<a href="https://colab.research.google.com/github/superbunny38/DeepLearning/blob/main/papers/VectorizedNT_XentLossExplained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Requirements

In [None]:
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Vectorized Criterion Version

In [None]:
class Criterion(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR.

    Row ``k`` of ``emb_i`` and row ``k`` of ``emb_j`` are two views of the same
    sample. Each of the ``2N`` embeddings is used as an anchor:
    - its partner in the other view is the positive;
    - the remaining ``2N - 2`` embeddings are negatives.

    The returned loss is the mean of the ``2N`` per-anchor losses.

    Args:
        batch_size: Nominal batch size ``N``. It is only used to pre-build the
            ``negatives_mask`` buffer. ``forward`` derives ``N`` from its inputs,
            so smaller (e.g. final, partial) batches also work. Defaults to
            ``cfg.train.batch_size``.
        temperature: Temperature ``tau`` that divides the similarities before
            exponentiation. Defaults to ``cfg.train.temperature``.
    """

    def __init__(self, batch_size=None, temperature=None):
        super().__init__()
        # Fall back to the global config only when a value is not supplied, so
        # the original zero-argument construction keeps working.
        if batch_size is None:
            batch_size = cfg.train.batch_size
        if temperature is None:
            temperature = cfg.train.temperature
        self.batch_size = batch_size
        self.register_buffer("temperature", torch.tensor(temperature))
        # 1.0 everywhere except the diagonal (self-similarity). This is kept as
        # a buffer so existing state_dicts and attribute accesses stay valid.
        # forward() builds a mask that matches the actual input size instead.
        self.register_buffer(
            "negatives_mask",
            (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)).float(),
        )

    def forward(self, emb_i, emb_j):
        """Compute the NT-Xent loss for a batch of positive pairs.

        Args:
            emb_i: First-view embeddings of shape (N, D) (z_i in the SimCLR
                paper). Row k pairs with row k of ``emb_j``.
            emb_j: Second-view embeddings (z_j), same shape as ``emb_i``.

        Returns:
            0-dim tensor: the mean loss over all 2N anchors.

        Raises:
            ValueError: if the inputs are not 2-D, differ in shape, or are empty.
        """
        if emb_i.dim() != 2 or emb_i.shape != emb_j.shape:
            raise ValueError(
                f"emb_i and emb_j must be 2-D with equal shapes, got "
                f"{tuple(emb_i.shape)} and {tuple(emb_j.shape)}"
            )
        batch_size = emb_i.shape[0]
        if batch_size == 0:
            raise ValueError("NT-Xent loss needs at least one positive pair")
        n_views = 2 * batch_size

        # Row-normalise so a plain dot product equals cosine similarity.
        z = F.normalize(torch.cat([emb_i, emb_j], dim=0), dim=1)
        # (2N, 2N) matrix of s_{a,b} / tau. A matmul avoids the (2N, 2N, D)
        # intermediate that broadcasting cosine_similarity would allocate.
        logits = (z @ z.T) / self.temperature
        # Drop k == i from the denominator: exp(-inf) contributes 0.
        self_mask = torch.eye(n_views, dtype=torch.bool, device=z.device)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # The partner of anchor a is a + N for the first view and a - N for
        # the second view.
        anchors = torch.arange(n_views, device=z.device)
        partners = (anchors + batch_size) % n_views
        positives = logits[anchors, partners]

        # -log(exp(p) / sum_k exp(l_k)) == logsumexp(l) - p.
        # Computing it this way avoids overflow for small temperatures.
        loss_partial = torch.logsumexp(logits, dim=1) - positives
        return loss_partial.mean()

In [None]:
batch_size: int = 3  # N: number of samples per view in this toy example

In [None]:
# 1.0 on every off-diagonal entry and 0.0 on the diagonal, so multiplying by
# this mask removes each sample's similarity with itself.
negatives_mask = 1.0 - torch.eye(2 * batch_size, dtype=torch.float32)
negatives_mask

tensor([[0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 0.]])

In [None]:
# Two views of three samples: row k of batch_1 (I_k) pairs with row k of batch_2 (J_k).
batch_1 = torch.tensor([[1, 8, 2], [5, 10, 4], [0, 9, 9]], dtype=torch.float64)  # I1, I2, I3
batch_2 = torch.tensor([[9, 2, 2], [6, 1, 3], [4, 8, 9]], dtype=torch.float64)  # J1, J2, J3
# L2-normalise each row so that dot products become cosine similarities.
z_1, z_2 = (F.normalize(view, dim=1) for view in (batch_1, batch_2))
z_2

tensor([[0.9540, 0.2120, 0.2120],
        [0.8847, 0.1474, 0.4423],
        [0.3152, 0.6305, 0.7093]], dtype=torch.float64)

In [None]:
# Stack both views: rows 0..N-1 are I1..I3, rows N..2N-1 are J1..J3.
representations = torch.cat((z_1, z_2))
representations

tensor([[0.1204, 0.9631, 0.2408],
        [0.4211, 0.8422, 0.3369],
        [0.0000, 0.7071, 0.7071],
        [0.9540, 0.2120, 0.2120],
        [0.8847, 0.1474, 0.4423],
        [0.3152, 0.6305, 0.7093]], dtype=torch.float64)

In [None]:
# Broadcast rows against columns so every representation is compared with every
# other one; entry [a, b] is the cosine similarity of rows a and b.
similarity_matrix = F.cosine_similarity(representations[:, None, :], representations[None, :, :], dim=2)
similarity_matrix

tensor([[1.0000, 0.9429, 0.8513, 0.3701, 0.3550, 0.8159],
        [0.9429, 1.0000, 0.8337, 0.6517, 0.6457, 0.9026],
        [0.8513, 0.8337, 1.0000, 0.2998, 0.4170, 0.9474],
        [0.3701, 0.6517, 0.2998, 1.0000, 0.9690, 0.5848],
        [0.3550, 0.6457, 0.4170, 0.9690, 1.0000, 0.6856],
        [0.8159, 0.9026, 0.9474, 0.5848, 0.6856, 1.0000]], dtype=torch.float64)

Row 0 of `similarity_matrix` holds the similarities between I1 and every sample: sim(I1,I1), sim(I1,I2), …, sim(I1,J3).

In [None]:
similarity_matrix[0, :]  # I1 compared with I1, I2, I3, J1, J2, J3

tensor([1.0000, 0.9429, 0.8513, 0.3701, 0.3550, 0.8159], dtype=torch.float64)

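The similarity matrix is symmetric: sim(I1,I2) == sim(I2,I1). The positive pairs (I_k, J_k) lie on the diagonal offset by +N (row k, column k+N). Their mirror images (J_k, I_k) lie on the diagonal offset by −N, so the two diagonals below contain the same values.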

In [None]:
# Positive pairs for the first view: sim(I1,J1), sim(I2,J2), sim(I3,J3).
# They sit on the diagonal offset by +batch_size (row k, column k + N); using
# batch_size instead of a literal keeps this cell correct if N changes.
sim_ij = torch.diag(similarity_matrix, batch_size)
sim_ij

tensor([0.3701, 0.6457, 0.9474], dtype=torch.float64)

In [None]:
# Positive pairs for the second view: sim(J1,I1), sim(J2,I2), sim(J3,I3).
# They sit on the diagonal offset by -batch_size (row k + N, column k).
sim_ji = torch.diag(similarity_matrix, -batch_size)
sim_ji

tensor([0.3701, 0.6457, 0.9474], dtype=torch.float64)

In [None]:
# The positive for each row of similarity_matrix, in row order:
# sim(I1,J1), sim(I2,J2), sim(I3,J3), sim(J1,I1), sim(J2,I2), sim(J3,I3)
positives = torch.cat((sim_ij, sim_ji))
positives

tensor([0.3701, 0.6457, 0.9474, 0.3701, 0.6457, 0.9474], dtype=torch.float64)

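$\ell(i,j) = -\log\frac{\exp(s_{i,j}/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k \neq i]}\exp(s_{i,k}/\tau)}$

Here $s_{i,k}$ is the cosine similarity between samples $i$ and $k$. The denominator sums over every sample except the anchor itself (including the positive $j$). The final loss averages $\ell$ over all $2N$ anchors.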

In [None]:
# tau: the temperature that divides the similarities before exponentiation.
temperature = 0.5
nominator = (positives / temperature).exp()  # numerator: exp(s_ij / tau) for each positive
nominator

tensor([2.0962, 3.6377, 6.6509, 2.0962, 3.6377, 6.6509], dtype=torch.float64)

In [None]:
# Denominator terms: exp(s_ik / tau) for every pair, with the diagonal (k == i)
# zeroed by the mask. Row i holds the terms summed for anchor i.
denominator = torch.exp(similarity_matrix / temperature) * negatives_mask
denominator

tensor([[0.0000, 6.5911, 5.4877, 2.0962, 2.0340, 5.1135],
        [6.5911, 0.0000, 5.2982, 3.6815, 3.6377, 6.0817],
        [5.4877, 5.2982, 0.0000, 1.8214, 2.3026, 6.6509],
        [2.0962, 3.6815, 1.8214, 0.0000, 6.9447, 3.2206],
        [2.0340, 3.6377, 2.3026, 6.9447, 0.0000, 3.9399],
        [5.1135, 6.0817, 6.6509, 3.2206, 3.9399, 0.0000]], dtype=torch.float64)

In [None]:
# Denominator for anchor I1: the sum of exp(sim / tau) over every k != I1.
# Tensor.sum() reduces in a single call instead of iterating elements through
# Python's builtin sum.
denominator[0].sum()

tensor(21.3226, dtype=torch.float64)

In [None]:
# Sum each row (dim=1, i.e. across the columns): the softmax denominator for
# anchors I1, I2, I3, J1, J2, J3 in that order.
denominator.sum(dim=1)

tensor([21.3226, 25.2903, 21.5609, 17.7643, 18.8589, 25.0066],
       dtype=torch.float64)

In [None]:
loss_partial = -(nominator / denominator.sum(dim=1)).log()  # l(i, j) for every anchor i

In [None]:
# Per-anchor losses in row order: l(I1,J1), l(I2,J2), l(I3,J3), l(J1,I1), l(J2,I2), l(J3,I3)
loss_partial

tensor([2.3196, 1.9391, 1.1761, 2.1371, 1.6456, 1.3244], dtype=torch.float64)

In [None]:
# Final loss: the average over all 2N anchors (N = batch_size).
loss = loss_partial.sum() / (2 * batch_size)
loss

tensor(1.7570, dtype=torch.float64)