In [246]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader


from torchvision import models
import torchvision.transforms as T
from torchvision.datasets import MNIST

from torchinfo import summary

import matplotlib.pyplot as plt
import numpy as np

In [247]:
Ya_transforms = T.Compose(
    [
        T.ToTensor(),
        T.RandomChoice(
            [
                T.ColorJitter(),
                T.RandomHorizontalFlip(),
                T.GaussianBlur(3),
                T.RandomSolarize(.6)
            ]
        )
    ]
)

Yb_transforms = T.Compose(
    [
        T.ToTensor(),
        T.RandomChoice(
            [
                T.ColorJitter(),
                T.RandomHorizontalFlip(),
                T.GaussianBlur(3),
                T.RandomSolarize(.6)
            ]
        )
    ]
)

In [248]:
A_dataset = MNIST(
    'data', 
    'True',
    transform=Ya_transforms,
    download=True
)

B_dataset = MNIST(
    'data', 
    'True',
    transform=Yb_transforms,
)

In [249]:
BATCH_SIZE = 32
A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)

In [23]:
A_sample = next(iter(A_loader))[0]
B_sample = next(iter(B_loader))[0]

In [225]:
A_sample[0]

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

In [59]:
def plot_samples(A_batch, B_batch):

    fig, ax = plt.subplots(2, 7)
    # ax = ax.reshape(14)

    for i in range(7):
        ax[0, i].imshow(A_batch[i][0])
        ax[1, i].imshow(B_batch[i][0])

In [258]:
class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 4, 1),
            nn.Dropout(),
            nn.ReLU(),
            nn.BatchNorm2d(32)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 6, 2),
            nn.Dropout(),
            nn.ReLU(),
            nn.BatchNorm2d(16)
        )

        self.linear = nn.Sequential(
            nn.LazyLinear(16)
        )
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        out = self.linear(x)
      
        return out

In [259]:
class Projector(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.proj = nn.Sequential(
            nn.Linear(in_features=16, out_features=32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 16)
        )

    def forward(self, x):
        x = self.proj(x)
        return x

In [260]:
class F_theta(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.encoder = Encoder()
        self.proj = Projector()

    def forward(self, x):
        encoded = self.encoder(x)
        x = self.proj(encoded)
        return x, encoded

In [261]:
f_theta = F_theta()

In [262]:
summary(f_theta)

Layer (type:depth-idx)                   Param #
F_theta                                  --
├─Encoder: 1-1                           --
│    └─Sequential: 2-1                   --
│    │    └─Conv2d: 3-1                  544
│    │    └─Dropout: 3-2                 --
│    │    └─ReLU: 3-3                    --
│    │    └─BatchNorm2d: 3-4             64
│    └─Sequential: 2-2                   --
│    │    └─Conv2d: 3-5                  18,448
│    │    └─Dropout: 3-6                 --
│    │    └─ReLU: 3-7                    --
│    │    └─BatchNorm2d: 3-8             32
│    └─Sequential: 2-3                   --
│    │    └─LazyLinear: 3-9              --
├─Projector: 1-2                         --
│    └─Sequential: 2-4                   --
│    │    └─Linear: 3-10                 544
│    │    └─BatchNorm1d: 3-11            64
│    │    └─ReLU: 3-12                   --
│    │    └─Linear: 3-13                 528
Total params: 20,224
Trainable params: 20,224
Non-trainable para

In [263]:
def normalize(x):
    return ((x - x.mean(0)) / x.std(0))

In [264]:
optimizer = optim.Adam(f_theta.parameters())

In [265]:
epoch_loss = []

In [274]:
f_theta.train()
for epoch in range(50):
    batch_loss = []
    for x_a, x_b in zip(A_loader, B_loader):
        # print(x_a[0].shape, x_b[0].shape)
        out_a = normalize(f_theta(x_a[0])[0])
        out_b = normalize(f_theta(x_b[0])[0])
        emb = f_theta(x_b[0])[0]
        # print(f'f_theta out: {out_a.shape}')
        cross_corr = torch.matmul(out_a.T, out_b) / BATCH_SIZE
        c_diff = (cross_corr - torch.eye(cross_corr.size(0))).pow(2)
        # print(f'Cross corr: {cross_corr.shape}')
        # print(c_diff.shape)
        c_diff.flatten()[:-1].view(c_diff.size(0) - 1, c_diff.size(0)+1)[:, 1:].mul_(.5)
        loss = c_diff.sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    epoch_loss.append(np.mean(batch_loss))
    print(f_theta(x_b[0])[1][0])
    if epoch+1 % 5 == 0:
        torch.save(f_theta.state_dict(), f'weights/n_epochs_{epoch}_new.pt')
    print(f'Epoch {epoch+1} finished. Loss: {epoch_loss[-1]:.3f}')
        
        


tensor([ 46.1071,  16.4787,   4.9440,  15.8657,   5.5808, -18.1278,   1.0149,
         41.6873, -31.0885, -20.4176, -22.5892,   6.4539, -15.9461, -11.5106,
          7.8488,   8.7911], grad_fn=<SelectBackward0>)
Epoch 1 finished. Loss: 1.602
tensor([ 52.3106,  29.1177,  14.8703,  13.9134,  17.5207,  -7.6311,   1.0507,
         36.9610, -35.8788, -24.6493, -12.7419,  10.3941,  -9.7290,  -8.7875,
         25.2963,   7.0533], grad_fn=<SelectBackward0>)
Epoch 2 finished. Loss: 1.592
tensor([ 54.8341,  21.6798,   8.5622,  18.2535,  19.7113,  -8.1439,  10.1015,
         34.6644, -44.1203, -23.0919,  -4.8697,  11.9878,   1.6465,  -9.8415,
         17.7444,  22.8734], grad_fn=<SelectBackward0>)
Epoch 3 finished. Loss: 1.597
tensor([ 61.3607,  26.4037,   3.5679,  16.0325,  20.6803, -15.5255,   9.7262,
         32.6398, -46.1915, -24.3841, -10.7016,   0.8762, -11.2492,  -2.3860,
         20.9871,   0.2347], grad_fn=<SelectBackward0>)
Epoch 4 finished. Loss: 1.575
tensor([ 62.0398,  23.9060,  15.

KeyboardInterrupt: 

In [275]:
torch.save(f_theta.state_dict(), f'weights/n_epochs_{epoch}.pt')

In [216]:
f_theta = F_theta()

f_theta.load_state_dict(torch.load('weights/n_epochs_23.pt'))



<All keys matched successfully>

In [282]:
valid_loader = DataLoader(MNIST('data', train=False, transform=T.ToTensor()), batch_size=32)

In [285]:
test_batch_img, test_batch_lbl = (next(iter(valid_loader)))

In [286]:
f_theta(test_batch_img)[1]

tensor([[-6.9082e+00,  1.4926e+01,  2.0808e+01, -4.7688e+01,  4.3042e+01,
          4.5962e+01, -4.9906e+01,  2.5168e+00, -3.1275e+01,  1.5816e+01,
         -3.1620e+01,  3.5737e+01, -4.5994e+00,  4.2364e+01,  1.5794e+01,
          1.1120e+01],
        [ 8.6672e+01, -1.8814e+01,  3.2317e+01, -3.2457e+01, -9.0589e+00,
         -8.8214e+00, -1.9338e+01,  5.0588e+01,  3.0347e+01,  7.8374e+00,
         -1.2130e+01,  1.3253e+01, -2.5789e+01, -1.8396e+01,  3.6527e+01,
          7.4531e+00],
        [ 9.0540e+01, -5.9130e+01, -1.4260e+01, -1.2459e+01,  5.2794e+01,
         -4.6375e+01, -2.1636e+01, -2.6174e+01,  2.5309e+00,  4.8639e+01,
         -2.5860e+01,  1.0535e+01, -4.0442e+01,  1.7193e+01,  2.1324e+01,
         -2.7289e+01],
        [ 4.1543e+01,  8.1364e+01,  2.1217e+01, -5.3465e+01, -1.4093e+01,
         -5.5029e+01,  2.1401e+01,  3.8139e+01, -2.5498e+01,  2.6841e+01,
         -5.4498e+01,  1.4192e+01, -3.8766e+01,  2.9353e+01,  5.5838e+01,
         -1.0746e+01],
        [ 2.3217e+01