# setup

In [None]:
# Install runtime dependencies. Versions are not pinned, so a re-run may pick up newer releases.
# NOTE(review): `%pip install` targets the running kernel's environment more reliably than `!pip`.
# NOTE(review): this installs the `lightning` package while the code imports `pytorch_lightning`;
# confirm that `pytorch_lightning` is importable in the target environment.
!pip install --upgrade transformers
!pip install lightning
!pip install wandb
!pip install gdown
!pip install torchmetrics

In [None]:
# Authenticate with Weights & Biases before the WandbLogger is created further down.
# NOTE(review): presumably prompts for an API key when none is configured — confirm in headless runs.
import wandb
wandb.login()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn  
from torch.utils.data import Dataset, DataLoader 
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, DataCollatorWithPadding

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from torchmetrics.classification import MulticlassRecall

In [None]:
# Global compute device: the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [None]:
# Download the input CSVs from Google Drive into the working directory.
# NOTE(review): presumably these IDs resolve to pandas_duplicates.csv and pandas_negatives.csv,
# the files read by TripletDataset — confirm.
!gdown 15hLDYO17KUvP9O0TMVx-7U7Muqm4q5PS
!gdown 1PmGhoMOXTLvceGCo-btUpPl8_MtIKWzR

# dataset

In [None]:
class TripletDataset(Dataset):
  """Triplets of (duplicate, duplicate, negative) posts, each tokenized as a (text, code) pair.

  Row `idx` of the duplicates CSV supplies two posts that are duplicates of each
  other. Row `idx` of the negatives CSV supplies the negative post. Cell values
  are converted with `str()`, so missing cells become the text "nan".

  Args:
    model: Hugging Face model name whose tokenizer is loaded.
    duplicates_path: CSV with columns NaturalLang1, ProgrammingLang1,
      NaturalLang2 and ProgrammingLang2.
    negatives_path: CSV with columns NaturalLang and ProgrammingLang.
    max_seq_len: truncation length (in tokens) of every tokenized pair.
    max_samples: upper bound on the number of triplets exposed. The default of
      100 keeps the original limit; None exposes every row available in both CSVs.
  """

  def __init__(self, model, duplicates_path="pandas_duplicates.csv",
               negatives_path="pandas_negatives.csv", max_seq_len=100,
               max_samples=100):
    self.duplicates_df = pd.read_csv(duplicates_path)
    self.negatives_df = pd.read_csv(negatives_path)

    self.tokenizer = AutoTokenizer.from_pretrained(model)

    self.max_seq_len = max_seq_len
    self.max_samples = max_samples

  def __len__(self):
    # Rows of the two CSVs are paired by position, so only the shorter one's rows
    # are usable. Never report more items than exist: __getitem__ would raise.
    available = min(len(self.duplicates_df.index), len(self.negatives_df.index))
    if self.max_samples is None:
      return available
    return min(self.max_samples, available)

  def _encode(self, text, code):
    """Tokenize one (text, code) pair and drop the leading batch dimension."""
    encoding = self.tokenizer(text, code, truncation=True,
                              max_length=self.max_seq_len, return_tensors="pt")
    # return_tensors="pt" yields (1, seq_len) tensors. Squeeze only dim 0, and do it
    # for every field (including token_type_ids when the tokenizer emits them), so
    # each becomes (seq_len,) and a length-1 sequence is not collapsed to a scalar.
    for key in list(encoding.keys()):
      encoding[key] = encoding[key].squeeze(0)
    return encoding

  def __getitem__(self, idx):
    # .at is label-based; this assumes the default RangeIndex from read_csv.
    dups = self.duplicates_df
    negs = self.negatives_df

    dup1 = self._encode(str(dups.at[idx, "NaturalLang1"]), str(dups.at[idx, "ProgrammingLang1"]))
    dup2 = self._encode(str(dups.at[idx, "NaturalLang2"]), str(dups.at[idx, "ProgrammingLang2"]))
    neg = self._encode(str(negs.at[idx, "NaturalLang"]), str(negs.at[idx, "ProgrammingLang"]))

    return dup1, dup2, neg

# dataloader

In [None]:
# Dynamic padding: pad every sequence to the longest sequence in the current batch.
# NOTE(review): this tokenizer is hard-coded to the MQDD checkpoint, while TripletDataset
# tokenizes with config['model']. If another model (e.g. codebert) is selected, padding
# presumably uses this tokenizer's pad token id, which may not match that model's —
# confirm the ids agree, or build the collator from config['model'].
tokenizer = AutoTokenizer.from_pretrained("UWB-AIR/MQDD-pretrained")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def my_collate_fn(batch, collator=None):
  """Collate a list of (dup1, dup2, neg) triples into three padded batches.

  Each position is collated independently, so each is padded to its own longest
  sequence in the batch.

  Args:
    batch: list of (dup1, dup2, neg) tuples, as returned by TripletDataset.
    collator: callable applied to each per-position list of encodings. Defaults
      to the module-level `data_collator`; bind a collator built from the
      training model's tokenizer with functools.partial when they differ.

  Returns:
    (dups1, dups2, negs): the collator's output for each position.
  """
  if collator is None:
    collator = data_collator
  # columns[0] = all dup1 encodings, columns[1] = all dup2, columns[2] = all negatives
  columns = [[triple[position] for triple in batch] for position in range(3)]
  return tuple(collator(column) for column in columns)

# triplet loss

In [None]:
def get_triplet_loss_labels(batch_size):
  """Class labels for the rows of a concatenated (dup1, dup2, neg) embedding batch.

  With N = batch_size the result is [0..N-1, 0..N-1, N..2N-1]. The i-th dup1 row
  and the i-th dup2 row share label i, and every negative gets a label of its own.
  Returns an int64 tensor on the global `device`.
  """
  pair_ids = torch.arange(batch_size)
  return torch.cat([pair_ids, pair_ids, pair_ids + batch_size]).to(device)

In [None]:
def calc_pairwise_distances(embeddings):
  """Squared Euclidean distance between every pair of rows of `embeddings`.

  Expands ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 over the Gram matrix.
  Floating-point cancellation can make the expansion slightly negative, so
  results are floored at 0.

  Returns:
    (B, B) tensor whose [i, j] entry is the squared distance between rows i and j.
  """
  gram = embeddings @ embeddings.T
  sq_norms = torch.diag(gram)
  squared = sq_norms.unsqueeze(0) - 2*gram + sq_norms.unsqueeze(1)
  return squared.clamp(min=0)

def get_triplet_mask(labels):
  """Boolean mask of valid (anchor i, positive j, negative k) index triplets.

  mask[i, j, k] is True iff i, j and k are pairwise distinct,
  labels[i] == labels[j] and labels[i] != labels[k].

  Args:
    labels: 1-D tensor of length n.

  Returns:
    (n, n, n) bool tensor on labels.device.
  """
  size = labels.shape[0]
  off_diag = ~torch.eye(size, device=labels.device).bool()

  # broadcast the pairwise "different index" test over the (i, j), (i, k) and (j, k) axes
  all_distinct = off_diag[:, :, None] & off_diag[:, None, :] & off_diag[None, :, :]

  same_label = labels[None, :] == labels[:, None]
  anchor_matches_positive = same_label[:, :, None]
  anchor_matches_negative = same_label[:, None, :]

  return all_distinct & anchor_matches_positive & ~anchor_matches_negative

def get_anchor_positive_triplet_mask(labels):
  """Boolean (n, n) mask of valid anchor/positive pairs.

  mask[i, j] is True iff i != j and labels[i] == labels[j].
  The mask is created on labels.device.
  """
  not_self = ~torch.eye(labels.shape[0], device=labels.device).bool()
  same_label = labels[None, :] == labels[:, None]
  return same_label & not_self

def get_anchor_negative_triplet_mask(labels):
  """Boolean (n, n) mask of valid anchor/negative pairs.

  mask[i, k] is True iff labels[i] != labels[k]. The diagonal is therefore
  always False.
  """
  return labels[:, None] != labels[None, :]

In [None]:
def batch_all_triplet_loss(embeddings, labels, margin=0.05):
  """Batch-all triplet loss on squared Euclidean distances.

  For every valid triplet (i, j, k) from get_triplet_mask, the per-triplet loss is
  relu(d(i, j) - d(i, k) + margin). The summed loss is divided by the number of
  triplets whose loss exceeds 1e-16 (those still violating the margin), so easy
  triplets do not dilute the average.

  Args:
    embeddings: (B, D) tensor with one row per sample.
    labels: (B,) tensor of class labels for those rows.
    margin: margin added to each anchor-positive minus anchor-negative distance.

  Returns:
    Scalar tensor; 0 when no triplet violates the margin.
  """
  pairwise_distances = calc_pairwise_distances(embeddings)  # (B, B)

  anchor_positive_distances = pairwise_distances.unsqueeze(2)  # (B, B, 1): d(i, j)
  anchor_negative_distances = pairwise_distances.unsqueeze(1)  # (B, 1, B): d(i, k)

  triplet_loss = anchor_positive_distances - anchor_negative_distances + margin  # (B, B, B)

  # Zero out invalid triplets. Move the mask to the loss's own device rather than
  # the global `device`, so this works wherever `embeddings` lives.
  mask = get_triplet_mask(labels).float().to(triplet_loss.device)
  triplet_loss = triplet_loss * mask

  # Only semi-hard/hard triplets keep a positive loss.
  triplet_loss = F.relu(triplet_loss)

  num_positive_triplets = int((triplet_loss > 1e-16).sum())

  # The epsilon avoids 0/0 when every triplet is already satisfied.
  return triplet_loss.sum() / (num_positive_triplets + 1e-16)

In [None]:
def batch_hard_triplet_loss(embeddings, labels, margin=0.05):
  """Batch-hard triplet loss on squared Euclidean distances.

  For each anchor row i, take its farthest positive (same label, j != i) and its
  closest negative (different label). The per-anchor loss is
  relu(d_pos - d_neg + margin), and the result is its mean over all anchors.

  NOTE(review): an anchor with no positive in the batch gets d_pos = 0 (the
  masked max of non-negative distances), so it still contributes
  relu(margin - d_neg). With the labels from get_triplet_loss_labels this holds
  for every negative row — confirm this is intended.
  """
  dist = calc_pairwise_distances(embeddings)

  # farthest positive: non-positives are multiplied to 0, so they never win the max
  pos_mask = get_anchor_positive_triplet_mask(labels).float()
  farthest_pos, _ = (dist * pos_mask).max(1, keepdim=True)

  # closest negative: lift every non-negative by the row's max distance so it cannot win the min
  neg_mask = get_anchor_negative_triplet_mask(labels).float()
  row_max, _ = dist.max(1, keepdim=True)
  closest_neg, _ = (dist + row_max * (1.0 - neg_mask)).min(1, keepdim=True)

  return F.relu(farthest_pos - closest_neg + margin).mean()

# evaluation

In [None]:
def calc_pairwise_cosine_similarity(embeddings):
  """Cosine similarity between every pair of rows of `embeddings`.

  Rows are L2-normalised first, so entry [i, j] is cos(row_i, row_j).
  Returns a (B, B) tensor moved to the global `device`.
  """
  unit_rows = F.normalize(embeddings, p=2, dim=1)
  return (unit_rows @ unit_rows.T).to(device)

In [None]:
def get_recall_labels(batch_size):
  """Target column of each duplicate row's partner in the recall similarity matrix.

  Rows 0..N-1 (dup1) should retrieve column N+i (their dup2 partner), and rows
  N..2N-1 (dup2) should retrieve column i. This gives [N..2N-1, 0..N-1] as an
  int64 tensor on the global `device`.
  """
  partner = torch.arange(batch_size)
  return torch.cat((partner + batch_size, partner)).long().to(device)

# pytorch lightning

In [None]:
class TripletLightningModule(pl.LightningModule):
  """Fine-tunes a Hugging Face encoder with a triplet loss for duplicate detection.

  Each batch is a (dup1, dup2, neg) triple of tokenized inputs, as produced by
  `my_collate_fn`. The encoder output for the three parts is stacked along dim 0
  into 3 * batch_size rows, whose labels are precomputed by
  `get_triplet_loss_labels`. Recall@5 checks whether each duplicate row's partner
  is among its 5 most cosine-similar rows.

  Args:
    config: dict with the following keys:
      'model': Hugging Face model name.
      'batch_size': number of triples per batch.
      'loss': 'all' for the batch-all loss; any other value selects batch-hard.
      'margin': triplet margin.
      'lr': SGD learning rate.

  NOTE: the labels are built once for `batch_size`, so every batch must contain
  exactly `batch_size` triples. The DataLoaders use drop_last=True for this.
  """

  def __init__(self, config):
    super().__init__()

    self.model_name = config['model']
    self.model = AutoModel.from_pretrained(config['model'])

    self.batch_size = config['batch_size']

    self.loss = config['loss']
    self.margin = config['margin']

    self.lr = config['lr']

    self.save_hyperparameters()

    # fixed row labels for the concatenated (dup1, dup2, neg) embeddings
    self.triplet_labels = get_triplet_loss_labels(self.batch_size).to(device)

    # target column of each duplicate row's partner in the similarity matrix
    self.recall_labels = get_recall_labels(self.batch_size).to(device)

    # one metric object per stage, so their accumulated states stay separate
    num_classes = self.batch_size * 3
    self.train_recall = MulticlassRecall(num_classes=num_classes, top_k=5).to(device)
    self.val_recall = MulticlassRecall(num_classes=num_classes, top_k=5).to(device)
    self.test_recall = MulticlassRecall(num_classes=num_classes, top_k=5).to(device)

  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, "train", self.train_recall, on_step=True)

  def validation_step(self, batch, batch_idx):
    return self._shared_step(batch, "val", self.val_recall)

  def test_step(self, batch, batch_idx):
    # Logged under test_* keys so test results are not reported as validation metrics.
    return self._shared_step(batch, "test", self.test_recall)

  def _shared_step(self, batch, stage, recall_metric, **log_kwargs):
    """Compute, log and return the triplet loss and recall@5 for one batch.

    Values are logged as f"{stage}_loss" and f"{stage}_recall"; `log_kwargs`
    are forwarded to `self.log`.
    """
    embeddings = self._get_embeddings(batch)

    loss = self._get_loss(embeddings)

    # Recall is only monitored, so keep it out of the autograd graph.
    dupl_cos_sims = self._get_dupl_cos_sims(embeddings.detach()).to(device)
    recall = recall_metric(dupl_cos_sims, self.recall_labels)

    self.log(f"{stage}_loss", loss, **log_kwargs)
    self.log(f"{stage}_recall", recall, **log_kwargs)

    return loss

  def configure_optimizers(self):
    """Plain SGD over all parameters at the configured learning rate."""
    optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
    return optimizer

  def _get_embeddings(self, batch):
    """Encode the three parts of a batch and stack them as [dup1; dup2; neg] along dim 0."""
    d1, d2, n = batch

    # Element [1] of the model output: presumably the pooled output — confirm for the chosen model.
    d1 = self.model(**d1)[1]
    d2 = self.model(**d2)[1]
    n = self.model(**n)[1]

    embeddings = torch.cat((d1, d2, n), dim=0)
    return embeddings

  def _get_loss(self, embeddings):
    """Batch-all triplet loss when config['loss'] == 'all', otherwise batch-hard."""
    if self.loss == 'all':
      loss = batch_all_triplet_loss(embeddings, self.triplet_labels, margin=self.margin)
    else:
      loss = batch_hard_triplet_loss(embeddings, self.triplet_labels, margin=self.margin)

    return loss

  def _get_dupl_cos_sims(self, embeddings):
    """Similarity rows for the duplicate embeddings, with self-matches excluded.

    Returns the first 2 * batch_size rows (dup1 then dup2) of the pairwise
    cosine-similarity matrix. The columns still cover every embedding, including
    the negatives.
    """
    pairwise_cos_sims = calc_pairwise_cosine_similarity(embeddings)

    # Exclude self-matches. Cosine similarity lies in [-1, 1], so -2 ranks the
    # diagonal below every real candidate. Zeroing it instead would let an item
    # outrank its true partner whenever the other similarities are negative.
    self_match = torch.eye(pairwise_cos_sims.shape[0], dtype=torch.bool,
                           device=pairwise_cos_sims.device)
    pairwise_cos_sims = pairwise_cos_sims.masked_fill(self_match, -2.0)

    # Negatives are retrieval candidates only, never queries.
    dupl_cos_sims = pairwise_cos_sims[:2*self.batch_size]
    return dupl_cos_sims.to(device)


# run

In [None]:
# Experiment configuration: select the loss variant and encoder by index.
losses = ['all', 'hard']  # 'all' -> batch-all loss; any other value -> batch-hard loss
models = ['UWB-AIR/MQDD-pretrained', 'microsoft/codebert-base']

config = {
    'model': models[0],
    'batch_size': 4,      # triples per batch; the module's labels are precomputed for this size
    'loss': losses[0],
    'margin': 0.05,       # triplet margin on squared Euclidean distances
    'lr': 0.001,          # SGD learning rate
}

In [None]:
# Fixed seed so the 70/15/15 train/val/test split is reproducible across runs.
SPLIT_SEED = 42
dataset = TripletDataset(config['model'])
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.7, 0.15, 0.15], generator=torch.Generator().manual_seed(SPLIT_SEED))

In [None]:
batch_size = config['batch_size']

# drop_last=True: the module's triplet/recall labels assume full batches of `batch_size`.
# Only the training set is shuffled, so its batches (and in-batch negatives) change every epoch.
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True, collate_fn=my_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size, drop_last=True, collate_fn=my_collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size, drop_last=True, collate_fn=my_collate_fn)

In [None]:
# W&B project that receives the run's metrics.
project_name = 'stack-overflow-duplicate-detection'  

# Start a W&B run explicitly, then create the Lightning logger for the same project.
# NOTE(review): WandbLogger presumably attaches to the run started by wandb.init, and
# log_model='all' presumably uploads every saved checkpoint as an artifact — confirm.
wandb.init(project=project_name)
wandb_logger = WandbLogger(project=project_name, log_model='all') 

In [None]:
triplet_lightning_module = TripletLightningModule(config)

# Keep the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
# log_every_n_steps=1: presumably so metrics are recorded on every step of the short epochs.
trainer = pl.Trainer(max_epochs=5, logger=wandb_logger, callbacks=[checkpoint_callback], log_every_n_steps=1)
trainer.fit(triplet_lightning_module, train_dataloader, val_dataloader)

In [None]:
# NOTE(review): passing the module without ckpt_path presumably evaluates its current
# (last-epoch) weights, not the best val_loss checkpoint saved during fit; pass
# ckpt_path="best" if the best checkpoint is intended.
trainer.test(triplet_lightning_module, dataloaders=test_dataloader)