This file contains code that demonstrates how contrastive learning is effectively an instance discrimination problem.
In particular, we present a similarity-rank-based metric that quantifies the difficulty of the instance discrimination task.

We use the openly available public dataset [Nomao](https://archive.ics.uci.edu/dataset/227/nomao) and aim to learn data representations with a contrastive loss.

In [3]:
# importing packages

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, linear_model, model_selection, metrics, ensemble
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score, pairwise,mutual_info_score,mean_squared_error
from scipy import linalg, stats
import os.path
import pandas as pd
import matplotlib.pyplot as plt
import random
from matching.games import StableMarriage
import pingouin as pg
import datetime
# from datetime import datetime
import json, sys, argparse
from tqdm.auto import tqdm

In [4]:
def NTXentLoss(z1, z2,
               temperature=0.1):
    """Symmetric NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair; every other
    row of the opposite view acts as a negative. The loss is the cross-entropy of
    the temperature-scaled cosine-similarity matrix taken in both directions:
    over rows (z1 -> z2) and over columns (z2 -> z1).

    Args:
        z1: 2-D tensor of embeddings for the first view, one sample per row.
        z2: 2-D tensor of embeddings for the second view; must have the same
            number of rows as ``z1``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        (batch_loss, avg_rank):
        ``batch_loss`` is a scalar tensor summed (not averaged) over the batch
        and over both directions.
        ``avg_rank`` is the mean ascending rank (1-based, via
        ``scipy.stats.rankdata``) of each positive pair's similarity within its
        row. It reaches the batch size when every positive is the most similar
        candidate, so higher is better.

    Raises:
        ValueError: if ``z1`` and ``z2`` have different numbers of rows.
    """
    if z1.shape[0] != z2.shape[0]:
        raise ValueError(
            f"z1 and z2 must have the same number of rows, got {z1.shape[0]} and {z2.shape[0]}")

    # Cosine similarity: L2-normalise every row, then take all pairwise dot products.
    cos_sim_o = torch.matmul(torch.nn.functional.normalize(z1),
                             torch.nn.functional.normalize(z2).transpose(0, 1)) / temperature

    # The positive for row i sits in column i (and vice versa for the transpose).
    targets = torch.arange(cos_sim_o.shape[0], device=cos_sim_o.device)

    # cross_entropy applies a numerically stable log-softmax per row of its input.
    # Feeding the transpose normalises each column correctly too. A single
    # per-row max shift would not be valid for the column-wise softmax.
    batch_loss_o = (
        torch.nn.functional.cross_entropy(cos_sim_o, targets, reduction="sum")
        + torch.nn.functional.cross_entropy(cos_sim_o.transpose(0, 1), targets, reduction="sum")
    )

    # Average rank of the positive pairs within their rows; used to monitor whether
    # positives move to the top of the similarity ranking during training.
    ranks = stats.rankdata(cos_sim_o.detach().cpu().numpy(), axis=1)
    avg_rank_cos_sim_o = np.trace(ranks) / len(cos_sim_o)

    return batch_loss_o, avg_rank_cos_sim_o

In [None]:
# encoder networks
class MLP(torch.nn.Sequential):
    """Stack of fully connected layers with ReLU activations and dropout.

    The network contains ``n_layers`` Linear layers in total. Each of the first
    ``n_layers - 1`` is followed by an in-place ReLU and a ``Dropout(dropout)``.
    The last Linear layer has no activation. Every Linear layer outputs
    ``hidden_dim`` features, so the network's output width is ``hidden_dim``.

    Args:
        input_dim: number of input features.
        hidden_dim: output width of every Linear layer.
        n_layers: total number of Linear layers; values below 1 still yield a
            single Linear layer.
        dropout: dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, dropout=0.0):
        modules = []
        width = input_dim
        n_hidden = n_layers - 1
        for _layer_idx in range(n_hidden):
            modules.extend((
                torch.nn.Linear(width, hidden_dim),
                nn.ReLU(inplace=True),
                torch.nn.Dropout(dropout),
            ))
            width = hidden_dim

        # Output projection: no activation after the final layer.
        modules.append(torch.nn.Linear(width, hidden_dim))

        super().__init__(*modules)
        
class CL_model(nn.Module):
    def __init__(
        self,
        input_dim,
        emb_dim,
        encoder_depth=4,
        head_depth=2,
    ):
        """Build a basic SimCLR-style contrastive-learning model.

        The model has two parts:
        - an encoder MLP that maps inputs to ``emb_dim``-wide embeddings;
        - a pretraining-head MLP that maps those embeddings to another
          ``emb_dim``-wide space.

        It is intended to be trained by minimising the contrastive loss between
        a sample and an augmented view of it.

        Args:
            input_dim (int): size of the inputs.
            emb_dim (int): dimension of the embedding space (also the head's width).
            encoder_depth (int, optional): number of Linear layers in the encoder MLP. Defaults to 4.
            head_depth (int, optional): number of Linear layers in the pretraining head. Defaults to 2.
        """
        super().__init__()

        # Build both sub-networks first, then re-initialise them encoder-first.
        # This order fixes how the RNG is consumed, which matters for
        # reproducibility under a fixed seed.
        self.encoder = MLP(input_dim, emb_dim, encoder_depth)
        self.pretraining_head = MLP(emb_dim, emb_dim, head_depth)
        for subnet in (self.encoder, self.pretraining_head):
            subnet.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            module.bias.data.fill_(0.01)