In [10]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim
from torchvision.datasets import CIFAR10
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim


# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w
        
transform = transforms.Compose(
    [transforms.Grayscale(),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

cifar_trainset = CIFAR10(root='data/', download=True, transform=transform)
subset = torch.utils.data.Subset(cifar_trainset, range(42000))
unlabeled, target = torch.utils.data.random_split(subset, [40000, 2000])

# Filter target dataset to include only one label, e.g., 0
target_images, target_labels = zip(*target)
label_to_keep = 9
print("Target Label:{}".format(label_to_keep))
filtered_target_indices = [i for i, label in enumerate(target_labels) if label == label_to_keep]
target_images_org = [target_images[i] for i in filtered_target_indices]
target_images = torch.stack(target_images_org)
print("No of target images:{}".format(target_images.shape[0]))

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=4000, shuffle=False)

target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)

# Define an optimizer
#optimizer = optim.SGD([weights_unlabeled], lr=0.001)
optimizer = optim.Adam([weights_unlabeled], lr=0.01)

# Loop over the datasets 10 times
for epoch in range(10):

    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _),target_images) in enumerate(zip(unlabeled_loader, target_loader)):
        optimizer.zero_grad()  # Reset gradients

        # Select the weights for the current batch
        unlabeled_images = unlabeled_images[:,0,:,:]
        target_images = target_images[:,0,:,:]
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]
        weights_batch = weights_batch.clone() / weights_batch.sum()

        # Reshape the images to be 1D tensors
        unlabeled_images = unlabeled_images.view(unlabeled_images.shape[0], -1)
        target_images = target_images.view(target_images.shape[0], -1)


        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled.data = project_simplex(weights_unlabeled.data)
    
    #weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:10]
top_indices = indices[:10]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Files already downloaded and verified
Target Label:9
No of target images:200
Epoch 1, Average Sinkhorn loss: 9269.59375
Epoch 2, Average Sinkhorn loss: 9254.1494140625
Epoch 3, Average Sinkhorn loss: 9249.857421875
Epoch 4, Average Sinkhorn loss: 9245.0400390625
Epoch 5, Average Sinkhorn loss: 9243.1083984375
Epoch 6, Average Sinkhorn loss: 9241.6767578125
Epoch 7, Average Sinkhorn loss: 9240.587890625
Epoch 8, Average Sinkhorn loss: 9239.6962890625
Epoch 9, Average Sinkhorn loss: 9239.017578125
Epoch 10, Average Sinkhorn loss: 9238.53515625
Top 10 weights, their indices, and corresponding labels:
Weight: 0.02720998413860798, Index: 1841, Label: 5
Weight: 0.027074463665485382, Index: 1663, Label: 5
Weight: 0.02196543663740158, Index: 299, Label: 4
Weight: 0.020615218207240105, Index: 348, Label: 3
Weight: 0.01563732512295246, Index: 3588, Label: 3
Weight: 0.010885179042816162, Index: 1478, Label: 0
Weight: 0.010186053812503815, Index: 3762, Label: 3
Weight: 0.008024487644433975, Index: