In [5]:
"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
#---------------------------------------------------------------------------------------------------------------
        #This part of the code uses GPU if present. Else, it uses CPU
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))
#---------------------------------------------------------------------------------------------------------------
        #this part of the code requires that the input feature dimension be 3 dimensional or larger.
        #perhaps because the input data is in RGB format? 
        if len(features.shape) < 3: 
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)
#---------------------------------------------------------------------------------------------------------------

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
#----------------------------------------------------------------------------------------------------------------

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [14]:
features = torch.randn(2,2)
features

tensor([[-0.3236, -0.5551],
        [-2.0026,  0.0230]])

In [42]:
#batch_size
bs = 4
features = torch.randn(2, 2)
contrast_feature  = features
anchor_feature = contrast_feature
#torch.randn: Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1

temperature = 0.07
anchor_feature = contrast_feature #Note we not doing exp their is a reason see below

anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)

#torch.div: Divides each element of the input 'input' by the corresponding element of 'other'.
#torch.matmul: Matrix product of two tensors.

print('feature:\n {}'.format(features))
print('anchor_dot_contrast: \n {}'.format(anchor_dot_contrast))

feature:
 tensor([[0.4683, 1.1201],
        [1.3249, 1.1536]])
anchor_dot_contrast: 
 tensor([[21.0554, 27.3226],
        [27.3226, 44.0876]])


In [21]:
print('anchor_dot_contrast: \n{}\n'.format(anchor_dot_contrast))
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
print('logits_max:\n {}\n'.format(logits_max))
logits = anchor_dot_contrast - logits_max.detach()
print(' logits:\n {}\n'.format(logits))#output see what happen to diagonal

anchor_dot_contrast: 
tensor([[35.6513, 13.5210],
        [13.5210,  5.8937]])

logits_max:
 tensor([[35.6513],
        [13.5210]])

 logits:
 tensor([[  0.0000, -22.1303],
        [  0.0000,  -7.6273]])



In [None]:
bs = 4
print('batch size:', bs)

temperature = 0.07

labels = torch.randint(4, (1,4))
print('labels:', labels)

mask = torch.eq(labels, labels.T).float()
print('mask = \n {}'.format(mask))#hard coding it for easier understanding otherwise its features.shape[1]
# torch.eq: Computes element-wise equality
# example: torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
# output: tensor([[ True, False],
#                [False, True]])


contrast_count = 2

anchor_count = contrast_count

mask = mask.repeat(anchor_count, contrast_count)# mask-out self-contrast cases
#tensor.repeat: Repeats this tensor along the specified dimensions.

logits_mask = torch.scatter(
    torch.ones_like(mask), #Returns a tensor filled with the scalar value 1, with the same size as input.
    1,
    torch.arange(bs * anchor_count).view(-1, 1), #Returns a 1-D tensor of size [(end-stop)/step] with values from the 
                                                #interval [start, end) taken with common difference step beginning from start
    0
)

In [39]:
logits_mask = torch.scatter(
    torch.ones_like(mask), #Returns a tensor filled with the scalar value 1, with the same size as input.
    1,
    torch.arange(bs * anchor_count).view(-1, 1), 
    #Returns a 1-D tensor of size [(end-stop)/step] with values from the 
    #interval [start, end) taken with common difference step beginning from start
    0
)
print("torch.ones_like(mask):\n {}".format(torch.ones_like(mask)))
print("torch.arange(bs*anchor_count.view(-1.1)):\n {}".format(torch.torch.arange(bs*anchor_count).view(-1,1)))
print("logits_mask\n: {}".format(logits_mask))

torch.ones_like(mask):
 tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
torch.arange(bs*anchor_count.view(-1.1)):
 tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])
logits_mask
: tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0.]])


In [40]:
mask = mask * logits_mask
print('mask * logits_mask = \n{}'.format(mask))

mask * logits_mask = 
tensor([[0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 0., 0., 1., 1., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 1., 0., 0., 1., 1., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 1.],
        [1., 1., 0., 1., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 1., 1., 1., 0., 0.]])


In [33]:
torch.arange(bs*anchor_count).view(-1,1)

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])

In [25]:
batch_size = 4
labels_tensor = ([[3, 0, 2, 3]])
#what above means in this particular batch of 4 we got 3,0,2,3 labels. 
#Just in case you forgot we are contrasting here only once so we will have 3_c, 0_c, 2_c, 3_c 
#as our contrast in the input batch.

#basically 'batch_size' X 'contrast_count' X 'C' x 'Width' X 'height' -> check above if you confused mask = 
#tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
#        [1., 0., 1., 1., 1., 1., 1., 1.],
#        [1., 1., 0., 1., 1., 1., 1., 1.],
#        [1., 1., 1., 0., 1., 1., 1., 1.],
#        [1., 1., 1., 1., 0., 1., 1., 1.],
#        [1., 1., 1., 1., 1., 0., 1., 1.],
#        [1., 1., 1., 1., 1., 1., 0., 1.],
#        [1., 1., 1., 1., 1., 1., 1., 0.]])
#Easy to understand the Self Supervised Contrastive Loss now which is simpler than this


#this is really important so we created a mask = mask * logits_mask which tells us for 
#0 th image representation which are the image it should be contrasted with.
# so our labels are labels tensor([[3, 0, 2, 3]])

# I am renaming them for better understanding tensor([[3_1, 0_1, 2_1, 3_2]])
# so at 3_0 will be contrasted with its own augmentation which is at position 5 (index = 4) and 
# position 8 (index = 7) in the first row those are the position marked one else its zero
#See the image bellow for better understanding

#mask * logits_mask = 
#tensor([[0., 0., 0., 1., 1., 0., 0., 1.],
#        [0., 0., 0., 0., 0., 1., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 1., 0.],
#        [1., 0., 0., 0., 1., 0., 0., 1.],
#        [1., 0., 0., 1., 0., 0., 0., 1.],
#        [0., 1., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 1., 0., 0., 0., 0., 0.],
#        [1., 0., 0., 1., 1., 0., 0., 0.]])

In [6]:
# compute logits
anchor_dot_contrast = torch.div(          
    torch.matmul(anchor_feature, contrast_feature.T),        
    self.temperature)
#torch.div: Divides each element of the input 'input' by the corresponding element of 'other'.
#torch.matmul: #Matrix product of two tensors.

# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 
#Returns the maximum value for each row.
#keepdim (bool): whether the output tensor has dim retained or not. Default is False

logits = anchor_dot_contrast - logits_max.detach()
#Tensor.detach() is used to detach a tensor from the current computational graph. It returns a new tensor that doesn't require a gradient.


# tile mask
mask = mask.repeat(anchor_count, contrast_count)
#Repeats this tensor along the specified dimensions

# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask

# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# loss
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()

return loss

NameError: name 'self' is not defined

In [2]:
pip install torch torchvision torchaudio

Collecting torchNote: you may need to restart the kernel to use updated packages.
  Downloading torch-1.12.0-cp38-cp38-win_amd64.whl (161.9 MB)
Collecting torchvision
  Downloading torchvision-0.13.0-cp38-cp38-win_amd64.whl (1.1 MB)
Collecting torchaudio
  Downloading torchaudio-0.12.0-cp38-cp38-win_amd64.whl (969 kB)
Installing collected packages: torch, torchvision, torchaudio
Successfully installed torch-1.12.0 torchaudio-0.12.0 torchvision-0.13.0

