In [None]:
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

from utils import batch_plot, seed_everything


# Limit the printed precision of NumPy arrays and pandas objects to two digits.
np.set_printoptions(precision=2)
pd.set_option("display.precision", 2)
# Load the autoreload extension and reload imported modules before every cell
# execution, so edits to local helper modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

# `seed_everything` is imported from the local `utils` module; presumably it seeds
# the random number generators for reproducibility — confirm against its definition.
seed_everything()

In [None]:
# Number of source images drawn per batch by the data loaders below.
batch_size = 4
# Output size of the random resized crop; also used to derive the blur kernel size.
image_size = 32
# Passed to the model as `num_classes`, i.e. the width of the returned feature vectors.
feature_dims = 128
# Temperature that the similarity scores are divided by inside the loss functions.
temperature = 0.1
# NOTE(review): `device` is not referenced by any other visible cell; the model and
# the image tensors are used as created — confirm whether moving them was intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh temporary directory that the example images are written to. `mkdtemp` does
# not remove the directory automatically.
tmp_dir = Path(tempfile.mkdtemp())


# https://github.com/sthalles/SimCLR
class ContrastiveLearningViewGenerator:
    """Callable that produces several transformed views of a single input.

    Args:
        base_transform: Callable applied to the input once per requested view.
            The views only differ from each other if this callable is stochastic.
        n_views: How many times ``base_transform`` is applied per call.
    """

    def __init__(self, base_transform, n_views=2):
        self.n_views = n_views
        self.base_transform = base_transform

    def __call__(self, x):
        """Return a list holding ``n_views`` results of ``base_transform(x)``."""
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views


# Identity-like pipeline: only converts the loaded image to a tensor, no augmentation.
null_transforms = transforms.Compose([transforms.ToTensor()])

# The blur kernel is sized at roughly 10% of the image side. GaussianBlur only accepts
# odd, positive kernel sizes, so clamp to at least 1 and force the lowest bit on
# (`| 1` turns an even value into the next odd one). For image_size = 32 this is 3,
# exactly what the plain `int(0.1 * image_size)` produced.
blur_kernel_size = max(1, int(0.1 * image_size)) | 1

# Stochastic augmentation pipeline used to create the contrastive views:
# random crop/resize, horizontal flip, colour jitter (applied 80% of the time),
# occasional grayscale conversion, Gaussian blur, then conversion to a tensor.
contrast_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=image_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=blur_kernel_size),
        transforms.ToTensor(),
    ]
)


# CIFAR-10 training split with the tensor-only transform (downloaded if missing).
original_dataset = torchvision.datasets.CIFAR10(
    root="../code/cs231n/datasets",
    train=True,
    download=True,
    transform=null_transforms,
)
# No shuffling, so the first batch is always the same leading images of the dataset.
original_loader = torch.utils.data.DataLoader(
    dataset=original_dataset,
    batch_size=batch_size,
    shuffle=False,
)
original_images, original_labels = next(iter(original_loader))
# Move the channel axis last (axis order 0, 2, 3, 1) before handing the batch to the plot helper.
batch_plot(
    xs=original_images.permute(0, 2, 3, 1).numpy(),
    ys=original_labels.numpy(),
)

In [None]:
# Write every image of the un-augmented batch to the temporary directory as a
# zero-padded, numbered PNG file.
for i, image in enumerate(original_images):
    save_image(image, tmp_dir / f"original_{i:02d}.png")

In [None]:
# Same CIFAR-10 training split, but every sample is expanded into several randomly
# augmented views by ContrastiveLearningViewGenerator.
contrast_dataset = torchvision.datasets.CIFAR10(
    root="../code/cs231n/datasets",
    train=True,
    download=True,
    transform=ContrastiveLearningViewGenerator(base_transform=contrast_transforms),
)
contrast_loader = torch.utils.data.DataLoader(contrast_dataset, batch_size=batch_size, shuffle=False)
# Because the transform returns a list per sample, the loader yields a list with one
# batched tensor per view. Keep that list under its own name instead of rebinding
# `contrast_images` from a list to a tensor.
contrast_views, _ = next(iter(contrast_loader))
# Stack the views along the batch axis: all first views, followed by all second views.
contrast_images = torch.cat(contrast_views, dim=0)
# One row per view and one column per source image. With two views and batch_size = 4
# this is the same 2 x 4 grid as before, but it now follows the configuration.
batch_plot(
    xs=contrast_images.numpy().transpose((0, 2, 3, 1)),
    rows=len(contrast_views),
    cols=batch_size,
)

In [None]:
# Write every augmented view to the temporary directory as a zero-padded, numbered
# PNG file. The prefix is spelled "contrast" (it was previously misspelled "constract").
for i, image in enumerate(contrast_images):
    save_image(image, tmp_dir.joinpath(f"contrast_{i:02d}.png"))

In [None]:
import torch.nn as nn
import torchvision.models as models


class ResNetSimCLR(nn.Module):
    """Backbone from ``torchvision.models`` with its ``fc`` layer turned into an MLP head.

    The model constructor is looked up by name in ``torchvision.models`` and called
    with ``**kwargs``. Its ``fc`` layer is then replaced by
    ``Linear -> ReLU -> original fc``, so the chosen model must expose an ``fc``
    attribute with ``in_features``.

    Args:
        base_model: Name of the constructor inside ``torchvision.models``.
        **kwargs: Forwarded unchanged to that constructor.
    """

    # https://github.com/sthalles/SimCLR/blob/master/models/resnet_simclr.py
    def __init__(self, base_model="resnet18", **kwargs):
        super().__init__()
        build_backbone = vars(models)[base_model]
        # f(.)
        self.encoder = build_backbone(**kwargs)
        output_layer = self.encoder.fc
        hidden_dim = output_layer.in_features
        # g(.)
        self.encoder.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            output_layer,
        )

    def forward(self, x):
        """Run ``x`` through the encoder, including the replaced ``fc`` head."""
        return self.encoder(x)


def compute_sim_matrix_direct(features):
    """Pairwise cosine similarity of the rows of ``features`` via broadcasting.

    Args:
        features: 2-D tensor with one feature vector per row.

    Returns:
        Square tensor whose entry ``[i, j]`` is the cosine similarity between
        row ``i`` and row ``j``.
    """
    as_rows = features[:, None, :]  # inserts a broadcast axis after the row axis
    as_cols = features[None, :, :]  # inserts a broadcast axis before the row axis
    return torch.nn.functional.cosine_similarity(as_rows, as_cols, dim=-1)


def compute_sim_matrix_functional(features):
    """Pairwise cosine similarity using ``torch.nn.functional.normalize``.

    Each row is normalized along ``dim=1`` and the similarity matrix is the product
    of the normalized matrix with its own transpose.

    Args:
        features: 2-D tensor with one feature vector per row.

    Returns:
        Square tensor of row-by-row cosine similarities.
    """
    unit_rows = torch.nn.functional.normalize(features, dim=1)
    return torch.matmul(unit_rows, unit_rows.T)


def compute_sim_matrix_norm(features, eps=1e-12):
    """Pairwise cosine similarity using an explicit L2-norm division.

    Args:
        features: 2-D tensor with one feature vector per row.
        eps: Lower bound applied to each row norm before dividing. Without it an
            all-zero row divides by zero and fills its row and column with NaN.
            The default leaves rows with a larger norm unaffected.

    Returns:
        Square tensor of row-by-row cosine similarities; an all-zero row has
        similarity 0 to every row.
    """
    # Per-row L2 norm, kept as a column so the division broadcasts across each row.
    norms = torch.linalg.norm(features, dim=1, keepdim=True)
    features = features / torch.clamp(norms, min=eps)
    return features @ features.T


def simclr_loss_direct(features, temperature, n_views=2):
    """SimCLR contrastive loss written out directly as ``-log(exp(pos) / sum(exp(others)))``.

    Assumes ``features`` stacks the views along the first axis (all first views,
    then all second views), so row ``i`` and row ``i + N`` belong to the same source
    image. The matrix size is hard-coded as ``2 * N``, so only ``n_views=2`` is
    supported.

    Args:
        features: Tensor of shape ``[2 * N, D]``.
        temperature: Positive scalar the similarities are divided by.
        n_views: Number of views per image; used to derive ``N``.

    Returns:
        Scalar tensor: the mean loss over all ``2 * N`` anchors.
    """
    n = len(features) // n_views
    similarity_matrix = compute_sim_matrix_norm(features)  # [2*N, 2*N]
    exponential = (similarity_matrix / temperature).exp()  # [2*N, 2*N]
    # Build the mask on the same device as the features so the function also works
    # when the features live on an accelerator.
    mask = torch.eye(2 * n, dtype=torch.bool, device=features.device)  # [2*N, 2*N]
    # Sum over every non-self entry of each row. keepdim=True keeps a [2*N, 1] column
    # so the division below normalizes row i by its own denominator. A flat [2*N]
    # vector would broadcast along columns instead; the final mean would still be
    # right, but only because the matrix is symmetric.
    denom = exponential[~mask].view(2 * n, -1).sum(dim=1, keepdim=True)  # [2*N, 1]
    loss = -(exponential / denom).log()
    # Rolling the identity by N rows marks, for each anchor, the other view of its image.
    return loss[mask.roll(shifts=n, dims=0)].mean()


def simclr_loss_separate(features, temperature, n_views=2):
    """SimCLR contrastive loss computed as ``-positive + logsumexp(all non-self logits)``.

    Assumes ``features`` stacks the views along the first axis (all first views,
    then all second views), so row ``i`` and row ``i + N`` belong to the same source
    image. The matrix size is hard-coded as ``2 * N``, so only ``n_views=2`` is
    supported.

    Args:
        features: Tensor of shape ``[2 * N, D]``.
        temperature: Positive scalar the similarities are divided by.
        n_views: Number of views per image; used to derive ``N``.

    Returns:
        Scalar tensor: the mean loss over all ``2 * N`` anchors.
    """
    n = len(features) // n_views
    similarity_matrix = compute_sim_matrix_norm(features) / temperature  # [2*N, 2*N]
    # The mask must live on the same device as the tensor it is applied to;
    # masked_fill_ rejects a CPU mask when the features are on an accelerator.
    mask = torch.eye(2 * n, dtype=torch.bool, device=features.device)  # [2*N, 2*N]
    # Self-similarities become -inf so they contribute exp(-inf) = 0 to the logsumexp.
    # The in-place fill is safe: the tensor was freshly created by the division above.
    similarity_matrix.masked_fill_(mask, float("-inf"))  # [2*N, 2*N]
    # -log(exp(a)/sum(exp(b))) = -a + logsumexp(b)
    positives = similarity_matrix[mask.roll(shifts=n, dims=0)]  # one positive per row, in row order
    loss = -positives + torch.logsumexp(similarity_matrix, dim=-1)
    return loss.mean()


def simclr_loss_criterion(features, temperature, n_views=2):
    """SimCLR contrastive loss expressed as a cross-entropy classification problem.

    For every anchor the logits are laid out as ``[positive, negatives...]``, so the
    correct class index is always 0.

    Assumes ``features`` stacks the views along the first axis (all first views,
    then all second views), so row ``i`` and row ``i + N`` belong to the same source
    image. The matrix size is hard-coded as ``2 * N``, so only ``n_views=2`` is
    supported.

    Args:
        features: Tensor of shape ``[2 * N, D]``.
        temperature: Positive scalar the logits are divided by.
        n_views: Number of views per image; used to derive ``N``.

    Returns:
        Scalar tensor: the mean cross-entropy over all ``2 * N`` anchors.
    """
    n = len(features) // n_views
    similarity_matrix = compute_sim_matrix_norm(features)  # [2*N, 2*N]
    # Masks and targets are created on the features' device; cross_entropy requires
    # logits and targets to be on the same device.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=features.device)  # [2*N, 2*N]
    # Drop the diagonal (self-similarity) from both the similarities and the positive mask.
    similarity_matrix = similarity_matrix[~self_mask].view(2 * n, -1)  # [2*N, 2*N - 1]
    positive_mask = self_mask.roll(shifts=n, dims=0)[~self_mask].view(2 * n, -1)
    positives = similarity_matrix[positive_mask].view(2 * n, -1)  # [2*N, 1]
    negatives = similarity_matrix[~positive_mask].view(2 * n, -1)  # [2*N, 2*N - 2]
    logits = torch.cat([positives, negatives], dim=1) / temperature
    # The positive sits in column 0 of every row.
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=features.device)
    return torch.nn.functional.cross_entropy(logits, labels)


# Backbone built with `weights=None`; `num_classes` sets the width of the feature
# vectors the model returns.
model = ResNetSimCLR(num_classes=feature_dims, weights=None)

In [None]:
# Embed the stacked augmented views in a single forward pass.
features = model(contrast_images)

assert features.shape == (2 * batch_size, feature_dims)

# The similarity-matrix implementations must agree with the normalize-based one.
reference_sim = compute_sim_matrix_functional(features=features)
for sim_fn in (compute_sim_matrix_direct, compute_sim_matrix_norm):
    torch.testing.assert_close(sim_fn(features=features), reference_sim)

# The loss implementations must agree with the direct formulation.
reference_loss = simclr_loss_direct(features=features, temperature=temperature)
for loss_fn in (simclr_loss_criterion, simclr_loss_separate):
    torch.testing.assert_close(reference_loss, loss_fn(features=features, temperature=temperature))

In [None]:
# Time the three similarity-matrix implementations on the same `features` batch.
%timeit compute_sim_matrix_norm(features=features)
%timeit compute_sim_matrix_functional(features=features)
%timeit compute_sim_matrix_direct(features=features)

In [None]:
# Time the three loss implementations on the same `features` batch and temperature.
%timeit simclr_loss_direct(features=features, temperature=temperature)
%timeit simclr_loss_criterion(features=features, temperature=temperature)
%timeit simclr_loss_separate(features=features, temperature=temperature)

In [None]:
# Shell escape that passes the temporary image directory to the `open` command.
# NOTE(review): `open` is the macOS launcher, so this presumably targets macOS; other
# platforms would need a different command (e.g. `xdg-open`) — confirm.
!open {tmp_dir}

# references

- [Contrastive Representation Learning](https://lilianweng.github.io/posts/2021-05-31-contrastive/)