# Details of contrastive loss

In [None]:
import torch
def contrastive_force(positive, negative, temp):
    '''
    Simplified (unnormalized) force acting on each negative sample.

    For every negative j, sum over all positives i and all *other*
    negatives k != j:
        A[j, M + i] * A[j, k] * (positive[i] - negative[k])
    where A is an affinity matrix over the concatenation [negative; positive].

    positive: (N, D)
    negative: (M, D)
    temp: float, softmax temperature applied to negated L2 distances
    returns: (M, D) tensor, one force vector per negative
    '''
    N, M = positive.shape[0], negative.shape[0]
    concat = torch.cat([negative, positive], dim=0) # [N + M, D]
    pairwise_dist = torch.cdist(concat, concat, p=2) # [N + M, N + M]
    # Geometric mean of the row-normalized and column-normalized softmax.
    affinity = torch.softmax(-pairwise_dist / temp, dim=1).pow(0.5) * \
                torch.softmax(-pairwise_dist / temp, dim=0).pow(0.5)  # [N + M, N + M]
    total_force = torch.zeros_like(negative) # [M, D]
    for j in range(M):
        for i in range(N):
            for k in range(M):
                if k == j:  # skip the self pair k == j
                    continue
                total_force[j] += affinity[j, i + M] * affinity[j, k] * (positive[i] - negative[k])
    # ... normalization skipped in this sketch.
    return total_force

def contrastive_loss(positive, negative, temps=(0.01, 0.02, 0.05, 0.1, 0.2)):
    '''
    Simplified multi-temperature loss (normalization steps omitted).

    positive: (N, D)
    negative: (M, D)
    temps: softmax temperatures whose forces are summed
    returns: (M, D) unreduced squared error between ``negative`` and its
             stop-gradient target ``negative + force``
    '''
    # ... input normalization skipped in this sketch.
    total_force = 0
    for temp in temps:
        total_force += contrastive_force(positive, negative, temp)
    target = negative + total_force # [M, D]
    # ... target normalization skipped in this sketch.
    # Stop-gradient on the target (``sg``): gradients flow only through ``negative``.
    return (negative - target.detach()) ** 2




In [None]:

def contrastive_force(positive, negative, temp):
    '''
    Force acting on each negative sample at softmax temperature ``temp``.

    Vectorized equivalent of the explicit loop
        force[j] = sum_i sum_{k != j} A[j, M + i] * A[j, k] * (positive[i] - negative[k])
    where A is an affinity matrix over the concatenation [negative; positive],
    followed by a global rescaling.

    positive: (N, D)
    negative: (M, D)
    temp: float, softmax temperature applied to negated L2 distances
    returns: (M, D) tensor, one force vector per negative
    '''
    M = negative.shape[0]
    # main part
    concat = torch.cat([negative, positive], dim=0) # [N + M, D]
    pairwise_dist = torch.cdist(concat, concat, p=2) # [N + M, N + M]
    logits = -pairwise_dist / temp
    # Geometric mean of the row-normalized and column-normalized softmax.
    affinity = torch.softmax(logits, dim=1).pow(0.5) * \
                torch.softmax(logits, dim=0).pow(0.5)  # [N + M, N + M]
    # Affinity of each negative j to every positive i, and to every other
    # negative k (self pairs k == j zeroed, matching k != j in the loop).
    aff_pos = affinity[:M, M:]  # [M, N]
    self_mask = torch.eye(M, dtype=torch.bool, device=affinity.device)
    aff_neg = affinity[:M, :M].masked_fill(self_mask, 0)  # [M, M]
    # The double sum factorizes:
    #   sum_k aff_neg[j, k] * (aff_pos[j] @ positive)
    # - sum_i aff_pos[j, i] * (aff_neg[j] @ negative)
    total_force = aff_neg.sum(dim=1, keepdim=True) * (aff_pos @ positive) \
        - aff_pos.sum(dim=1, keepdim=True) * (aff_neg @ negative)  # [M, D]
    # A heuristic normalization; possibly unnecessary (not yet ablated).
    total_force = total_force / ((pairwise_dist * affinity) ** 2).mean().sqrt()
    return total_force

def contrastive_loss(positive, negative, temps=(0.01, 0.02, 0.05, 0.1, 0.2)):
    '''
    Multi-temperature loss that regresses each negative onto the
    stop-gradient target ``negative + force``.

    positive: (N, D)
    negative: (M, D)
    temps: softmax temperatures whose forces are summed
    returns: (M, D) unreduced squared error, computed in the rescaled space
    '''
    # Rescale both sets by one shared factor so that the mean of the
    # pairwise-distance matrix (diagonal zeros included) becomes 1.
    concat = torch.cat([positive, negative], dim=0) # [N + M, D]
    pairwise_dist = torch.cdist(concat, concat, p=2) # [N + M, N + M]
    # NOTE(review): scale is not detached, so gradients also flow through it.
    # Confirm this is intended.
    scale = pairwise_dist.mean()
    positive, negative = positive / scale, negative / scale
    # main part
    # The target is stop-gradient, so the force needs no autograd graph.
    with torch.no_grad():
        total_force = 0
        for temp in temps:
            total_force += contrastive_force(positive, negative, temp)
    target = negative + total_force # [M, D]
    # Remark: in the actual implementation, the batched [B, M, D] target is
    # also normalized to scale 1 at this point; omitted here.
    # Stop-gradient on the target (``sg``): gradients flow only through ``negative``.
    return (negative - target.detach()) ** 2


