In [None]:
!pip install geomloss


In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from geomloss import SamplesLoss

# Configuration: fixed seeds make the unlabeled/target split and the batch order
# (and therefore the printed Sinkhorn loss) reproducible on a fresh kernel.
SEED = 0
TARGET_SIZE = 1000  # number of images held out as the "target" subset
BATCH_SIZE = 256

# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then rescales them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the MNIST training split.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE(review): not used by the cells shown here; kept in case later cells rely on it.
trainloader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)

# Split into a large unlabeled pool and a small target set. The unlabeled size is derived
# from the dataset length instead of being hard-coded, so the two sizes always sum correctly.
unlabeled, target = torch.utils.data.random_split(
    mnist_trainset,
    [len(mnist_trainset) - TARGET_SIZE, TARGET_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)
# Each loader gets its own seeded generator so its shuffle order is reproducible
# under Restart & Run All.
unlabeled_loader = DataLoader(unlabeled, batch_size=BATCH_SIZE, shuffle=True,
                              generator=torch.Generator().manual_seed(SEED + 1))
target_loader = DataLoader(target, batch_size=BATCH_SIZE, shuffle=True,
                           generator=torch.Generator().manual_seed(SEED + 2))

In [None]:
# Sinkhorn loss between two point clouds, built with GeomLoss's SamplesLoss.
# Per the GeomLoss docs: p=2 selects the ground cost |x - y|^2 / 2, and blur=0.05 sets
# the entropic temperature eps = blur**p (smaller blur -> closer to exact optimal transport).
sinkhorn_loss = SamplesLoss(
    loss="sinkhorn",
    p=2,
    blur=0.05,
)

In [None]:
# Average the per-batch Sinkhorn loss between batches of the unlabeled and target subsets.
# NOTE: zip() stops at the shorter loader. With a 1000-image target split and batch_size=256,
# target_loader yields only 4 batches, so only 4 unlabeled batches (~1024 images) are compared,
# not the whole unlabeled pool.
total_loss = 0.0
num_batches = 0

# Evaluation only: no gradients are needed, so disable autograd tracking.
with torch.no_grad():
    for (unlabeled_images, _), (target_images, _) in zip(unlabeled_loader, target_loader):
        # Flatten everything after the batch dimension so each image is one point;
        # each batch becomes a (batch_size, num_pixels) point cloud.
        loss = sinkhorn_loss(unlabeled_images.flatten(start_dim=1),
                             target_images.flatten(start_dim=1))
        total_loss += loss.item()
        num_batches += 1

# Fail with a clear message instead of a bare ZeroDivisionError if a loader is empty.
if num_batches == 0:
    raise ValueError("No batch pairs were produced; check that both loaders are non-empty.")

average_loss = total_loss / num_batches

print("Average Sinkhorn loss:", average_loss)
