In [2]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader


from torchvision import models
import torchvision.transforms as T
from torchvision.datasets import MNIST

from torchinfo import summary

import matplotlib.pyplot as plt
import numpy as np

In [7]:
Ya_transforms = T.Compose(
    [
        T.ToTensor(),
        T.RandomChoice(
            [
                T.ColorJitter(),
                T.RandomHorizontalFlip(),
                T.GaussianBlur(3),
                T.RandomSolarize(.6)
            ]
        )
    ]
)

Yb_transforms = T.Compose(
    [
        T.ToTensor(),
        T.RandomChoice(
            [
                T.ColorJitter(),
                T.RandomHorizontalFlip(),
                T.GaussianBlur(3),
                T.RandomSolarize(.6)
            ]
        )
    ]
)

In [8]:
A_dataset = MNIST(
    'data',
    train = 'True',
    download = True,
    transform=Ya_transforms
)


B_dataset = MNIST(
    'data',
    train = 'True',
    download = True,
    transform=Yb_transforms
)

In [9]:
BATCH_SIZE=256

A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)

In [10]:
next(iter(A_loader))

[tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         ...,
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ..

In [17]:
class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 4, 1),
            nn.Dropout(),
            nn.ReLU(),
            nn.BatchNorm2d(32)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 6, 2),
            nn.Dropout(),
            nn.ReLU(),
            nn.BatchNorm2d(16)
        )

        self.linear = nn.Sequential(
            nn.LazyLinear(16)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        out = self.linear(x)
        return out

In [16]:
fake_data = torch.randn(256, 1, 28, 28)

Encoder()(fake_data).shape



torch.Size([256, 16])

In [19]:
class Projector(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.proj = nn.Sequential(
            nn.Linear(in_features=16, out_features=32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 16)
        )

    def forward(self, x):
        return self.proj(x)

In [21]:
Projector()(torch.randn(256, 16)).shape

torch.Size([256, 16])

In [22]:
class F_theta(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.encoder = Encoder()
        self.projector = Projector()

    def forward(self, x):
        return self.projector(self.encoder(x))

In [24]:
model = F_theta()



In [25]:
summary(model)

Layer (type:depth-idx)                   Param #
F_theta                                  --
├─Encoder: 1-1                           --
│    └─Sequential: 2-1                   --
│    │    └─Conv2d: 3-1                  544
│    │    └─Dropout: 3-2                 --
│    │    └─ReLU: 3-3                    --
│    │    └─BatchNorm2d: 3-4             64
│    └─Sequential: 2-2                   --
│    │    └─Conv2d: 3-5                  18,448
│    │    └─Dropout: 3-6                 --
│    │    └─ReLU: 3-7                    --
│    │    └─BatchNorm2d: 3-8             32
│    └─Sequential: 2-3                   --
│    │    └─LazyLinear: 3-9              --
├─Projector: 1-2                         --
│    └─Sequential: 2-4                   --
│    │    └─Linear: 3-10                 544
│    │    └─BatchNorm1d: 3-11            64
│    │    └─ReLU: 3-12                   --
│    │    └─Linear: 3-13                 528
Total params: 20,224
Trainable params: 20,224
Non-trainable para

In [30]:
def normalize(x):
    return (x - x.mean(0)) / x.std(0)

In [34]:
optimizer = torch.optim.Adam(model.parameters())

In [37]:
model.train()

losses = []

for epoch in range(10):
    batch_losses = []
    for x_a, x_b in zip(A_loader, B_loader):
        out_a = model(x_a[0])
        out_b = model(x_b[0])

        out_a_normed = normalize(out_a)
        out_b_normed = normalize(out_b)

        c = torch.matmul(out_a_normed.T, out_b_normed) / BATCH_SIZE

        c_diff = (c - torch.eye(c.shape[0])).pow(2)

        c_diff.flatten()[:-1].view(c_diff.size(0) - 1, c_diff.size(0)+1)[:, 1:].mul_(.5)

        loss = c_diff.sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    losses.append(np.mean(batch_losses))

    print(f'Epoch n: {epoch}, loss: {losses[-1]}')

Epoch n: 0, loss: 3.3138827334059044
Epoch n: 1, loss: 1.2140538943574783
Epoch n: 2, loss: 0.9484281950808586
Epoch n: 3, loss: 0.8098967663785245
Epoch n: 4, loss: 0.7613271729743227
Epoch n: 5, loss: 0.7083654177949784
Epoch n: 6, loss: 0.6771670341491699
Epoch n: 7, loss: 0.650630200796939
Epoch n: 8, loss: 0.6249631072612519
Epoch n: 9, loss: 0.6046989103581043
