# Contrastive Learning development
## Method: SimCLR

- reference: https://arxiv.org/pdf/2002.05709.pdf  
- use: training on pairs of angle-jittered thickness-sphere images
- TO DO: add more testing for loss function; all other components seem to work correctly

In [1]:
""" Note on development:
    https://stackoverflow.com/questions/66065272/customizing-the-batch-with-specific-elements
    was a useful reference for designing the custom batch_sampler.

"""
import user
from user.data import CustomImageDataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchsummary import summary

import numpy as np
import pandas as pd
import os

from typing import Union  # useful for function annotation

In [2]:
def inner_products(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """ Computes all inner products of two batches of tensors.
        Given x_bi, y_cj (batch dimension first), returns
            s_bc = x_bi y_ci, inner product over vector dim.
    """
    assert x.shape == y.shape  # inputs must be compatible
    xnorm = torch.norm(x, dim=-1)  # norm is Euclidean by default
    ynorm = torch.norm(y, dim=-1)
    denom = torch.outer(xnorm, ynorm)
    # d = xnorm.shape[0] # get appropriate dimension for outer product
    # denom = xnorm.reshape(d, 1) @ ynorm.reshape(1, d)
    xdoty = x @ y.T
    
    return torch.div(xdoty, denom)

In [3]:
class ProjectionHead(nn.Module):
    """ One hidden-layer MLP for the projection heads of network.
    """
    def __init__(self, nin, nmid, nout, nonlin=None):
        super().__init__()
        if nonlin is None:
            nonlin = F.elu
        self.l1 = nn.Linear(nin, nmid)
        self.l2 = nn.Linear(nmid, nout)
        self.nonlin = nonlin
        
    def forward(self, x):
        x = self.nonlin(self.l1(x))
        x = self.l2(x)
        return x

In [52]:
def loss(inn: torch.Tensor, tau: Union[float, torch.Tensor]) -> torch.Tensor:
    """ Contrastive loss.
        Inputs:
            inn: s_ij, inner products of vector batch.
            tau: temperature parameter
    """
    d = inn.shape[0]
    assert d % 2 == 0  # must be complete number of pairs
    innexp = torch.exp(inn/tau)  # complete matrix of all pairwise exps
    denoms = torch.sum(innexp, dim=-1) - torch.diagonal(inn, 0)
    
    # compute all losses between (positive) pairs: 2x,2x+1 and 2x+1,x
    left_nums = torch.diagonal(innexp, 1)  # above diagonal; innexp_i,i+1
    left_loss = -torch.log(left_nums[::2]/denoms[:-1][::2])
    right_nums = torch.diagonal(innexp, -1)  # below diagonal; innexp_i+1,i
    right_loss = -torch.log(right_nums[::2]/denoms[:-1][::2])
    total_loss = (1/d) * torch.sum(right_loss + left_loss) # normalize by 1/d
    
    return total_loss

In [55]:
output = model(X)

In [57]:
output_inner = inner_products(output, output)

In [59]:
output_inner.shape

torch.Size([64, 64])

In [60]:
loss(output_inner, 0.5)

tensor(4.1568, device='cuda:0', grad_fn=<MulBackward0>)

In [5]:
#  dataset loading / processing

img_dir = '../data/fsa-data/images/images-thickness-sphere-processed'
label_dir = '../data/labels'
label_file = label_dir + '/age_thickness_sphere_labels.csv'

In [6]:
labels = pd.read_csv(label_file)

In [7]:
def create_paired_labels(df, nsamples=10, savepath=None):
    """ Takes a non-paired label file and creates
        a file of uniform labels, grouped by subject
        with the same number of samples per subject.
    """
    if isinstance(df, str):
        labels = pd.read_csv(df)
    labels = df
    # create new labels file, in preparation for custom batch sampler
    labels['Subject'] = labels['Filename'].apply(lambda s: s[:16])
    counts = labels.groupby('Subject', as_index=False).count()
    counts.columns = list(counts.columns)[:-1] + ['Count']
    label_counts = pd.merge(labels, counts[['Subject', 'Count']], on='Subject', how='left')
    labels_uniform = label_counts[label_counts.Count==nsamples]
    # very important to sort data in labels so that indices are in order
    labels_uniform_sorted = labels_uniform.sort_values('Filename')
    
    #  optionally saves a .csv of modified labels
    if savepath is not None:
        labels_uniform_sorted.to_csv(savepath, index=None)
    
    return labels_uniform_sorted

In [8]:
labelfile = label_dir + '/age_thickness_sphere_paired_labels.csv'

labels_exist = True
if not labels_exist:
    labels_unif = create_paired_labels(labels)
    labels_unif.to_csv(labelfile, index=False)

In [9]:
t = pd.read_csv(labelfile)

In [10]:
t.head(20)

Unnamed: 0,Filename,Age,Subject,Count
0,sub-NDARAA075AMK.fsa.lh.thickness--62.10--1.83...,7,sub-NDARAA075AMK,10
1,sub-NDARAA075AMK.fsa.lh.thickness--64.32--1.43...,7,sub-NDARAA075AMK,10
2,sub-NDARAA075AMK.fsa.lh.thickness--64.88--3.30...,7,sub-NDARAA075AMK,10
3,sub-NDARAA075AMK.fsa.lh.thickness--64.95-2.36-...,7,sub-NDARAA075AMK,10
4,sub-NDARAA075AMK.fsa.lh.thickness--64.97-3.92-...,7,sub-NDARAA075AMK,10
5,sub-NDARAA075AMK.fsa.lh.thickness--65.86-2.87-...,7,sub-NDARAA075AMK,10
6,sub-NDARAA075AMK.fsa.lh.thickness--66.10--1.62...,7,sub-NDARAA075AMK,10
7,sub-NDARAA075AMK.fsa.lh.thickness--67.74--1.54...,7,sub-NDARAA075AMK,10
8,sub-NDARAA075AMK.fsa.lh.thickness--67.89-3.39-...,7,sub-NDARAA075AMK,10
9,sub-NDARAA075AMK.fsa.lh.thickness--68.23--2.24...,7,sub-NDARAA075AMK,10


In [11]:
# note that here, img_dir need not change; the labelfile
# simply specifies a subset of images
from user.data import ToFloat

ds = CustomImageDataset(labelfile, img_dir, transform=transforms.ConvertImageDtype(torch.float32), target_transform=ToFloat())

In [12]:
class PairedBatchSampler(torch.utils.data.Sampler):
    """ Custom batch sampler, to sample (either sequentially
        or randomly) in pairs. 
    """
    
    def __init__(self, datasource, nxs, ntrans, batchsize, random=False):
        super().__init__(datasource)
        self.nx = nxs
        self.nt = ntrans
        self.nb = batchsize
        self.random = random
        if not self.random:
            #  these properties are only necessary for sequential sampling
            self.nextx = 0
            self.pair_ind = 0
        
    def __iter__(self):
        """ Creates index sample of 2 * batchsize,
            batchsize number of pairs. Different depending on whether
            sampling is random or sequential.
        """
        return self._create_batch_iterator()
    
    def _create_batch_iterator(self):
        return PairedBatchSampler._batch_iterator(self)

    class _batch_iterator():
        """ Helper class to create a batch iterator. This is necessary
            because batch_sampler (arg of DataLoader) expects an iterator
            over batches, not merely over elements of a batch.
        """
        #  Note: could accomplish the random sampling by using _pair()
        def __init__(self, outer):
            self.outer = outer
        
        def __iter__(self):
            return self
        
        def __next__(self):
            if self.outer.random:
                pinds = []
                for i in range(self.outer.nb):
                    pinds.append(np.random.choice(list(range(self.outer.nt)), size=2, replace=False))
                pinds = np.array(pinds)
                xinds = np.random.randint(0, self.outer.nx, self.outer.nb)
                inds = (pinds.T + self.outer.nt * xinds).T # must multiply xinds by nt to yield actual ind
                inds = inds.reshape(self.outer.nb * 2)

            if not self.outer.random:
                xinds = list(range(self.outer.nextx, self.outer.nextx + self.outer.nb))
                xinds = [i % self.outer.nx for i in xinds]  # wrap back to beginning of db
                if self.outer.nextx > self.outer.nx - self.outer.nb:  # change pair after all subjects
                    self.outer.pair_ind += 1
                i, j = self.outer._pairs(self.outer.pair_ind)
                xinds  = [self.outer.nt * xi for xi in xinds]
                eveninds = [xi + i for xi in xinds]
                oddinds = [xi + j for xi in xinds]
                inds = [v for pair in zip(eveninds, oddinds) for v in pair]

                self.outer.nextx = (self.outer.nextx + self.outer.nb) % self.outer.nx

            return inds

    def _pairs(self, ind, base=None):
        """ Helper function to uniquely associate indices of a valid pair
            with the numbers from 0 to <number of possible pairs> - 1.
            If called without base specified, base will be set to
            self.nt as this is the primary use case.
        """
        if base is None:
            base = self.nt
        #  get pair of indices    
        i = ind // (base - 1)
        j = ind % (base - 1)
        if j >= i:
            j += 1
        return i, j   

In [13]:
len(ds)

15510

In [14]:
s = PairedBatchSampler(ds, len(ds)//10, 10, 32, random=True)

In [15]:
dl = DataLoader(ds, batch_sampler=s)

In [16]:
X, y = next(iter(dl))

# Test entire SimCLR pipeline

by training on pairs of angle-jittered images

In [17]:
from user.models import ConvPoolBlock

In [18]:
class SimCLR(nn.Module):
    """ Network architecture for contrastive learning.
        200 x 200 inputs.
    """
    
    def __init__(self, pool='Conv'):
        super().__init__()
        self.block1 = ConvPoolBlock(1, 32, 64, pool=pool)
        self.block2 = ConvPoolBlock(64, 128, 128, pool=pool)
        self.block3 = ConvPoolBlock(128, 128, 128, pool=pool)
        self.block4 = ConvPoolBlock(128, 128, 64, pool=pool)
        self.fc1 = nn.Linear(6400, 100)
        self.ph = ProjectionHead(100, 32, 8)
        # second-to-last hidden layer
        # self.z = torch.autograd.Variable(torch.zeros(100, dtype=torch.float), requires_grad=False)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = torch.flatten(x, 1)
        x = F.elu(self.fc1(x))
        # self.z = x  # save second-to-last hidden layer's values in network Variable.
        # x = F.elu(self.fc2(x))
        # x = F.elu(self.fc3(x))
        x = self.ph(x)
        return x
    


In [19]:
net = SimCLR()

In [20]:
device = user.utils.get_device()
net.to(device)

SimCLR(
  (block1): ConvPoolBlock(
    (c1): Conv2d(1, 32, kernel_size=(4, 4), stride=(1, 1), padding=same)
    (c2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  )
  (block2): ConvPoolBlock(
    (c1): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=same)
    (c2): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2))
  )
  (block3): ConvPoolBlock(
    (c1): Conv2d(128, 128, kernel_size=(4, 4), stride=(1, 1), padding=same)
    (c2): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2))
  )
  (block4): ConvPoolBlock(
    (c1): Conv2d(128, 128, kernel_size=(4, 4), stride=(1, 1), padding=same)
    (c2): Conv2d(128, 64, kernel_size=(4, 4), stride=(2, 2))
  )
  (fc1): Linear(in_features=6400, out_features=100, bias=True)
  (ph): ProjectionHead(
    (l1): Linear(in_features=100, out_features=32, bias=True)
    (l2): Linear(in_features=32, out_features=8, bias=True)
  )
)

In [21]:
summary(net, (1, 200, 200))

Layer (type:depth-idx)                   Output Shape              Param #
├─ConvPoolBlock: 1-1                     [-1, 64, 99, 99]          --
|    └─Conv2d: 2-1                       [-1, 32, 200, 200]        544
|    └─Conv2d: 2-2                       [-1, 64, 99, 99]          32,832
├─ConvPoolBlock: 1-2                     [-1, 128, 48, 48]         --
|    └─Conv2d: 2-3                       [-1, 128, 99, 99]         131,200
|    └─Conv2d: 2-4                       [-1, 128, 48, 48]         262,272
├─ConvPoolBlock: 1-3                     [-1, 128, 23, 23]         --
|    └─Conv2d: 2-5                       [-1, 128, 48, 48]         262,272
|    └─Conv2d: 2-6                       [-1, 128, 23, 23]         262,272
├─ConvPoolBlock: 1-4                     [-1, 64, 10, 10]          --
|    └─Conv2d: 2-7                       [-1, 128, 23, 23]         262,272
|    └─Conv2d: 2-8                       [-1, 64, 10, 10]          131,136
├─Linear: 1-5                            [-1, 100]

  return F.conv2d(input, weight, bias, self.stride,


Layer (type:depth-idx)                   Output Shape              Param #
├─ConvPoolBlock: 1-1                     [-1, 64, 99, 99]          --
|    └─Conv2d: 2-1                       [-1, 32, 200, 200]        544
|    └─Conv2d: 2-2                       [-1, 64, 99, 99]          32,832
├─ConvPoolBlock: 1-2                     [-1, 128, 48, 48]         --
|    └─Conv2d: 2-3                       [-1, 128, 99, 99]         131,200
|    └─Conv2d: 2-4                       [-1, 128, 48, 48]         262,272
├─ConvPoolBlock: 1-3                     [-1, 128, 23, 23]         --
|    └─Conv2d: 2-5                       [-1, 128, 48, 48]         262,272
|    └─Conv2d: 2-6                       [-1, 128, 23, 23]         262,272
├─ConvPoolBlock: 1-4                     [-1, 64, 10, 10]          --
|    └─Conv2d: 2-7                       [-1, 128, 23, 23]         262,272
|    └─Conv2d: 2-8                       [-1, 64, 10, 10]          131,136
├─Linear: 1-5                            [-1, 100]

In [23]:
X = X.to(device)
y = y.to(device)

In [24]:
ypred = net(X)

In [25]:
ypred.shape

torch.Size([64, 8])

In [27]:
# test by training to predict angle

In [28]:
model = net

In [29]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [61]:
def loss_and_inner(X, tau=0.5):
    """ Loss function for contrastive learning.
        Computes the inner products of pairs.
        Default temperature tau of 0.5 is taken from original paper."""
    inn = inner_products(X, X)
    return loss(inn, tau=tau)

def loss_and_inner_dummy(X, _):
    return loss_and_inner(X)

In [63]:
loss_and_inner_dummy(output, y)

tensor(4.1568, device='cuda:0', grad_fn=<MulBackward0>)

In [47]:
loss_fn = loss_and_inner_dummy

In [64]:
from user.utils import get_device

def contrastive_train_loop(dataloader, model, loss_fn, optimizer):
    device = get_device()
    size = len(dataloader.dataset)
    # ensure model is set to training mode in case previously in evaluation mode
    model.train()
    for batch, (X_cpu, y_cpu) in enumerate(dataloader):
        batch_size = X_cpu.shape[0]
        # Compute prediction and loss
        X = X_cpu.to(device) # put both model and data on gpu (if available)
        y = y_cpu.to(device)
        output = model(X)
        # this is required for models with multiple returns; by assumption,
        # the first return of a tuple is the prediction; can only squeeze a tensor
        if isinstance(output, tuple):
            output = tuple([torch.squeeze(o) for o in output])
        
        loss = loss_fn(output, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        
        # in case dataloader has cyclic (infinite) enumeration
        if batch_size * batch > size:
            break

In [66]:
EPOCHS = 50
for epoch in range(EPOCHS):
    contrastive_train_loop(dl, model, loss_fn, optimizer)

torch.save(model, './models/contrastive_model.pt')

loss: 4.156766  [    0/15510]
loss: 4.151641  [  640/15510]
loss: 4.155554  [ 1280/15510]
loss: 3.558662  [ 1920/15510]
loss: 3.697443  [ 2560/15510]
loss: 3.498979  [ 3200/15510]
loss: 3.535395  [ 3840/15510]
loss: 3.272998  [ 4480/15510]
loss: 3.101564  [ 5120/15510]
loss: 3.004882  [ 5760/15510]
loss: 3.105434  [ 6400/15510]
loss: 3.027913  [ 7040/15510]
loss: 2.913542  [ 7680/15510]
loss: 2.799510  [ 8320/15510]
loss: 3.060619  [ 8960/15510]
loss: 2.887210  [ 9600/15510]
loss: 2.793036  [10240/15510]
loss: 2.829901  [10880/15510]
loss: 2.869812  [11520/15510]
loss: 2.894574  [12160/15510]
loss: 2.815442  [12800/15510]
loss: 2.758851  [13440/15510]
loss: 2.940223  [14080/15510]
loss: 2.785467  [14720/15510]
loss: 2.772412  [15360/15510]
loss: 3.013589  [    0/15510]
loss: 2.887372  [  640/15510]
loss: 2.845021  [ 1280/15510]
loss: 2.793487  [ 1920/15510]
loss: 2.880816  [ 2560/15510]
loss: 2.810629  [ 3200/15510]
loss: 2.756248  [ 3840/15510]
loss: 2.793505  [ 4480/15510]
loss: 2.78