# Finetuning the language model on the computer science papers dataset using Contrastive loss


This notebook demonstrates the process of fine-tuning a language model on a computer science papers dataset using contrastive loss. The workflow includes data preparation, model definition, and training. Below is a summary of the key steps:

1. **Introduction**
    - Finetuning the language model on the computer science papers dataset using Contrastive loss.

2. **Import necessary packages**
    - Libraries such as `torch`, `pandas`, `transformers`, and `datasets` are imported.

3. **Utility Functions**
    - Functions for saving/loading parameters, calculating similarity matrices, transferring batches to devices, and counting model parameters are defined.

4. **Data Preparation**
    - Loading and preprocessing of paper submission data and journal aims.
    - Merging datasets and constructing pairs for contrastive fine-tuning.

5. **Data Loading and Preprocessing**
    - Loading the dataset and preparing features using a tokenizer.

6. **Data Loader**
    - Creating a custom dataset class and data loader for batching.

7. **Model Definition**
    - Defining a pooler layer for different pooling operations.
    - Implementing a model for contrastive learning using a pre-trained BERT model.

8. **Contrastive Loss**
    - Defining a supervised contrastive loss function.

9. **Model Declaration**
    - Initializing the model and moving it to the appropriate device.

10. **Training**
     - Setting up the optimizer, learning rate scheduler, and loss function.
     - Training the model for a specified number of epochs and saving the model with the lowest loss.

This notebook provides a comprehensive guide to fine-tuning a language model using contrastive learning, including data preparation, model implementation, and training.


## Import necessary packages

In [1]:
import os
import torch
import numpy as np
import pandas as pd
import pickle
import random
from numpy import ndarray
from torch import Tensor
from typing import Union, List, Dict
from multiprocessing import cpu_count
from tqdm.notebook import trange, tqdm
from torch import nn
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions


## Some useful functions

In [2]:
# Utils

def save_parameter(save_object, save_file):
    
    with open(save_file, 'wb') as f:
        pickle.dump(save_object, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_parameter(load_file):
    with open(load_file, 'rb') as f:
        output = pickle.load(f)
    return output

def sim_matrix(a, b, eps=1e-8):
    """
    Calculate cosine similarity between two matrices.
    Note: added eps for numerical stability
    """
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.clamp(a_n, min=eps)
    b_norm = b / torch.clamp(b_n, min=eps)
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

def batch2device(batch, device):
    """
    Transfers all tensors in a batch to the specified device.
    Args:
        batch (dict): A dictionary where the values are tensors.
        device (torch.device or str): The device to which the tensors should be moved.
    Returns:
        dict: The batch with all tensors moved to the specified device.
    """
    
    for key, value in batch.items():
        batch[key] = batch[key].to(device)
    return batch

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# GPU accelerator
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


# Data preparation

## Data Loading and Preprocessing

In [4]:
# working dir
work_path = "./" # Removed ./ from the beginning
checkpoint_path = "./checkpoint/"

In [5]:
# Load paper submission data and Journal's aims
data_train = pd.read_csv("./data/preprocessed_data_com_sci/01_train.csv", encoding = "ISO-8859-1")
data_aims = pd.read_csv("./data/preprocessed_data_com_sci/01_aims.csv", encoding = "ISO-8859-1")

data_train.fillna("", inplace=True)
data_aims.fillna("", inplace=True)

# merge two tables respect to Label and Index
merged_df = pd.merge(data_train[['Title', 'Abstract', 'Keywords', 'Label']], data_aims[['Aims']], right_index=True, left_on='Label')


# construct set of pairs for contrastive fine-tuning
train_pairs = pd.DataFrame({'TAK': merged_df['Title'] + ' ' + merged_df['Abstract'] + ' ' + merged_df['Keywords'],
                            'Aims': data_aims['Aims']})
train_pairs.to_csv(work_path + "train_pairs.csv", index=False)

In [6]:
print(data_train.columns)
print(data_aims.columns)

Index(['Title', 'Abstract', 'Keywords', 'itr', 'Label'], dtype='object')
Index(['Aims', 'itr'], dtype='object')


## Load saved pairs for training

In [7]:
data_args = {
    "train_file": work_path + "train_pairs.csv",
    "preprocessing_num_workers": None
}
data_files = {
    "train": data_args["train_file"]
}
tokenizer_kwargs = {
    "pretrained_path": "distilbert/distilroberta-base",
    "use_fast": True,
    "max_seq_length": 300,
    "pad_to_max_length": True,
    "truncation": True,
    "return_tensors": None
}

datasets = load_dataset("csv", data_files=data_files)
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_kwargs["pretrained_path"],
    use_fast=tokenizer_kwargs["use_fast"]
)
column_names = datasets["train"].column_names

def prepare_features(examples):
    total = len(examples[column_names[0]])
    for idx in range(total):
        if examples[column_names[0]][idx] is None:
            examples[column_names[0]][idx] = " "
        if examples[column_names[1]][idx] is None:
            examples[column_names[1]][idx] = " "
    sentences = examples[column_names[0]] + examples[column_names[1]]
    sent_features = tokenizer(
            sentences,
            max_length=tokenizer_kwargs["max_seq_length"],
            truncation=True,
            padding="max_length" if tokenizer_kwargs["pad_to_max_length"] else False,
            return_tensors=tokenizer_kwargs["return_tensors"]
        )
    features = {}
    for key in sent_features:
        features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]
    return features

train_dataset = datasets["train"].map(
    prepare_features,
    batched=True,
    num_proc=data_args["preprocessing_num_workers"],
    remove_columns=column_names
)

Generating train split: 0 examples [00:00, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/298317 [00:00<?, ? examples/s]

## Data loader

In [8]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        x = {
            key: torch.tensor(val) for key, val in self.dataset[idx].items()
        }
        return x
    def __len__(self):
        return len(self.dataset)

In [9]:
dataset = Dataset(train_dataset)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Model definition

## Pooler layer

In [10]:
class Pooler(nn.Module):
    """
    Pooler class to perform different types of pooling operations on the output of a transformer model.
    Args:
        pooler_type (str): The type of pooling to perform. Must be one of ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"].
    Methods:
        forward(attention_mask, outputs):
            Perform the specified pooling operation on the model outputs.
            Args:
                attention_mask (torch.Tensor): The attention mask tensor.
                outputs (transformers.modeling_outputs.BaseModelOutput): The output from the transformer model, containing last_hidden_state and hidden_states.
            Returns:
                torch.Tensor: The pooled output tensor.
    """
    
    def __init__(self, pooler_type):
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

    def forward(self, attention_mask, outputs):
        last_hidden = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        if self.pooler_type in ['cls_before_pooler', 'cls']:
            return last_hidden[:, 0]
        elif self.pooler_type == "avg":
            return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
        elif self.pooler_type == "avg_first_last":
            first_hidden = hidden_states[0]
            last_hidden = hidden_states[-1]
            pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
            return pooled_result
        elif self.pooler_type == "avg_top2":
            second_last_hidden = hidden_states[-2]
            last_hidden = hidden_states[-1]
            pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
            return pooled_result
        else:
            raise NotImplementedError


# Model for contrastive leanring training

In [11]:
class ModelForCL(nn.Module):
    """
    Model for Contrastive Learning (CL) using a pre-trained BERT model.
    Args:
        model_name_or_path (str): Path to the pre-trained model or model identifier from huggingface.co/models.
        pooler_type (str): Type of pooling to be applied. Options include "cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last".
    Methods:
        forward(input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, mlm_input_ids, mlm_labels):
            Forward pass for the model.
            Args:
                input_ids (torch.Tensor): Input tensor of token IDs.
                attention_mask (torch.Tensor): Attention mask to avoid performing attention on padding token indices.
                token_type_ids (torch.Tensor, optional): Segment token indices to indicate different portions of the inputs.
                position_ids (torch.Tensor, optional): Indices of positions of each input sequence tokens in the position embeddings.
                head_mask (torch.Tensor, optional): Mask to nullify selected heads of the self-attention modules.
                inputs_embeds (torch.Tensor, optional): Optionally, instead of passing input_ids you can choose to directly pass an embedded representation.
                labels (torch.Tensor, optional): Labels for computing the contrastive loss.
                output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
                output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
                return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
                mlm_input_ids (torch.Tensor, optional): Input tensor of token IDs for masked language modeling.
                mlm_labels (torch.Tensor, optional): Labels for masked language modeling.
            Returns:
                BaseModelOutputWithPoolingAndCrossAttentions: A dataclass with the following attributes:
                    - pooler_output (torch.Tensor): The pooled output tensor.
                    - last_hidden_state (torch.Tensor): The last hidden state of the model.
                    - hidden_states (tuple(torch.Tensor), optional): The hidden states of the model at each layer.
    """
    
    
    def __init__(self, model_name_or_path, pooler_type):
        super(ModelForCL, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name_or_path)
        self.pooler_type = pooler_type
        self.pooler = Pooler(self.pooler_type)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        mlm_input_ids=None,
        mlm_labels=None,
    ):
        batch_size = input_ids.size(0)
        # Number of sentences in one instance
        # 2: pair instance; 3: pair instance with a hard negative
        num_sent = input_ids.size(1)

        # Flatten input for encoding
        input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
        attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent, len)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len)

        # Get raw embeddings
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True if self.pooler_type in ['avg_top2', 'avg_first_last'] else False,
            return_dict=return_dict,
        )

        # Pooling
        if self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"]:
            pooler_output = self.pooler(attention_mask, outputs)
        pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            pooler_output=pooler_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )

## Contrastive Loss

In [12]:
class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.sim = nn.CosineSimilarity()
    def _eval_denom(self, z1, z2):
        cosine_vals = []
        for v in z1:
            cosine_vals.append(self.sim(v.view(1,-1), z2)/self.temperature)
        cos_batch = torch.cat(cosine_vals, dim=0).view(z1.shape[0], -1)
        denom = torch.sum(torch.exp(cos_batch),dim=1)
        return denom
    def _contrastive_loss(self, z1, z2):
        num = torch.exp(self.sim(z1, z2)/self.temperature)
        denom = self._eval_denom(z1, z2)
        loss = -torch.mean(torch.log(num/denom))
        return loss
    def forward(self, z1, z2):
        return self._contrastive_loss(z1, z2)

## Model declaration

In [13]:
# device = xm.xla_device()
model_args = {
    'model_name_or_path': 'distilbert/distilroberta-base',
    'pooler_type': 'cls_before_pooler'
}
model = ModelForCL(**model_args)
model.to(device)

model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

ModelForCL(
  (bert): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-5): 6 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

# Training

## Optimizer and configuration

In [14]:

"""
This script sets up the optimizer, learning rate scheduler, and loss function for fine-tuning an embedding model.

Variables:
    decayRate (float): The factor by which the learning rate is reduced at each step.
    optimizer (torch.optim.Optimizer): The optimizer used for training the model, specifically AdamW.
    lr_scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler that reduces the learning rate by `decayRate` every `step_size` epochs.
    loss_fn (SupervisedContrastiveLoss): The loss function used for training, specifically Supervised Contrastive Loss with a temperature parameter of 0.1.
"""
decayRate = 0.86

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=2, gamma=decayRate)

loss_fn = SupervisedContrastiveLoss(0.1)

In [15]:
"""
Trains a model using contrastive learning for a specified number of epochs and saves the model with the lowest loss.

Variables:
    min_loss (float): The minimum loss observed during training, initialized to infinity.
    max_epochs (int): The maximum number of epochs to train the model.
    epoch (int): The current epoch number.
    loop (tqdm): A tqdm progress bar for the data loader.
    train_loss (float): The cumulative training loss for the current epoch.
    batch (dict): A batch of data from the data loader.
    inputs (dict): The batch of samples transferred to the specified device.
    outputs (dict): The model outputs after a forward pass.
    z1 (Tensor): The first part of the separated representation from the model output.
    z2 (Tensor): The second part of the separated representation from the model output.
    loss (Tensor): The computed loss for the current batch.
    save_path (str): The path to save the model checkpoint.

Functions:
    batch2device(batch, device): Transfers a batch of samples to the specified device.
    loss_fn(z1, z2): Computes the loss between the two representations.
    lr_scheduler.step(): Updates the learning rate according to the scheduler.
    torch.save(): Saves the model and optimizer state dictionaries along with the minimum loss and epoch number.
"""
min_loss = np.inf
max_epochs = 8
for epoch in range(max_epochs):
    loop = tqdm(data_loader, leave=True)
    train_loss = 0.0

    for batch in loop:
        optimizer.zero_grad()

        # if torch.cuda.is_available():
        # Transfer batch of samples to GPU
        inputs = batch2device(batch, device)

        # forward
        outputs = model(**inputs)
        # Separate representation
        z1, z2 = outputs.pooler_output[:, 0], outputs.pooler_output[:, 1]

        # backward
        loss = loss_fn(z1, z2)
        loss.backward()
        train_loss += loss.item()
        # Update Weights
        optimizer.step()
        # xm.optimizer_step(optimizer) 

        loop.set_description('Epoch: {} - lr:{}'.format(epoch, optimizer.param_groups[0]['lr']))
        loop.set_postfix(loss=loss.item())
    train_loss = train_loss / len(data_loader)
    lr_scheduler.step()
    if min_loss > train_loss:
        print(f">> Loss Decreased({min_loss:.6f}--->{train_loss:.6f})")
        min_loss = train_loss
        save_path = work_path + 'saved_model/'
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "min_loss": min_loss,
            "epoch": epoch
        }, save_path + "Epoch:{:0>2} SupCL-.pth".format(epoch))

  0%|          | 0/4662 [00:00<?, ?it/s]

>> Loss Decreased(inf--->4.158542)


  0%|          | 0/4662 [00:00<?, ?it/s]

>> Loss Decreased(4.158542--->4.158541)


  0%|          | 0/4662 [00:00<?, ?it/s]

>> Loss Decreased(4.158541--->4.158539)


  0%|          | 0/4662 [00:00<?, ?it/s]

>> Loss Decreased(4.158539--->4.158436)


  0%|          | 0/4662 [00:00<?, ?it/s]

KeyboardInterrupt: 