In [None]:
import os
# Set before torch is imported in the next cell.
# NOTE(review): presumably a debugging aid (the preprocessing cell mentions a
# "CUDA assert error"); CUDA_LAUNCH_BLOCKING=1 is expected to make kernel
# launches synchronous, which slows training — confirm it is still needed.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
import torch
import numpy as np
import random

def set_random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for repeatable results.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning selects kernels by timing them, so the selected algorithm can
    # change between runs; it has to be off for `deterministic` to be meaningful.
    torch.backends.cudnn.benchmark = False


set_random_seed(0)

In [None]:
! pip install datasets transformers rouge-score nltk py7zr



In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# %cd /content/drive/MyDrive/NLP Project with SCL

# Fine-tuning a model on a summarization task

## Loading the dataset

In [None]:
from datasets import load_dataset, load_metric

# Dialogue-summarization corpus; later cells use its 'train', 'validation' and
# 'test' splits and the 'id', 'dialogue' and 'summary' columns.
raw_datasets = load_dataset("samsum")

# ROUGE scorer.
# NOTE(review): `load_metric` is reportedly deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm against the
# installed version before upgrading.
metric = load_metric("rouge")

Found cached dataset samsum (/root/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)


  0%|          | 0/3 [00:00<?, ?it/s]

  """


## BART

### Preprocessing the data

In [None]:
# Hub identifier of the pretrained checkpoint; the tokenizer below is loaded from it.
model_checkpoint = "facebook/bart-base"

In [None]:
from transformers import AutoTokenizer
    
# Notebook-wide tokenizer; check_token_length() and preprocess_function() read
# this global directly.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
def check_token_length(dataset, max_tokens=1000, tok=None):
    """Find dialogues whose tokenization is longer than ``max_tokens``.

    Args:
        dataset: Mapping-like object exposing a ``'dialogue'`` column
            (a sequence of strings).
        max_tokens: A dialogue is reported when it has strictly more
            ``input_ids`` than this. Defaults to 1000, the value that was
            previously hard-coded.
        tok: Callable returning a mapping with an ``'input_ids'`` entry for a
            string. Defaults to the notebook-level ``tokenizer``.

    Returns:
        List of integer row indices of the over-long dialogues. The list is
        also printed.
    """
    if tok is None:
        tok = tokenizer
    # Look the column up once instead of on every loop iteration.
    dialogues = dataset['dialogue']
    ids = [
        index
        for index, dialogue in enumerate(dialogues)
        if len(tok(dialogue)['input_ids']) > max_tokens
    ]
    print(ids)
    return ids
def remove_idx(list_idx, dataset):
    """Return ``dataset`` without the rows whose index is in ``list_idx``.

    Args:
        list_idx: Iterable of integer row indices to drop.
        dataset: Object supporting ``len()`` and ``select(indices)``.

    Returns:
        Whatever ``dataset.select`` returns for the kept indices, in their
        original order.
    """
    # Build the lookup set once; it used to be rebuilt for every candidate index.
    excluded = set(list_idx)
    kept = [index for index in range(len(dataset)) if index not in excluded]
    return dataset.select(kept)
    
# Indices of dialogues that tokenize to more than 1000 tokens, per split
# (each call also prints its list).
train_ids=check_token_length(raw_datasets['train'])
validation_ids=check_token_length(raw_datasets['validation'])
test_ids = check_token_length(raw_datasets['test'])
# Copies of the splits with those over-long dialogues removed; the raw splits
# are left untouched.
changed_datasets_train=remove_idx(train_ids, raw_datasets['train'])
changed_datasets_val = remove_idx(validation_ids, raw_datasets['validation'])
changed_datasets_test = remove_idx(test_ids, raw_datasets['test'])

In [None]:
# Token limits passed as `max_length` to the tokenizer in preprocess_function():
# the first for dialogues, the second for summaries (labels).
max_input_length = 1024
max_target_length = 128

def make_one_hot_sequence(input_ids, sequence_ids):
    """Build the per-token speaker/utterance label sequence for one dialogue.

    Args:
        input_ids: Token ids of the whole dialogue.
        sequence_ids: List of dicts, each with ``'spk'`` and ``'utt'`` entries
            holding ``[start, end)`` index pairs into ``input_ids``.

    Returns:
        A list that starts and ends with ``0``. In between, every speaker
        token carries a positive speaker id (ids are handed out from 1 in
        order of first appearance; speakers are identified by the token ids
        of their name) and every utterance token carries ``-1``.
    """
    speaker_ids = {}
    labels = [0]
    for span in sequence_ids:
        spk_start, spk_end = span['spk'][0], span['spk'][1]
        # The speaker is keyed by the textual form of its name tokens.
        name_key = str(input_ids[spk_start:spk_end])
        if name_key not in speaker_ids:
            speaker_ids[name_key] = len(speaker_ids) + 1
        labels.extend([speaker_ids[name_key]] * (spk_end - spk_start))

        utt_start, utt_end = span['utt'][0], span['utt'][1]
        labels.extend([-1] * (utt_end - utt_start))
    labels.append(0)
    return labels


def preprocess_function(examples):
    r"""Tokenize a batch of rows and annotate speaker / utterance token positions.

    Every dialogue is split on "\r\n" into turns, and every turn into a
    speaker part (text before the first ':') and an utterance part (from the
    ':' onwards). A turn without ':' is treated as an utterance with an empty
    speaker. The parts are tokenized separately, the "\r\n" tokens are
    re-inserted at the end of every utterance except the last one, and the
    pieces are concatenated into a single sequence with one leading and one
    trailing special token.

    Args:
        examples: Batch mapping with ``'dialogue'`` and ``'summary'`` lists of
            strings.

    Returns:
        Dict with ``'input_ids'``, ``'attention_mask'``, ``'spk_utt_pos'``
        (see ``make_one_hot_sequence``) and ``'labels'`` (summary token ids
        truncated to ``max_target_length``).

    Note:
        Inputs are NOT truncated; a sequence longer than ``max_input_length``
        is only reported on stdout.
    """
    # Tokens of the turn separator, without the special tokens added around it.
    separator = tokenizer(["\r\n"])
    slash_n = separator['input_ids'][0][1:-1]
    slash_n_mask = separator['attention_mask'][0][1:-1]

    inputs_list = []
    masks_list = []
    pos_list = []
    for dialogue in examples['dialogue']:
        # Break the dialogue into alternating [speaker, utterance, speaker, ...] pieces.
        broken = []
        for utt in dialogue.split("\r\n"):
            first_ind = utt.find(':')
            if first_ind == -1:
                # No speaker prefix: keep the whole line as the utterance
                # instead of slicing with -1.
                first_ind = 0
            broken.append(utt[:first_ind])
            broken.append(utt[first_ind:])

        # Tokenize all pieces once and reuse both outputs.
        encoded = tokenizer(broken)
        tokenized_broken = encoded['input_ids']
        attention_broken = encoded['attention_mask']

        # Re-insert the separator before the closing special token of every
        # utterance except the last one (utterances sit at odd positions).
        for i in range(1, len(tokenized_broken) - 1, 2):
            tokenized_broken[i][-1:-1] = slash_n
            attention_broken[i][-1:-1] = slash_n_mask

        joined = tokenized_broken[0]
        joined_mask = attention_broken[0]

        # Spans are [start, end) offsets into `joined`; offset 0 is the leading
        # special token, so the first speaker starts at 1.
        assoc_dict = {'spk': [1, len(joined) - 1]}
        running_length = len(joined)
        sequence_ids = []
        is_utterance = True
        for inner, inner_attention in zip(tokenized_broken[1:], attention_broken[1:]):
            # The piece loses its leading special token and overwrites the
            # trailing special token of what has been joined so far.
            span = [running_length - 1, running_length + len(inner) - 3]
            if is_utterance:
                assoc_dict['utt'] = span
                sequence_ids.append(assoc_dict)
                assoc_dict = {}
            else:
                assoc_dict['spk'] = span
            is_utterance = not is_utterance
            joined = joined[:-1] + inner[1:]
            joined_mask = joined_mask[:-1] + inner_attention[1:]
            running_length += len(inner) - 2

        # Over-long inputs are reported but kept (the old message wrongly
        # claimed the example was skipped).
        if len(joined) > max_input_length:
            print(f"input token list length {len(joined)} exceeds {max_input_length}; "
                  "the example is kept unchanged and may trigger a CUDA assert")
            print(tokenizer.decode(joined))

        inputs_list.append(joined)
        masks_list.append(joined_mask)
        pos_list.append(make_one_hot_sequence(joined, sequence_ids))

    model_inputs = {
        'input_ids': inputs_list,
        'attention_mask': masks_list,
        'spk_utt_pos': pos_list,
    }
    # NOTE(review): `as_target_tokenizer` is reportedly deprecated in newer
    # transformers releases (`text_target=` replaces it) — confirm before upgrading.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Apply preprocess_function batch-wise to each filtered split; the `_o`
# variables still carry the original text columns.
tokenized_datasets_train_o = changed_datasets_train.map(preprocess_function, batched=True)
tokenized_datasets_val_o = changed_datasets_val.map(preprocess_function, batched=True)
tokenized_datasets_test_o = changed_datasets_test.map(preprocess_function, batched=True)

# tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
# Drop the raw text columns so only the columns added by preprocess_function remain.
tokenized_datasets_train = tokenized_datasets_train_o.remove_columns(['id', 'dialogue', 'summary'])
tokenized_datasets_val = tokenized_datasets_val_o.remove_columns(['id', 'dialogue', 'summary'])
tokenized_datasets_test = tokenized_datasets_test_o.remove_columns(['id', 'dialogue', 'summary'])

In [None]:
from transformers import Seq2SeqTrainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer with an explicit, overridable ``compute_loss``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on ``inputs`` and return its loss.

        Args:
            model: The model being trained.
            inputs: Mapping of keyword arguments for ``model``. When a label
                smoother is configured, ``"labels"`` is popped from it (the
                mapping is mutated) and the loss is computed by the smoother.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted only so that Trainer versions which
                pass this keyword do not fail; it is ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            ValueError: If the model returns a dict without a ``"loss"`` entry
                and no label smoother is in use.
        """
        # How the loss is computed by Trainer. By default, all models return the loss in the first element.
        # Subclass and override for custom behavior.
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # Causal LMs predict the next token, so their labels are shifted.
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss


In [None]:
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from transformers import DataCollatorForSeq2Seq
from typing import Optional, Any, Union
import numpy as np


class CustomCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that will dynamically pad the inputs received, as well as the labels and the
    per-token `spk_utt_pos` annotations.

    `labels` are padded with `label_pad_token_id`; `spk_utt_pos` (when present) is padded with `0`, the same
    value that marks the start/end of a sequence in that annotation. Both are padded on the tokenizer's
    padding side before `tokenizer.pad` is called, because that method does not pad them and needs them to
    have equal length to return tensors. The feature dicts passed in are modified in place.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`]):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*
            This is useful when using *label_smoothing* to avoid calculating loss twice.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
              is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def _pad_column(self, features, key, pad_value):
        """Pad ``feature[key]`` of every feature, in place, to one common length.

        The common length is the longest entry, rounded up to a multiple of
        ``pad_to_multiple_of`` when that is set. List entries stay lists;
        anything else is concatenated with NumPy and cast to int64.
        """
        target_length = max(len(feature[key]) for feature in features)
        if self.pad_to_multiple_of is not None:
            target_length = (
                (target_length + self.pad_to_multiple_of - 1)
                // self.pad_to_multiple_of
                * self.pad_to_multiple_of
            )

        pad_right = self.tokenizer.padding_side == "right"
        for feature in features:
            values = feature[key]
            remainder = [pad_value] * (target_length - len(values))
            if isinstance(values, list):
                feature[key] = values + remainder if pad_right else remainder + values
            elif pad_right:
                feature[key] = np.concatenate([values, remainder]).astype(np.int64)
            else:
                feature[key] = np.concatenate([remainder, values]).astype(np.int64)

    def __call__(self, features, return_tensors=None):
        """Collate a list of feature dicts into one padded batch."""
        if return_tensors is None:
            return_tensors = self.return_tensors

        # We have to pad these columns before calling `tokenizer.pad` as this method won't pad them and needs
        # them of the same length to return tensors.
        has_labels = "labels" in features[0].keys()
        if has_labels:
            self._pad_column(features, "labels", self.label_pad_token_id)
        # Only pad the annotation when the batch actually carries it, so the
        # collator also works on features without speaker information.
        if "spk_utt_pos" in features[0].keys():
            self._pad_column(features, "spk_utt_pos", 0)

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # prepare decoder_input_ids
        if (
                has_labels
                and self.model is not None
                and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
            features["decoder_input_ids"] = decoder_input_ids

        return features


In [None]:
from torch import nn
from transformers import BartForConditionalGeneration, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from transformers.models.bart.modeling_bart import BartConfig
import torch
from typing import *
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.bart.modeling_bart import shift_tokens_right
import random
from tqdm import tqdm
import gc
import itertools

class BartWithSCL(BartForConditionalGeneration):
    def __init__(self, config: BartConfig):
        """Build the BART model and initialise the contrastive-loss settings.

        Args:
            config: Model configuration forwarded to the parent constructor.
        """
        super().__init__(config)
        # forward() reads SCLossesList, so give it (and the coefficient) the
        # same defaults as the setters; otherwise a model on which the setters
        # were never called fails with AttributeError.
        self.SCLossesList = ['token']
        self.scl_coeff = 1e-1

    def set_losses_list(self, SCLossesList=['token']):
        """Select which contrastive losses are active.

        Args:
            SCLossesList: Names of the losses to enable; ``forward`` checks it
                for ``'token'``, ``'turn'`` and ``'global'``. Defaults to
                ``['token']``.
        """
        # Store a copy: the default list object is shared by every call, so
        # keeping a reference would let one instance's mutation leak into the
        # others. Strings are kept as-is to preserve substring membership tests.
        if isinstance(SCLossesList, str):
            self.SCLossesList = SCLossesList
        else:
            self.SCLossesList = list(SCLossesList)
    def set_scl_coeff(self, scl_coeff=1e-1):
        """Store ``scl_coeff`` on the instance as ``self.scl_coeff``.

        NOTE(review): presumably the weight applied to the contrastive loss
        terms when they are added to the generation loss — the code using it
        is not in this part of the file; confirm.
        """
        self.scl_coeff=scl_coeff
    def token_scl(self,
                  last_hidden_state: torch.FloatTensor,
                  spk_utt_pos: torch.LongTensor,
    ) -> torch.FloatTensor:
        r"""
        Token-level supervised contrastive loss between utterances.

        last_hidden_state (torch.FloatTensor) of shape (batch_size, sequence_length, n_dims):
            Output of the last layer of the encoder.
        spk_utt_pos (torch.LongTensor) of shape (batch_size, sequence_length,):
            Per-token labels: a positive speaker id on speaker-name tokens, a negative value on utterance
            tokens and 0 on the first token, the end of the sequence and padding.

        For every sample with at least two speakers, one random pair of utterances of the same speaker is
        pulled together (for each speaker that has two or more utterances) and, for every speaker, one of its
        utterances is pushed away from a random utterance of a random other speaker. Similarities are the
        pairwise inner products of the token states of the two utterances. Sampling uses Python's `random`.

        Returns:
        Token Level Supervised Contrastive Loss, summed over the samples and divided by the batch size.
        A plain number (0.0) is returned when no sample contributes.
        """
        batch_scl = 0
        for i in range(len(spk_utt_pos)):
            # One transfer per row instead of one tensor comparison per token.
            labels = spk_utt_pos[i].tolist()

            # Each entry: (speaker id of the current speaker span,
            #              utterance start, utterance end (exclusive),
            #              speaker id the utterance belongs to).
            pairs = []
            cur_spk, in_spk = 0, False
            utt_start, utt_spk = 0, 0
            for j, label in enumerate(labels):
                if label == 0 and j > 0:
                    # End of sequence / start of padding closes the last utterance.
                    pairs.append((cur_spk, utt_start, j, utt_spk))
                    break
                if label > 0 and not in_spk:
                    # A new speaker span closes the previous utterance (if any).
                    if j > 1:
                        pairs.append((cur_spk, utt_start, j, utt_spk))
                    cur_spk, in_spk = label, True
                if label < 0 and in_spk:
                    # First utterance token after a speaker span.
                    in_spk = False
                    utt_start, utt_spk = j, cur_spk

            # Skip samples without spans or whose first utterance has no speaker.
            if not pairs or pairs[0][0] == 0:
                continue
            uniq_spks = sorted({pair[0] for pair in pairs})
            if len(uniq_spks) == 1:
                continue

            # Token states of every utterance, grouped by speaker.
            spk_utt_states = {spk: [] for spk in uniq_spks}
            for _, start, end, owner in pairs:
                if owner in spk_utt_states:
                    spk_utt_states[owner].append(last_hidden_state[i, start:end])

            # A speaker without any utterance cannot be sampled from.
            speakers = [spk for spk in uniq_spks if spk_utt_states[spk]]
            if len(speakers) < 2:
                continue

            # positive samples
            L_pos = 0
            for spk in speakers:
                states = spk_utt_states[spk]
                if len(states) > 1:
                    id1, id2 = random.sample(range(len(states)), 2)
                    mat_mul = torch.einsum('ij, kj->ik', states[id1], states[id2])
                    # logsigmoid is the numerically stable log(sigmoid(x)); the
                    # two-step form underflows to -inf and produces NaN gradients.
                    L_pos += torch.sum(-1 * torch.nn.functional.logsigmoid(mat_mul))
                    L_pos = torch.nan_to_num(L_pos, posinf = 1e10, neginf = -1e10)

            # negative samples
            L_neg = 0
            for spk in speakers:
                others = [other for other in speakers if other != spk]
                spk2 = random.choice(others)

                id1 = random.randint(0, len(spk_utt_states[spk])-1)
                id2 = random.randint(0, len(spk_utt_states[spk2])-1)

                mat_mul = torch.einsum('ij, kj->ik', spk_utt_states[spk][id1], spk_utt_states[spk2][id2])
                sigm = torch.sigmoid(mat_mul)
                # The epsilon keeps the log finite when the sigmoid saturates at 1.
                L_neg += torch.sum(-1 * torch.log(1 - sigm+1e-5))
                L_neg = torch.nan_to_num(L_neg, posinf = 1e10, neginf = -1e10)

            batch_scl += L_pos
            batch_scl += L_neg

        batch_scl /= last_hidden_state.size(0)
        # NOTE(review): a full garbage collection on every call is costly;
        # confirm it is still needed.
        gc.collect()
        return batch_scl
    
    def turn_scl(self,
                  last_hidden_state: torch.FloatTensor,
                  spk_utt_pos: torch.LongTensor,
    ) -> torch.FloatTensor:
        r"""
        Turn-level supervised contrastive loss between utterances.

        last_hidden_state (torch.FloatTensor) of shape (batch_size, sequence_length, n_dims):
            Output of the last layer of the encoder.
        spk_utt_pos (torch.LongTensor) of shape (batch_size, sequence_length,):
            Per-token labels: a positive speaker id on speaker-name tokens, a negative value on utterance
            tokens and 0 on the first token, the end of the sequence and padding.

        Every utterance is mean-pooled into one vector. For every sample with at least two speakers, one random
        pair of utterances of the same speaker is pulled together (for each speaker that has two or more
        utterances) and, for every speaker, one of its utterances is pushed away from a random utterance of a
        random other speaker. The similarity is the inner product of the two pooled vectors. Sampling uses
        Python's `random`.

        Returns:
        Turn Level Supervised Contrastive Loss, summed over the samples and divided by the batch size.
        A plain number (0.0) is returned when no sample contributes.
        """
        batch_scl = 0
        for i in range(len(spk_utt_pos)):
            # One transfer per row instead of one tensor comparison per token.
            labels = spk_utt_pos[i].tolist()

            # Each entry: (speaker id of the current speaker span,
            #              utterance start, utterance end (exclusive),
            #              speaker id the utterance belongs to).
            pairs = []
            cur_spk, in_spk = 0, False
            utt_start, utt_spk = 0, 0
            for j, label in enumerate(labels):
                if label == 0 and j > 0:
                    # End of sequence / start of padding closes the last utterance.
                    pairs.append((cur_spk, utt_start, j, utt_spk))
                    break
                if label > 0 and not in_spk:
                    # A new speaker span closes the previous utterance (if any).
                    if j > 1:
                        pairs.append((cur_spk, utt_start, j, utt_spk))
                    cur_spk, in_spk = label, True
                if label < 0 and in_spk:
                    # First utterance token after a speaker span.
                    in_spk = False
                    utt_start, utt_spk = j, cur_spk

            # Skip samples without spans or whose first utterance has no speaker.
            if not pairs or pairs[0][0] == 0:
                continue
            uniq_spks = sorted({pair[0] for pair in pairs})
            if len(uniq_spks) == 1:
                continue

            # Mean-pooled state of every utterance, grouped by speaker.
            spk_utt_states = {spk: [] for spk in uniq_spks}
            for _, start, end, owner in pairs:
                if owner in spk_utt_states:
                    mean_pool = torch.mean(last_hidden_state[i, start:end], 0)
                    spk_utt_states[owner].append(mean_pool)

            # A speaker without any utterance cannot be sampled from.
            speakers = [spk for spk in uniq_spks if spk_utt_states[spk]]
            if len(speakers) < 2:
                continue

            # positive samples
            L_pos = 0
            for spk in speakers:
                states = spk_utt_states[spk]
                if len(states) > 1:
                    id1, id2 = random.sample(range(len(states)), 2)
                    # 'i, i->' is the inner product; the former 'i, j->'
                    # computed sum(a) * sum(b) instead.
                    mat_mul = torch.einsum('i, i->', states[id1], states[id2])
                    # Numerically stable log(sigmoid(x)).
                    L_pos += torch.sum(-1 * torch.nn.functional.logsigmoid(mat_mul))

            # negative samples
            L_neg = 0
            for spk in speakers:
                others = [other for other in speakers if other != spk]
                spk2 = random.choice(others)

                id1 = random.randint(0, len(spk_utt_states[spk])-1)
                id2 = random.randint(0, len(spk_utt_states[spk2])-1)

                mat_mul = torch.einsum('i, i->', spk_utt_states[spk][id1], spk_utt_states[spk2][id2])
                sigm = torch.sigmoid(mat_mul)
                # The epsilon keeps the log finite when the sigmoid saturates at 1.
                L_neg += torch.sum(-1 * torch.log(1 - sigm+1e-5))

            batch_scl += L_pos
            batch_scl += L_neg

        batch_scl /= last_hidden_state.size(0)
        # NOTE(review): a full garbage collection on every call is costly;
        # confirm it is still needed.
        gc.collect()
        return batch_scl
    
    def global_scl(self,
                  last_hidden_state: torch.FloatTensor,
                  spk_utt_pos: torch.LongTensor,
    ) -> torch.FloatTensor:
        r"""
        Global-level supervised contrastive loss between utterances and speaker representations.

        last_hidden_state (torch.FloatTensor) of shape (batch_size, sequence_length, n_dims):
            Output of the last layer of the encoder.
        spk_utt_pos (torch.LongTensor) of shape (batch_size, sequence_length,):
            Per-token labels: a positive speaker id on speaker-name tokens, a negative value on utterance
            tokens and 0 on the first token, the end of the sequence and padding.

        Every utterance is mean-pooled into one vector. For each speaker with two or more utterances, one
        utterance is held out at random and the mean of the speaker's remaining utterances serves as the
        speaker representation: the held-out utterance is pulled towards it and a random utterance of a random
        other speaker is pushed away from it. The similarity is the inner product. Sampling uses Python's
        `random`.

        Returns:
        Global Level Supervised Contrastive Loss, summed over the samples and divided by the batch size.
        A plain number (0.0) is returned when no sample contributes.
        """
        batch_scl = 0
        for i in range(len(spk_utt_pos)):
            # One transfer per row instead of one tensor comparison per token.
            labels = spk_utt_pos[i].tolist()

            # Each entry: (speaker id of the current speaker span,
            #              utterance start, utterance end (exclusive),
            #              speaker id the utterance belongs to).
            pairs = []
            cur_spk, in_spk = 0, False
            utt_start, utt_spk = 0, 0
            for j, label in enumerate(labels):
                if label == 0 and j > 0:
                    # End of sequence / start of padding closes the last utterance.
                    pairs.append((cur_spk, utt_start, j, utt_spk))
                    break
                if label > 0 and not in_spk:
                    # A new speaker span closes the previous utterance (if any).
                    if j > 1:
                        pairs.append((cur_spk, utt_start, j, utt_spk))
                    cur_spk, in_spk = label, True
                if label < 0 and in_spk:
                    # First utterance token after a speaker span.
                    in_spk = False
                    utt_start, utt_spk = j, cur_spk

            # Skip samples without spans or whose first utterance has no speaker.
            if not pairs or pairs[0][0] == 0:
                continue
            uniq_spks = sorted({pair[0] for pair in pairs})
            if len(uniq_spks) == 1:
                continue

            # Mean-pooled state of every utterance, grouped by speaker.
            spk_utt_states = {spk: [] for spk in uniq_spks}
            for _, start, end, owner in pairs:
                if owner in spk_utt_states:
                    mean_pool = torch.mean(last_hidden_state[i, start:end], 0)
                    spk_utt_states[owner].append(mean_pool)

            # A speaker without any utterance cannot be sampled from.
            speakers = [spk for spk in uniq_spks if spk_utt_states[spk]]
            if len(speakers) < 2:
                continue

            L_pos = 0
            L_neg = 0
            for spk in speakers:
                states = spk_utt_states[spk]
                if len(states) > 1:
                    # Hold one utterance out; the rest form the speaker representation.
                    ids = random.choice(range(len(states)))
                    spk_mean_exc = torch.mean(
                        torch.vstack([state for temp, state in enumerate(states) if temp != ids]), 0)

                    # positive sample
                    # 'i, i->' is the inner product; the former 'i, j->'
                    # computed sum(a) * sum(b) instead.
                    pos_mat_mul = torch.einsum('i, i->', states[ids], spk_mean_exc)
                    # Numerically stable log(sigmoid(x)).
                    L_pos += torch.sum(-1 * torch.nn.functional.logsigmoid(pos_mat_mul))

                    # negative sample
                    others = [other for other in speakers if other != spk]
                    spk2 = random.choice(others)
                    id_neg = random.choice(range(len(spk_utt_states[spk2])))
                    neg_mat_mul = torch.einsum('i, i->', spk_utt_states[spk2][id_neg], spk_mean_exc)
                    neg_sigm = torch.sigmoid(neg_mat_mul)
                    # The epsilon keeps the log finite when the sigmoid saturates at 1.
                    L_neg += torch.sum(-1 * torch.log(1 - neg_sigm+1e-5))

            batch_scl += L_pos
            batch_scl += L_neg

        batch_scl /= last_hidden_state.size(0)
        # NOTE(review): a full garbage collection on every call is costly;
        # confirm it is still needed.
        gc.collect()
        return batch_scl

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            spk_utt_pos: Optional[torch.Tensor] = None, ##changed here
            decoder_input_ids: Optional[torch.LongTensor] = None,
            decoder_attention_mask: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            decoder_head_mask: Optional[torch.Tensor] = None,
            cross_attn_head_mask: Optional[torch.Tensor] = None,
            encoder_outputs: Optional[List[torch.FloatTensor]] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if encoder_outputs is None:
            encoder = self.get_encoder()
            # TODO: mask the speaker names from the input IDs using the speaker pos info
            turn_attention_mask=None
            token_encoder_outputs=None
            tog_encoder_outputs=None
            
            if 'token' in self.SCLossesList:
                token_encoder_outputs = encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

            if 'turn' in self.SCLossesList or 'global' in self.SCLossesList:
                tog_attention_mask = torch.where(spk_utt_pos>0, 0, attention_mask)
                tog_encoder_outputs = encoder(
                    input_ids=input_ids,
                    attention_mask=tog_attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        # if 'hidden_states' in encoder_outputs:
        #     print("encoder_outputs['last_hidden_state'].size(), encoder_outputs['hidden_states'].size()",
        #     encoder_outputs['last_hidden_state'].size(), encoder_outputs['hidden_states'].size())
        # else:
        #     print("encoder_outputs['last_hidden_state'].size()", encoder_outputs['last_hidden_state'].size())

        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
        # added here
        sc_loss = 0
        if 'token' in self.SCLossesList and labels is not None:
            sc_loss += self.token_scl(last_hidden_state=token_encoder_outputs['last_hidden_state'], spk_utt_pos=spk_utt_pos)
            # print(sc_loss)
        if 'turn' in self.SCLossesList and labels is not None:
            sc_loss += self.turn_scl(last_hidden_state=tog_encoder_outputs['last_hidden_state'], spk_utt_pos=spk_utt_pos)
        
        if 'global' in self.SCLossesList and labels is not None:
            sc_loss += self.global_scl(last_hidden_state=tog_encoder_outputs['last_hidden_state'], spk_utt_pos=spk_utt_pos)
        
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss+(self.scl_coeff*sc_loss),) + output) if masked_lm_loss is not None else output
        loss = None
        if masked_lm_loss is None:
            loss = None
        else:
            loss = masked_lm_loss+(self.scl_coeff*sc_loss)
        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


In [None]:
# Run a full garbage-collection pass before building the model and trainer.
# As the last expression in the cell, the return value of gc.collect()
# (the number of unreachable objects found) is displayed as the cell output.
import gc 
gc.collect()

### Fine-tuning the model

In [None]:
# from models import BartWithSCL
# from datacollator import CustomCollatorForSeq2Seq
# from trainer import CustomTrainer


from transformers import BartForConditionalGeneration, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

In [None]:
# Instantiate the SCL-augmented BART class from the pretrained checkpoint
# named by `model_checkpoint`.
model = BartWithSCL.from_pretrained(model_checkpoint)
# Select which SCL terms are active. forward() tests for the strings 'token',
# 'turn' and 'global' in self.SCLossesList — presumably that is the attribute
# this setter fills; its definition is outside this section, TODO confirm.
model.set_losses_list(['token','turn','global'])
# Coefficient for the SCL terms. forward() adds self.scl_coeff * sc_loss to the
# language-modeling loss — presumably this setter stores that value; TODO confirm.
model.set_scl_coeff(0.1)

In [None]:
# Number of examples per device for both the training and evaluation loops.
batch_size = 3

args = Seq2SeqTrainingArguments(
    # Directory that receives checkpoints and logs for this run.
    output_dir="bart-tjoin-b6c0.1",
    # --- optimisation ---
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Two accumulation steps of `batch_size` examples per optimizer update.
    gradient_accumulation_steps=2,
    fp16=True,
    # --- evaluation and logging ---
    evaluation_strategy="epoch",
    predict_with_generate=True,
    logging_steps=10,
    # Do not let the trainer drop dataset columns on its own.
    remove_unused_columns=False,
)

In [None]:
# Batch collator used by the trainer. CustomCollatorForSeq2Seq is defined
# elsewhere in this notebook; presumably it extends the stock seq2seq collator
# to also batch the extra `spk_utt_pos` feature — verify against its definition.
data_collator = CustomCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
import nltk
import numpy as np
import torch
torch.cuda.empty_cache()
def compute_metrics(eval_pred, max_examples_to_print=50):
    """Decode generated summaries and score them against the references with ROUGE.

    Args:
        eval_pred: A ``(predictions, labels)`` pair supplied by the trainer.
            Both are treated as arrays of token ids; ``-100`` entries are
            replaced by the pad token id before decoding. If ``predictions``
            is a tuple, its first element is used.
        max_examples_to_print: Upper bound on how many prediction/reference
            pairs are printed for manual inspection. Fewer are printed when the
            evaluation set is smaller. Defaults to 50, the previously
            hard-coded amount.

    Returns:
        dict: One entry per score returned by ``metric.compute`` (mid
        F-measure scaled to 0-100) plus ``"gen_len"``, the mean number of
        non-pad tokens per prediction. All values are rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some evaluation loops hand back a tuple whose first item holds the ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is used as padding when sequences of different lengths are gathered;
    # it is not a valid token id, so swap it for the pad token in BOTH arrays
    # before decoding (the predictions were previously left untouched).
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Print a bounded sample for eyeballing; never index past the end of the
    # decoded lists (the fixed range(0, 50) raised IndexError on small sets).
    num_to_print = min(max_examples_to_print, len(decoded_preds), len(decoded_labels))
    for i in range(num_to_print):
        print("-----------",i,"--------------")
        print("------>Predictions by Model")
        print(decoded_preds[i])
        print("----->Predictions Original")
        print(decoded_labels[i])
        print("**************************")

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Keep only the mid F-measure of each score, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Mean generated length; padding (including former -100 slots) is excluded.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# Assemble the trainer from the model, the training arguments, the tokenized
# train/validation splits, the custom collator and the ROUGE metric function.
# CustomTrainer is defined elsewhere in this notebook; the first two arguments
# are passed positionally, so their order must match its constructor.
trainer = CustomTrainer(
    model,
    args,
    train_dataset=tokenized_datasets_train,
    eval_dataset=tokenized_datasets_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# Fetch the Punkt sentence-tokenizer data; compute_metrics calls
# nltk.sent_tokenize, so this must have run before the first evaluation.
import nltk
nltk.download('punkt')

In [None]:
# Start fine-tuning with the arguments configured in `args`.
trainer.train()

In [None]:
# NOTE(review): this is a second, identical trainer.train() call. On a
# Restart & Run All it presumably launches another full training run on the
# already fine-tuned weights rather than resuming the first one — confirm
# whether this cell is intentional or a leftover that should be removed.
trainer.train()

In [None]:
# Evaluate the trained model; with no arguments this presumably uses the
# eval_dataset given to the trainer (tokenized_datasets_val) — verify against
# CustomTrainer. The returned metrics are displayed as the cell output.
trainer.evaluate()