In [34]:
!pip install geomloss
!pip install torchvision



In [21]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from geomloss import SamplesLoss

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
#trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])

subset = torch.utils.data.Subset(mnist_trainset, range(990))
unlabeled, target = torch.utils.data.random_split(subset, [900, 90])


unlabeled_loader = DataLoader(unlabeled, batch_size=256, shuffle=True)
target_loader = DataLoader(target, batch_size=256, shuffle=True)

weights_unlabeled = torch.rand(len(unlabeled), 1)
weights_unlabeled /= weights_unlabeled.sum()

weights_target = torch.rand(len(target), 1)
weights_target /= weights_target.sum()

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
total_loss = 0
num_batches = 0

for batch_idx, ((unlabeled_images, _),(target_images,_)) in enumerate(zip(unlabeled_loader, target_loader)):
    # Select the weights for the current batch
    weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]
    #weights_batch/= weights_batch.sum()
    
     # Compute Sinkhorn loss
    loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )
    total_loss += loss.item()
    num_batches += 1

average_loss = total_loss / num_batches

print("Average Sinkhorn loss:", average_loss)

#loss over entire dataset
# For unlabeled dataset
unlabeled_images_all = torch.tensor([])
unlabeled_labels_all = torch.tensor([])

for images, labels in unlabeled_loader:
    unlabeled_images_all = torch.cat((unlabeled_images_all, images.view(images.shape[0], -1)))
    unlabeled_labels_all = torch.cat((unlabeled_labels_all, labels))
    
# For target dataset
target_images_all = torch.tensor([])
target_labels_all = torch.tensor([])

for images, labels in target_loader:
    target_images_all = torch.cat((target_images_all, images.view(images.shape[0], -1)))
    target_labels_all = torch.cat((target_labels_all, labels))


loss = sinkhorn_loss(weights_unlabeled,unlabeled_images_all.view(unlabeled_images_all.shape[0], -1), weights_target,
                     target_images_all.view(target_images_all.shape[0], -1))

print("Sinkhorn loss:", loss.item())



Average Sinkhorn loss: 92.34629821777344
Sinkhorn loss: 1043.08544921875


In [23]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset = torch.utils.data.Subset(mnist_trainset, range(40090))
unlabeled, target = torch.utils.data.random_split(subset, [40000, 90])

# Filter target dataset to include only one label, e.g., 0
target_images, target_labels = zip(*target)
label_to_keep = 9
filtered_target_indices = [i for i, label in enumerate(target_labels) if label == label_to_keep]
target_images_org = [target_images[i] for i in filtered_target_indices]
target_images = torch.stack(target_images_org)

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=4000, shuffle=False)

# Count the number of datapoints with labels 0 in the unlabeled dataset
zero_label_count = unlabeled_labels.count(0)
print(f"Number of datapoints with label 0 in the unlabeled dataset: {zero_label_count}")

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
#unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=False)
#target_loader = DataLoader(target, batch_size=90, shuffle=False)
target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)


# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.001)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):

    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _),target_images) in enumerate(zip(unlabeled_loader, target_loader)):
        optimizer.zero_grad()  # Reset gradients

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]
        weights_batch/=weights_batch.sum()
        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:400]
top_indices = indices[:400]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Number of datapoints with label 0 in the unlabeled dataset: 3924
Epoch 1, Average Sinkhorn loss: 19.38302993774414
Epoch 2, Average Sinkhorn loss: 5118.08349609375
Epoch 3, Average Sinkhorn loss: 5118.18701171875
Epoch 4, Average Sinkhorn loss: 5118.18701171875
Epoch 5, Average Sinkhorn loss: 5118.18701171875
Epoch 6, Average Sinkhorn loss: 5118.18701171875
Epoch 7, Average Sinkhorn loss: 5118.18701171875
Epoch 8, Average Sinkhorn loss: 5118.18701171875
Epoch 9, Average Sinkhorn loss: 5118.18701171875
Epoch 10, Average Sinkhorn loss: 5118.18701171875
Top 10 weights, their indices, and corresponding labels:
Weight: 0.045969098806381226, Index: 1429, Label: 9
Weight: 0.03998115658760071, Index: 2462, Label: 9
Weight: 0.03712168335914612, Index: 3337, Label: 9
Weight: 0.0304887592792511, Index: 1409, Label: 9
Weight: 0.029543250799179077, Index: 2713, Label: 9
Weight: 0.02908340096473694, Index: 17, Label: 9
Weight: 0.028588443994522095, Index: 2083, Label: 9
Weight: 0.02777719497680664, 

In [10]:
target_images = torch.stack(target_images_org)

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)

# Count the number of datapoints with labels 0 in the unlabeled dataset
zero_label_count = unlabeled_labels.count(0)
print(f"Number of datapoints with label 0 in the unlabeled dataset: {zero_label_count}")

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled_images), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.01)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w

# Loop over the datasets 10 times
for epoch in range(10):
    optimizer.zero_grad()  # Reset gradients

    # Compute Sinkhorn loss
    loss = sinkhorn_loss(weights_unlabeled,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                         target_images.view(target_images.shape[0], -1)
                        )

    # Backpropagate the loss
    loss.backward()

    # Update the weights
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
        
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)
    
    print(f"Epoch {epoch+1}, Sinkhorn loss: {loss.item()}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:10]
top_indices = indices[:10]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Number of datapoints with label 0 in the unlabeled dataset: 4930


: 

: 

In [8]:

target_images = torch.stack(target_images_org)

unlabeled_images, unlabeled_labels = zip(*unlabeled)
unlabeled_images = torch.stack(unlabeled_images)
unlabeled_loader = DataLoader(list(zip(unlabeled_images, unlabeled_labels)), batch_size=90, shuffle=False)

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
#unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=False)
#target_loader = DataLoader(target, batch_size=90, shuffle=False)
target_loader = DataLoader(target_images, batch_size=len(target_images), shuffle=False)


# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target_images), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.001)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):
    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _),target_images) in enumerate(zip(unlabeled_loader, target_loader)):

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]

        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled.data = weights_unlabeled_new.data


    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")

# Sort the weights in descending order and print the non-zero weights
sorted_weights, indices = torch.sort(weights_unlabeled.flatten(), descending=True)
top_weights = sorted_weights[:40]
top_indices = indices[:40]

# Retrieve the labels of the images corresponding to the top indices
top_labels = [unlabeled_labels[idx] for idx in top_indices]

print("Top 10 weights, their indices, and corresponding labels:")
for weight, idx, label in zip(top_weights, top_indices, top_labels):
    print(f"Weight: {weight}, Index: {idx}, Label: {label}")

Epoch 1, Average Sinkhorn loss: 269.4306335449219
Epoch 2, Average Sinkhorn loss: 5825.7294921875
Epoch 3, Average Sinkhorn loss: 8194755.5
Epoch 4, Average Sinkhorn loss: 8653.630859375
Epoch 5, Average Sinkhorn loss: 8194755.5
Epoch 6, Average Sinkhorn loss: 8653.630859375
Epoch 7, Average Sinkhorn loss: 8194755.5
Epoch 8, Average Sinkhorn loss: 8653.630859375
Epoch 9, Average Sinkhorn loss: 8194755.5
Epoch 10, Average Sinkhorn loss: 8653.630859375
Top 10 weights, their indices, and corresponding labels:
Weight: 1.6763335224823095e-05, Index: 39835, Label: 9
Weight: 1.6763335224823095e-05, Index: 39821, Label: 9
Weight: 1.6763335224823095e-05, Index: 39822, Label: 4
Weight: 1.6763335224823095e-05, Index: 39823, Label: 1
Weight: 1.6763335224823095e-05, Index: 39824, Label: 0
Weight: 1.6763335224823095e-05, Index: 39825, Label: 5
Weight: 1.6763335224823095e-05, Index: 39826, Label: 1
Weight: 1.6763335224823095e-05, Index: 39827, Label: 9
Weight: 1.6763335224823095e-05, Index: 39828, La