In [34]:
!pip install geomloss
!pip install torchvision



In [35]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from geomloss import SamplesLoss

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)

unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59000, 1000])
unlabeled_loader = DataLoader(unlabeled, batch_size=256, shuffle=True)
target_loader = DataLoader(target, batch_size=256, shuffle=True)

In [36]:
# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

In [37]:
total_loss = 0
num_batches = 0

for (unlabeled_images, _), (target_images, _) in zip(unlabeled_loader, target_loader):
    loss = sinkhorn_loss(unlabeled_images.view(unlabeled_images.shape[0], -1), 
                         target_images.view(target_images.shape[0], -1))
    total_loss += loss.item()
    num_batches += 1

average_loss = total_loss / num_batches

print("Average Sinkhorn loss:", average_loss)

#loss over entire dataset
(unlabeled_images, _), (target_images, _) = next(zip(unlabeled_loader, target_loader))

loss = sinkhorn_loss(unlabeled_images.view(unlabeled_images.shape[0], -1), 
                     target_images.view(target_images.shape[0], -1))

print("Sinkhorn loss:", loss.item())



Average Sinkhorn loss: 91.48714637756348
Sinkhorn loss: 89.9547119140625


In [44]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset = torch.utils.data.Subset(mnist_trainset, range(990))
unlabeled, target = torch.utils.data.random_split(subset, [900, 90])

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=False)
target_loader = DataLoader(target, batch_size=90, shuffle=False)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.001)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):
    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _), (target_images, _)) in enumerate(zip(unlabeled_loader, target_loader)):

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]

        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)

print(f"Non-zero weights in descending order: {sorted_weights[:10]} and its indices: {indices[:10]}")


Epoch 1, Average Sinkhorn loss: 4365.8427734375
Epoch 2, Average Sinkhorn loss: 107.07838439941406
Epoch 3, Average Sinkhorn loss: 110.40301513671875
Epoch 4, Average Sinkhorn loss: 109.58042907714844
Epoch 5, Average Sinkhorn loss: 110.27548217773438
Epoch 6, Average Sinkhorn loss: 106.31277465820312
Epoch 7, Average Sinkhorn loss: 107.30313110351562
Epoch 8, Average Sinkhorn loss: 107.15576171875
Epoch 9, Average Sinkhorn loss: 111.96646881103516
Epoch 10, Average Sinkhorn loss: 108.69554901123047
Non-zero weights in descending order: tensor([0.0503, 0.0471, 0.0428, 0.0393, 0.0380, 0.0348, 0.0326, 0.0316, 0.0294,
        0.0290], grad_fn=<SliceBackward0>) and its indices: tensor([ 8, 86, 87, 23, 37, 75, 14, 34, 45, 78])


In [47]:
import torch
from torchvision import datasets, transforms
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
#mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

#code to get a subset of mnist_trainset
#subset = torch.utils.data.Subset(mnist_trainset, range(990))
#unlabeled, target = torch.utils.data.random_split(subset, [900, 90])



#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])

# Load the entire datasets
unlabeled_images, _ = zip(*unlabeled)
target_images, _ = zip(*target)

unlabeled_images = torch.stack(unlabeled_images)
target_images = torch.stack(target_images)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.01)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w

# Loop over the datasets 10 times
for epoch in range(10):
    optimizer.zero_grad()  # Reset gradients

    # Compute Sinkhorn loss
    loss = sinkhorn_loss(weights_unlabeled,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                         target_images.view(target_images.shape[0], -1)
                        )

    # Backpropagate the loss
    loss.backward()

    # Update the weights
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
        
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)
    
    print(f"Epoch {epoch+1}, Sinkhorn loss: {loss.item()}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)

print(f"Non-zero weights in descending order: {sorted_weights[:10]} and its indices: {indices[:10]}")


Epoch 1, Sinkhorn loss: 94.65599822998047
Epoch 2, Sinkhorn loss: 108.08488464355469
Epoch 3, Sinkhorn loss: 108.08488464355469
Epoch 4, Sinkhorn loss: 108.08488464355469
Epoch 5, Sinkhorn loss: 108.08488464355469
Epoch 6, Sinkhorn loss: 108.08488464355469
Epoch 7, Sinkhorn loss: 108.08488464355469
Epoch 8, Sinkhorn loss: 108.08488464355469
Epoch 9, Sinkhorn loss: 108.08488464355469
Epoch 10, Sinkhorn loss: 108.08488464355469
Non-zero weights in descending order: tensor([0.1006, 0.0769, 0.0643, 0.0607, 0.0603, 0.0578, 0.0564, 0.0563, 0.0533,
        0.0430], grad_fn=<SliceBackward0>) and its indices: tensor([197, 720, 768, 838, 299, 261,  57, 204, 758, 454])


In [48]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset = torch.utils.data.Subset(mnist_trainset, range(990))
#unlabeled, target = torch.utils.data.random_split(subset, [900, 90])

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=False)
target_loader = DataLoader(target, batch_size=90, shuffle=False)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.001)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):
    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _), (target_images, _)) in enumerate(zip(unlabeled_loader, target_loader)):

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]

        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)

print(f"Non-zero weights in descending order: {sorted_weights[:10]} and its indices: {indices[:10]}")


Epoch 1, Average Sinkhorn loss: 4546.77001953125
Epoch 2, Average Sinkhorn loss: 109.22112274169922
Epoch 3, Average Sinkhorn loss: 101.6413803100586
Epoch 4, Average Sinkhorn loss: 106.3034896850586
Epoch 5, Average Sinkhorn loss: 107.51469421386719
Epoch 6, Average Sinkhorn loss: 111.22582244873047
Epoch 7, Average Sinkhorn loss: 107.99440002441406
Epoch 8, Average Sinkhorn loss: 106.12525939941406
Epoch 9, Average Sinkhorn loss: 113.13867950439453
Epoch 10, Average Sinkhorn loss: 103.81668853759766
Non-zero weights in descending order: tensor([0.0466, 0.0461, 0.0410, 0.0404, 0.0379, 0.0370, 0.0368, 0.0350, 0.0339,
        0.0329], grad_fn=<SliceBackward0>) and its indices: tensor([70, 18, 55, 51, 19, 10, 23,  6, 47, 64])
