In [34]:
!pip install geomloss
!pip install torchvision



In [16]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from geomloss import SamplesLoss

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)

unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
unlabeled_loader = DataLoader(unlabeled, batch_size=256, shuffle=True)
target_loader = DataLoader(target, batch_size=256, shuffle=True)
# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
total_loss = 0
num_batches = 0

for (unlabeled_images, _), (target_images, _) in zip(unlabeled_loader, target_loader):
    loss = sinkhorn_loss(unlabeled_images.view(unlabeled_images.shape[0], -1), 
                         target_images.view(target_images.shape[0], -1))
    total_loss += loss.item()
    num_batches += 1

average_loss = total_loss / num_batches

print("Average Sinkhorn loss:", average_loss)

#loss over entire dataset
(unlabeled_images, _), (target_images, _) = next(zip(unlabeled_loader, target_loader))

loss = sinkhorn_loss(unlabeled_images.view(unlabeled_images.shape[0], -1), 
                     target_images.view(target_images.shape[0], -1))

print("Sinkhorn loss:", loss.item())



Average Sinkhorn loss: 90.25309753417969
Sinkhorn loss: 89.77881622314453


In [23]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset = torch.utils.data.Subset(mnist_trainset, range(9090))
unlabeled, target = torch.utils.data.random_split(subset, [9000, 90])

# Filter target dataset to include only one label, e.g., 0
target_images, target_labels = zip(*target)
label_to_keep = 0
filtered_target_indices = [i for i, label in enumerate(target_labels) if label == label_to_keep]
target_images_org = [target_images[i] for i in filtered_target_indices]
target_images = torch.stack(target_images_org)

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=90, shuffle=False)

# Count the number of datapoints with labels 0 in the unlabeled dataset
zero_label_count = unlabeled_labels.count(0)
print(f"Number of datapoints with label 0 in the unlabeled dataset: {zero_label_count}")

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
#unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=False)
#target_loader = DataLoader(target, batch_size=90, shuffle=False)
target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)


# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.001)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):
    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _),target_images) in enumerate(zip(unlabeled_loader, target_loader)):

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]

        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:10]
top_indices = indices[:10]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Epoch 1, Average Sinkhorn loss: 790.334716796875
Epoch 2, Average Sinkhorn loss: 3110.748779296875
Epoch 3, Average Sinkhorn loss: 3110.807373046875
Epoch 4, Average Sinkhorn loss: 3110.807373046875
Epoch 5, Average Sinkhorn loss: 3110.807373046875
Epoch 6, Average Sinkhorn loss: 3110.807373046875
Epoch 7, Average Sinkhorn loss: 3110.807373046875
Epoch 8, Average Sinkhorn loss: 3110.807373046875
Epoch 9, Average Sinkhorn loss: 3110.807373046875
Epoch 10, Average Sinkhorn loss: 3110.807373046875
Top 10 weights, their indices, and corresponding labels:
Weight: 0.09257984161376953, Index: 14, Label: 0
Weight: 0.09045600891113281, Index: 62, Label: 0
Weight: 0.08690357208251953, Index: 45, Label: 0
Weight: 0.08185911178588867, Index: 58, Label: 0
Weight: 0.0791015625, Index: 19, Label: 0
Weight: 0.07868146896362305, Index: 41, Label: 0
Weight: 0.0736842155456543, Index: 34, Label: 0
Weight: 0.07110166549682617, Index: 85, Label: 0
Weight: 0.07033014297485352, Index: 16, Label: 0
Weight: 0.

In [24]:
target_images = torch.stack(target_images_org)

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)


# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled_images), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.01)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w

# Loop over the datasets 10 times
for epoch in range(10):
    optimizer.zero_grad()  # Reset gradients

    # Compute Sinkhorn loss
    loss = sinkhorn_loss(weights_unlabeled,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                         target_images.view(target_images.shape[0], -1)
                        )

    # Backpropagate the loss
    loss.backward()

    # Update the weights
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
        
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)
    
    print(f"Epoch {epoch+1}, Sinkhorn loss: {loss.item()}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:10]
top_indices = indices[:10]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Epoch 1, Sinkhorn loss: 4347.38671875
Epoch 2, Sinkhorn loss: 4284.7587890625
Epoch 3, Sinkhorn loss: 4284.619140625
Epoch 4, Sinkhorn loss: 4284.619140625
Epoch 5, Sinkhorn loss: 4284.619140625
Epoch 6, Sinkhorn loss: 4284.619140625
Epoch 7, Sinkhorn loss: 4284.619140625
Epoch 8, Sinkhorn loss: 4284.619140625
Epoch 9, Sinkhorn loss: 4284.619140625
Epoch 10, Sinkhorn loss: 4284.619140625
Top 10 weights, their indices, and corresponding labels:
Weight: 0.145369291305542, Index: 4972, Label: 0
Weight: 0.13164782524108887, Index: 6091, Label: 0
Weight: 0.10820269584655762, Index: 4186, Label: 0
Weight: 0.07911562919616699, Index: 1543, Label: 0
Weight: 0.07526278495788574, Index: 4959, Label: 0
Weight: 0.06982684135437012, Index: 896, Label: 0
Weight: 0.06844592094421387, Index: 1199, Label: 0
Weight: 0.05932879447937012, Index: 610, Label: 0
Weight: 0.05243563652038574, Index: 2254, Label: 0
Weight: 0.05175280570983887, Index: 5096, Label: 0


In [26]:

target_images = torch.stack(target_images_org)

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=90, shuffle=False)

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
#unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=False)
#target_loader = DataLoader(target, batch_size=90, shuffle=False)
target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)


# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.001)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):
    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _),target_images) in enumerate(zip(unlabeled_loader, target_loader)):

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]

        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:40]
top_indices = indices[:40]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Epoch 1, Average Sinkhorn loss: 790.334716796875
Epoch 2, Average Sinkhorn loss: 3110.748779296875
Epoch 3, Average Sinkhorn loss: 3110.807373046875
Epoch 4, Average Sinkhorn loss: 3110.807373046875
Epoch 5, Average Sinkhorn loss: 3110.807373046875
Epoch 6, Average Sinkhorn loss: 3110.807373046875
Epoch 7, Average Sinkhorn loss: 3110.807373046875
Epoch 8, Average Sinkhorn loss: 3110.807373046875
Epoch 9, Average Sinkhorn loss: 3110.807373046875
Epoch 10, Average Sinkhorn loss: 3110.807373046875
Top 10 weights, their indices, and corresponding labels:
Weight: 0.09257984161376953, Index: 14, Label: 0
Weight: 0.09045600891113281, Index: 62, Label: 0
Weight: 0.08690357208251953, Index: 45, Label: 0
Weight: 0.08185911178588867, Index: 58, Label: 0
Weight: 0.0791015625, Index: 19, Label: 0
Weight: 0.07868146896362305, Index: 41, Label: 0
Weight: 0.0736842155456543, Index: 34, Label: 0
Weight: 0.07110166549682617, Index: 85, Label: 0
Weight: 0.07033014297485352, Index: 16, Label: 0
Weight: 0.