In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from ssl_playground.data import MNIST64
from ssl_playground.models import create_resnet18, shape_trace_resnet

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
mnist64 = MNIST64("../data/mnist64")
batch_size = 8
mnist64.make_data_and_loaders(batch_size)

In [5]:
example_iter = iter(mnist64.train_data_loader)
example_images, example_labels = next(example_iter)

In [6]:
example_images.size(), example_labels

(torch.Size([8, 1, 64, 64]), tensor([5, 1, 8, 2, 8, 6, 3, 6]))

In [7]:
resnet18 = torchvision.models.resnet18()

In [8]:
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
resnet18.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet18.fc = torch.nn.Identity()

In [10]:
example_outputs = resnet18(example_images)

In [11]:
example_outputs.size()

torch.Size([8, 512])

In [12]:
resnet18.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet18.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True)

In [13]:
example_outputs = resnet18(example_images)

In [14]:
example_outputs.size()

torch.Size([8, 10])

In [15]:
print(example_outputs.softmax(0))

tensor([[0.1195, 0.1344, 0.1241, 0.1341, 0.1730, 0.1652, 0.0790, 0.1999, 0.1053,
         0.1468],
        [0.0990, 0.1685, 0.1351, 0.0858, 0.1057, 0.0872, 0.1021, 0.0733, 0.2727,
         0.1841],
        [0.0967, 0.1316, 0.1551, 0.1837, 0.1379, 0.1343, 0.1237, 0.1070, 0.0548,
         0.0750],
        [0.1421, 0.0707, 0.1903, 0.1355, 0.0928, 0.1471, 0.0892, 0.1419, 0.0776,
         0.0741],
        [0.1257, 0.1356, 0.1033, 0.0841, 0.1367, 0.1265, 0.1495, 0.0958, 0.1196,
         0.0959],
        [0.1564, 0.1444, 0.1340, 0.0729, 0.1173, 0.1123, 0.1494, 0.1418, 0.1038,
         0.0986],
        [0.1191, 0.1020, 0.0681, 0.1829, 0.1441, 0.1247, 0.1274, 0.1280, 0.1715,
         0.2005],
        [0.1415, 0.1128, 0.0900, 0.1210, 0.0924, 0.1028, 0.1797, 0.1123, 0.0946,
         0.1251]], grad_fn=<SoftmaxBackward0>)


In [16]:
x = example_images
print("Initial:", x.size())
x = resnet18.conv1(x)
print("Conv1:", x.size())
x = resnet18.maxpool(x)
print("Maxpool:", x.size())
x = resnet18.layer1(x)
print("Layer1:", x.size())
x = resnet18.layer2(x)
print("Layer2:", x.size())
x = resnet18.layer3(x)
print("Layer3:", x.size())
x = resnet18.layer4(x)
print("Layer4:", x.size())
x = resnet18.avgpool(x)
print("AvgPool:", x.size())
x = torch.flatten(x, 1)
print("Flatten:", x.size())
x = resnet18.fc(x)
print("fc:", x.size())

Initial: torch.Size([8, 1, 64, 64])
Conv1: torch.Size([8, 64, 32, 32])
Maxpool: torch.Size([8, 64, 16, 16])
Layer1: torch.Size([8, 64, 16, 16])
Layer2: torch.Size([8, 128, 8, 8])
Layer3: torch.Size([8, 256, 4, 4])
Layer4: torch.Size([8, 512, 2, 2])
AvgPool: torch.Size([8, 512, 1, 1])
Flatten: torch.Size([8, 512])
fc: torch.Size([8, 10])


In [17]:
model_sl = create_resnet18("supervised")
model_ssl = create_resnet18("self-supervised")

In [18]:
model_sl

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [19]:
model_ssl

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [20]:
shape_trace_resnet(model_sl, example_images)

Initial: torch.Size([8, 1, 64, 64])
Conv1: torch.Size([8, 64, 32, 32])
Maxpool: torch.Size([8, 64, 16, 16])
Layer1: torch.Size([8, 64, 16, 16])
Layer2: torch.Size([8, 128, 8, 8])
Layer3: torch.Size([8, 256, 4, 4])
Layer4: torch.Size([8, 512, 2, 2])
AvgPool: torch.Size([8, 512, 1, 1])
Flatten: torch.Size([8, 512])
fc: torch.Size([8, 10])


In [21]:
shape_trace_resnet(model_ssl, example_images)

Initial: torch.Size([8, 1, 64, 64])
Conv1: torch.Size([8, 64, 32, 32])
Maxpool: torch.Size([8, 64, 16, 16])
Layer1: torch.Size([8, 64, 16, 16])
Layer2: torch.Size([8, 128, 8, 8])
Layer3: torch.Size([8, 256, 4, 4])
Layer4: torch.Size([8, 512, 2, 2])
AvgPool: torch.Size([8, 512, 1, 1])
Flatten: torch.Size([8, 512])
fc: torch.Size([8, 512])


In [23]:
model_sl.train()

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [24]:
model_sl.eval()

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  