In [1]:
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to /home/worachotn/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/worachotn/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/worachotn/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/worachotn/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [2]:
from nltk.corpus import wordnet
from nltk.corpus import stopwords

lemmatizer = nltk.stem.WordNetLemmatizer()  # Initiate nltk lemmatizer

In [3]:
def simple_tokenize(sentence):
    """ Simple function for tokenizing text with nltk """
    return nltk.word_tokenize(sentence, preserve_line=True)

In [4]:
import math
import os
import pprint
import logging

import datasets
import nltk
import numpy as np
import torch
from tqdm.auto import tqdm

import transformers
from accelerate import Accelerator
from filelock import FileLock
from transformers import AdamW, get_scheduler, set_seed

from transformers.file_utils import is_offline_mode
from transformers.utils.versions import require_version

# from args import parse_args
# from data_loader import raw_data_loader, data_processor
from model_loader import model_loader
from rouge_s import py_rouge_scores
from utils import label_smoothed_nll_loss, postprocess_text

import json

from collections import Counter

from nltk.util import ngrams
from nltk import word_tokenize,sent_tokenize

import random
import utils

import datasets
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

In [5]:
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

In [6]:
from transformers import (
    MODEL_MAPPING,
    SchedulerType,
)

# You should update this to your particular problem to have better documentation of `model_type`
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

In [7]:
import argparse
arg_parser = argparse.ArgumentParser(description="BART")
arg_parser.add_argument("--len_input", dest="len_input", type=str, default=None, help="set up prefix input",choices=('no', 'topic', 'length', 'topic-length', 'length-topic', 'simple', 'simple-topic-tagger', 'simple-tagger'))
arg_parser.add_argument("--len_output", dest="len_output", default=None, help="Use the ctrlen model or not", choices=('no', 'topic', 'length', 'topic-length', 'length-topic'))
arg_parser.add_argument("--output_dir", dest="output_dir", type=str, default="./output/1", help="default")
arg_parser.add_argument("--train_file", dest="train_file", type=str, default=None, help="A csv or a json file containing the training data.")
arg_parser.add_argument("--validation_file", dest="validation_file", type=str, default=None, help="A csv or a json file containing the validation data.")
arg_parser.add_argument("--test_file", dest="test_file", type=str, default=None, help="A csv or a json file containing the test data.")
arg_parser.add_argument("--ignore_pad_token_for_loss", dest="ignore_pad_token_for_loss", type=bool, default=True, help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",)
arg_parser.add_argument("--text_column", dest="text_column", type=str, default="dialogue", help="The name of the column in the datasets containing the full texts (for summarization).")
arg_parser.add_argument("--summary_column", dest="summary_column", type=str, default="summary", help="The name of the column in the datasets containing the summaries (for summarization).")
arg_parser.add_argument("--model_name_or_path", dest="model_name_or_path", type=str, default="facebook/bart-large", help="Path to pretrained model or model identifier from huggingface.co/models.")
arg_parser.add_argument("--model_type", dest="model_type", type=str, default="bart", help="Model type to use if training from scratch.", choices=MODEL_TYPES)
arg_parser.add_argument("--max_source_length", dest="max_source_length", type=int, default=1024, help="default")
arg_parser.add_argument("--source_prefix", dest="source_prefix", type=str, default=None, help="A prefix to add before every source text " "(useful for T5 models).")
arg_parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.")
# arg_parser.add_argument("--overwrite_cache", dest="overwrite_cache", type=lambda x:bool(strtobool(x)), default=True, help="default")
arg_parser.add_argument("--overwrite_cache", dest="overwrite_cache", type=bool, default=None, help="Overwrite the cached training and evaluation sets")
arg_parser.add_argument("--min_target_length", dest="min_target_length", type=int, default=1, help="The minimal total sequence length for target text")
arg_parser.add_argument("--max_target_length", dest="max_target_length", type=int, default=128, help="The maximum total sequence length for target text after "
        "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
        "during ``evaluate`` and ``predict``.")
arg_parser.add_argument("--num_beams", dest="num_beams", type=int, default=4, help="Number of beams to use for evaluation. This argument will be "
        "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.")
arg_parser.add_argument("--learning_rate", dest="learning_rate", type=float, default=5e-5, help="Initial learning rate (after the potential warmup period) to use.")
arg_parser.add_argument("--pad_to_max_length", action="store_true", help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",)
arg_parser.add_argument("--weight_decay", dest="weight_decay", type=float, default=1e-3, help="Weight decay to use.")
arg_parser.add_argument("--label_smoothing", dest="label_smoothing", type=float, default=0.1, help="hyperparameter for label smoothing.")
arg_parser.add_argument("--length_penalty", dest="length_penalty", type=float, default=1.0, help="large - longer sequence, small - shorter sequence")
arg_parser.add_argument("--num_train_epochs", dest="num_train_epochs", type=int, default=15, help="Total number of training epochs to perform.")
arg_parser.add_argument("--per_device_train_batch_size", dest="per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.")
arg_parser.add_argument("--gradient_accumulation_steps", dest="gradient_accumulation_steps", type=int, default=64, help="Number of updates steps to accumulate before performing a backward/update pass.")
arg_parser.add_argument("--per_device_eval_batch_size", dest="per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.")
arg_parser.add_argument("--per_device_test_batch_size", dest="per_device_test_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.")
arg_parser.add_argument("--num_warmup_steps", dest="num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
arg_parser.add_argument("--cache_dir", dest="cache_dir", type=str, default="./output/cache", help="default")
arg_parser.add_argument("--seed", dest="seed", type=int, default=12345, help="default")
# arg_parser.add_argument("-f", required=False) #important
arg_parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name")
arg_parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name")
arg_parser.add_argument("--use_slow_tokenizer", dest="use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).")
arg_parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
arg_parser.add_argument("--lr_scheduler_type", type=SchedulerType, default="linear", help="The scheduler type to use.", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"])
arg_parser.add_argument("--ctrlen_model", action='store_true', default=False, help="Use the ctrlen model or not")
arg_parser.add_argument("--sim_window_size", type=int, default=5, help="window size for computing loss.")
arg_parser.add_argument("--sim_loss", type=float, default=0, help="the loss weight for similarity scores.")
arg_parser.add_argument("--special_len_token_init", type=str, default=None, help="ways to initialize special token for length (random, zero, token_embs)")
arg_parser.add_argument("--embedding_lr", type=float, default=5e-5, help="Initial learning rate for embedding layers.")
arg_parser.add_argument("--len_start", type=int, default=1, help="start length.")
arg_parser.add_argument("--len_end", type=int, default=100, help="end length.")
arg_parser.add_argument("--data_aug",action='store_true',default=False,help="whether to perform data augmentation or not")
arg_parser.add_argument("--pred_len", action='store_true', default=False, help="whether to use the golden length or predicted length")
arg_parser.add_argument("--shuffle", action='store_true', default=False, help="whether to shuffle the dataset to balance train/validation/test")
arg_parser.add_argument("--debug", action='store_true', default=False, help="Use the debug mode or not")

arg_parser.add_argument("--topic_tagger", dest="topic_tagger", type=bool, default=None, help="Use topic tag [TAG] or not")

_StoreAction(option_strings=['--topic_tagger'], dest='topic_tagger', nargs=None, const=None, default=None, type=<class 'bool'>, choices=None, required=False, help='Use topic tag [TAG] or not', metavar=None)

In [8]:
args = arg_parser.parse_args('')

In [9]:
args.train_file = "./data/dialogtest/dialogsum.train.jsonl"
args.validation_file = "./data/dialogtest/dialogsum.dev.jsonl"
args.test_file = "./data/dialogtest/dialogsum.test.jsonl"
args.text_column = "dialogue"
args.summary_column = "summary"
args.model_name_or_path = "facebook/bart-large"
args.model_type = "bart"
args.max_source_length = 1024
args.min_target_length = 1
args.max_target_length = 128
args.num_beams = 4
args.learning_rate = 5e-5
args.weight_decay = 1e-3
args.label_smoothing = 0.1
args.length_penalty = 1.0 
args.num_train_epochs = 3
args.per_device_train_batch_size = 2 
args.gradient_accumulation_steps = 64 
args.per_device_eval_batch_size = 8 
args.per_device_test_batch_size = 8 
args.num_warmup_steps = 0 
args.cache_dir = "./output/cache"
args.overwrite_cache = True
args.seed = 12345

args.len_input = 'topic-length'
args.len_output = 'no'
args.output_dir = "./output/1-bart-baseline-loss"

args.topic_tagger = False

In [10]:
print(args.len_input)
print(args.topic_tagger)

topic-length
False


In [11]:
def simple_tokenize(sentence):
    """ Simple function for tokenizing text with nltk """
    return nltk.word_tokenize(sentence)

def nltk_to_pos(pos):
    """ Simple function for converting nltk pos to wordnet pos"""
    if pos.startswith('J'):
        return wordnet.ADJ
    elif pos.startswith('V'):
        return wordnet.VERB
    elif pos.startswith('N'):
        return wordnet.NOUN
    elif pos.startswith('R'):
        return wordnet.ADV
    else:
        return None

def lemmatize_text(text):
    """ Function to lemmatize text according to the wordnet POS of each token """

    tokenized_text = nltk.word_tokenize(text)
    POS_assigned_text = nltk.pos_tag(tokenized_text)

    available_POS = map(lambda x: (x[0], nltk_to_pos(x[1])), POS_assigned_text)

    lemmatized_text = [token if pos is None
                       else lemmatizer.lemmatize(token, pos)
                       for token, pos in available_POS]

    return lemmatized_text

def build_tagger(original_tokens,lemmatized_tokens, topic_list, idx):
    tagged_tokens = []
    # Extract all the seed words according to the corresponding topic
    token_topics = topic_list
    original_list = original_tokens[idx]

    for j, token in enumerate(lemmatized_tokens[idx]):
        # If the lemmatized form of the token is in topic seeds, tag the original token
        if token.lower() in token_topics:
            # print(token.lower())
            if token.lower() not in stopwords.words('english'):
                # print("="*100)
                # print(token.lower())
                original_list[j] = '[TAG]' + original_list[j] + '[TAG]'

    tagged_tokens.append(" ".join(original_list))
    return tagged_tokens

In [12]:
def load_from_dialogsum(args, file_path):
    ''' load dialoguesum jsonl data '''

    data = []

    with open(file_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    id_list = [sample['fname'] for sample in data]
    dialogue_list = [sample['dialogue'] for sample in data]

    if 'summary' in data[0]:
        # summary
        summary_list = [sample['summary'] for sample in data]
        # topic
        topic_list = [sample['topic'] for sample in data]

    elif 'summary1' in data[0]:

        id_list1 = [id+"_sum1" for id in id_list]
        id_list2 = [id+"_sum2" for id in id_list]
        id_list3 = [id+"_sum3" for id in id_list]

        id_list = id_list1 + id_list2 + id_list3
        dialogue_list = dialogue_list + dialogue_list + dialogue_list

        # summary
        summary_list1 = [sample['summary1'] for sample in data]
        summary_list2 = [sample['summary2'] for sample in data]
        summary_list3 = [sample['summary3'] for sample in data]

        summary_list = summary_list1 + summary_list2 + summary_list3

        # topic
        topic_list1 = [sample['topic1'] for sample in data]
        topic_list2 = [sample['topic2'] for sample in data]
        topic_list3 = [sample['topic3'] for sample in data]

        topic_list = topic_list1 + topic_list2 + topic_list3
        
    negative_topic_list = []
    for topic in topic_list:
        negative_topic = random.choice(topic_list)
        if negative_topic == topic:
            negative_topic = random.choice(negative_topic)
        negative_topic_list.append(negative_topic)
        

    if args.topic_tagger:
        topic_tagger = []
        original_tokens = [simple_tokenize(x) for x in dialogue_list]
        lemmatized_tokens = [lemmatize_text(x) for x in dialogue_list]
        for i in range(len(lemmatized_tokens)):
            tagger = build_tagger(original_tokens, lemmatized_tokens, topic_list[i], i)
            topic_tagger.extend(tagger)

        data_dict = {'id': id_list,
                     'dialogue': topic_tagger,
                     'summary': summary_list,
                     'topic': topic_list}
    else:
        data_dict = {'id': id_list,
                     'dialogue': dialogue_list,
                     'summary': summary_list,
                     'topic': topic_list,
                     'negative_topic': negative_topic_list}

    data_dict = Dataset.from_dict(data_dict)

    return data_dict

In [13]:
train_dict = load_from_dialogsum(args, args.train_file)

In [14]:
train_dict

Dataset({
    features: ['id', 'dialogue', 'summary', 'topic', 'negative_topic'],
    num_rows: 1500
})

In [15]:
train_dict['topic'][10]

'do a favor'

In [16]:
train_dict['negative_topic'][10]

'turning off TV'

In [17]:
all(train_dict['topic'][i] != train_dict['negative_topic'][i] for i in range(len(train_dict['topic'])))

True

In [18]:
accelerator = Accelerator(mixed_precision="fp16")
logger.info(accelerator.state)

10/18/2023 14:30:59 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [19]:
config, tokenizer, model = model_loader(accelerator, logger, args)

In [20]:
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
accelerator.is_local_main_process
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
set_seed(args.seed)
torch.backends.cudnn.enabled = False 
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
accelerator.is_main_process
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()

In [21]:
def raw_data_loader(args):
    ''' load raw datasets from csv files '''

    data_files = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    if args.validation_file is not None:
        data_files["validation"] = args.validation_file
    if args.test_file is not None:
        data_files["test"] = args.test_file

    if 'dialogsum' in args.train_file:
        train_dict = load_from_dialogsum(args, args.train_file)
        val_dict   = load_from_dialogsum(args, args.validation_file)
        test_dict  = load_from_dialogsum(args, args.test_file)

    train_dict = utils.len_adjust(args, train_dict, 'train')
    val_dict   = utils.len_adjust(args, val_dict, 'val')
    test_dict  = utils.len_adjust(args, test_dict, 'test')

    raw_datasets = datasets.DatasetDict({"train":train_dict, "validation":val_dict, "test":test_dict})

    return raw_datasets

In [22]:
raw_datasets = raw_data_loader(args)

In [23]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'negative_dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'negative_dialogue', 'summary', 'topic'],
        num_rows: 50
    })
    test: Dataset({
        features: ['id', 'dialogue', 'negative_dialogue', 'summary', 'topic'],
        num_rows: 150
    })
})

In [24]:
raw_datasets['train']['dialogue'][1]

"Topic of Summary: vaccines. Length of Summary: 18. Dialogue: #Person1#: Hello Mrs. Parker, how have you been?\n#Person2#: Hello Dr. Peters. Just fine thank you. Ricky and I are here for his vaccines.\n#Person1#: Very well. Let's see, according to his vaccination record, Ricky has received his Polio, Tetanus and Hepatitis B shots. He is 14 months old, so he is due for Hepatitis A, Chickenpox and Measles shots.\n#Person2#: What about Rubella and Mumps?\n#Person1#: Well, I can only give him these for now, and after a couple of weeks I can administer the rest.\n#Person2#: OK, great. Doctor, I think I also may need a Tetanus booster. Last time I got it was maybe fifteen years ago!\n#Person1#: We will check our records and I'll have the nurse administer and the booster as well. Now, please hold Ricky's arm tight, this may sting a little."

In [25]:
raw_datasets['train']['negative_dialogue'][1]

"Topic of Summary: get dressed. Length of Summary: 18. Dialogue: #Person1#: Hello Mrs. Parker, how have you been?\n#Person2#: Hello Dr. Peters. Just fine thank you. Ricky and I are here for his vaccines.\n#Person1#: Very well. Let's see, according to his vaccination record, Ricky has received his Polio, Tetanus and Hepatitis B shots. He is 14 months old, so he is due for Hepatitis A, Chickenpox and Measles shots.\n#Person2#: What about Rubella and Mumps?\n#Person1#: Well, I can only give him these for now, and after a couple of weeks I can administer the rest.\n#Person2#: OK, great. Doctor, I think I also may need a Tetanus booster. Last time I got it was maybe fifteen years ago!\n#Person1#: We will check our records and I'll have the nurse administer and the booster as well. Now, please hold Ricky's arm tight, this may sting a little."

In [26]:
class CustomDataCollator:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def __call__(self, examples):
        # positive_input_ids = examples['dialogue']
        positive_input = [np.array(example['dialogue']) for example in examples]
        # positive_input_ids = [example['dialogue'][example['dialogue'] != 1] for example in examples]
        # negative_input_ids = examples['negative_dialogue']
        negative_input = [np.array(example['negative_dialogue']) for example in examples]
        # negative_input_ids = [example['negative_dialogue'][example['dialogue'] != 1] for example in examples]
        # summary_input_ids = examples['summary']
        summary_input = [np.array(example['summary']) for example in examples]
        # summary_input_ids = [example['summary'][example['dialogue'] != 1] for example in examples]

        positive_input_ids = [example[example != 1] for example in positive_input]
        negative_input_ids = [example[example != 1] for example in negative_input]
        summary_input_ids = [example[example != 1] for example in summary_input]
        # summary_input_ids = [example['summary'] for example in examples]
            
        # inputs["input_ids"] = tokenizer.pad
        # negative_inputs["input_ids"] = tokenizer.pad
        
        batch = self.tokenizer.pad(encoded_inputs={"input_ids": positive_input_ids+ negative_input_ids}, padding=True, return_tensors='pt')
        # batch["decoder_input_ids"] = torch.stack((inputs["labels"], inputs["labels"]))
        # batch["decoder_attention_mask"] = torch.stack((inputs["labels"], inputs["decoder_attention_mask"]))
        batch["decoder_input_ids"] = self.tokenizer.pad(encoded_inputs={"input_ids": summary_input_ids+summary_input_ids}, padding=True, return_tensors='pt')["input_ids"]
        # summary = [[(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in summary_input_ids]
        summary = self.tokenizer.pad(encoded_inputs={"input_ids": summary_input_ids+summary_input_ids}, padding=True, return_tensors='pt')["input_ids"]
        summary[summary == 1] = -100
        batch["labels"] = summary
        # summary =  self.tokenizer.pad(encoded_inputs={"input_ids": summary_input_ids+summary_input_ids}, padding=True, return_tensors='pt')["input_ids"]
        # summary = [[(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in summary_input_ids]

        return batch

In [27]:
def data_processor(logger, args, accelerator, raw_datasets, tokenizer, model):
    ''' prepare dataset format for train/val/test '''
    def preprocess_function(examples):
        positive_documents = examples['dialogue']
        negative_documents = examples['negative_dialogue']
        source_summaries = examples['summary']

        # Tokenize and create input tensors
        inputs = tokenizer(
            positive_documents,
            # negative_summaries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_source_length  # Adjust as needed
        )
        
        # Tokenize and create input tensors
        negative_inputs = tokenizer(
            negative_documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_source_length  # Adjust as needed
        )
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                source_summaries,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=args.max_target_length  # Adjust as needed
            )
        
        # batch = tokenizer.pad(encoded_inputs={"input_ids": inputs["input_ids"].squeeze().tolist() + negative_inputs["input_ids"].squeeze().tolist()}, padding=True, return_tensors='pt')
        # # batch["decoder_input_ids"] = torch.stack((inputs["labels"], inputs["labels"]))
        # # batch["decoder_attention_mask"] = torch.stack((inputs["labels"], inputs["decoder_attention_mask"]))
        # batch["decoder_input_ids"] = tokenizer.pad(encoded_inputs={"input_ids": labels["input_ids"].squeeze().tolist()+labels["input_ids"].squeeze().tolist()}, padding=False, return_tensors='pt')["input_ids"]
        # labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
        # batch["labels"] =  tokenizer.pad(encoded_inputs={"input_ids": labels["input_ids"]+labels["input_ids"]}, padding=True, return_tensors='pt')["input_ids"]

        # return batch
        model_inputs = inputs
        model_inputs["dialogue"] = inputs["input_ids"]
        model_inputs["negative_dialogue"] = negative_inputs["input_ids"]
        model_inputs["summary"] = labels["input_ids"]

        return model_inputs

    prefix = args.source_prefix if args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    column_names = raw_datasets["train"].column_names

    # Get the column names for input/target.
    text_column = args.text_column
    if text_column not in column_names:
        raise ValueError(
            f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}"
        )

    summary_column = args.summary_column
    if summary_column not in column_names:
        raise ValueError(
            f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}"
        )

    # Temporarily set max_target_length for training.
    max_target_length = args.max_target_length
    padding = "max_length" if args.pad_to_max_length else False

    with accelerator.main_process_first():
        processed_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            batch_size=100,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    train_dataset = processed_datasets["train"]
    eval_dataset  = processed_datasets["validation"]
    test_dataset  = processed_datasets["test"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 1):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    # data_collator = DataCollatorForSeq2Seq(
    #     tokenizer,
    #     model=model,
    #     label_pad_token_id=label_pad_token_id,
    #     pad_to_multiple_of=8 if accelerator.use_fp16 else None,
    # )
    data_collator = CustomDataCollator(
        tokenizer,
        model=model,
    )

    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
    test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.per_device_test_batch_size)

    return (train_dataloader, eval_dataloader, test_dataloader), (train_dataset, eval_dataset, test_dataset)

In [28]:
inputs = tokenizer(
            raw_datasets['train']['dialogue'][:2],
            # negative_summaries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_source_length  # Adjust as needed
        )
negative_inputs = tokenizer(
            raw_datasets['train']['dialogue'][2:4],
            # negative_summaries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_source_length  # Adjust as needed
        )

In [29]:
a = [i.tolist() for i in inputs['input_ids']]

In [30]:
print(a)

[[0, 48931, 9, 19584, 35, 120, 10, 1649, 12, 658, 4, 41852, 9, 19584, 35, 389, 4, 33854, 35, 849, 41761, 134, 10431, 35, 12289, 6, 427, 4, 1259, 4, 38, 437, 12521, 15633, 4, 2612, 32, 47, 259, 452, 116, 50118, 10431, 41761, 176, 10431, 35, 38, 303, 24, 74, 28, 10, 205, 1114, 7, 120, 10, 1649, 12, 658, 4, 50118, 10431, 41761, 134, 10431, 35, 3216, 6, 157, 6, 47, 2220, 75, 56, 65, 13, 195, 107, 4, 370, 197, 33, 65, 358, 76, 4, 50118, 10431, 41761, 176, 10431, 35, 38, 216, 4, 38, 1955, 25, 251, 25, 89, 16, 1085, 1593, 6, 596, 213, 192, 5, 3299, 116, 50118, 10431, 41761, 134, 10431, 35, 2647, 6, 5, 275, 169, 7, 1877, 1473, 14971, 16, 7, 465, 66, 59, 106, 419, 4, 407, 860, 7, 283, 23, 513, 683, 10, 76, 13, 110, 308, 205, 4, 50118, 10431, 41761, 176, 10431, 35, 5148, 4, 50118, 10431, 41761, 134, 10431, 35, 2780, 162, 192, 259, 4, 2486, 2473, 8, 12137, 356, 2051, 4, 4624, 10, 1844, 8016, 6, 2540, 4, 1832, 47, 4603, 6, 427, 4, 1259, 116, 50118, 10431, 41761, 176, 10431, 35, 3216, 4, 50118, 104

In [31]:
# inputs['input_ids'].shape
# # inputs['input_ids'].numpy()
# model_inputs = inputs
# negative_inputs['input_ids'].shape
# model_inputs['negative_dialogue'] = negative_inputs['input_ids']
# type(model_inputs['negative_dialogue'])

In [32]:
def preprocess_function(examples):
    positive_documents = examples['dialogue']
    negative_documents = examples['negative_dialogue']
    source_summaries = examples['summary']

    # Tokenize and create input tensors
    inputs = tokenizer(
        positive_documents,
        # negative_summaries,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=args.max_source_length  # Adjust as needed
    )
    
    # Tokenize and create input tensors
    negative_inputs = tokenizer(
        negative_documents,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=args.max_source_length  # Adjust as needed
    )
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            source_summaries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_target_length  # Adjust as needed
        )
    
    # batch = tokenizer.pad(encoded_inputs={"input_ids": inputs["input_ids"].squeeze().tolist() + negative_inputs["input_ids"].squeeze().tolist()}, padding=True, return_tensors='pt')
    # # batch["decoder_input_ids"] = torch.stack((inputs["labels"], inputs["labels"]))
    # # batch["decoder_attention_mask"] = torch.stack((inputs["labels"], inputs["decoder_attention_mask"]))
    # batch["decoder_input_ids"] = tokenizer.pad(encoded_inputs={"input_ids": labels["input_ids"].squeeze().tolist()+labels["input_ids"].squeeze().tolist()}, padding=False, return_tensors='pt')["input_ids"]
    # labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
    # batch["labels"] =  tokenizer.pad(encoded_inputs={"input_ids": labels["input_ids"]+labels["input_ids"]}, padding=True, return_tensors='pt')["input_ids"]

    # return batch
    model_inputs = inputs
    model_inputs["dialogue"] = inputs["input_ids"]
    model_inputs["negative_dialogue"] = negative_inputs["input_ids"]
    model_inputs["summary"] = labels["input_ids"]

    return model_inputs

In [33]:
with accelerator.main_process_first():
    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        batch_size=1000,
        # remove_columns=column_names,
        load_from_cache_file=not args.overwrite_cache,
        desc="Running tokenizer on dataset",
    )

Running tokenizer on dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]



Running tokenizer on dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/150 [00:00<?, ? examples/s]

In [34]:
print(type(processed_datasets['train']['dialogue'][:2]))

<class 'list'>


In [35]:
# for example in processed_datasets['train']:
#     dialogue = example['dialogue']
#     print(type(dialogue))
#     # print([i for i in dialogue if i != 1])
#     break

In [36]:
# positive_input = [example for example in processed_datasets['train']['dialogue']]

In [37]:
dataloader, processed_dataset = data_processor(logger, args, accelerator, raw_datasets, tokenizer, model)
train_dataloader, eval_dataloader, test_dataloader = dataloader
train_dataset, _, _ = processed_dataset

Running tokenizer on dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/150 [00:00<?, ? examples/s]

10/18/2023 14:31:28 - INFO - __main__ - Sample 580 of the training set: {'dialogue': [0, 48931, 9, 19584, 35, 8018, 4, 41852, 9, 19584, 35, 706, 4, 33854, 35, 849, 41761, 134, 10431, 35, 653, 18, 70, 14, 11347, 59, 116, 50118, 10431, 41761, 176, 10431, 35, 38, 95, 13414, 103, 2480, 514, 15, 6918, 6, 150, 79, 21, 11, 5, 2131, 9310, 6, 47, 197, 33, 450, 69, 652, 4, 50118, 10431, 41761, 134, 10431, 35, 370, 4395, 75, 29993, 110, 2761, 98, 203, 4, 50118, 10431, 41761, 176, 10431, 35, 83, 46909, 24, 21, 95, 10, 8018, 3795, 4, 50118, 10431, 41761, 134, 10431, 35, 370, 185, 24, 350, 444, 2128, 2150, 6, 114, 127, 2138, 56, 57, 101, 47, 77, 38, 21, 1197, 62, 6, 38, 74, 33, 1613, 5373, 4, 9427, 5, 86, 47, 4209, 69, 13495, 34156, 19, 2131, 10702, 116, 178, 77, 47, 342, 6740, 11, 69, 8492, 6, 14, 21, 95, 137, 69, 17008, 4115, 4, 50118, 10431, 41761, 176, 10431, 35, 19719, 59, 14, 3795, 6, 14, 21, 10, 410, 350, 203, 4, 125, 6918, 3829, 127, 11248, 6, 79, 460, 17216, 59, 24, 11795, 4, 50118, 10431, 

In [38]:
train_dataset

Dataset({
    features: ['dialogue', 'negative_dialogue', 'summary', 'input_ids', 'attention_mask'],
    num_rows: 1500
})

In [39]:
for step, data in enumerate(train_dataset):
    # print(data)
    # break
    print('dialogue: ', len(data['dialogue']))
    print('negative_dialogue: ', len(data['negative_dialogue']))
    print('summary: ', len(data['summary']))
    # # print('input_ids: ', batch.input_ids[1].shape)
    # # print('attention_mask: ', batch.attention_mask[1].shape)
    # # print('negative_input_ids: ', batch.negative_input_ids[1].shape)
    # # print('negative_attention_mask: ', batch.negative_attention_mask[1].shape)
    # # print('labels: ', batch.labels[1].shape)
    print('='*100)
    break
    # if step == 500:
    #     print('dialogue: ', len(data['dialogue']))
    #     print('negative_dialogue: ', len(data['negative_dialogue']))
    #     print('summary: ', len(data['summary']))
    # if step == 5:
    #     break

dialogue:  496
negative_dialogue:  497
summary:  74


In [40]:
# len(tokenizer.pad({"input_ids": train_dataset['input_ids'][0]}, padding="max_length", max_length=1024, return_tensors='pt', pad_to_multiple_of=8)['input_ids'])

In [41]:
for step, batch in enumerate(train_dataloader):
    # print(batch)
    print('input_ids: ', batch.input_ids[0].shape)
    print('attention_mask: ', batch.attention_mask[0].shape)
    print('labels: ', batch.labels[0])
    labels = batch.labels[0]
    batch_labels = batch.labels
    # print('input_ids: ', batch.input_ids[1].shape)
    # print('attention_mask: ', batch.attention_mask[1].shape)
    # print('negative_input_ids: ', batch.negative_input_ids[1].shape)
    # print('negative_attention_mask: ', batch.negative_attention_mask[1].shape)
    # print('labels: ', batch.labels[1].shape)
    if step == 1:
        break

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


input_ids:  torch.Size([416])
attention_mask:  torch.Size([416])
labels:  tensor([    0, 10431, 41761,   176, 10431,  1072,     7,   671,     5,  3369,
         3400,   142,    51,   218,    75,   914,    69,  3089, 11556,    53,
          849, 41761,   134, 10431,   161,     5,  3369,  3400,    58,    15,
         1392,    98,   849, 41761,   176, 10431,    64,    75,   671,    24,
            4,  1773,   849, 41761,   176, 10431,    16,    10,  1675,  2111,
            6,     5,  1044, 11687,     7,   146,    41,  8219,    98,   849,
        41761,   134, 10431,  2029,    69,  1400,  7751,     4,     2])
input_ids:  torch.Size([191])
attention_mask:  torch.Size([191])
labels:  tensor([    0, 41004,    18, 15126,     7,   989,   334,     8,   708,     7,
          213,     7,  2737,    53,    24,  7971,    15,    69,    83,   672,
          775,     4,     2])


In [42]:
# = = = Training Preparation = = =
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]

if args.ctrlen_model: 
    no_decay_emb_matrix = ["bias", "LayerNorm.weight", "shared"]
else:
    no_decay_emb_matrix = ["bias", "LayerNorm.weight"]

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay_emb_matrix)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

if args.ctrlen_model:
    if args.model_type == 'bart': 
        optimizer_grouped_parameters.extend([{
            "params": model.seq2seq_model.model.shared.parameters(),
            "lr": args.embedding_lr}])
    elif args.model_type == 't5':
        optimizer_grouped_parameters.extend([{
            "params": model.seq2seq_model.shared.parameters(),
            "lr": args.embedding_lr}])
    else:
        raise ValueError('{} model type not implemented'.format(args.model_type))

# optimizer
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)



In [43]:
model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader
)

In [44]:
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)

# = = = = = = = = = = = = = = = = Train = = = = = = = = = = = = = = = = = = =
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), desc="Training: ", disable=not accelerator.is_local_main_process)
completed_steps = 0

val_results = []
acc_losses  = []
best_r2_f1  = None
best_epoch  = 0

if args.model_type == 'bart' or args.model_type == 't5':
    task_specific_params = model.config.task_specific_params
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    model.config.update(params)
else:
    raise ValueError('{} model type not implemented'.format(args.model_type))

10/18/2023 14:31:31 - INFO - __main__ - ***** Running training *****
10/18/2023 14:31:31 - INFO - __main__ -  Num examples = 1500
10/18/2023 14:31:31 - INFO - __main__ -  Num Epochs = 3
10/18/2023 14:31:31 - INFO - __main__ -  Instantaneous batch size per device = 2
10/18/2023 14:31:31 - INFO - __main__ -  Total train batch size (w. parallel, distributed & accumulation) = 128
10/18/2023 14:31:31 - INFO - __main__ -  Gradient Accumulation steps = 64
10/18/2023 14:31:31 - INFO - __main__ -  Total optimization steps = 36


Training:   0%|          | 0/36 [00:00<?, ?it/s]

In [45]:
# = = = = = = = = = = = = = = = = Train = = = = = = = = = = = = = = = = = = =
total_batch_size = args.per_device_train_batch_size * \
    accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(
    f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(
    f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
    f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), desc="Training: ",
                    disable=not accelerator.is_local_main_process)
completed_steps = 0

val_results = []
acc_losses = []
best_r2_f1 = None
best_epoch = 0

# edit #
if args.model_type == 'bart' or args.model_type == 't5':
    # task_specific_params = model.module.config.task_specific_params
    task_specific_params = model.config.task_specific_params
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    # model.module.config.update(params)
    model.config.update(params)
else:
    raise ValueError(
        '{} model type not implemented'.format(args.model_type))

loss_list = []
train_loss_list = []
val_loss_list = []
last_output = None
hidden_states = None

# =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = Train =  =  =  =  =  =  =  =  =  =  =  =  =  =  =
for epoch in range(args.num_train_epochs):
    # train
    model.train()
    epoch_loss = 0.0
    for step, batch in enumerate(train_dataloader):

        if args.ctrlen_model:  # CTRLen model
            outputs, loss = model(batch, tokenizer)
        # w/ and w/o label smoothing (always better with label smoothing)
        else:
            if args.label_smoothing == 0:
                outputs = model(**batch)
                loss = outputs.loss
            else:
                outputs = model(**batch, output_hidden_states=True)
                last_output = outputs
                output_logits = outputs.logits
                hidden_states = outputs.decoder_hidden_states
                # print(f"logits: {output_logits.shape}")
                # print("="*100)
                # # print(f"hidden states: {hidden_states.shape}")
                # print(f"loss: {outputs.loss}")
                # print(outputs.keys())
                # print("="*100)
                output_probs = torch.nn.functional.log_softmax(
                    output_logits, dim=-1)
                # edit #
                # output_probs = output_probs.view(-1,
                #                                  model.module.config.vocab_size)
                output_probs = output_probs.view(-1,
                                                 model.config.vocab_size)

                gt_logits = batch['labels']
                # print(f"label: {gt_logits.shape}")
                # print("="*100)
                gt_logits = gt_logits.view(-1)

                # print(f"output_probs: {output_probs.shape}")
                # print("-"*100)
                # print(f"gt_logits: {gt_logits.shape}")
                # print("="*100)
                
                loss_nll, nll = label_smoothed_nll_loss(
                    output_probs, gt_logits, args.label_smoothing, ignore_index=tokenizer.pad_token_id)
                
                cosine_loss = torch.nn.CosineEmbeddingLoss()
                
                # loss_cs = cosine_loss(outputs.encoder_last_hidden_state[0], outputs.encoder_last_hidden_state[1], torch.ones(outputs.encoder_last_hidden_state.size(dim=1)).to(torch.device('cuda')))
                positive_embeddings_1 = outputs.encoder_last_hidden_state[0]
                print(positive_embeddings_1.shape)
                positive_embeddings_2 = outputs.encoder_last_hidden_state[1]
                print(positive_embeddings_2.shape)
                negative_embeddings_1 = outputs.encoder_last_hidden_state[2]
                print(negative_embeddings_1.shape)
                negative_embeddings_2 = outputs.encoder_last_hidden_state[3]
                print(negative_embeddings_2.shape)
                # break
                # Compute contrastive loss
                loss_1 = cosine_loss(positive_embeddings_1, negative_embeddings_1, torch.ones(positive_embeddings_1.size(dim=0)))
                loss_2 = cosine_loss(positive_embeddings_1, negative_embeddings_1, torch.ones(positive_embeddings_2.size(dim=0)))
                loss_cs = (loss_1 + loss_2) / 2
                # print("loss 1: ", loss_1)
                # print('-'*100)
                # print("loss 2: ", loss_2)
                # print('-'*100)
                # print("loss: ", loss)
                # print('='*100)
                
                alpha = 0.5
                
                loss = loss_nll + alpha * (1 - loss_cs)
                
                # print(f"loss_fn: {loss}")
                # print("-"*100)
                # print(f"nll: {nll}")
                # print("="*100)

        acc_losses.append(loss.item())
        loss_list.append(loss)
        epoch_loss += loss.item()
        loss = loss / args.gradient_accumulation_steps
        # print(f"loss_grad: {loss}")
        accelerator.backward(loss)
        # break

        if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_postfix(lr=lr_scheduler.get_last_lr()[
                                     0], loss=np.mean(acc_losses[-50:]))
            completed_steps += 1
            train_loss_list.append(epoch_loss/len(batch))

        if completed_steps >= args.max_train_steps:
            break

    # # =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = EVAL =  =  =  =  =  =  =  =  =  =  =  =  =  =  =
    model.eval()
    val_predict = []
    val_groundtruth = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            val_loss = []
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"]
            )

            # print(generated_tokens)
            # print("="*100)
            
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            
            # print(generated_tokens)
            # print("="*100)
            
            labels = batch["labels"]
            if not args.pad_to_max_length:
                # If we did not pad to max length, we need to pad the labels too
                labels = accelerator.pad_across_processes(
                    batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
            
            # print(generated_tokens)
            # print("="*100)
            # print(labels)
            # print("="*100)
            # loss, _ = label_smoothed_nll_loss_gen(generated_tokens, labels, args.label_smoothing, ignore_index=tokenizer.pad_token_id)
            # break    
            # loss, _ = map(label_smoothed_nll_loss, generated_tokens, labels, args.label_smoothing, ignore_index=tokenizer.pad_token_id)
            # val_loss.extend(loss)
            
            generated_tokens = accelerator.gather(
                generated_tokens).cpu().numpy()
            labels = accelerator.gather(labels).cpu().numpy()

            if args.ignore_pad_token_for_loss:
                # Replace -100 in the labels as we can't decode them.
                labels = np.where(labels != -100, labels,
                                tokenizer.pad_token_id)
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]

            # print(generated_tokens[0])
            # print("="*100)
            # print(labels[0])
            # loss, _ = label_smoothed_nll_loss(generated_tokens[0], labels[0], args.label_smoothing, ignore_index=tokenizer.pad_token_id)
            # break 
            
            decoded_preds = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(
                labels, skip_special_tokens=True)
            
            # print(decoded_preds)
            # print("="*100)
            # print(decoded_labels)
            # print("="*100)
            
            decoded_preds, decoded_labels = postprocess_text(
                decoded_preds, decoded_labels)

            # print(decoded_preds)
            # print("="*100)
            # print(decoded_labels)
            # print("="*100)
            # if step == 1:
            #     break

            val_predict.extend(decoded_preds)
            val_groundtruth.extend(decoded_labels)

    if args.len_output == 'real':
        new_val_predict = []
        for sample in val_predict:
            try:
                gen_sum = sample.split('Summary: ')[2]
                new_val_predict.append(gen_sum)
            except:
                new_val_predict.append(sample)
        val_predict = new_val_predict
    else:
        new_val_predict = val_predict

    logger.info("")
    logger.info("Rouge score on val set after epoch {}".format(epoch+1))
    eval_results = py_rouge_scores(val_predict, val_groundtruth)

    if best_r2_f1 is None:
        best_r2_f1 = eval_results
    if eval_results['rouge-2']['f'] >= best_r2_f1['rouge-2']['f']:
        best_r2_f1 = eval_results
        best_epoch = epoch + 1

        os.makedirs(args.output_dir+'/best', exist_ok=True)
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir+'/best', save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir+'/best')
        # save vocab
        vocab = tokenizer.vocab.copy()
        vocab = {k: v for k, v in sorted(
            vocab.items(), key=lambda item: item[1])}
        with open(args.output_dir + '/best/vocab.txt', 'w') as f:
            for word, index in vocab.items():
                # it lead to encoding bug on some machines, so i add this line
                word = word.encode('ascii', 'ignore').decode('ascii')
                f.write(str(index) + ': ' + word + '\n')

    # = = = = = = = = = = = = = = = = = = = = = = = = =
    logger.info(
        "Current Best Validation Result is at epoch {}".format(best_epoch))
    py_rouge_scores(None, None, best_r2_f1)
        

10/18/2023 14:31:31 - INFO - __main__ - ***** Running training *****
10/18/2023 14:31:31 - INFO - __main__ -  Num examples = 1500
10/18/2023 14:31:31 - INFO - __main__ -  Num Epochs = 3
10/18/2023 14:31:31 - INFO - __main__ -  Instantaneous batch size per device = 2
10/18/2023 14:31:31 - INFO - __main__ -  Total train batch size (w. parallel, distributed & accumulation) = 128
10/18/2023 14:31:31 - INFO - __main__ -  Gradient Accumulation steps = 64
10/18/2023 14:31:31 - INFO - __main__ -  Total optimization steps = 36


Training:   0%|          | 0/36 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 0; 47.54 GiB total capacity; 2.73 GiB already allocated; 28.56 MiB free; 2.80 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [25]:
len(outputs.encoder_last_hidden_state)

2

In [26]:
outputs.encoder_last_hidden_state

tensor([[[ 1.4340e-02,  3.4438e-02,  3.4452e-02,  ..., -8.5761e-03,
           2.4686e-03, -2.0838e-03],
         [-1.3822e-01, -4.3015e-02,  2.6078e-01,  ..., -1.4171e-01,
          -1.0360e-01, -2.5367e-01],
         [-5.5475e-04,  1.2314e-02,  7.7116e-03,  ...,  5.5711e-03,
          -3.1495e-03,  3.5236e-04],
         ...,
         [-1.9751e-03,  1.7929e-02,  9.2430e-03,  ...,  4.5511e-03,
           8.8687e-04, -5.6736e-04],
         [-1.0416e-02,  6.9710e-03,  2.3643e-02,  ...,  5.7164e-03,
           5.1921e-03, -4.1290e-03],
         [-8.8324e-04,  1.6706e-02,  9.7642e-03,  ...,  5.1859e-03,
          -4.2920e-03,  4.3312e-04]],

        [[ 2.3276e-02,  3.6314e-02,  3.2427e-02,  ..., -7.2097e-03,
          -1.4225e-02, -8.0660e-03],
         [-1.6026e-01, -2.3157e-02,  2.5404e-01,  ..., -4.2496e-02,
          -2.6232e-02, -2.0289e-01],
         [-3.6479e-02, -5.3065e-01, -4.0594e-02,  ...,  2.8221e-02,
          -5.6595e-02, -3.1839e-01],
         ...,
         [-1.8181e-02,  6

In [27]:
outputs.encoder_last_hidden_state.size(dim=1)

192

In [28]:
outputs.decoder_hidden_states[-1].size()

torch.Size([2, 32, 1024])

In [29]:
outputs.decoder_hidden_states[-1][0].size()

torch.Size([32, 1024])

In [30]:
print(len(outputs.logits))
print(outputs.logits[0].size())

2
torch.Size([32, 50266])


In [31]:
cosine_loss = torch.nn.CosineEmbeddingLoss()

In [32]:
outputs.decoder_hidden_states[-1].shape[1]

32

In [33]:
loss_cs = cosine_loss(outputs.decoder_hidden_states[-1][0], outputs.decoder_hidden_states[-1][1], torch.ones(outputs.decoder_hidden_states[-1].shape[1]).to(torch.device('cuda')))

In [34]:
loss_cs

tensor(0.4350, device='cuda:0', grad_fn=<MeanBackward0>)

In [38]:
x = -1 * torch.ones(outputs.decoder_hidden_states[-1].shape[1])

In [39]:
loss_cs = cosine_loss(outputs.decoder_hidden_states[-1][0], outputs.decoder_hidden_states[-1][1], x.to(torch.device('cuda')))

In [40]:
loss_cs

tensor(0.5650, device='cuda:0', grad_fn=<MeanBackward0>)

In [41]:
1 - loss_cs

tensor(0.4350, device='cuda:0', grad_fn=<RsubBackward1>)

In [35]:
# =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = Test =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = 
# load best model
logger.info("Loading Best Result is at epoch {} for Testing".format(best_epoch))

unwrapped_model = accelerator.unwrap_model(model)
config          = config.from_pretrained(args.output_dir+'/best')
tokenizer       = tokenizer.from_pretrained(args.output_dir+'/best', config=config)
unwrapped_model = unwrapped_model.from_pretrained(args.output_dir+'/best', config=config)
model           = accelerator.prepare(unwrapped_model)

if args.model_type == 'bart' or args.model_type == 't5':
    task_specific_params = model.config.task_specific_params
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    model.config.update(params)
else:
    raise ValueError('{} model type not implemented'.format(args.model_type))

# start Test 
logger.info("Collecting Testing Result...")
model.eval()

test_predict     = []
test_groundtruth = []
for step, batch in enumerate(tqdm(test_dataloader, leave=False)):
    with torch.no_grad():
        generated_tokens = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

        generated_tokens = accelerator.pad_across_processes(
            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
        )
        labels = batch["labels"]

        if not args.pad_to_max_length:
            # If we did not pad to max length, we need to pad the labels too
            labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

        generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
        labels = accelerator.gather(labels).cpu().numpy()

        if args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]

        decoded_preds  = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        decoded_preds  = [' '.join(sent.split('\n')) for sent in decoded_preds]
        decoded_labels = [' '.join(sent.split('\n')) for sent in decoded_labels]

        test_predict.extend(decoded_preds)
        test_groundtruth.extend(decoded_labels)

print(raw_datasets['test']['dialogue'][0])

if args.len_output == 'real':
    new_test_predict = []
    for sample in test_predict:
        try:
            gen_sum = sample.split('Summary: ')[2]
            new_test_predict.append(gen_sum)
        except:
            new_test_predict.append(sample)
    test_predict = new_test_predict

logger.info("")
logger.info("ROUGE score on test set")
test_scores = py_rouge_scores(test_predict, test_groundtruth)
logger.info("")


# Save generated summaries
if args.len_input == 'predict':
    os.makedirs(args.output_dir+'/predict_gen_samples', exist_ok=True)
else:
    os.makedirs(args.output_dir+'/gen_samples', exist_ok=True)

for i in range(len(test_predict)):
    test_id        = raw_datasets['test']['id'][i]
    test_dialogue  = raw_datasets['test']['dialogue'][i]
    test_summary   = raw_datasets['test']['summary'][i]
    test_predict_s = test_predict[i]

    if args.len_input == 'predict':
        with open(args.output_dir+'/predict_gen_samples/'+str(test_id)+'.txt', 'w') as f:
            test_dialogue = test_dialogue.encode('ascii', 'ignore').decode('ascii')
            f.write(test_dialogue)
            f.write('\n\n')
            f.write('Golden Summary:\n')
            test_summary = test_summary.encode('ascii', 'ignore').decode('ascii')
            f.write(test_summary)
            f.write('\n\n')
            f.write('Generate Summary:\n')
            test_predict_s = test_predict_s.encode('ascii', 'ignore').decode('ascii')
            f.write(test_predict_s)
    else:
        with open(args.output_dir+'/gen_samples/'+str(test_id)+'.txt', 'w') as f:
            test_dialogue = test_dialogue.encode('ascii', 'ignore').decode('ascii')
            f.write(test_dialogue)
            f.write('\n\n')
            f.write('Golden Summary:\n')
            test_summary = test_summary.encode('ascii', 'ignore').decode('ascii')
            f.write(test_summary)
            f.write('\n\n')
            f.write('Generate Summary:\n')
            test_predict_s = test_predict_s.encode('ascii', 'ignore').decode('ascii')
            f.write(test_predict_s)

10/08/2023 11:39:42 - INFO - __main__ - Loading Best Result is at epoch 2 for Testing
loading configuration file ./output/1-bart-baseline-loss/best/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"

  0%|          | 0/19 [00:00<?, ?it/s]

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "max_length": 128,
  "min_length": 1,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "max_length": 128,
  "min_length": 1,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "max_length": 128,
  "min_length": 1,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stop

