In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

### Get Data

In [9]:
# GPU index 3 is specific to the original training machine; fall back to the CPU when
# that device does not exist so Restart & Run All still works elsewhere.
device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu'
# torch.set_default_device(device) ## causes errors with dataset and loader. 

In [10]:
# CIFAR-10 preprocessing: image -> float tensor, then per-channel normalization using
# the ImageNet mean/std (the commented-out SimCLR pipeline further down uses the same values).
# NOTE: this assignment rebinds the name `transforms`, shadowing the
# `torchvision.transforms` module imported at the top. The module is therefore
# referenced by its full path so this cell can be re-run; prefer a distinct name
# (e.g. `cifar_transform`) for the pipeline in new code.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ]
)

In [11]:
batch_size = 128

# CIFAR-10 training split read from a local directory, with the preprocessing
# pipeline from the previous cell applied to every image.
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
train = torchvision.datasets.CIFAR10(
    root='/home/sshad/data/wcl-data/',
    train=True,
    transform=transforms,
)
trainloader = DataLoader(
    dataset=train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [12]:
# Pull a single (images, labels) batch to drive the shape checks below.
(x, y) = next(iter(trainloader))

In [13]:
(x.shape, y.shape)  # images (N, 3, 32, 32) and labels (N,)

(torch.Size([128, 3, 32, 32]), torch.Size([128]))

In [12]:
# Put the sample images and labels on the configured device.
x, y = x.to(device), y.to(device)

### Custom Dataset

In [5]:
from torchvision.datasets import CIFAR10

In [7]:
# From WCL
class GaussianBlur:
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call blurs the image with a radius drawn uniformly from ``sigma``.

    Args:
        sigma: ``(low, high)`` bounds for the blur radius handed to
            ``ImageFilter.GaussianBlur``. A tuple default avoids sharing one
            mutable list across every instance.
    """

    def __init__(self, sigma=(.1, 2.)):
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` blurred with a radius sampled from ``[sigma[0], sigma[1]]``.

        Assumes ``x`` is a PIL image (anything exposing PIL's ``.filter`` method).
        """
        # `random` and `ImageFilter` were used here but imported nowhere in the
        # notebook, so the original raised NameError on the first call. They are
        # imported locally to keep this cell self-contained; `ImageFilter` is
        # Pillow's filter module, matching the PIL-style `x.filter(...)` call.
        import random
        from PIL import ImageFilter

        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


# # Not sure how to use this with torch dataset. 
# class ContrastiveCIFAR10(Dataset): 
#     def __init__(self, train = True): 
#         super(Dataset, self).__init__

#         self.data = CIFAR10(root='/home/sshad/data/wcl-data/', 
#                             train= train,
#                             transform=None)
#         type(self.data)
#         ## SimCLR transformations
#         self.transform = transforms.Compose([
#             transforms.RandomResizedCrop(224),
#             transforms.RandomHorizontalFlip(),
#             transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
#             transforms.RandomGrayscale(p=0.2),
#             transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#         ])

#     def __len__(self): 
#         return len(self.data)
    
#     def __getitem__(self, index): 
#         image, label = self.data[index]
#         img1, img2 = self.transform(image), self.transform(image)
#         return img1, img2, label


### Model

In [4]:
# torch.set_default_device(device)

In [5]:
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

In [6]:
resnet = models.resnet50(weights=None)
# How many children Net keeps: every child except the first and the last one.
len([*resnet.children()][1:-1])

8

In [14]:
# Using downloadable resnet model without all WCL changes to architecture

# replace the first layer with a smaller conv layer [kernel size = 3]
# remove the final fc (linear) layer so the output is of size 2048

class Net(nn.Module):
    """ResNet-50 encoder with a 3x3 stem that returns flattened 2048-dim features.

    The torchvision ResNet-50's first child is replaced by a 3x3, stride-1,
    padding-1 convolution (3 -> 64 channels, no bias). Its last child (the fc
    classifier) is dropped, so ``forward`` returns features, not logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        backbone_children = list(models.resnet50(weights=None).children())
        # Keep everything between the stock first layer and the fc head.
        self.middle = nn.Sequential(*backbone_children[1:-1])

    def forward(self, x):
        """Encode the image batch ``x`` and flatten to (batch, features)."""
        features = self.middle(self.conv1(x))
        return features.view(features.shape[0], -1)
        

In [15]:
## Test Net

test_model = Net()
# Net() is built on the CPU, but an earlier cell moves x to `device`; feed a CPU copy
# so weights and input live on the same device (the exploration cells below also
# build CPU layers around z). Call the module rather than .forward so hooks run.
z = test_model(x.cpu())
z.shape

torch.Size([128, 3, 32, 32])
torch.Size([128, 64, 32, 32])
torch.Size([128, 2048, 1, 1])


torch.Size([128, 2048])

In [22]:
## Adapted from WCL. The first draft of this cell had two bugs, fixed here:
##   1. it unsqueezed the (N, C) linear output to (N, C, 1, 1) before BatchNorm1d,
##      which only accepts (N, C) or (N, C, L) and raises on 4-D input (the
##      "spatial" unsqueeze trick only fits BatchNorm2d; see the cells below);
##   2. `x = nn.ReLU(inplace=True)` replaced the activations with a ReLU *module*
##      instead of applying it, so the next Linear received a module, not a tensor.
## The CORRECT cell below redefines ProjectionHead using F.relu.
class ProjectionHead(nn.Module):
    """Projection MLP: (Linear -> BatchNorm1d -> ReLU) x 2, then a final Linear.

    Args:
        dim_in: size of the incoming feature vectors.
        dim_out: size of the projected embedding.
        dim_hidden: width of the two hidden layers.
    """

    def __init__(self, dim_in=2048, dim_out=128, dim_hidden=2048):
        super(ProjectionHead, self).__init__()
        self.linear1 = nn.Linear(dim_in, dim_hidden)
        self.bn1 = nn.BatchNorm1d(dim_hidden)
        self.linear2 = nn.Linear(dim_hidden, dim_hidden)
        self.bn2 = nn.BatchNorm1d(dim_hidden)
        self.linear3 = nn.Linear(dim_hidden, dim_out)
        # Parameter-free activation; it is *applied* in forward, never assigned to x.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Project ``x`` (assumed shape (N, dim_in)) to shape (N, dim_out)."""
        x = self.relu(self.bn1(self.linear1(x)))  # BatchNorm1d takes (N, C) as-is
        x = self.relu(self.bn2(self.linear2(x)))
        return self.linear3(x)

##### Why do we squeeze and unsqueeze?

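Batch normalization (BatchNorm) typically operates on mini-batches of data in convolutional neural networks (CNNs). In CNNs, the input to each layer is often represented as a 4D tensor with dimensions [batch_size, channels, height, width]. BatchNorm keeps one mean/variance pair per channel, computed across the batch and spatial dimensions while the channel dimension stays intact. `nn.BatchNorm2d` expects exactly this 4D layout, so unsqueezing a `[batch_size, features]` linear output to `[batch_size, features, 1, 1]` is what would let a 2D BatchNorm be applied after a linear layer. `nn.BatchNorm1d`, by contrast, accepts `[batch_size, features]` (or `[batch_size, features, length]`) directly, so with `BatchNorm1d` the unsqueeze is unnecessary, and the resulting 4D tensor is actually rejected, as the next cell shows.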

In [None]:
# In this case though, they are using BatchNorm1d,
# which accepts (N, C, L) or (N, C). The extra unsqueeze is not just unnecessary:
# it turns the (N, C) Linear output into a 4-D (N, C, 1, 1) tensor that BatchNorm1d
# rejects. The error is caught and shown so the notebook still runs end to end.

test_lin = nn.Linear(2048, 2048)
a = test_lin(z).unsqueeze(-1).unsqueeze(-1)  # (N, 2048, 1, 1)
test_bn = nn.BatchNorm1d(2048)
try:
    b = test_bn(a)  # This does not work: 4-D input
except ValueError as err:
    print(f"BatchNorm1d rejected input of shape {tuple(a.shape)}: {err}")

In [25]:
a = test_lin(z)   # plain 2-D (N, 2048) output of the Linear layer
b = test_bn(a)    # BatchNorm1d takes (N, C) directly, so this works
b.shape

torch.Size([128, 2048])

In [17]:
## CORRECT
class ProjectionHead(nn.Module):
    """Projection MLP: (Linear -> BatchNorm1d -> ReLU) twice, then a final Linear.

    Maps encoder features of shape (N, dim_in) to embeddings of shape (N, dim_out).
    BatchNorm1d consumes the 2-D linear output directly, so no unsqueeze is needed.
    """

    def __init__(self, dim_in=2048, dim_out=128, dim_hidden=2048):
        super().__init__()
        # Submodules are registered in layer order; keep it stable so parameter
        # order and state_dict keys stay the same.
        self.linear1 = nn.Linear(dim_in, dim_hidden)
        self.bn1 = nn.BatchNorm1d(dim_hidden)
        self.linear2 = nn.Linear(dim_hidden, dim_hidden)
        self.bn2 = nn.BatchNorm1d(dim_hidden)
        self.linear3 = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.linear1(x)))
        hidden = F.relu(self.bn2(self.linear2(hidden)))
        return self.linear3(hidden)

In [18]:
test_model = ProjectionHead().to(z.device)  # keep the head on the same device as z
a = test_model(z)  # call the module (not .forward) so any registered hooks run
a.shape

torch.Size([128, 2048])
torch.Size([128, 2048])
torch.Size([128, 2048])
torch.Size([128, 2048])
torch.Size([128, 128])


torch.Size([128, 128])

##### Question

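What is the purpose of `@torch.no_grad()`, and how does it differ from `model.eval()`? Short answer: `torch.no_grad()` turns off autograd tracking inside its scope, so no graph is recorded, memory and compute are saved, and outputs have `requires_grad=False`. `model.eval()` instead switches modules such as BatchNorm and Dropout to their inference behaviour: BatchNorm uses its running statistics and Dropout is disabled. The two are independent and are usually combined when evaluating a model.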

In [101]:
class WCL(nn.Module): 
    def __init__(self, dim_hidden=4096, dim_out=256):
        super(WCL, self).__init__
        self.net = Net()
        self.head1 = ProjectionHead(dim_in=2048, 
                                    dim_out=dim_out, 
                                    dim_hidden=dim_hidden)
        self.head1 = ProjectionHead(dim_in=2048, 
                                    dim_out=dim_out, 
                                    dim_hidden=dim_hidden)
    
    def build_connected_component(self, distances): 
        b = distances.shape[0]
        distances = torch.eye(b, b) *2
        # Returns a 2-D tensor with twos on the diagonal and zeros elsewhere.

    def forward(self, x1, x2, t=0.1): 
        