In [1]:
import time, random, numpy as np, argparse, sys, re, os
from types import SimpleNamespace
import torch
from torch import nn
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.amp import autocast
from contextlib import nullcontext
from tqdm import tqdm
from itertools import cycle
import copy
from transformers import BertModel, BertTokenizer
import pandas as pd

In [2]:
# fix the random seed
def seed_everything(seed= 10002):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [3]:
'''
This module contains our Dataset classes and functions to load the 3 datasets we're using.
You should only need to call load_multitask_data to get the training and dev examples
to train your model.
'''

import csv
import torch
from torch.utils.data import Dataset

def preprocess_string(s):
    '''Preprocesses a string by lowercasing it and adding spaces around punctuation.'''
    return ' '.join(s.lower().replace('.', ' .').replace('?', ' ?').replace(',', ' ,').replace('\'', ' \'').split())


class SentenceClassificationDataset(Dataset):
    '''This class is a wrapper around the dataset with one sentence inputs that we will use to train our model.
    (ie. the SST dataset)'''
    def __init__(self, dataset):
        self.dataset = dataset
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        '''This function pads the data to the max length of the batch.'''
        sents = [x[0] for x in data]
        labels = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, labels, sents, sent_ids= self.pad_data(all_data)

        batched_data = {
                'token_ids': token_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'sents': sents,
                'sent_ids': sent_ids
            }

        return batched_data


class SentenceClassificationTestDataset(Dataset):
    '''This class is a wrapper around the dataset with one sentence inputs that we will use to test our model.
    (ie. the SST dataset)'''
    def __init__(self, dataset):
        self.dataset = dataset
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        '''This function pads the data to the max length of the batch.'''
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)

        batched_data = {
                'token_ids': token_ids,
                'attention_mask': attention_mask,
                'sents': sents,
                'sent_ids': sent_ids
            }

        return batched_data


class SentencePairDataset(Dataset):
    '''This class is a wrapper around the dataset with pair sentences that we will use to train our model.
    (ie. A class for handling the SemEval and Quora datasets.)'''
    def __init__(self, dataset, isRegression =False):
        self.dataset = dataset
        # self.p = args
        self.isRegression = isRegression
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        '''This function pads the data to the max length of the batch.'''
        sent1 = [x[0] for x in data]
        sent2 = [x[1] for x in data]
        labels = [x[2] for x in data]
        sent_ids = [x[3] for x in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])
        if self.isRegression:
            labels = torch.FloatTensor(labels)
        else:
            labels = torch.LongTensor(labels)
            

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                labels,sent_ids)

    def collate_fn(self, all_data):
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         labels, sent_ids) = self.pad_data(all_data)

        batched_data = {
                'token_ids_1': token_ids,
                'token_type_ids_1': token_type_ids,
                'attention_mask_1': attention_mask,
                'token_ids_2': token_ids2,
                'token_type_ids_2': token_type_ids2,
                'attention_mask_2': attention_mask2,
                'labels': labels,
                'sent_ids': sent_ids
            }

        return batched_data


class SentencePairTestDataset(Dataset):
    '''This class is a wrapper around the dataset with pair sentences that we will use to test our model.
    (ie. A class for handling the SemEval and Quora datasets.)'''
    def __init__(self, dataset, tokenizer, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        '''This function pads the data to the max length of the batch'''
        sent1 = [x[0] for x in data]
        sent2 = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])


        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
               sent_ids)

    def collate_fn(self, all_data):
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         sent_ids) = self.pad_data(all_data)

        batched_data = {
                'token_ids_1': token_ids,
                'token_type_ids_1': token_type_ids,
                'attention_mask_1': attention_mask,
                'token_ids_2': token_ids2,
                'token_type_ids_2': token_type_ids2,
                'attention_mask_2': attention_mask2,
                'sent_ids': sent_ids
            }

        return batched_data


def load_multitask_test_data():
    '''This function loads the test datasets for the multitask dataset.'''
    paraphrase_filename = f'data/quora-test.csv'
    sentiment_filename = f'data/ids-sst-test.txt'
    similarity_filename = f'data/sts-test.csv'

    sentiment_data = []

    with open(sentiment_filename, 'r') as fp:
        for record in csv.DictReader(fp,delimiter = '\t'):
            sent = record['sentence'].lower().strip()
            sentiment_data.append(sent)

    print(f"Loaded {len(sentiment_data)} test examples from {sentiment_filename}")

    paraphrase_data = []
    with open(paraphrase_filename, 'r') as fp:
        for record in csv.DictReader(fp,delimiter = '\t'):
            #if record['split'] != split:
            #    continue
            paraphrase_data.append((preprocess_string(record['sentence1']),
                                    preprocess_string(record['sentence2']),
                                    ))

    print(f"Loaded {len(paraphrase_data)} test examples from {paraphrase_filename}")

    similarity_data = []
    with open(similarity_filename, 'r') as fp:
        for record in csv.DictReader(fp,delimiter = '\t'):
            similarity_data.append((preprocess_string(record['sentence1']),
                                    preprocess_string(record['sentence2']),
                                    ))

    print(f"Loaded {len(similarity_data)} test examples from {similarity_filename}")

    return sentiment_data, paraphrase_data, similarity_data



def load_multitask_data( sentiment_filename, paraphrase_filename, similarity_filename,split='train'):
    '''This function loads the training datasets for the multitask dataset'''
    sentiment_data = []
    num_labels = {}
    if split == 'test':
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp,delimiter = '\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                sentiment_data.append((sent,sent_id))
    else:
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp,delimiter = '\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                sentiment_data.append((sent, label,sent_id))

    print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")

    paraphrase_data = []
    if split == 'test':
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp,delimiter = '\t'):
                sent_id = record['id'].lower().strip()
                paraphrase_data.append((preprocess_string(record['sentence1']), preprocess_string(record['sentence2']), sent_id))
    else:
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp,delimiter = '\t'):
                try:
                    sent_id = record['id'].lower().strip()
                    paraphrase_data.append((preprocess_string(record['sentence1']), preprocess_string(record['sentence2']), int(float(record['is_duplicate'])),sent_id))
                except:
                    pass

    print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")

    similarity_data = []
    if split == 'test':
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp,delimiter = '\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']), preprocess_string(record['sentence2']) ,sent_id))
    else:
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp,delimiter = '\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']), preprocess_string(record['sentence2']), float(record['similarity']),sent_id))

    print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")

    return sentiment_data, num_labels, paraphrase_data, similarity_data

In [4]:
class Bert_MultiTask(nn.Module):
    """"
    config: {
        'hidden_dropout_prob':
        'option':
        'num_hidden_lrs'
    }
    """
    def __init__(self,config):
        super(Bert_MultiTask, self).__init__()
        self.model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa");  
        self.tokenizer= BertTokenizer.from_pretrained("bert-base-uncased")
        BERT_HIDDEN_SIZE = 768
        
        N_SENTIMENT_CLASSES = 5
        N_STS_CLASSES = 6

        
        # setting the model for finetuning or full training
        for param in self.model.parameters():
            if config.option == 'finetune':
                param.requires_grad = True
            else:
                param.requires_grad = False
        
        # defining the lineear layers for sentiment classification
        self.dropout_sentiment = nn.ModuleList([nn.Dropout(config.hidden_dropout_prob) for _ in range(config.n_hidden_layers + 1)])
        self.linear_sentiment = nn.ModuleList([nn.Linear(BERT_HIDDEN_SIZE, BERT_HIDDEN_SIZE, dtype=torch.float16) for _ in range(config.n_hidden_layers)] + [nn.Linear(BERT_HIDDEN_SIZE, N_SENTIMENT_CLASSES, dtype=torch.float16)])
        self.last_linear_sentiment = None

        # Step 3: Add a linear layer for paraphrase detection
        self.dropout_paraphrase = nn.ModuleList([nn.Dropout(config.hidden_dropout_prob) for _ in range(config.n_hidden_layers + 1)])
        self.linear_paraphrase = nn.ModuleList([nn.Linear(BERT_HIDDEN_SIZE, BERT_HIDDEN_SIZE, dtype=torch.float16) for _ in range(config.n_hidden_layers)] + [nn.Linear(BERT_HIDDEN_SIZE, 1, dtype=torch.float16)])

        # Step 4: Add a linear layer for semantic textual similarity
        # This is a regression task, so the output should be a single number
        self.dropout_similarity = nn.ModuleList([nn.Dropout(config.hidden_dropout_prob) for _ in range(config.n_hidden_layers + 1)])
        self.linear_similarity = nn.ModuleList([nn.Linear(BERT_HIDDEN_SIZE, BERT_HIDDEN_SIZE,dtype=torch.float16) for _ in range(config.n_hidden_layers)] + [nn.Linear(BERT_HIDDEN_SIZE, 1,dtype=torch.float16)])

    def forward(self, input_ids, attention_mask, task_id):
        # gives embeddings for the batch of sentences
        # The final BERT embedding is the hidden state of [CLS] token (the first token)
        # Here, you can start by just returning the embeddings straight from BERT.
        # When thinking of improvements, you can later try modifying this
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        
        # Extract the [CLS] token embedding (the first token's hidden state)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # (batch_size, hidden_size)

        # Combine CLS embedding with task embedding (e.g., addition, concatenation)
        combined_embedding = cls_embedding
        return combined_embedding
    
    def last_layers_sentiment(self, x):
        """Given a batch of sentences embeddings, outputs logits for classifying sentiment."""
        for i in range(len(self.linear_sentiment) - 1):
            x = self.dropout_sentiment[i](x)
            x = self.linear_sentiment[i](x)
            x = F.relu(x)

        x = self.dropout_sentiment[-1](x)
        logits = self.linear_sentiment[-1](x)
        # logits = F.softmax(logits, dim=1)
        return logits
    
    def predict_sentiment(self, input_ids, attention_mask):
        '''Given a batch of sentences, outputs logits for classifying sentiment.
        There are 5 sentiment classes:
        (0 - negative, 1- somewhat negative, 2- neutral, 3- somewhat positive, 4- positive)
        Thus, your output should contain 5 logits for each sentence.
        '''
        x = self.forward(input_ids, attention_mask, task_id=0 )
        x = self.last_layers_sent(x)
        return x

    def get_similarity_paraphrase_embeddings(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, task_id):
        '''Given a batch of pairs of sentences, get the BERT embeddings.'''
        # Get [SEP] token ids
        sep_token_id = torch.tensor([self.tokenizer.sep_token_id], dtype=torch.long, device=input_ids_1.device)
        batch_sep_token_id = sep_token_id.repeat(input_ids_1.shape[0], 1)

        # Concatenate the two sentences in: sent1 [SEP] sent2 [SEP]
        input_id = torch.cat((input_ids_1, batch_sep_token_id, input_ids_2, batch_sep_token_id), dim=1)
        attention_mask = torch.cat((attention_mask_1, torch.ones_like(batch_sep_token_id), attention_mask_2, torch.ones_like(batch_sep_token_id)), dim=1)

        # Get the BERT embeddings
        x = self.forward(input_id, attention_mask, task_id=task_id)

        return x

    def last_layers_paraphrase(self, x):
        """Given a batch of pairs of sentences embedding, outputs logits for predicting whether they are paraphrases."""
        #Step 2: Hidden layers
        for i in range(len(self.linear_paraphrase) - 1):
            x = self.dropout_paraphrase[i](x)
            x = self.linear_paraphrase[i](x)
            x = F.relu(x)

        # Step 3: Final layer
        x = self.dropout_paraphrase[-1](x)
        logits = self.linear_paraphrase[-1](x)
        # logits = torch.sigmoid(logits)
        return logits

    def predict_paraphrase(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
        Note that your output should be unnormalized (a logit).
        '''
        # Step 1: Get the BERT embeddings
        x = self.get_similarity_paraphrase_embeddings(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, task_id=1)
        return self.last_layers_paraphrase(x)


    def last_layers_similarity(self, x):
        """Given a batch of pairs of sentences embeddings, outputs logits for predicting how similar they are."""
        # Step 3: Hidden layers
        for i in range(len(self.linear_similarity) - 1):
            x = self.dropout_similarity[i](x)
            x = self.linear_similarity[i](x)
            x = F.relu(x)

        # Step 4: Final layer
        x = self.dropout_similarity[-1](x)
        preds = self.linear_similarity[-1](x)
        # preds = torch.sigmoid(preds) * 6 - 0.5 # Scale to [-0.5, 5.5]

        # # If we are evaluating, then we cap the predictions to the range [0, 5]
        # if not self.training:
        #     preds = torch.clamp(preds, 0, 5)
        return preds
    
    def predict_similarity(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
        Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
        during evaluation, and handled as a logit by the appropriate loss function.
        '''
        # Step 1 : Get the BERT embeddings
        x = self.get_similarity_paraphrase_embeddings(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, task_id=2)
        return self.last_layers_similarity(x)

In [5]:
class ObjectsGroup:

    def __init__(self, model, optimizer, scaler = None):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.loss_sum = 0

class Scheduler:
    '''A class to manage the learning rate scheduler.'''

    def __init__(self, dataloaders, reset=True):
        self.dataloaders = dataloaders
        self.names = list(dataloaders.keys())
        if reset: self.reset()

    def reset(self):
        self.sst_iter = iter(self.dataloaders['sst'])
        # self.para_iter = iter(self.dataloaders['para'])
        # self.sts_iter = iter(self.dataloaders['sts'])
        self.steps = {'sst': 0}

    def get_SST_batch(self):
        try:
            return next(self.sst_iter)
        except StopIteration:
            self.sst_iter = cycle(self.dataloaders['sst'])
            return next(self.sst_iter)

    def get_Paraphrase_batch(self):
        try:
            return next(self.para_iter)
        except StopIteration:
            self.para_iter = cycle(self.dataloaders['para'])
            return next(self.para_iter)

    def get_STS_batch(self):
        try:
            return next(self.sts_iter)
        except StopIteration:
            self.sts_iter = cycle(self.dataloaders['sts'])
            return next(self.sts_iter)

    def get_batch(self, name: str):
        if name == "sst": return self.get_SST_batch()
        # elif name == "para": return self.get_Paraphrase_batch()
        # elif name == "sts": return self.get_STS_batch()
        raise ValueError(f"Unknown batch name: {name}")

    def process_named_batch(self, objects_group: ObjectsGroup, args: dict, name: str, apply_optimization: bool = True):
        '''Processes a batch of data from the given dataset, and updates the model accordingly.'''
        batch = self.get_batch(name)
        process_fn, gradient_accumulations = None, 0
        if name == "sst":
            process_fn = process_sentiment_batch
            gradient_accumulations = args.gradient_accumulations_sst
        elif name == "para":
            process_fn = process_paraphrase_batch
            gradient_accumulations = args.gradient_accumulations_para
        elif name == "sts":
            process_fn = process_similarity_batch
            gradient_accumulations = args.gradient_accumulations_sts
        else:
            raise ValueError(f"Unknown batch name: {name}")
        
        # Process the batch
        loss_of_batch = 0
        for _ in range(gradient_accumulations):
            loss_of_batch += process_fn(batch, objects_group, args)

        # Update the model
        self.steps[name] += 1
        if apply_optimization: step_optimizer(objects_group, args, step=self.steps[name])

        return loss_of_batch


class RoundRobinScheduler(Scheduler):
    '''A scheduler that processes batches in a round-robin fashion.'''
    def __init__(self, dataloaders):
        super().__init__(dataloaders, reset=False)
        self.reset()

    def reset(self):
        self.index = 0
        return super().reset()

    def process_one_batch(self, epoch: int, num_epochs: int, objects_group: ObjectsGroup, args: dict):
        name = self.names[self.index]
        self.index = (self.index + 1) % len(self.names)
        return name, self.process_named_batch(objects_group, args, name)

def process_sentiment_batch(batch, objects_group: ObjectsGroup, args: dict):
    '''This function processes a batch of SST data. It takes as input a batch of data, a group of objects (model, optimizer, scheduler, etc.), 
    and the arguments. It returns the loss of the batch.'''
    device = args.device
    model, scaler = objects_group.model, objects_group.scaler

    with torch.autocast(device_type='cuda', dtype=torch.float16):
        b_ids, b_mask, b_labels = (batch['token_ids'], batch['attention_mask'], batch['labels'])
        b_ids, b_mask, b_labels = b_ids.to(device), b_mask.to(device), b_labels.to(device)

        embeddings = model.forward(b_ids, b_mask, task_id=0)
        logits = model.last_layers_sentiment(embeddings)
        
        loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
        loss_value = loss.item()
        
        objects_group.loss_sum += loss_value

        if args.projection == "none":
            if args.use_amp: scaler.scale(loss).backward()
            else: loss.backward()
        return loss


def process_paraphrase_batch(batch, objects_group: ObjectsGroup, args: dict):
    '''This function processes a batch of paraphrase data. It takes as input a batch of data, 
    a group of objects (model, optimizer, scheduler, etc.), and the arguments. It returns the loss of the batch.'''
    device = args.device
    model, scaler = objects_group.model, objects_group.scaler

    with autocast('cuda') if args.use_amp else nullcontext():
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = (batch['token_ids_1'], batch['attention_mask_1'], batch['token_ids_2'], batch['attention_mask_2'], batch['labels'])
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = b_ids_1.to(device), b_mask_1.to(device), b_ids_2.to(device), b_mask_2.to(device), b_labels.to(device)

        embeddings = model.get_similarity_paraphrase_embeddings(b_ids_1, b_mask_1, b_ids_2, b_mask_2, task_id=1)
        preds = model.last_layers_paraphrase(embeddings)
        loss = F.binary_cross_entropy_with_logits(preds.view(-1), b_labels.float(), reduction='sum') / args.batch_size
        loss_value = loss.item()

        #To use smart_regularization
        # if args.use_smart_regularization:
        #     smart_regularization(loss_value, args.smart_weight_regularization, embeddings, preds, model.last_layers_paraphrase)

        objects_group.loss_sum += loss_value
        
        if args.projection == "none":
            if args.use_amp: scaler.scale(loss).backward()
            else: loss.backward()
        return loss


def process_similarity_batch(batch, objects_group: ObjectsGroup, args: dict):
    '''This function processes a batch of similarity data. It takes as input a batch of data,
    a group of objects (model, optimizer, scheduler, etc.), and the arguments. It returns the loss of the batch.'''
    device = args.device
    model, scaler = objects_group.model, objects_group.scaler

    with autocast('cuda') if args.use_amp else nullcontext():
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = (batch['token_ids_1'], batch['attention_mask_1'], batch['token_ids_2'], batch['attention_mask_2'], batch['labels'])
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = b_ids_1.to(device), b_mask_1.to(device), b_ids_2.to(device), b_mask_2.to(device), b_labels.to(device)

        embeddings = model.get_similarity_paraphrase_embeddings(b_ids_1, b_mask_1, b_ids_2, b_mask_2, task_id=2)
        preds = model.last_layers_similarity(embeddings)
        loss = F.mse_loss(preds.view(-1), b_labels.view(-1), reduction='sum') / args.batch_size
        loss_value = loss.item()

        #To use smart_regularization
        # if args.use_smart_regularization:
        #     smart_regularization(loss_value, args.smart_weight_regularization, embeddings, preds, model.last_layers_similarity)

        objects_group.loss_sum += loss_value
        
        if args.projection == "none":
            if args.use_amp: scaler.scale(loss).backward()
            else: loss.backward()
        return loss

def step_optimizer(objects_group: ObjectsGroup, args: dict, step: int, total_nb_batches = None):
    """Step the optimizer and update the scaler. Returns the loss"""
    optimizer, scaler = objects_group.optimizer, objects_group.scaler
    if args.use_amp:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad()
    loss_value = objects_group.loss_sum
    objects_group.loss_sum = 0
    torch.cuda.empty_cache()
    return loss_value

In [6]:
seed_everything(10004)

In [7]:
args= {
    'sst_file': '/home/interiit/hp3/uday/nlp_pr/train_data/final_sentiment.csv',
    'para_file': '/home/interiit/hp3/uday/nlp_pr/train_data/final_paraphase.csv',
    'sts_file':  '/home/interiit/hp3/uday/nlp_pr/train_data/final_similarity.csv',
    'para_batch_size': 2,
    'sst_batch_size': 2,
    'sts_batch_size': 2,
    'option': 'optimize',
    'hidden_layers': 2,
    'hidden_drp_prob': 0.2,
    'lr': 1e-6,
    'epochs': 2 
}

In [8]:
device = torch.device('cuda')
sst_train_data, num_labels,para_train_data, sts_train_data = load_multitask_data(args['sst_file'],args['para_file'],args['sts_file'], split ='train')

Loaded 9810 train examples from /home/interiit/hp3/uday/nlp_pr/train_data/final_sentiment.csv
Loaded 49401 train examples from /home/interiit/hp3/uday/nlp_pr/train_data/final_paraphase.csv
Loaded 9840 train examples from /home/interiit/hp3/uday/nlp_pr/train_data/final_similarity.csv


In [9]:
sst_train_data = SentenceClassificationDataset(sst_train_data)
sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size= args['sst_batch_size'], collate_fn=sst_train_data.collate_fn)
"""validation dataset for sentence classification dataset"""

# Para: Paraphrase detection
para_train_data = SentencePairDataset(para_train_data)
para_train_dataloader = DataLoader(para_train_data, shuffle=True, batch_size= args['para_batch_size'],collate_fn=para_train_data.collate_fn)
"""validation dataset for paraphase detection dataset"""

# STS: Semantic textual similarity
sts_train_data = SentencePairDataset(sts_train_data, isRegression=True)
sts_train_dataloader = DataLoader(sts_train_data, shuffle=True, batch_size=args['sts_batch_size'],collate_fn=sts_train_data.collate_fn)
"""validation dataset for semantic textual similarity dataset"""

'validation dataset for semantic textual similarity dataset'

In [10]:
config = {'hidden_dropout_prob': args['hidden_drp_prob'],'num_labels': 5,'hidden_size': 768,'data_dir': '.','option': args['option'],'n_hidden_layers': args['hidden_layers']}
config = SimpleNamespace(**config)
model = Bert_MultiTask(config).to(device)
optimizer = AdamW(model.parameters(), lr=args['lr'])
scaler = None

"""PCGrad used here"""

'PCGrad used here'

In [11]:
# Package objects
objects_group = ObjectsGroup(model, optimizer, scaler)
args['device'] = device
dataloaders = {'sst': sst_train_dataloader, 'para': para_train_dataloader, 'sts': sts_train_dataloader}
scheduler = RoundRobinScheduler(dataloaders)

In [12]:
if args['option'] == "optimize":
        # Run Kernel Optimization for SST
    linear = nn.Linear(5,5, dtype=torch.float16)
    W = np.eye(5).astype(np.float16)
    B = np.array([0, 0, 0, 0, 0]).astype(np.float16)

    # Init to W
    linear.weight.data = torch.from_numpy(W).to(torch.float16).to(device)
    linear.bias.data = torch.from_numpy(B).to(torch.float16).to(device)
    linear.to(device)
    optimizer = AdamW(linear.parameters(), lr=args['lr'])
    # Compute accuracy on dev set
    # model.last_linear_sentiment = linear
    # dev_ac, _, _, _ = model_eval_sentiment(sst_dev_dataloader, model, device)
    # print(Colors.BOLD + Colors.BLUE + "Accuracy on dev set: " + Colors.END + Colors.BLUE + str(dev_ac) + Colors.END)

    # Print number of parameters for the optimizer
    # print(Colors.BOLD + Colors.BLUE + "Number of parameters for the optimizer: " + Colors.END + Colors.BLUE + str(count_parameters(linear)) + Colors.END)
    for epoch in range(args['epochs']):
        
        model.last_linear_sentiment = None
        model.eval()
        
        for batch in tqdm(sst_train_dataloader, desc="Kernel Optimization", smoothing=0):
            b_ids, b_mask, b_labels = (batch['token_ids'], batch['attention_mask'], batch['labels'])
            b_ids, b_mask, b_labels = b_ids.to(device), b_mask.to(device), b_labels.to(device)
            embeddings = model.forward(b_ids, b_mask, task_id=0)
            logits = model.last_layers_sentiment(embeddings)
            logits = linear(logits)
            loss = F.cross_entropy(logits, b_labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Evaluate on dev set
        # Actual evaluation
        # model.last_linear_sentiment = linear
        # dev_ac, _, _, _ = model_eval_sentiment(sst_dev_dataloader, model, device)
        # print(Colors.BOLD + Colors.BLUE + "Accuracy on dev set: " + Colors.END + Colors.BLUE + str(dev_ac) + Colors.END)

# Print weights of linear layer
# print(Colors.BOLD + Colors.BLUE + "Weights of linear layer: " + Colors.END + Colors.BLUE + str(linear.weight.data) + Colors.END)
# print(Colors.BOLD + Colors.BLUE + "Bias of linear layer: " + Colors.END + Colors.BLUE + str(linear.bias.data) + Colors.END)
        

Kernel Optimization:   0%|          | 0/4905 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half