In [1]:
import os
import sys
import time
import copy

from typing import Dict, List, Tuple, Optional, Set
import torch
from torch import nn  
from torch.utils import data  
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

In [80]:
# Reusable module-level layers: a softmax taken over dim 1 (row-wise for 2-D
# inputs) and a cross-entropy criterion built with PyTorch's default settings.
softmax = nn.Softmax(dim=1)
crossentropy = nn.CrossEntropyLoss()

def LossFlatNCE(feat1, feat2, inv_temp=1., feat2_transposed=False, normalizer=None):

    '''
    FlatNCE contrastive loss together with diagnostic statistics.

    Official implementation of
    Junya Chen, et al. Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE
    NeurIPS 2021 SSL Workshop
    https://arxiv.org/abs/2107.01152
    Some of the results are published in NeurIPS 2022 paper
    Tight Mutual Information Estimation With Contrastive Fenchel-Legendre Optimization

    Pairing convention: bs1 <= bs2, and (feat1[i], feat2[i]) are positive pairs;
    all (feat1[i], feat2[j]) with i != j are negative pairs. When bs1 < bs2 the
    extra rows of feat2 act as additional negatives (e.g. a MoCo-style queue).

    Args:
        feat1: 2-D tensor, bs1 x dim.
        feat2: 2-D tensor, bs2 x dim (dim x bs2 if feat2_transposed is truthy).
        inv_temp: multiplier applied to the logits (inverse temperature).
        feat2_transposed: truthy when feat2 is already laid out as dim x bs2.
        normalizer: optional callable normalizer(x, dim=...) applied to feat1
            along dim=1 and to the (dim x bs2) feat2 along dim=0. None skips
            normalization.

    Returns:
        (loss, res) where
        loss: scalar tensor, mean(exp(s - s.detach())) - 1 with
            s = logsumexp of the contrastive logits per row. Its value is 0;
            its gradient equals the gradient of mean(s).
        res: dict with keys
            'loss_vec'            per-row exp(s - s.detach()), shape (bs1,)
            'similarity'          feat1 @ feat2, shape (bs1, bs2)
            'contrastive_logits'  inv_temp * (negatives - positives), shape (bs1, bs2-1)
            'ness_vec'            per-row normalized effective sample size, in [1/(bs2-1), 1]
            'ness'                mean of 'ness_vec'
            'mi'                  shape-(1,) tensor: log(bs2) minus the cross-entropy of
                                  inv_temp * similarity against the diagonal labels

    Raises:
        AssertionError: if an input is not 2-D, the feature dimensions differ,
            bs1 > bs2, bs1 < 1, or bs2 < 2 (no negative sample available).
    '''

    assert feat1.dim() == 2, 'input1 dimension should be batch_size x feature_dim'
    assert feat2.dim() == 2, 'input2 dimension should be batch_size x feature_dim (or transpose)'

    # Truthiness test (rather than `is False`) so 0 / None behave like False.
    if not feat2_transposed:
        feat2 = feat2.t()

    assert feat1.size(dim=1) == feat2.size(dim=0), 'The feature dimension should match for input1 and input2'

    n1 = feat1.size(dim=0)
    n2 = feat2.size(dim=1)
    assert n1 <= n2, 'Size of input2 should not be less than input1'
    # Without any negative the logsumexp below runs over an empty row and the
    # loss evaluates to NaN, so fail loudly instead.
    assert n1 >= 1 and n2 >= 2, 'At least one positive pair and one negative sample are required'

    # Normalize features if a normalizer is specified
    if normalizer is not None:
        feat1 = normalizer(feat1, dim=1)
        feat2 = normalizer(feat2, dim=0)

    similarity = feat1 @ feat2
    device = similarity.device

    # n1 x n2 boolean mask with True on the (i, i) positive pairs; built on the
    # same device as the similarities so indexing never crosses devices.
    mask = torch.eye(n1, n2, dtype=torch.bool, device=device)

    # Each row holds exactly one positive and n2 - 1 negatives.
    positives = similarity[mask].view(n1, 1)
    negatives = similarity[~mask].view(n1, n2 - 1)

    contrastive_logits = inv_temp * (negatives - positives)

    s = torch.logsumexp(contrastive_logits, dim=1)

    # FlatNCE trick: the value is identically 1 but the gradient is that of s.
    loss_vec = torch.exp(s - s.detach())

    loss = loss_vec.mean() - 1.

    # Normalized effective sample size of the softmax weights over negatives.
    weight = torch.softmax(contrastive_logits, dim=1)
    ness_vec = (1. / (torch.square(weight).sum(dim=1))) / contrastive_logits.size(dim=1)  # [1/size, 1]
    ness = ness_vec.mean()

    res = dict()
    res['loss_vec'] = loss_vec
    res['similarity'] = similarity
    res['contrastive_logits'] = contrastive_logits
    res['ness_vec'] = ness_vec
    res['ness'] = ness

    # Row i's positive sits in column i, so the labels are 0..n1-1.
    labels = torch.arange(n1, dtype=torch.long, device=device)
    # Kept as a shape-(1,) tensor so res['mi'] retains its original shape.
    log_n2 = torch.log(torch.tensor([float(n2)], device=device))
    res['mi'] = -torch.nn.functional.cross_entropy(inv_temp * similarity, labels) + log_n2

    return loss, res

In [41]:
def l2_normalizer(x, dim=1, eps=1e-12):
    '''
    Scale x to unit L2 norm along `dim`.

    Args:
        x: tensor to normalize.
        dim: dimension along which the L2 norm is computed.
        eps: lower bound applied to the norm before dividing, so that
            all-zero vectors map to zeros instead of NaN (0 / 0).

    Returns:
        Tensor with the same shape as x; every vector along `dim` whose norm
        is at least eps has unit L2 norm.
    '''

    norm = torch.sqrt(torch.square(x).sum(dim=dim, keepdim=True))
    # Clamp only affects (near-)zero norms; all other results are unchanged.
    x_norm = x / torch.clamp(norm, min=eps)

    return x_norm

In [58]:
# Toy inputs for exercising the loss: bs1 anchor rows and bs2 >= bs1 candidate
# rows sharing the same feature dimension.
SEED = 0  # fixed so that Restart & Run All reproduces the same numbers
bs1 = 10
bs2 = 15
dim = 20

# A local generator keeps the draw reproducible without touching global RNG state.
_gen = torch.Generator().manual_seed(SEED)
x1 = torch.randn(bs1, dim, generator=_gen)
x2 = torch.randn(bs2, dim, generator=_gen)

In [81]:
# Evaluate the loss on the toy batch with L2-normalized features.
flatnce, res = LossFlatNCE(
    feat1=x1,
    feat2=x2,
    normalizer=l2_normalizer,
)

In [82]:
# Display the value LossFlatNCE stored under 'mi': the negated cross-entropy of
# the scaled similarities plus the log of the number of candidate columns.
res['mi']

tensor([0.0728])

In [None]:
# Normalized effective sample size, averaged over the rows of the batch
# (each per-row value lies in [1/size, 1], size = number of negatives).
res['ness']