In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim
from torchvision.datasets import CIFAR10
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim


# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w
        
transform = transforms.Compose(
    [transforms.Grayscale(),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

cifar_trainset = CIFAR10(root='data/', download=True, transform=transform)
subset = torch.utils.data.Subset(cifar_trainset, range(42000))
unlabeled, target = torch.utils.data.random_split(subset, [40000, 2000])

# Filter target dataset to include only one label, e.g., 0
target_images, target_labels = zip(*target)
label_to_keep = 9
print("Target Label:{}".format(label_to_keep))
filtered_target_indices = [i for i, label in enumerate(target_labels) if label == label_to_keep]
target_images_org = [target_images[i] for i in filtered_target_indices]
target_images = torch.stack(target_images_org)
print("No of target images:{}".format(target_images.shape[0]))

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=4000, shuffle=False)

target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)

# Define an optimizer
#optimizer = optim.SGD([weights_unlabeled], lr=0.001)
optimizer = optim.Adam([weights_unlabeled], lr=0.01)

# Loop over the datasets 10 times
for epoch in range(10):

    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _),target_images) in enumerate(zip(unlabeled_loader, target_loader)):
        optimizer.zero_grad()  # Reset gradients

        # Select the weights for the current batch
        unlabeled_images = unlabeled_images[:,0,:,:]
        target_images = target_images[:,0,:,:]
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]
        weights_batch = weights_batch.clone() / weights_batch.sum()

        # Reshape the images to be 1D tensors
        unlabeled_images = unlabeled_images.view(unlabeled_images.shape[0], -1)
        target_images = target_images.view(target_images.shape[0], -1)


        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled.data = project_simplex(weights_unlabeled.data)
    
    #weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:10]
top_indices = indices[:10]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Files already downloaded and verified
Target Label:9
No of target images:200
Epoch 1, Average Sinkhorn loss: 9271.09765625
Epoch 2, Average Sinkhorn loss: 9255.0703125
Epoch 3, Average Sinkhorn loss: 9252.1484375
Epoch 4, Average Sinkhorn loss: 9247.1044921875
Epoch 5, Average Sinkhorn loss: 9244.13671875
Epoch 6, Average Sinkhorn loss: 9243.0244140625
Epoch 7, Average Sinkhorn loss: 9242.146484375
Epoch 8, Average Sinkhorn loss: 9241.0302734375
Epoch 9, Average Sinkhorn loss: 9239.998046875
Epoch 10, Average Sinkhorn loss: 9239.4033203125
Top 10 weights, their indices, and corresponding labels:
Weight: 0.019431760534644127, Index: 3189, Label: 4
Weight: 0.018872834742069244, Index: 1140, Label: 4
Weight: 0.013660362921655178, Index: 2342, Label: 4
Weight: 0.009914053604006767, Index: 3590, Label: 0
Weight: 0.009910528548061848, Index: 3127, Label: 0
Weight: 0.0093167619779706, Index: 2934, Label: 4
Weight: 0.008644971996545792, Index: 2155, Label: 3
Weight: 0.008335231803357601, Index

In [3]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim
from torchvision.datasets import CIFAR10

# Define a function to project weights to a simplex
def project_simplex(v):
    z = 1
    orig_shape = v.shape
    v = v.view(1, -1)
    shape = v.shape
    with torch.no_grad():
        mu = torch.sort(v, dim=1)[0]
        mu = torch.flip(mu, dims=(1,))
        cum_sum = torch.cumsum(mu, dim=1)
        j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
        rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
        rho = rho.to(int)
        max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
        theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
        w = torch.clamp(v - theta, min=0.0).view(orig_shape)
        return w
        
transform = transforms.Compose(
    [transforms.Grayscale(),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

cifar_trainset = CIFAR10(root='data/', download=True, transform=transform)
subset = torch.utils.data.Subset(cifar_trainset, range(42000))
unlabeled, remaining = torch.utils.data.random_split(subset, [40000, 2000])

# split remaining into target and private
target, private = torch.utils.data.random_split(remaining, [1000, 1000])

# Filter target dataset to include only one label, e.g., 0
target_images, target_labels = zip(*target)
label_to_keep = 9
print("Target Label:{}".format(label_to_keep))
filtered_target_indices = [i for i, label in enumerate(target_labels) if label == label_to_keep]
target_images_org = [target_images[i] for i in filtered_target_indices]
target_images = torch.stack(target_images_org)
print("No of target images:{}".format(target_images.shape[0]))

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=4000, shuffle=False)

target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)

# Filter private dataset to exclude label_to_keep
private_images, private_labels = zip(*private)
filtered_private_indices = [i for i, label in enumerate(private_labels) if label != label_to_keep]
private_images_org = [private_images[i] for i in filtered_private_indices]
private_images = torch.stack(private_images_org)

private_loader = DataLoader(private_images, batch_size=len(private_images), shuffle=False)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images and private_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
weights_private = torch.full((len(private_images), 1), 1.0 / len(private), requires_grad=False)

# Define an optimizer
optimizer = optim.Adam([weights_unlabeled], lr=0.01)

# Loop over the datasets 10 times
for epoch in range(10):

    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _), target_images, private_images) in enumerate(zip(unlabeled_loader, target_loader, private_loader)):
        optimizer.zero_grad()  # Reset gradients

        # Select the weights for the current batch
        unlabeled_images = unlabeled_images[:,0,:,:]
        target_images = target_images[:,0,:,:]
        private_images = private_images[:,0,:,:]
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]
        weights_batch = weights_batch.clone() / weights_batch.sum()

        # Reshape the images to be 1D tensors
        unlabeled_images = unlabeled_images.view(unlabeled_images.shape[0], -1)
        target_images = target_images.view(target_images.shape[0], -1)
        private_images = private_images.view(private_images.shape[0], -1)

        # Compute Sinkhorn loss
        loss_unlabeled_target = sinkhorn_loss(weights_batch, unlabeled_images, weights_target, target_images)
        loss_unlabeled_private = sinkhorn_loss(weights_batch, unlabeled_images, weights_private, private_images)
        loss_private_target = sinkhorn_loss(weights_private, private_images, weights_target, target_images)

        loss = loss_unlabeled_target - loss_unlabeled_private + 0.2*loss_private_target

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled.data = project_simplex(weights_unlabeled.data)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:10]
top_indices = indices[:10]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")


Files already downloaded and verified
Target Label:9
No of target images:91
Epoch 1, Average Sinkhorn loss: 11295.7626953125
Epoch 2, Average Sinkhorn loss: 11267.4140625
Epoch 3, Average Sinkhorn loss: 11240.3623046875
Epoch 4, Average Sinkhorn loss: 11225.4228515625
Epoch 5, Average Sinkhorn loss: 11214.7158203125
Epoch 6, Average Sinkhorn loss: 11206.5166015625
Epoch 7, Average Sinkhorn loss: 11200.021484375
Epoch 8, Average Sinkhorn loss: 11194.755859375
Epoch 9, Average Sinkhorn loss: 11190.2275390625
Epoch 10, Average Sinkhorn loss: 11186.5927734375
Top 10 weights, their indices, and corresponding labels:
Weight: 0.01211019791662693, Index: 672, Label: 9
Weight: 0.011650919914245605, Index: 2267, Label: 9
Weight: 0.01039084792137146, Index: 728, Label: 9
Weight: 0.010268377140164375, Index: 1855, Label: 1
Weight: 0.010249389335513115, Index: 1395, Label: 9
Weight: 0.010145265609025955, Index: 3802, Label: 9
Weight: 0.01006278581917286, Index: 496, Label: 1
Weight: 0.0099889021366

In [4]:
top_weights = sorted_weights[:40]
top_indices = indices[:40]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Top 10 weights, their indices, and corresponding labels:
Weight: 0.01211019791662693, Index: 672, Label: 9
Weight: 0.011650919914245605, Index: 2267, Label: 9
Weight: 0.01039084792137146, Index: 728, Label: 9
Weight: 0.010268377140164375, Index: 1855, Label: 1
Weight: 0.010249389335513115, Index: 1395, Label: 9
Weight: 0.010145265609025955, Index: 3802, Label: 9
Weight: 0.01006278581917286, Index: 496, Label: 1
Weight: 0.009988902136683464, Index: 2081, Label: 8
Weight: 0.009909534826874733, Index: 1401, Label: 9
Weight: 0.00984535925090313, Index: 3611, Label: 9
Weight: 0.00944882445037365, Index: 1711, Label: 9
Weight: 0.009372292086482048, Index: 1501, Label: 9
Weight: 0.009309044107794762, Index: 2154, Label: 8
Weight: 0.009260408580303192, Index: 2523, Label: 7
Weight: 0.008967723697423935, Index: 1515, Label: 9
Weight: 0.008804993703961372, Index: 1114, Label: 1
Weight: 0.00873599760234356, Index: 623, Label: 1
Weight: 0.008594978600740433, Index: 1156, Label: 9
Weight: 0.0084696