In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    
    print('Memory Usage:',round(torch.cuda.get_device_properties(0).total_memory/1024**3,1), 'GB')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')


Using device: cuda

A100-SXM4-40GB
Memory Usage: 39.6 GB
Allocated: 0.0 GB
Cached:    0.0 GB




In [3]:
import torch
import numpy as np
import random

def set_random_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
set_random_seed(0)

In [4]:
! pip install datasets transformers rouge-score nltk py7zr



In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

In [6]:
# %cd /content/drive/MyDrive/NLP Project with SCL

## Loading the dataset

In [7]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("samsum")

metric = load_metric("rouge")

Found cached dataset samsum (/root/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)


  0%|          | 0/3 [00:00<?, ?it/s]

  """


## T5

### Preprocessing the data

In [8]:
model_checkpoint = "t5-base"

In [9]:
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [10]:
tokenizer(["Amanda: bla bla\r\nGrey: toyot"])

{'input_ids': [[21542, 10, 3, 4605, 3, 4605, 12630, 10, 12, 63, 32, 17, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [11]:
def check_token_length(dataset):
    ids=[]
    for i in range(len(dataset['dialogue'])):
        if len(tokenizer(dataset['dialogue'][i])['input_ids'])>1000:
            ids.append(i)
    print(ids)
    return ids
def remove_idx(list_idx, dataset):
    return dataset.select((
          i for i in range(len(dataset)) 
          if i not in set(list_idx)))
    
train_ids=check_token_length(raw_datasets['train'])
validation_ids=check_token_length(raw_datasets['validation'])
test_ids = check_token_length(raw_datasets['test'])
changed_datasets_train=remove_idx(train_ids, raw_datasets['train'])
changed_datasets_val = remove_idx(validation_ids, raw_datasets['validation'])
changed_datasets_test = remove_idx(test_ids, raw_datasets['test'])

Token indices sequence length is longer than the specified maximum sequence length for this model (567 > 512). Running this sequence through the model will result in indexing errors


[4269, 9491]
[]




[]


In [12]:
max_input_length = 1024
max_target_length = 128

def make_one_hot_sequence(input_ids, sequence_ids):
    changed_sequence_id=[]
    token_to_speaker_id={}
    uniq_id = 1
    for dic in sequence_ids:
        if str(input_ids[dic['spk'][0]:dic['spk'][1]]) in token_to_speaker_id:
            speaker_id = token_to_speaker_id[str(input_ids[dic['spk'][0]:dic['spk'][1]])]
        else:
            token_to_speaker_id[str(input_ids[dic['spk'][0]:dic['spk'][1]])] = uniq_id
            speaker_id = uniq_id
            uniq_id+=1
        for _ in range(dic['spk'][0], dic['spk'][1]):
            changed_sequence_id.append(speaker_id)
        for _ in range(dic['utt'][0], dic['utt'][1]):
            changed_sequence_id.append(-1)
    changed_sequence_id.append(0)
    return changed_sequence_id 


def preprocess_function(examples): ## hit gold here. change this preprocess function to include speaker and turn information. 
    # slash_n = tokenizer(["\r\n"])['input_ids'][0][1:-1]
    # slash_n_mask = tokenizer(["\r\n"])['attention_mask'][0][1:-1]
    inputs_list=[]
    masks_list=[]
    pos_list=[]
    for index in range(len(examples['dialogue'])):
        # breaking the dialogue for spk:utt info
        
        
        original = tokenizer(examples['dialogue'][index])['input_ids']
        
        
        broken=[]
        for utt in examples['dialogue'][index].split("\r\n"):
            first_ind = utt.find(':')
            broken.append(utt[:first_ind+1])
            broken.append(utt[first_ind+1:])
        
        tokenized_broken = tokenizer(broken)['input_ids']
        attention_broken = tokenizer(broken)['attention_mask']
        # # adding \r\n tokens
        # for i in range(1, len(tokenized_broken)-1, 2):
        #     print(slash_n[0])

        #     tokenized_broken[i].insert(-1, slash_n[0])
        #     tokenized_broken[i].insert(-1, slash_n[1])
        #     attention_broken[i].insert(-1, slash_n_mask[0])
        #     attention_broken[i].insert(-1, slash_n_mask[1])
        #     print("second",tokenized_broken[i])

        joined = tokenized_broken[0]

        # annotating for spk_utt_pos
        assoc_dict={}
        assoc_dict['spk'] = [0, len(tokenized_broken[0])-1] # the range is actually exclusive of the last index. 
        odd_bool = True
        running_length = len(tokenized_broken[0])
        sequence_ids=[]
        for inner in tokenized_broken[1:]:
            if odd_bool==True:
                assoc_dict['utt']=[running_length-1, running_length+len(inner)-2]
                odd_bool=False
                sequence_ids.append(assoc_dict)
                assoc_dict={}
            else:
                assoc_dict['spk']=[running_length-1, running_length+len(inner)-2]
                odd_bool=True
            joined = joined[:-1]+inner
            running_length += (len(inner)-1)
        
        # test for CUDA assert error
        if(len(joined)>1024):
            print("input tokens list length greater than 1024, skipping example", end=' ')
            print("equal to", len(joined))
            print(tokenizer.decode(joined))
        
        # creating inputs list
        inputs_list.append(joined)
        one_hot_spk_pos = make_one_hot_sequence(joined, sequence_ids)
        pos_list.append(one_hot_spk_pos)
        
        # creating new mask
        joined_mask = attention_broken[0]
        for inner_attention in attention_broken[1:]:
            joined_mask = joined_mask[:-1]+inner_attention
        masks_list.append(joined_mask)
    
    # overriding normal model_inputs
    inputs = [doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    model_inputs['input_ids'] = inputs_list
    model_inputs['attention_mask'] = masks_list
    model_inputs['spk_utt_pos'] = pos_list
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [13]:
tokenized_datasets_train_o = changed_datasets_train.map(preprocess_function, batched=True)
tokenized_datasets_val_o = changed_datasets_val.map(preprocess_function, batched=True)
tokenized_datasets_test_o = changed_datasets_test.map(preprocess_function, batched=True)

# tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
tokenized_datasets_train = tokenized_datasets_train_o.remove_columns(['id', 'dialogue', 'summary'])
tokenized_datasets_val = tokenized_datasets_val_o.remove_columns(['id', 'dialogue', 'summary'])
tokenized_datasets_test = tokenized_datasets_test_o.remove_columns(['id', 'dialogue', 'summary'])

Loading cached processed dataset at /root/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-199093c5bc04cca3.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-6b9cd97b7cbc6058.arrow


  0%|          | 0/1 [00:00<?, ?ba/s]

  "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "


In [14]:
from transformers import Seq2SeqTrainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


class CustomTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # How the loss is computed by Trainer. By default, all models return the loss in the first element.
        # Subclass and override for custom behavior.
        # print(inputs)
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and mselfade cleaner later.

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss


In [15]:
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from transformers import DataCollatorForSeq2Seq
from typing import Optional, Any, Union
import numpy as np


class CustomCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that will dynamically pad the inputs received, as well as the labels.
    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`]):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*
            This is useful when using *label_smoothing* to avoid calculating loss twice.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
              is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
        if labels is not None:
            max_label_length = max(len(l) for l in labels)
            if self.pad_to_multiple_of is not None:
                max_label_length = (
                        (max_label_length + self.pad_to_multiple_of - 1)
                        // self.pad_to_multiple_of
                        * self.pad_to_multiple_of
                )

            padding_side = self.tokenizer.padding_side
            for feature in features:
                remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
                if isinstance(feature["labels"], list):
                    feature["labels"] = (
                        feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
                    )
                elif padding_side == "right":
                    feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
                else:
                    feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
        # added here
        spk_utt_pos = [feature["spk_utt_pos"] for feature in features]
        max_spk_utt_pos_length = max(len(l) for l in spk_utt_pos)

        if self.pad_to_multiple_of is not None:
            max_spk_utt_pos_length = (
                    (max_spk_utt_pos_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
            )

        padding_side = self.tokenizer.padding_side
        for feature in features:
            remainder = [0] * (max_spk_utt_pos_length - len(feature["spk_utt_pos"]))
            if isinstance(feature["spk_utt_pos"], list):
                feature["spk_utt_pos"] = (
                    feature["spk_utt_pos"] + remainder if padding_side == "right" else remainder + feature[
                        "spk_utt_pos"]
                )
            elif padding_side == "right":
                feature["spk_utt_pos"] = np.concatenate([feature["spk_utt_pos"], remainder]).astype(np.int64)
            else:
                feature["spk_utt_pos"] = np.concatenate([remainder, feature["spk_utt_pos"]]).astype(np.int64)

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # prepare decoder_input_ids
        if (
                labels is not None
                and self.model is not None
                and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
            features["decoder_input_ids"] = decoder_input_ids

        return features


In [16]:
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from torch import nn
from transformers import BartForConditionalGeneration, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from torch.nn import CrossEntropyLoss
from transformers.models.bart.modeling_bart import BartConfig
import torch
from typing import *
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.bart.modeling_bart import shift_tokens_right
import random
from tqdm import tqdm
import gc
import itertools
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
class T5WithSCL(T5ForConditionalGeneration):
    def __init__(self, config: BartConfig):
        super().__init__(config)

    def set_losses_list(self, SCLossesList=['token']):
        self.SCLossesList = SCLossesList
    
    def set_scl_coeff(self, scl_coeff=1e-1):
        self.scl_coeff=scl_coeff
    
    def token_scl(self,
                  last_hidden_state: torch.FloatTensor,
                  spk_utt_pos: torch.LongTensor,
    ) -> torch.FloatTensor:
        r"""
        last_hidden_state (torch.LongTensor) of shape (batch_size, sequence_length, n_dims):
            Output of the last layer of the encoder.
        spk_utt_pos (torch.LongTensor) of shape (batch_size, sequence_length,):
            metadata about the speaker tokens and utterance tokens
        Returns:
        Token Level Supervised Constrastive Loss (torch.LongTensor)
        """
        batch_scl = 0
        for i in range(len(spk_utt_pos)):
            batch_element = spk_utt_pos[i]
            spk_utt_list = []
            spk_dict = {'start': 0, 'end': 0, 'spk_id': 0, 'bool': False}
            utt_dict = {'start': 0, 'end': 0, 'spk_id': 0, 'bool': False}
            for j in range(len(batch_element)):
                if batch_element[j] == 0 and j > 0:
                    utt_dict['end'] = j
                    utt_dict['bool'] = False
                    spk_utt_list.append({'spk': [spk_dict['start'], spk_dict['end'], spk_dict['spk_id']],
                                         'utt': [utt_dict['start'], utt_dict['end'], utt_dict['spk_id']]})
                    break
                if batch_element[j] > 0 and spk_dict['bool'] == False:
                    utt_dict['end'] = j
                    utt_dict['bool'] = False
                    if j > 1:
                        spk_utt_list.append({'spk': [spk_dict['start'], spk_dict['end'], spk_dict['spk_id']],
                                             'utt': [utt_dict['start'], utt_dict['end'], utt_dict['spk_id']]})
                    spk_dict['start'] = j
                    spk_dict['bool'] = True
                    spk_dict['spk_id'] = batch_element[j]
                    

                if batch_element[j] < 0 and spk_dict['bool'] == True:
                    spk_dict['end'] = j
                    spk_dict['bool'] = False
                    utt_dict['spk_id'] = spk_dict['spk_id']
                    utt_dict['start'] = j
                    utt_dict['bool'] = True
            # uniq spks
            if spk_utt_list[0]['spk'][2]==0:
                continue
            uniq_spks = list(set([int(dic['spk'][2].cpu()) for dic in spk_utt_list]))
            if len(uniq_spks)==1:
                continue
            # spk_utt_states
            spk_utt_states = {spk: [] for spk in uniq_spks}

            for spk in uniq_spks:
                for dic in spk_utt_list:
                    if spk == dic['utt'][2]:
                        spk_utt_states[spk].append(last_hidden_state[i, dic['utt'][0]:dic['utt'][1]])
            
            
            #---------- hitesh------------------------------
            # positive samples
            # L_pos = 0
            # L_neg = 0 

            # sampled_spk_utt_states = []           

            # for spk in uniq_spks:
            #     utts = len(spk_utt_states[spk])
            #     spk_utt = []
            #     if utts > 1:
            #         # ids = random.sample(list(range(len(spk_utt_states[spk]))), random.randint(1, utts))
            #         ids = random.sample(list(range(len(spk_utt_states[spk]))), 2)
            #         for i in ids:
            #           spk_utt.append(spk_utt_states[spk][i])
            #     sampled_spk_utt_states.append(spk_utt)

            # for instance in sampled_spk_utt_states:
            #   for i in range(len(instance)):
            #     for j in range(len(instance)):
            #       mat_mul = torch.einsum('ij, kj->ik', instance[i], instance[j])
            #       sigm = torch.sigmoid(mat_mul)
            #       log = torch.log(sigm)
            #       L_pos += torch.sum(-1 * log)
            # # print("L_pos", L_pos)

            # #negative loss
            # for i in range(0,len(sampled_spk_utt_states)):
            #   instance = sampled_spk_utt_states[i]

            #   neg_instances = sampled_spk_utt_states[:i]+sampled_spk_utt_states[i+1:]
            #   neg_instances = list(itertools.chain(*neg_instances))
            #   # neg_instances = random.choices(neg_instances,k = random.randint(1, len(neg_instances)))
            #   if len(neg_instances)>0:
            #     # print(len(neg_instances))
            #     # print("-------------------------")
            #     # print(sampled_spk_utt_states)
            #     neg_instances = random.choices(neg_instances,k = 2)
            #     for i in range(len(instance)):
            #       for j in range(len(neg_instances)):
            #         mat_mul = torch.einsum('ij, kj->ik', instance[i], neg_instances[j])
            #         sigm = torch.sigmoid(mat_mul)
            #         log = torch.log(1 - sigm+1e-5)
            #         L_neg += torch.sum(-1 * log)
            #---------- hitesh------------------------------
            
            
            # positive samples
            L_pos = 0
            
            for spk in uniq_spks:
                if len(spk_utt_states[spk]) > 1:
                    ids = random.sample(list(range(len(spk_utt_states[spk]))), 2)
                    id1 = ids[0]
                    id2 = ids[1]
                    mat_mul = torch.einsum('ij, kj->ik', spk_utt_states[spk][id1], spk_utt_states[spk][id2])
                    sigm = torch.sigmoid(mat_mul)
                    log = torch.log(sigm)
                    L_pos += torch.sum(-1 * log)
                    L_pos = torch.nan_to_num(L_pos, posinf = 1e10, neginf = -1e10)
            # print("L_pos", L_pos)
            # negative samples
            
            L_neg = 0
            for spk in uniq_spks:
                new_uniq_spks = uniq_spks.copy()
                new_uniq_spks.remove(spk)

                spk2 = random.choice(new_uniq_spks)

                id1 = random.randint(0, len(spk_utt_states[spk])-1)
                id2 = random.randint(0, len(spk_utt_states[spk2])-1)

                mat_mul = torch.einsum('ij, kj->ik', spk_utt_states[spk][id1], spk_utt_states[spk2][id2])
                sigm = torch.sigmoid(mat_mul)
                # print(1 - sigm)
                # print(1 - sigm+1e-5)
                log = torch.log(1 - sigm+1e-5)
                L_neg += torch.sum(-1 * log)
                
                L_neg = torch.nan_to_num(L_neg, posinf = 1e10, neginf = -1e10)

            # print("L_neg", L_neg)
            
            batch_scl += L_pos
            batch_scl += L_neg
            

        batch_scl /= last_hidden_state.size(0)
        gc.collect()
        return batch_scl
    
    def turn_scl(self,
                  last_hidden_state: torch.FloatTensor,
                  spk_utt_pos: torch.LongTensor,
    ) -> torch.FloatTensor:
        r"""
        last_hidden_state (torch.LongTensor) of shape (batch_size, sequence_length, n_dims):
            Output of the last layer of the encoder.
        spk_utt_pos (torch.LongTensor) of shape (batch_size, sequence_length,):
            metadata about the speaker tokens and utterance tokens
        Returns:
        Turn Level Supervised Constrastive Loss (torch.LongTensor)
        """
        batch_scl = 0
        for i in range(len(spk_utt_pos)):
            batch_element = spk_utt_pos[i]
            spk_utt_list = []
            spk_dict = {'start': 0, 'end': 0, 'spk_id': 0, 'bool': False}
            utt_dict = {'start': 0, 'end': 0, 'spk_id': 0, 'bool': False}
            for j in range(len(batch_element)):
                if batch_element[j] == 0 and j > 0:
                    utt_dict['end'] = j
                    utt_dict['bool'] = False
                    spk_utt_list.append({'spk': [spk_dict['start'], spk_dict['end'], spk_dict['spk_id']],
                                         'utt': [utt_dict['start'], utt_dict['end'], utt_dict['spk_id']]})
                    break
                if batch_element[j] > 0 and spk_dict['bool'] == False:
                    utt_dict['end'] = j
                    utt_dict['bool'] = False
                    if j > 1:
                        spk_utt_list.append({'spk': [spk_dict['start'], spk_dict['end'], spk_dict['spk_id']],
                                             'utt': [utt_dict['start'], utt_dict['end'], utt_dict['spk_id']]})
                    spk_dict['start'] = j
                    spk_dict['bool'] = True
                    spk_dict['spk_id'] = batch_element[j]
                    

                if batch_element[j] < 0 and spk_dict['bool'] == True:
                    spk_dict['end'] = j
                    spk_dict['bool'] = False
                    utt_dict['spk_id'] = spk_dict['spk_id']
                    utt_dict['start'] = j
                    utt_dict['bool'] = True
            # uniq spks
            if spk_utt_list[0]['spk'][2]==0:
                continue
            uniq_spks = list(set([int(dic['spk'][2].cpu()) for dic in spk_utt_list]))
            if len(uniq_spks)==1:
                continue
            # spk_utt_states
            spk_utt_states = {spk: [] for spk in uniq_spks}

            for spk in uniq_spks:
                for dic in spk_utt_list:
                    if spk == dic['utt'][2]:
                        mean_pool = torch.mean(last_hidden_state[i, dic['utt'][0]:dic['utt'][1]], 0)
                        spk_utt_states[spk].append(mean_pool)

            # positive samples
            L_pos = 0
            for spk in uniq_spks:
                if len(spk_utt_states[spk]) > 1:
                    ids = random.sample(list(range(len(spk_utt_states[spk]))), 2)
                    id1 = ids[0]
                    id2 = ids[1]
                    mat_mul = torch.einsum('i, j->', spk_utt_states[spk][id1], spk_utt_states[spk][id2])
                    sigm = torch.sigmoid(mat_mul)
                    log = torch.log(sigm)
                    L_pos += torch.sum(-1 * log)
                    # L_pos = torch.nan_to_num(L_pos, posinf = 1e10, neginf = -1e10)
            # print("L_pos", L_pos)
            # negative samples
            L_neg = 0
            for spk in uniq_spks:
                new_uniq_spks = uniq_spks.copy()
                new_uniq_spks.remove(spk)

                spk2 = random.choice(new_uniq_spks)

                id1 = random.randint(0, len(spk_utt_states[spk])-1)
                id2 = random.randint(0, len(spk_utt_states[spk2])-1)

                mat_mul = torch.einsum('i, j->', spk_utt_states[spk][id1], spk_utt_states[spk2][id2])
                sigm = torch.sigmoid(mat_mul)
                # print(1 - sigm)
                # print(1 - sigm+1e-5)
                log = torch.log(1 - sigm+1e-5)
                L_neg += torch.sum(-1 * log)
                
                # L_neg = torch.nan_to_num(L_neg, posinf = 1e10, neginf = -1e10)

            # print("L_neg", L_neg)
            
            batch_scl += L_pos
            batch_scl += L_neg
            

        batch_scl /= last_hidden_state.size(0)
        gc.collect()
        return batch_scl
    
    def global_scl(self,
                  last_hidden_state: torch.FloatTensor,
                  spk_utt_pos: torch.LongTensor,
    ) -> torch.FloatTensor:
        r"""
        last_hidden_state (torch.LongTensor) of shape (batch_size, sequence_length, n_dims):
            Output of the last layer of the encoder.
        spk_utt_pos (torch.LongTensor) of shape (batch_size, sequence_length,):
            metadata about the speaker tokens and utterance tokens
        Returns:
        Turn Level Supervised Constrastive Loss (torch.LongTensor)
        """
        batch_scl = 0
        for i in range(len(spk_utt_pos)):
            batch_element = spk_utt_pos[i]
            spk_utt_list = []
            spk_dict = {'start': 0, 'end': 0, 'spk_id': 0, 'bool': False}
            utt_dict = {'start': 0, 'end': 0, 'spk_id': 0, 'bool': False}
            for j in range(len(batch_element)):
                if batch_element[j] == 0 and j > 0:
                    utt_dict['end'] = j
                    utt_dict['bool'] = False
                    spk_utt_list.append({'spk': [spk_dict['start'], spk_dict['end'], spk_dict['spk_id']],
                                         'utt': [utt_dict['start'], utt_dict['end'], utt_dict['spk_id']]})
                    break
                if batch_element[j] > 0 and spk_dict['bool'] == False:
                    utt_dict['end'] = j
                    utt_dict['bool'] = False
                    if j > 1:
                        spk_utt_list.append({'spk': [spk_dict['start'], spk_dict['end'], spk_dict['spk_id']],
                                             'utt': [utt_dict['start'], utt_dict['end'], utt_dict['spk_id']]})
                    spk_dict['start'] = j
                    spk_dict['bool'] = True
                    spk_dict['spk_id'] = batch_element[j]
                    

                if batch_element[j] < 0 and spk_dict['bool'] == True:
                    spk_dict['end'] = j
                    spk_dict['bool'] = False
                    utt_dict['spk_id'] = spk_dict['spk_id']
                    utt_dict['start'] = j
                    utt_dict['bool'] = True
            # uniq spks
            if spk_utt_list[0]['spk'][2]==0:
                continue
            uniq_spks = list(set([int(dic['spk'][2].cpu()) for dic in spk_utt_list]))
            if len(uniq_spks)==1:
                continue
            # spk_utt_states
            spk_utt_states = {spk: [] for spk in uniq_spks}

            for spk in uniq_spks:
                for dic in spk_utt_list:
                    if spk == dic['utt'][2]:
                        mean_pool = torch.mean(last_hidden_state[i, dic['utt'][0]:dic['utt'][1]], 0)
                        spk_utt_states[spk].append(mean_pool)

            # positive samples
            L_pos = 0
            L_neg = 0
            for spk in uniq_spks:
                if len(spk_utt_states[spk]) > 1:
                    ids = random.choice(list(range(len(spk_utt_states[spk]))))
                    
                    spk_mean_exc = torch.mean(torch.vstack([spk_utt_states[spk][temp] for temp in range(len(spk_utt_states[spk])) if temp != ids]), 0)
                    
                    pos_mat_mul = torch.einsum('i, j->', spk_utt_states[spk][ids], spk_mean_exc)
                    pos_sigm = torch.sigmoid(pos_mat_mul)
                    pos_log = torch.log(pos_sigm)
                    L_pos += torch.sum(-1 * pos_log)

                    # negative sample

                    new_uniq_spks = uniq_spks.copy()
                    new_uniq_spks.remove(spk)
                    
                    spk2 = random.choice(new_uniq_spks)
                    id_neg = random.choice(list(range(len(spk_utt_states[spk2]))))
                    neg_mat_mul = torch.einsum('i, j->', spk_utt_states[spk2][id_neg], spk_mean_exc)
                    neg_sigm = torch.sigmoid(neg_mat_mul)
                    neg_log = torch.log(1 - neg_sigm+1e-5)
                    L_neg += torch.sum(-1 * neg_log)
                

            # print("L_neg", L_neg)
            
            batch_scl += L_pos
            batch_scl += L_neg
            


        batch_scl /= last_hidden_state.size(0)
        gc.collect()
        return batch_scl
    
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        spk_utt_pos: Optional[torch.Tensor] = None, ##changed here
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`
        Returns:
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            turn_attention_mask=None
            token_encoder_outputs=None
            tog_encoder_outputs=None
            
            if 'token' in self.SCLossesList:
                token_encoder_outputs = self.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

            if 'turn' in self.SCLossesList or 'global' in self.SCLossesList:
                tog_attention_mask = torch.where(spk_utt_pos>0, 0, attention_mask)
                tog_encoder_outputs = self.encoder(
                    input_ids=input_ids,
                    attention_mask=tog_attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
            
        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            masked_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
        
        
        # added here
        sc_loss = 0
        if 'token' in self.SCLossesList and labels is not None:
            sc_loss += self.token_scl(last_hidden_state=token_encoder_outputs['last_hidden_state'], spk_utt_pos=spk_utt_pos)
            # print(sc_loss)
        if 'turn' in self.SCLossesList and labels is not None:
            sc_loss += self.turn_scl(last_hidden_state=tog_encoder_outputs['last_hidden_state'], spk_utt_pos=spk_utt_pos)
        
        if 'global' in self.SCLossesList and labels is not None:
            sc_loss += self.global_scl(last_hidden_state=tog_encoder_outputs['last_hidden_state'], spk_utt_pos=spk_utt_pos)
        
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((masked_lm_loss+(self.scl_coeff*sc_loss),) + output) if loss is not None else output

        loss = None
        if masked_lm_loss is None:
            loss = None
        else:
            loss = masked_lm_loss+(self.scl_coeff*sc_loss)      


        



        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    

In [17]:
model = T5WithSCL.from_pretrained(model_checkpoint)
model.set_losses_list(['token','turn','global'])
model.set_scl_coeff(0.1)


### Fine-tuning the model

In [18]:
# Parameters\
batch_size=4
training_args = Seq2SeqTrainingArguments(
    output_dir="t5-token-b4c0.1",
    num_train_epochs=5,
    do_train=True,
    do_eval=True,
    evaluation_strategy = "epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4,
    warmup_steps=500,
    weight_decay=0.1,
    # label_smoothing_factor=0.1, ## causes to throw an error
    predict_with_generate=True,
    # logging_dir="logs",
    logging_steps=10,
    # save_total_limit=3,
)


data_collator = CustomCollatorForSeq2Seq(tokenizer, model=model)


In [19]:
import nltk
import numpy as np
import torch
torch.cuda.empty_cache()
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    for i in range(0,50):
      # print(tokenized_datasets_val["dialogue"][i])
      print("-----------",i,"--------------")
      print("------>Predictions by Model")
      print(decoded_preds[i])
      print("----->Predictions Original")
      print(decoded_labels[i])
      print("**************************")
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [20]:

trainer = CustomTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets_train,
    eval_dataset=tokenized_datasets_val,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [21]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [22]:
trainer.train()

***** Running training *****
  Num examples = 14730
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 18415
  Number of trainable parameters = 222903552
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhiteshwarjmu[0m ([33makatsuki_leaf[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,0.0,,0.0,0.0,0.0,0.0,0.0
2,0.0,,0.0,0.0,0.0,0.0,0.0
3,0.0,,0.0,0.0,0.0,0.0,0.0


Saving model checkpoint to t5-token-b4c0.1/checkpoint-500
Configuration saved in t5-token-b4c0.1/checkpoint-500/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-500/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-500/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-500/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-1000
Configuration saved in t5-token-b4c0.1/checkpoint-1000/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-1000/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-1000/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-1500
Configuration saved in t5-token-b4c0.1/checkpoint-1500/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-1500/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-1500/token

----------- 0 --------------
------>Predictions by Model

----->Predictions Original
A will go to the animal shelter tomorrow to get a puppy for her son.
They already visited the shelter last Monday and the son chose the puppy.
**************************
----------- 1 --------------
------>Predictions by Model

----->Predictions Original
Emma and Rob love the advent calendar.
Lauren fits inside calendar various items, for instance, small toys and Christmas decorations.
Her children are excited whenever they get the calendar.
**************************
----------- 2 --------------
------>Predictions by Model

----->Predictions Original
Madison is pregnant but she doesn't want to talk about it.
Patricia Stevens got married and she thought she was pregnant.
**************************
----------- 3 --------------
------>Predictions by Model

----->Predictions Original
Marla found a pair of boxers under her bed.
**************************
----------- 4 --------------
------>Predictions by M

Saving model checkpoint to t5-token-b4c0.1/checkpoint-4000
Configuration saved in t5-token-b4c0.1/checkpoint-4000/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-4000/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-4000/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-4000/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-4500
Configuration saved in t5-token-b4c0.1/checkpoint-4500/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-4500/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-4500/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-4500/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-5000
Configuration saved in t5-token-b4c0.1/checkpoint-5000/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-5000/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-5000/

----------- 0 --------------
------>Predictions by Model

----->Predictions Original
A will go to the animal shelter tomorrow to get a puppy for her son.
They already visited the shelter last Monday and the son chose the puppy.
**************************
----------- 1 --------------
------>Predictions by Model

----->Predictions Original
Emma and Rob love the advent calendar.
Lauren fits inside calendar various items, for instance, small toys and Christmas decorations.
Her children are excited whenever they get the calendar.
**************************
----------- 2 --------------
------>Predictions by Model

----->Predictions Original
Madison is pregnant but she doesn't want to talk about it.
Patricia Stevens got married and she thought she was pregnant.
**************************
----------- 3 --------------
------>Predictions by Model

----->Predictions Original
Marla found a pair of boxers under her bed.
**************************
----------- 4 --------------
------>Predictions by M

Saving model checkpoint to t5-token-b4c0.1/checkpoint-7500
Configuration saved in t5-token-b4c0.1/checkpoint-7500/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-7500/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-7500/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-7500/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-8000
Configuration saved in t5-token-b4c0.1/checkpoint-8000/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-8000/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-8000/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-8000/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-8500
Configuration saved in t5-token-b4c0.1/checkpoint-8500/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-8500/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-8500/

----------- 0 --------------
------>Predictions by Model

----->Predictions Original
A will go to the animal shelter tomorrow to get a puppy for her son.
They already visited the shelter last Monday and the son chose the puppy.
**************************
----------- 1 --------------
------>Predictions by Model

----->Predictions Original
Emma and Rob love the advent calendar.
Lauren fits inside calendar various items, for instance, small toys and Christmas decorations.
Her children are excited whenever they get the calendar.
**************************
----------- 2 --------------
------>Predictions by Model

----->Predictions Original
Madison is pregnant but she doesn't want to talk about it.
Patricia Stevens got married and she thought she was pregnant.
**************************
----------- 3 --------------
------>Predictions by Model

----->Predictions Original
Marla found a pair of boxers under her bed.
**************************
----------- 4 --------------
------>Predictions by M

Saving model checkpoint to t5-token-b4c0.1/checkpoint-11500
Configuration saved in t5-token-b4c0.1/checkpoint-11500/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-11500/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-11500/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-11500/special_tokens_map.json
Saving model checkpoint to t5-token-b4c0.1/checkpoint-12000
Configuration saved in t5-token-b4c0.1/checkpoint-12000/config.json
Model weights saved in t5-token-b4c0.1/checkpoint-12000/pytorch_model.bin
tokenizer config file saved in t5-token-b4c0.1/checkpoint-12000/tokenizer_config.json
Special tokens file saved in t5-token-b4c0.1/checkpoint-12000/special_tokens_map.json


BrokenPipeError: ignored

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f5cf2ca1fd0>> (for post_run_cell):


BrokenPipeError: ignored

In [23]:
# evaluate before training for comparison
trainer.evaluate()

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f5cf2ca1fd0>> (for pre_run_cell):


BrokenPipeError: ignored

***** Running Evaluation *****
  Num examples = 818
  Batch size = 4


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,0.0,,0.0,0.0,0.0,0.0,0.0
2,0.0,,0.0,0.0,0.0,0.0,0.0
3,0.0,,0.0,0.0,0.0,0.0,0.0


----------- 0 --------------
------>Predictions by Model

----->Predictions Original
A will go to the animal shelter tomorrow to get a puppy for her son.
They already visited the shelter last Monday and the son chose the puppy.
**************************
----------- 1 --------------
------>Predictions by Model

----->Predictions Original
Emma and Rob love the advent calendar.
Lauren fits inside calendar various items, for instance, small toys and Christmas decorations.
Her children are excited whenever they get the calendar.
**************************
----------- 2 --------------
------>Predictions by Model

----->Predictions Original
Madison is pregnant but she doesn't want to talk about it.
Patricia Stevens got married and she thought she was pregnant.
**************************
----------- 3 --------------
------>Predictions by Model

----->Predictions Original
Marla found a pair of boxers under her bed.
**************************
----------- 4 --------------
------>Predictions by M

BrokenPipeError: ignored

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f5cf2ca1fd0>> (for post_run_cell):


BrokenPipeError: ignored

In [None]:
# torch.save(model.state_dict(), "./best_model.bin")