In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

## TL;DR


#### Components
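- one class for the ResNet-50 encoder (`Net`; note that `models.resnet50()` is built without loading pre-trained weights).
- one class for the projection head (`ProjectionHead`).
- combine them in a `SimCLR` class; its outputs are fed to the pre-defined `NTXentLoss` from pytorch-metric-learning.
- no optimizer yet (as of 4/29); the paper uses LARS.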


#### Projection Head
Slightly different from the WCL implementation, only in terms of dimensions. The WCL model 'unsqueezes' twice, applies BatchNorm, and 'squeezes' twice.

> Batch normalization (BatchNorm) typically operates on mini-batches of data in convolutional neural networks (CNNs). In CNNs, the input to each layer is often represented as a 4D tensor with dimensions [batch_size, channels, height, width]. BatchNorm normalizes along the channel dimension, so it requires statistics to be computed across the batch and spatial dimensions while keeping the channel dimension intact.

However, WCL also uses BatchNorm1d, which does not need the 4D layout: it takes [N, C] (or [N, C, L]) input directly. As a result, the unsqueeze/squeeze dimension adjustment does not work with BatchNorm1d. :(

#### Questions
Q: What does torch.no_grad() do that model.eval() does not?   
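Q: Does SimCLR use pre-trained ResNet weights? (The paper trains the encoder from scratch; `Net` here also builds `resnet50()` without loading weights.)  
Q: The WCL model applies `F.normalize` after the projection head — is that needed here?  
Q: SimCLR discards the projection head after pretraining and uses the encoder output $h_i$ for linear evaluation — do the same here?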

## Code

### Heads

For an `nn.Conv2d` layer, the input has shape (N, C, H, W), or (C, H, W) when unbatched.

In [None]:
# Using downloadable resnet model without all WCL changes to architecture:
# the stock stem convolution is swapped for a smaller 3x3 conv, and the final
# fc (linear) classifier is dropped so the encoder emits 2048-d features.

class Net(nn.Module):
    """ResNet-50 encoder f(.) producing one 2048-d representation per image.

    The children of torchvision's ``resnet50()`` are reused in order, except
    the first (the stock stem conv, replaced by ``self.conv1``) and the last
    (the fc classifier, dropped). No pre-trained weights are loaded.

    NOTE(review): the remaining stem layers (BatchNorm, ReLU, max-pool) are
    kept; the SimCLR CIFAR-10 setup also removes the max-pool — confirm which
    is intended.
    """

    def __init__(self):
        super().__init__()
        # 3x3 / stride 1 / padding 1 keeps H and W unchanged; 3 -> 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone_parts = list(models.resnet50().children())
        # Skip the original stem conv (first child) and the fc head (last child).
        self.middle = nn.Sequential(*backbone_parts[1:-1])

    def forward(self, x):
        """Encode images ``[N, 3, H, W]`` into features ``[N, 2048]``."""
        # After global average pooling the tensor is still 4D: [N, 2048, 1, 1].
        pooled = self.middle(self.conv1(x))
        return pooled.view(pooled.size(0), -1)

In [None]:
class ProjectionHead(nn.Module):
    """MLP projection head g(.) mapping encoder features into the loss space.

    Layout: (Linear -> BatchNorm1d -> ReLU) twice, then a final Linear with no
    activation or normalization on its output.

    Args:
        dim_in: width of the incoming features (2048 matches ``Net``'s output).
        dim_out: width of the returned projections.
        dim_hidden: width of both hidden layers.
    """

    def __init__(self, dim_in=2048, dim_out=128, dim_hidden=2048):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which initial weights are drawn.
        self.linear1 = nn.Linear(dim_in, dim_hidden)
        self.bn1 = nn.BatchNorm1d(dim_hidden)
        self.linear2 = nn.Linear(dim_hidden, dim_hidden)
        self.bn2 = nn.BatchNorm1d(dim_hidden)
        self.linear3 = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Project ``x`` of shape ``[N, dim_in]`` to ``[N, dim_out]``."""
        hidden_stages = ((self.linear1, self.bn1), (self.linear2, self.bn2))
        for linear, norm in hidden_stages:
            x = F.relu(norm(linear(x)))
        return self.linear3(x)

### Loss

**Contrastive loss**, defined on the projections $z_i$, per the [original paper](https://arxiv.org/pdf/2002.05709v3):

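The contrastive loss is defined for a contrastive prediction task: given a set $\{\tilde{x}_k\}$ that includes a positive pair of examples $\tilde{x}_i$ and $\tilde{x}_j$, the task is to identify $\tilde{x}_j$ in $\{\tilde{x}_k\}_{k \neq i}$ for a given $\tilde{x}_i$. For a positive pair $(i, j)$ among the $2N$ projections, with cosine similarity $\mathrm{sim}$ and temperature $\tau$:

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$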

[Kevin Musgrave added NTXentLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/issues/6) to pytorch-metric-learning.

[Requirements](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss): Positive pairs (embeddings[i], embeddings[j]) are defined when labels[i]==labels[j]  



In [4]:
## Example
# Requires pytorch-metric-learning; if missing, run once:  %pip install pytorch-metric-learning
# (torch itself is already imported in the first cell.)

from pytorch_metric_learning.losses import NTXentLoss

# Seed so the random embeddings, and therefore the displayed loss, are reproducible.
torch.manual_seed(0)

batch_size = 16
embedding_dim = 512

# Random stand-ins for real anchor/positive embeddings, just to make this example runnable
anchor_embeddings = torch.randn(batch_size, embedding_dim)
positive_embeddings = torch.randn(batch_size, embedding_dim)

# Stack to [2 * batch_size, embedding_dim]. Rows i and i + batch_size share label i,
# which is how NTXentLoss identifies a positive pair (labels[i] == labels[j]).
embeddings = torch.cat((anchor_embeddings, positive_embeddings))
indices = torch.arange(0, anchor_embeddings.size(0), device=anchor_embeddings.device)
labels = torch.cat((indices, indices))

loss = NTXentLoss(temperature=0.10)
loss(embeddings, labels)

tensor(3.5128)

### Final Class

In [5]:
## Basic Intuitive Model? 

class SimCLR(nn.Module):
    """Shared encoder (``Net``) followed by a projection head, applied to two views.

    ``forward`` returns the stacked projections of both views together with
    integer labels marking positive pairs, in the form ``NTXentLoss`` expects
    (equal labels => positive pair). The loss itself is applied outside.
    """

    def __init__(self):
        super(SimCLR, self).__init__()
        self.resnet = Net()
        self.head = ProjectionHead()

    def forward(self, x1, x2, t=0.1):
        """Encode and project two batches of views.

        Args:
            x1, x2: two batches with the same batch size N; row i of ``x1`` and
                row i of ``x2`` form a positive pair.
            t: unused. The temperature is configured on the loss
                (``NTXentLoss(temperature=...)``); kept for backward compatibility.

        Returns:
            z: projections of shape ``[2N, dim_out]`` — ``x1``'s rows first, then ``x2``'s.
            labels: long tensor of shape ``[2N]`` on ``z``'s device, where
                ``labels[i] == labels[i + N] == i``.

        Raises:
            ValueError: if ``x1`` and ``x2`` have different batch sizes (the
                labels would otherwise not line up with ``z``).
        """
        if x1.size(0) != x2.size(0):
            raise ValueError(
                f"x1 and x2 must have the same batch size, got {x1.size(0)} and {x2.size(0)}"
            )

        h1 = self.resnet(x1)
        h2 = self.resnet(x2)

        # WCL applies F.normalize to the projections. NTXentLoss's default
        # distance is cosine similarity, which should make that redundant here
        # — NOTE(review): confirm against the pytorch-metric-learning docs.
        z1 = self.head(h1)
        z2 = self.head(h2)  # [N, dim_out]

        z = torch.cat((z1, z2))  # [2N, dim_out]
        # Build labels on the embeddings' device (as the NTXentLoss example
        # does), so CPU labels are never paired with GPU embeddings.
        indices = torch.arange(z1.size(0), device=z1.device)
        labels = torch.cat((indices, indices))  # [2N]
        return z, labels


In [6]:
## test dimensions
# Two random batches standing in for the two augmented views, laid out as [N, C, H, W].
batch_size, channels = 5, 3
height = width = 10
view_shape = (batch_size, channels, height, width)
image1 = torch.randn(view_shape)
image2 = torch.randn(view_shape)

image1.shape, image2.shape

(torch.Size([5, 3, 10, 10]), torch.Size([5, 3, 10, 10]))

In [7]:
model = SimCLR()
# Call the module itself rather than .forward() so nn.Module's __call__ machinery (hooks) runs.
z, labels = model(image1, image2)
# Both views are stacked: 2N projections of ProjectionHead's default width (128), one label each.
assert z.shape == (2 * image1.shape[0], 128)
assert labels.shape == (z.shape[0],)
z.shape, labels.shape

(torch.Size([10, 128]), torch.Size([10]))

In [8]:
# NT-Xent over the 2N stacked projections; equal labels mark the positive pairs.
loss = NTXentLoss(temperature=0.1)  # same temperature as the standalone example
loss(z, labels)

tensor(2.0625, grad_fn=<MeanBackward0>)