In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Tuple
import random as rd

torch.manual_seed(0);

In [2]:
class PairedMNIST(Dataset):
    def __init__(self):
        self.mnist = MNIST('data', transform=transforms.ToTensor(), download=True)
    
    def __len__(self):
        return len(self.mnist)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        image, label = self.mnist[idx]
        similarity_label = rd.random() > .5
        if similarity_label:
            other_indices = self.mnist.targets == label
        else:
            other_indices = self.mnist.targets != label
        other, _ = self.mnist[rd.choice(other_indices.nonzero()[0])]
        return image, other, similarity_label * 1.

dataset = PairedMNIST()
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

In [3]:
class ContrastiveModel(nn.Module):
    
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 32)
        )
        
    def forward(self, image: torch.Tensor, other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            self.encoder(image),
            self.encoder(other)
        )

In [4]:
def contrastive_loss(x1: torch.Tensor, x2: torch.Tensor, y: torch.Tensor, m: float = 1.) -> torch.Tensor:
    h = torch.linalg.norm(x1 - x2, dim=1) ** 2
    loss = y * h / 2 + (1 - y) * torch.clamp(m - h, min=0) / 2
    return loss.mean()

In [5]:
model = ContrastiveModel()
optimizer = torch.optim.Adam(params=model.parameters())

In [6]:
for epoch in range(5):
    model.train()
    total_loss = 0.
    for X1, X2, y in tqdm(dataloader):
        
        X1, X2 = X1.flatten(1), X2.flatten(1)
        
        optimizer.zero_grad()
        y1, y2 = model(X1, X2)
        loss = contrastive_loss(y1, y2, y)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch} | Loss: {total_loss/len(dataloader):.4f}")

100%|██████████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.18it/s]


Epoch 0 | Loss: 0.0129


100%|██████████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.79it/s]


Epoch 1 | Loss: 0.0057


100%|██████████████████████████████████████████████████████████| 600/600 [00:41<00:00, 14.31it/s]


Epoch 2 | Loss: 0.0040


100%|██████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.73it/s]


Epoch 3 | Loss: 0.0030


100%|██████████████████████████████████████████████████████████| 600/600 [00:43<00:00, 13.83it/s]

Epoch 4 | Loss: 0.0025





# Is anything wrong until here?

In [7]:
mnist_dataloader = DataLoader(dataset.mnist, batch_size=100)
with torch.inference_mode():
    predictions, labels = zip(*[(model.encoder(X.flatten(1)), y) for X, y in mnist_dataloader])
    predictions = torch.cat(predictions)
    labels = torch.cat(labels)

In [8]:
from sklearn.cluster import KMeans

In [9]:
kmeans = KMeans(10)

In [10]:
kmeans.fit(predictions.numpy())

In [11]:
kmeans.inertia_

31.77687644958496

In [12]:
from sklearn.metrics import rand_score

In [13]:
preds = kmeans.predict(predictions.numpy())
preds

array([1, 2, 9, ..., 1, 9, 0], dtype=int32)

In [14]:
rand_score(preds, labels)

0.678226419329211

In [15]:
from sklearn.neighbors import KNeighborsClassifier

In [16]:
knn = KNeighborsClassifier()
knn.fit(predictions.numpy(), labels)

In [17]:
knn.score(predictions.numpy(), labels)

0.71685