In [1]:
!pip install geomloss
!pip install torchvision

Collecting geomloss
  Downloading geomloss-0.2.6.tar.gz (26 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting numpy (from geomloss)
  Using cached numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
Collecting torch (from geomloss)
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:03[0m
[?25hCollecting filelock (from torch->geomloss)
  Using cached filelock-3.12.0-py3-none-any.whl (10 kB)
Collecting typing-extensions (from torch->geomloss)
  Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)
Collecting sympy (from torch->geomloss)
  Downloading sympy-1.12-py3-none-any.whl (5.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m 

In [4]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from geomloss import SamplesLoss

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)

unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59000, 1000])
unlabeled_loader = DataLoader(unlabeled, batch_size=256, shuffle=True)
target_loader = DataLoader(target, batch_size=256, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz



6.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

In [2]:
total_loss = 0
num_batches = 0

for (unlabeled_images, _), (target_images, _) in zip(unlabeled_loader, target_loader):
    loss = sinkhorn_loss(unlabeled_images.view(unlabeled_images.shape[0], -1), 
                         target_images.view(target_images.shape[0], -1))
    total_loss += loss.item()
    num_batches += 1

average_loss = total_loss / num_batches

print("Average Sinkhorn loss:", average_loss)

#loss over entire dataset
(unlabeled_images, _), (target_images, _) = next(zip(unlabeled_loader, target_loader))

loss = sinkhorn_loss(unlabeled_images.view(unlabeled_images.shape[0], -1), 
                     target_images.view(target_images.shape[0], -1))

print("Sinkhorn loss:", loss.item())



Average Sinkhorn loss: 91.40467071533203
Sinkhorn loss: 92.88128662109375


In [9]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset = torch.utils.data.Subset(mnist_trainset, range(990))
unlabeled, target = torch.utils.data.random_split(subset, [900, 90])

#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])
unlabeled_loader = DataLoader(unlabeled, batch_size=90, shuffle=True)
target_loader = DataLoader(target, batch_size=90, shuffle=True)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.01)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w


# Loop over the datasets 10 times
for epoch in range(10):
    losses = []
    weights_unlabeled.grad = None  # Reset gradients at the beginning of each epoch

    for batch_idx, ((unlabeled_images, _), (target_images, _)) in enumerate(zip(unlabeled_loader, target_loader)):

        # Select the weights for the current batch
        weights_batch = weights_unlabeled[batch_idx * unlabeled_loader.batch_size : (batch_idx + 1) * unlabeled_loader.batch_size]

        # Compute Sinkhorn loss
        loss = sinkhorn_loss(weights_batch,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                             target_images.view(target_images.shape[0], -1),
                             )

        losses.append(loss.item())

        # Compute gradients for the loss
        loss.backward()  # Gradients are accumulated over mini-batches

    # Average the loss over all mini-batches
    loss_avg = sum(losses) / len(losses)

    # Update the weights based on the accumulated gradients
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
    
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)

    print(f"Epoch {epoch+1}, Average Sinkhorn loss: {loss_avg}")


Epoch 1, Average Sinkhorn loss: 7835.4306640625


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [18]:
import torch
from torchvision import datasets, transforms
from geomloss import SamplesLoss
from torch import optim

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

#code to get a subset of mnist_trainset
subset = torch.utils.data.Subset(mnist_trainset, range(990))
unlabeled, target = torch.utils.data.random_split(subset, [900, 90])



#unlabeled, target = torch.utils.data.random_split(mnist_trainset, [59744, 256])

# Load the entire datasets
unlabeled_images, _ = zip(*unlabeled)
target_images, _ = zip(*target)

unlabeled_images = torch.stack(unlabeled_images)
target_images = torch.stack(target_images)

# Create a loss function using GeomLoss
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

# Initialize weights for the unlabeled_images
weights_unlabeled = torch.full((len(unlabeled), 1), 1.0 / len(unlabeled), requires_grad=True)
weights_target = torch.full((len(target), 1), 1.0 / len(target), requires_grad=False)
# Define an optimizer
optimizer = optim.SGD([weights_unlabeled], lr=0.01)

# Define a function to project weights to a simplex
def project_to_simplex(weights):
    return torch.clamp(weights, min=0) / torch.sum(weights)

def project_simplex(v):
        """
        v: PyTorch Tensor to be projected to a simplex

        Returns:
        w: PyTorch Tensor simplex projection of v
        """
        z = 1
        orig_shape = v.shape
        v = v.view(1, -1)
        shape = v.shape
        with torch.no_grad():
            mu = torch.sort(v, dim=1)[0]
            mu = torch.flip(mu, dims=(1,))
            cum_sum = torch.cumsum(mu, dim=1)
            j = torch.unsqueeze(torch.arange(1, shape[1] + 1, dtype=mu.dtype, device=mu.device), 0)
            rho = torch.sum(mu * j - cum_sum + z > 0.0, dim=1, keepdim=True) - 1.
            rho = rho.to(int)
            max_nn = cum_sum[torch.arange(shape[0]), rho[:, 0]]
            theta = (torch.unsqueeze(max_nn, -1) - z) / (rho.type(max_nn.dtype) + 1)
            w = torch.clamp(v - theta, min=0.0).view(orig_shape)
            return w

# Loop over the datasets 10 times
for epoch in range(10):
    optimizer.zero_grad()  # Reset gradients

    # Compute Sinkhorn loss
    loss = sinkhorn_loss(weights_unlabeled,unlabeled_images.view(unlabeled_images.shape[0], -1), weights_target,
                         target_images.view(target_images.shape[0], -1)
                        )

    # Backpropagate the loss
    loss.backward()

    # Update the weights
    optimizer.step()

    # Project the weights to a simplex
    with torch.no_grad():
        weights_unlabeled_new = project_simplex(weights_unlabeled)
        
    weights_unlabeled = weights_unlabeled_new.clone().detach().requires_grad_(True)
    
    print(f"Epoch {epoch+1}, Sinkhorn loss: {loss.item()}")


Epoch 1, Sinkhorn loss: 92.76219177246094
Epoch 2, Sinkhorn loss: 156.490234375
Epoch 3, Sinkhorn loss: 156.490234375
Epoch 4, Sinkhorn loss: 156.490234375
Epoch 5, Sinkhorn loss: 156.490234375
Epoch 6, Sinkhorn loss: 156.490234375
Epoch 7, Sinkhorn loss: 156.490234375
Epoch 8, Sinkhorn loss: 156.490234375
Epoch 9, Sinkhorn loss: 156.490234375
Epoch 10, Sinkhorn loss: 156.490234375
