In [1]:
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to /home/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package wordnet to /home/nltk_data...
[nltk_data] Downloading package stopwords to /home/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [2]:
import math
import os
import pprint
import logging

import datasets
import nltk
from nltk.corpus import wordnet
from nltk.corpus import stopwords
import numpy as np
import torch
from tqdm.auto import tqdm

import transformers
from accelerate import Accelerator
from filelock import FileLock
from transformers import AdamW, get_scheduler, set_seed

from transformers.file_utils import is_offline_mode
from transformers.utils.versions import require_version

from args import parse_args
from data_loader import raw_data_loader, data_processor
from model_loader import model_loader
from rouge_s import py_rouge_scores
from utils import label_smoothed_nll_loss, postprocess_text
from contrastive_loss import margin_ranking_loss, cosine_embedding_loss

[nltk_data] Downloading package punkt to /home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to /home/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /home/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

In [4]:
from transformers import SchedulerType
import argparse

In [5]:
arg_parser = argparse.ArgumentParser(description="bart")
arg_parser.add_argument("--topic_prompt_input", dest="topic_prompt_input", type=bool,
                        default=False, help="Use topic prompt or not")
arg_parser.add_argument("--length_prompt_input", dest="length_prompt_input", type=bool,
                        default=False, help="Use length prompt or not")
arg_parser.add_argument("--predict_summary", dest="predict_summary", type=bool,
                        default=False, help="Use predict summary or not")
arg_parser.add_argument("--output_dir", dest="output_dir",
                        type=str, default="./output/1", help="default")
arg_parser.add_argument("--train_file", dest="train_file", type=str,
                        default=None, help="A json file containing the training data.")
arg_parser.add_argument("--validation_file", dest="validation_file", type=str,
                        default=None, help="A json file containing the validation data.")
arg_parser.add_argument("--test_file", dest="test_file", type=str,
                        default=None, help="A json file containing the test data.")
arg_parser.add_argument("--ignore_pad_token_for_loss", dest="ignore_pad_token_for_loss", type=bool, default=True,
                        help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",)
arg_parser.add_argument("--text_column", dest="text_column", type=str, default="dialogue",
                        help="The name of the column in the datasets containing the full texts (for summarization).")
arg_parser.add_argument("--summary_column", dest="summary_column", type=str, default="summary",
                        help="The name of the column in the datasets containing the summaries (for summarization).")
arg_parser.add_argument("--model_name_or_path", dest="model_name_or_path", type=str, default="facebook/bart-large",
                        help="Path to pretrained model or model identifier from huggingface.co/models.")
arg_parser.add_argument("--model_type", dest="model_type", type=str, default="bart",
                        help="Model type to use if training from scratch.")
arg_parser.add_argument("--max_source_length", dest="max_source_length", 
                        type=int, default=1024, help="default")
arg_parser.add_argument("--preprocessing_num_workers", type=int, default=None,
                        help="The number of processes to use for the preprocessing.")
arg_parser.add_argument("--overwrite_cache", dest="overwrite_cache", type=bool,
                        default=None, help="Overwrite the cached training and evaluation sets")
arg_parser.add_argument("--min_target_length", dest="min_target_length", type=int,
                        default=1, help="The minimal total sequence length for target text")
arg_parser.add_argument("--max_target_length", dest="max_target_length", type=int, default=128, help="The maximum total sequence length for target text"
                        "after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. during ``evaluate`` and"
                        "``predict``.")
arg_parser.add_argument("--num_beams", dest="num_beams", type=int, default=4, help="Number of beams to use for evaluation. This argument will be "
                        "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.")
arg_parser.add_argument("--learning_rate", dest="learning_rate", type=float, default=5e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")
arg_parser.add_argument("--pad_to_max_length", action="store_true",
                        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",)
arg_parser.add_argument("--weight_decay", dest="weight_decay",
                        type=float, default=1e-3, help="Weight decay to use.")
arg_parser.add_argument("--label_smoothing", dest="label_smoothing",
                        type=float, default=0.1, help="hyperparameter for label smoothing.")
arg_parser.add_argument("--length_penalty", dest="length_penalty", type=float,
                        default=1.0, help="large - longer sequence, small - shorter sequence")
arg_parser.add_argument("--num_train_epochs", dest="num_train_epochs",
                        type=int, default=15, help="Total number of training epochs to perform.")
arg_parser.add_argument("--per_device_train_batch_size", dest="per_device_train_batch_size",
                        type=int, default=8, help="Batch size (per device) for the training dataloader.")
arg_parser.add_argument("--gradient_accumulation_steps", dest="gradient_accumulation_steps", type=int,
                        default=64, help="Number of updates steps to accumulate before performing a backward/update pass.")
arg_parser.add_argument("--per_device_eval_batch_size", dest="per_device_eval_batch_size",
                        type=int, default=8, help="Batch size (per device) for the evaluation dataloader.")
arg_parser.add_argument("--per_device_test_batch_size", dest="per_device_test_batch_size",
                        type=int, default=8, help="Batch size (per device) for the evaluation dataloader.")
arg_parser.add_argument("--num_warmup_steps", dest="num_warmup_steps", type=int,
                        default=0, help="Number of steps for the warmup in the lr scheduler.")
arg_parser.add_argument("--cache_dir", dest="cache_dir",
                        type=str, default="./output/cache", help="default")
arg_parser.add_argument("--seed", dest="seed",
                        type=int, default=12345, help="default")
arg_parser.add_argument("--config_name", type=str, default=None,
                        help="Pretrained config name or path if not the same as model_name")
arg_parser.add_argument("--tokenizer_name", type=str, default=None,
                        help="Pretrained tokenizer name or path if not the same as model_name")
arg_parser.add_argument("--use_slow_tokenizer", dest="use_slow_tokenizer", action="store_true",
                        help="If passed, will use a slow tokenizer (not backed by the HuggingFaceTokenizers library).")
arg_parser.add_argument("--max_train_steps", type=int, default=None,
                        help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
arg_parser.add_argument("--lr_scheduler_type", type=SchedulerType, default="linear", help="The scheduler type to use.",
                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"])
arg_parser.add_argument("--embedding_lr", type=float, default=5e-5,
                        help="Initial learning rate for embedding layers.")
arg_parser.add_argument("--len_start", type=int,
                        default=1, help="start length.")
arg_parser.add_argument("--len_end", type=int,
                        default=100, help="end length.")
arg_parser.add_argument("--contrastive_loss", dest="contrastive_loss", type=bool,
                        default=False, help="Use contrastive loss or not")
arg_parser.add_argument("--tagging", dest="tagging", type=str, default="no",
                        choices=('no', 'word', 'prompt'), help="Use tagging (<tp>, </tp>) in word, sentence, or not")
arg_parser.add_argument("--synonym_replacement", dest="synonym_replacement", type=bool,
                        default=False, help="Synonym replacement or not")
arg_parser.add_argument("--random_topic", dest="random_topic", type=bool,
                        default=False, help="Random topic or not")
arg_parser.add_argument("--contrastive_encoder", dest="contrastive_encoder", type=bool,
                        default=False, help="Contrastive encoder or not")
arg_parser.add_argument("--contrastive_decoder", dest="contrastive_decoder", type=bool,
                        default=False, help="Contrastive decoder or not")
arg_parser.add_argument("--gen_sample", dest="gen_sample", type=int,
                        default=1, help="The number of sample")
arg_parser.add_argument("--alpha", dest="alpha", type=float,
                        default=0.5, help="ration of computation loss in encoder")
arg_parser.add_argument("--beta", dest="beta", type=float,
                        default=0.5, help="ration of computation loss in decoder")
arg_parser.add_argument("--margin", dest="margin", type=float,
                        default=0, help="margin of computation loss")
arg_parser.add_argument("--debug", action='store_true',
                        default=False, help="Use the debug mode or not")

_StoreTrueAction(option_strings=['--debug'], dest='debug', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Use the debug mode or not', metavar=None)

In [6]:
args = arg_parser.parse_args('')

In [7]:
args.train_file = "./data/dialogtest_aug/dialogsum.train.jsonl"
args.validation_file = "./data/dialogtest_aug/dialogsum.dev.jsonl"
args.test_file = "./data/dialogtest_aug/dialogsum.test.jsonl"
args.text_column = "prompt"
args.summary_column = "summary"
args.model_name_or_path = "facebook/bart-large"
args.model_type = "bart"
args.max_source_length = 1024
args.min_target_length = 1
args.max_target_length = 128
args.num_beams = 4
args.learning_rate = 5e-5
args.weight_decay = 1e-3
args.label_smoothing = 0.1
args.length_penalty = 1.0
args.num_train_epochs = 2
args.per_device_train_batch_size = 2
args.gradient_accumulation_steps = 64
args.per_device_eval_batch_size = 8
args.per_device_test_batch_size = 8
args.num_warmup_steps = 0
args.cache_dir = "./output/cache"
args.overwrite_cache = True
args.seed = 12345
args.output_dir = "./output/bart-test"

args.topic_prompt_input = True
args.length_prompt_input = True

args.contrastive_loss = True
args.tagging = "prompt"

args.contrastive_encoder = True
args.contrastive_decoder = False

args.synonym_replacement = True
args.random_topic = True

In [8]:
print(args.topic_prompt_input)
print(args.length_prompt_input)
print(args.contrastive_loss)
print(args.tagging)
print(args.contrastive_encoder)
print(args.contrastive_decoder)
print(args.synonym_replacement)
print(args.random_topic)

True
True
True
prompt
True
False
True
True


In [9]:
raw_datasets = raw_data_loader(args)

In [10]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'prompt', 'summary', 'topic', 'synonym_prompt', 'synonym_topic', 'random_prompt', 'random_topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'prompt', 'summary', 'topic', 'synonym_prompt', 'synonym_topic', 'random_prompt', 'random_topic'],
        num_rows: 50
    })
    test: Dataset({
        features: ['id', 'prompt', 'summary', 'topic', 'synonym_prompt', 'synonym_topic', 'random_prompt', 'random_topic'],
        num_rows: 150
    })
})

In [11]:
raw_datasets['train']['prompt'][10]

"<t>Topic of Summary: do a favor</t>. Length of Summary: 24. Dialogue: # Person1 # : Could you do me a <t>favor</t> ? # Person2 # : Sure . What is it ? # Person1 # : Could you run over to the store ? We need a few things . # Person2 # : All right . What do you want me to get ? # Person1 # : Well , could you pick up some sugar ? # Person2 # : Okay . How much ? # Person1 # : A small bag . I guess we also need a few oranges . # Person2 # : How many ? # Person1 # : Oh , let 's see . . . About six . # Person2 # : Anything else ? # Person1 # : Yes . We 're out of milk . # Person2 # : Okay . How much do you want me to get ? A gallon ? # Person1 # : No . I think a half gallon will be enough . # Person2 # : Is that all ? # Person1 # : I think so . Have you got all that ? # Person2 # : Yes . That 's small bag of sugar , four oranges , and a half gallon of milk . # Person1 # : Do you have enough money ? # Person2 # : I think so . # Person1 # : Thanks very much . I appreciate it ."

In [12]:
raw_datasets['train']['synonym_prompt'][10]

"<t>Topic of Summary: do a party favor</t>. Length of Summary: 24. Dialogue: # Person1 # : Could you do me a <t><t>favor</t></t> ? # Person2 # : Sure . What is it ? # Person1 # : Could you run over to the store ? We need a few things . # Person2 # : All right . What do you want me to get ? # Person1 # : Well , could you pick up some sugar ? # Person2 # : Okay . How much ? # Person1 # : A small bag . I guess we also need a few oranges . # Person2 # : How many ? # Person1 # : Oh , let 's see . . . About six . # Person2 # : Anything else ? # Person1 # : Yes . We 're out of milk . # Person2 # : Okay . How much do you want me to get ? A gallon ? # Person1 # : No . I think a half gallon will be enough . # Person2 # : Is that all ? # Person1 # : I think so . Have you got all that ? # Person2 # : Yes . That 's small bag of sugar , four oranges , and a half gallon of milk . # Person1 # : Do you have enough money ? # Person2 # : I think so . # Person1 # : Thanks very much . I appreciate it ."

In [13]:
raw_datasets['train']['synonym_topic'][10]

'do a party favor'

In [14]:
raw_datasets['train']['random_prompt'][10]

"<t>Topic of Summary: appointment</t>. Length of Summary: 24. Dialogue: # Person1 # : Could you do me a <t><t>favor</t></t> ? # Person2 # : Sure . What is it ? # Person1 # : Could you run over to the store ? We need a few things . # Person2 # : All right . What do you want me to get ? # Person1 # : Well , could you pick up some sugar ? # Person2 # : Okay . How much ? # Person1 # : A small bag . I guess we also need a few oranges . # Person2 # : How many ? # Person1 # : Oh , let 's see . . . About six . # Person2 # : Anything else ? # Person1 # : Yes . We 're out of milk . # Person2 # : Okay . How much do you want me to get ? A gallon ? # Person1 # : No . I think a half gallon will be enough . # Person2 # : Is that all ? # Person1 # : I think so . Have you got all that ? # Person2 # : Yes . That 's small bag of sugar , four oranges , and a half gallon of milk . # Person1 # : Do you have enough money ? # Person2 # : I think so . # Person1 # : Thanks very much . I appreciate it ."

In [15]:
raw_datasets['train']['random_topic'][10]

'appointment'

In [16]:
raw_datasets['train']['summary'][10]

'#Person1# asks #Person2# to do a favor. #Person2# agrees and helps buy a small bag of sugar, six oranges, and a half-gallon of milk.'

In [17]:
# raw_datasets['train']['synonym_summary'][10]

In [18]:
# raw_datasets['train']['random_summary'][10]

In [19]:
accelerator = Accelerator(mixed_precision="fp16")
logger.info(accelerator.state)

11/15/2023 05:08:54 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [20]:
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
accelerator.is_local_main_process
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
set_seed(args.seed)
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
accelerator.is_main_process
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()

In [21]:
config, tokenizer, model = model_loader(accelerator, logger, args)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

loading configuration file config.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

loading configuration file config.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

loading file vocab.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/vocab.json
loading file merges.txt from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/merges.txt
loading file tokenizer.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/tokenizer_config.json
loading configuration file config.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu"

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

loading weights file pytorch_model.bin from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/pytorch_model.bin
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}

All model checkpoint weights were used when initializing BartForConditionalGeneration.

All the weights of BartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditionalGeneration for predictions without further training.
Generation config file not found, using a generation config created from the model config.
You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the 

In [22]:
model.vocab_size

50267

In [23]:
tokenizer.SPECIAL_TOKENS_ATTRIBUTES

['bos_token',
 'eos_token',
 'unk_token',
 'sep_token',
 'pad_token',
 'cls_token',
 'mask_token',
 'additional_special_tokens']

In [24]:
print(tokenizer.additional_special_tokens)

['<t>', '</t>']


In [25]:
dataloader, processed_dataset = data_processor(logger, args, accelerator, raw_datasets, tokenizer, model)
train_dataloader, eval_dataloader, test_dataloader = dataloader
train_dataset, _, _ = processed_dataset

Running tokenizer on dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]



Running tokenizer on dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/150 [00:00<?, ? examples/s]



In [26]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels', 'synonym_inputs', 'random_inputs'],
    num_rows: 1500
})

In [27]:
tokenizer.decode(train_dataset['input_ids'][0])

"<s><t>Topic of Summary: get a check-up</t>. Length of Summary: 30. Dialogue: # Person1 # : Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today? # Person2 # : I found it would be a good idea to <t>get</t> a <t>check-up</t>. # Person1 # : Yes, well, you haven't had one for 5 years. You should have one every year. # Person2 # : I know. I figure as long as there is nothing wrong, why go see the doctor? # Person1 # : Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good. # Person2 # : Ok. # Person1 # : Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith? # Person2 # : Yes. # Person1 # : Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit. # Person2 # : I've tried hundreds of times, but I just can't seem to kick the habit. # Person1 # : Well, we have classes and some medications that might help. I 'll give you more in

In [28]:
tokenizer.decode(train_dataset['synonym_inputs'][0])

"<s><t>Topic of Summary: get a confirmation up</t>. Length of Summary: 30. Dialogue: # Person1 # : Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today? # Person2 # : I found it would be a good idea to <t><t>get</t></t> a <t>check-up</t>. # Person1 # : Yes, well, you haven't had one for 5 years. You should have one every year. # Person2 # : I know. I figure as long as there is nothing wrong, why go see the doctor? # Person1 # : Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good. # Person2 # : Ok. # Person1 # : Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith? # Person2 # : Yes. # Person1 # : Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit. # Person2 # : I've tried hundreds of times, but I just can't seem to kick the habit. # Person1 # : Well, we have classes and some medications that might help. I 'll gi

In [29]:
tokenizer.decode(train_dataset['random_inputs'][0])

"<s><t>Topic of Summary: an appointment</t>. Length of Summary: 30. Dialogue: # Person1 # : Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today? # Person2 # : I found it would be a good idea to <t><t>get</t></t> a <t>check-up</t>. # Person1 # : Yes, well, you haven't had one for 5 years. You should have one every year. # Person2 # : I know. I figure as long as there is nothing wrong, why go see the doctor? # Person1 # : Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good. # Person2 # : Ok. # Person1 # : Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith? # Person2 # : Yes. # Person1 # : Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit. # Person2 # : I've tried hundreds of times, but I just can't seem to kick the habit. # Person1 # : Well, we have classes and some medications that might help. I 'll give you 

In [30]:
for step, batch in enumerate(train_dataloader):
    # print(batch['input_ids'].shape)
    # print(batch['attention_mask'].shape)
    # print(batch['labels'].shape)
    # print(batch['decoder_input_ids'].shape)
    print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True))
    print(tokenizer.decode(batch['input_ids'][2], skip_special_tokens=True))
    print(tokenizer.decode(batch['input_ids'][4], skip_special_tokens=True))
    print("="*100)
    print(tokenizer.decode(batch['decoder_input_ids'][0], skip_special_tokens=True))
    print(tokenizer.decode(batch['decoder_input_ids'][2], skip_special_tokens=True))
    print(tokenizer.decode(batch['decoder_input_ids'][4], skip_special_tokens=True))
    print("*"*100)
    print(tokenizer.decode(batch['input_ids'][1], skip_special_tokens=True))
    print(tokenizer.decode(batch['input_ids'][3], skip_special_tokens=True))
    print(tokenizer.decode(batch['input_ids'][5], skip_special_tokens=True))
    print("="*100)
    print(tokenizer.decode(batch['decoder_input_ids'][1], skip_special_tokens=True))
    print(tokenizer.decode(batch['decoder_input_ids'][3], skip_special_tokens=True))
    print(tokenizer.decode(batch['decoder_input_ids'][5], skip_special_tokens=True))
    break

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Topic of Summary: medical treatment. Length of Summary: 16. Dialogue: # Person1 # : Could you give me something for the pain? I couldn't get to sleep until 3 o'clock this morning. # Person2 # : Aspirin is the strongest medicine I can give you. # Person1 # : That isn't strong enough, and I don't have to meet my doctor until next week. # Person2 # : Who is your doctor? # Person1 # : Dr. Hilary. # Person2 # : Doesn't he have his office on the corner? # Person1 # : Yes, he does. # Person2 # : Are you a regular patient? # Person1 # : Yes. # Person2 # : Oh. Then I can call him if you like. Dr. Hilary will give me a pain treatment over the phone. # Person1 # : I 'd appreciate that very much. Do you think that he 'll still be in his office? # Person2 # : Sure. It's only 4:30. He should be there until five.
Topic of Summary: aesculapian treatment. Length of Summary: 16. Dialogue: # Person1 # : Could you give me something for the pain? I couldn't get to sleep until 3 o'clock this morning. # Pers

In [31]:
for step, batch in enumerate(test_dataloader):
    print(batch.keys())
    print(batch['input_ids'].shape)
    print(batch['attention_mask'].shape)
    print(batch['labels'].shape)
    print(batch['decoder_input_ids'].shape)
    break

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])
torch.Size([8, 472])
torch.Size([8, 472])
torch.Size([8, 72])
torch.Size([8, 72])


In [32]:
for step, batch in enumerate(eval_dataloader):
    print(batch.keys())
    print(batch['input_ids'].shape)
    print(batch['attention_mask'].shape)
    print(batch['labels'].shape)
    print(batch['decoder_input_ids'].shape)
    break

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])
torch.Size([8, 256])
torch.Size([8, 256])
torch.Size([8, 48])
torch.Size([8, 48])


In [33]:
# = = = Training Preparation = = =
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]

no_decay_emb_matrix = ["bias", "LayerNorm.weight"]

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay_emb_matrix)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]


# if args.model_type == 'bart':
#     optimizer_grouped_parameters.extend([{
#         "params": model.seq2seq_model.model.shared.parameters(),
#         "lr": args.embedding_lr}])
# elif args.model_type == 't5':
#     optimizer_grouped_parameters.extend([{
#         "params": model.seq2seq_model.shared.parameters(),
#         "lr": args.embedding_lr}])

# optimizer
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)



In [34]:
model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader
)

In [35]:
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)

# = = = = = = = = = = = = = = = = Train = = = = = = = = = = = = = = = = = = =
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), desc="Training: ", disable=not accelerator.is_local_main_process)
completed_steps = 0

val_results = []
acc_losses  = []
best_r2_f1  = None
best_epoch  = 0

if args.model_type == 'bart' or args.model_type == 't5':
    task_specific_params = model.config.task_specific_params
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    model.config.update(params)
else:
    raise ValueError('{} model type not implemented'.format(args.model_type))

11/15/2023 05:09:19 - INFO - __main__ - ***** Running training *****
11/15/2023 05:09:19 - INFO - __main__ -  Num examples = 1500
11/15/2023 05:09:19 - INFO - __main__ -  Num Epochs = 2
11/15/2023 05:09:19 - INFO - __main__ -  Instantaneous batch size per device = 2
11/15/2023 05:09:19 - INFO - __main__ -  Total train batch size (w. parallel, distributed & accumulation) = 128
11/15/2023 05:09:19 - INFO - __main__ -  Gradient Accumulation steps = 64
11/15/2023 05:09:19 - INFO - __main__ -  Total optimization steps = 24


Training:   0%|          | 0/24 [00:00<?, ?it/s]

In [36]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [37]:
for step, batch in enumerate(train_dataloader):
    print(batch['input_ids'].shape[1])
    break

272


In [38]:
# = = = = = = = = = = = = = = = = Train = = = = = = = = = = = = = = = = = = =
total_batch_size = args.per_device_train_batch_size * \
    accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(
    f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(
    f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
    f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), desc="Training: ",
                    disable=not accelerator.is_local_main_process)
completed_steps = 0

val_results = []
acc_losses = []
best_r2_f1 = None
best_epoch = 0

# edit #
if args.model_type == 'bart' or args.model_type == 't5':
    # task_specific_params = model.module.config.task_specific_params
    task_specific_params = model.config.task_specific_params
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    # model.module.config.update(params)
    model.config.update(params)
else:
    raise ValueError(
        '{} model type not implemented'.format(args.model_type))

loss_list = []
train_loss_list = []
val_loss_list = []
last_output = None
hidden_states = None

# =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = Train =  =  =  =  =  =  =  =  =  =  =  =  =  =  =
for epoch in range(args.num_train_epochs):
    # train
    model.train()
    for step, batch in enumerate(train_dataloader):
        if args.label_smoothing == 0:
            outputs = model(**batch)
            loss = outputs.loss
        else:
            outputs = model(**batch, output_hidden_states=True)
            output_logits = outputs.logits

            output_probs = torch.nn.functional.log_softmax(
                output_logits, dim=-1)

            if args.contrastive_loss:
                max_encoder_token = model.config.max_position_embeddings
                # print(max_encoder_token)

                divide_num = int(output_probs.shape[0] / 2)
                # print(divide_num)

                embeddings_1 = outputs.encoder_last_hidden_state[0,:,:max_encoder_token]
                synonym_embeddings_1 = outputs.encoder_last_hidden_state[2,:,:max_encoder_token]
                random_embeddings_1 = outputs.encoder_last_hidden_state[4,:,:max_encoder_token]
                # synonym_embeddings = synonym_embeddings.view(-1, max_encoder_token)
                # random_embeddings = random_embeddings.view(-1, max_encoder_token)
                synonym_1 = -1 * torch.ones(synonym_embeddings_1.size(dim=0)).to(device)
                random_1 = -1 * torch.ones(random_embeddings_1.size(dim=0)).to(device)
                embeddings_2 = outputs.encoder_last_hidden_state[1,:,:max_encoder_token]
                synonym_embeddings_2 = outputs.encoder_last_hidden_state[3,:,:max_encoder_token]
                random_embeddings_2 = outputs.encoder_last_hidden_state[5,:,:max_encoder_token]
                synonym_2 = -1 * torch.ones(synonym_embeddings_2.size(dim=0)).to(device)
                random_2 = -1 * torch.ones(random_embeddings_2.size(dim=0)).to(device)
                # print(embeddings_1.shape)
                # print(synonym_embeddings_1.shape)
                # print(random_embeddings_1.shape)
                # print(embeddings.shape)
                # print(synonym_embeddings_2.shape)
                # print(random_embeddings_2.shape)
                # break

                loss_cs_synonym_1 = cosine_embedding_loss(embeddings_1, synonym_embeddings_1, synonym_1, args.margin)
                loss_cs_random_1 = cosine_embedding_loss(embeddings_1, random_embeddings_1, random_1, args.margin)
                loss_cs_synonym_2 = cosine_embedding_loss(embeddings_2, synonym_embeddings_2, synonym_2, args.margin)
                loss_cs_random_2 = cosine_embedding_loss(embeddings_2, random_embeddings_2, random_2, args.margin)
                # loss_cs_1 = loss_cs_synonym_1 + loss_cs_random_1
                # loss_cs_2 = loss_cs_synonym_2 + loss_cs_random_2
                # loss_cs = (loss_cs_1 + loss_cs_2) / 2
                # loss_cs_synonym = (loss_cs_synonym_1 + loss_cs_synonym_2) / 2
                # loss_cs_random = (loss_cs_random_1 + loss_cs_random_2) / 2
                loss_cs = (loss_cs_synonym_1 + loss_cs_synonym_2 + loss_cs_random_1 + loss_cs_random_2) / 4
                # print(f"loss_cs: {loss_cs}")


                output_probs_1 = output_probs[0,:,:]
                # print(output_probs_1.shape)
                output_probs_2 = output_probs[1,:,:]
                # print(output_probs_2.shape)
                output_probs_all = torch.stack((output_probs_1, output_probs_2), dim=1)
                # print("output_probs_all: ", output_probs_all.shape)

                # ## decoder
                # output_probs_synonym_1 = output_probs[2,:,:]
                # output_probs_synonym_2 = output_probs[4,:,:]
                # output_probs_synonym = torch.stack((output_probs_synonym_1, output_probs_synonym_2), dim=1)
                # # print("output_probs_synonym: ", output_probs_synonym.shape)
                # output_probs_random_1 = output_probs[3,:,:]
                # output_probs_random_2 = output_probs[5,:,:]
                # output_probs_random = torch.stack((output_probs_random_1, output_probs_random_2), dim=1)
                # # print("output_probs_random: ", output_probs_random.shape)
                # output_probs_all_mr = output_probs_all.view(-1,
                #                                  model.config.vocab_size)
                # # print("output_probs_all_mr: ", output_probs_all_mr.shape)
                # output_probs_synonym = output_probs_synonym.view(-1,
                #                                  model.config.vocab_size)
                # # print("output_probs_synonym: ", output_probs_synonym.shape)
                # output_probs_random = output_probs_random.view(-1,
                #                                  model.config.vocab_size)
                # # print("output_probs_random: ", output_probs_random.shape)
                # # (pos, neg, target, ignore_index=-100, ,device)
                # target_one = torch.ones(gt_logits_all_mr.shape[0]).to(device)
                # # print("target_one: ", target_one.shape)
                # loss_mr_1 = margin_ranking_loss(output_probs_all_mr, output_probs_synonym, 
                #                                           gt_logits_all_mr, target_one, ignore_index=tokenizer.pad_token_id)
                # loss_mr_2 = margin_ranking_loss(output_probs_all_mr, output_probs_random, 
                #                                           gt_logits_all_mr, target_one, ignore_index=tokenizer.pad_token_id)
                # loss_mr = (loss_mr_1 + loss_mr_2) / 2
                # # print(f"loss_mr: {loss_mr}")


                ## negative log-likelihood

                # gt_logits = batch['labels'][:divide_num]
                gt_logits = batch['labels']
                # print("gt_logits: ", gt_logits.shape)
                # gt_logits = gt_logits.view(-1)
                gt_logits_1 = gt_logits[0,:]
                gt_logits_2 = gt_logits[1,:]
                gt_logits_all = torch.stack((gt_logits_1, gt_logits_2), dim=1)

                # ## decoder
                # gt_logits_all_mr = gt_logits_all.view(-1)
                # # print("gt_logits_all_mr: ", gt_logits_all_mr.shape)

                loss_nll, nll = label_smoothed_nll_loss(
                    output_probs_all, gt_logits_all, args.label_smoothing, ignore_index=tokenizer.pad_token_id)

                loss = loss_nll + (args.alpha * loss_cs)

                # ## decoder
                # loss = loss_nll + (args.alpha * loss_cs) + (args.beta * loss_mr)
                # print(loss)
                break

            else:
                output_probs = output_probs
                output_probs = output_probs.view(-1,
                                                 model.config.vocab_size)

                gt_logits = batch['labels']
                gt_logits = gt_logits.view(-1)

                loss_nll, nll = label_smoothed_nll_loss(
                    output_probs, gt_logits, args.label_smoothing, ignore_index=tokenizer.pad_token_id)

                loss = loss_nll
                # break

        acc_losses.append(loss.item())
        loss = loss / args.gradient_accumulation_steps
        accelerator.backward(loss)

        if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_postfix(lr=lr_scheduler.get_last_lr()[0], loss=np.mean(acc_losses[-50:]))
            completed_steps += 1

        if completed_steps >= args.max_train_steps:
            break

11/15/2023 05:09:19 - INFO - __main__ - ***** Running training *****
11/15/2023 05:09:19 - INFO - __main__ -  Num examples = 1500
11/15/2023 05:09:19 - INFO - __main__ -  Num Epochs = 2
11/15/2023 05:09:19 - INFO - __main__ -  Instantaneous batch size per device = 2
11/15/2023 05:09:19 - INFO - __main__ -  Total train batch size (w. parallel, distributed & accumulation) = 128
11/15/2023 05:09:19 - INFO - __main__ -  Gradient Accumulation steps = 64
11/15/2023 05:09:19 - INFO - __main__ -  Total optimization steps = 24


Training:   0%|          | 0/24 [00:00<?, ?it/s]

In [42]:
print(tokenizer.decode(batch['decoder_input_ids'][0], skip_special_tokens=True))
print(tokenizer.decode(batch['decoder_input_ids'][1], skip_special_tokens=True))

Nicole and #Person1# talk about how they spent their last weekends.
#Person1# asks #Person2# for the notes as #Person1# didn't come to class last week.


In [44]:
batch['labels'][0]

tensor([    0, 31988,  4104,     8,   849, 41761,   134, 10431,  1067,    59,
          141,    51,  1240,    49,    94, 12729,     4,     2,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100], device='cuda:0')

In [48]:
a = batch['labels'][0][~batch['labels'][0].eq(-100)]
b = batch['labels'][1][~batch['labels'][1].eq(-100)]

In [49]:
print(tokenizer.decode(a))
print(tokenizer.decode(b))

<s>Nicole and #Person1# talk about how they spent their last weekends.</s>
<s>#Person1# asks #Person2# for the notes as #Person1# didn't come to class last week.</s>


In [None]:
import torch.nn as nn
mr_loss = nn.MarginRankingLoss()

In [None]:
def margin_ranking_loss(pos, neg, target, target_one, ignore_index=-100):
    
    probs_pos = pos[~target.eq(-100)]
    print(probs_pos.shape)
    probs_neg = neg[~target.eq(-100)]
    print(probs_neg.shape)
    target = target[~target.eq(-100)]
    print(target.shape)
    target_one = target_one[:target.shape[0]]
    print(target_one.shape)

    if target.dim() == probs_pos.dim() - 1:
        target = target.unsqueeze(-1)
    
    nll_pos = -probs_pos.gather(dim=-1, index=target)
    nll_neg = -probs_neg.gather(dim=-1, index=target)

    nll_sq_pos = nll_pos.squeeze(-1)
    nll_sq_neg = nll_neg.squeeze(-1)

    loss_margin_ranking = mr_loss(nll_sq_pos, nll_sq_neg, target_one)

    return loss_margin_ranking

In [None]:
m = margin_ranking_loss(output_probs_all_mr, output_probs_synonym, gt_logits_all_mr, target_one, ignore_index=tokenizer.pad_token_id)

In [131]:
# cs_loss = nn.CosineEmbeddingLoss(margin=0.5)
cs_loss = torch.nn.CosineEmbeddingLoss()

In [138]:
contrastive = -1 * torch.ones(embeddings.size(dim=0)).to(device)

In [139]:
loss_cosine_embedding = cs_loss(embeddings, positive_embeddings, contrastive)

In [140]:
loss_cosine_embedding

tensor(0.1410, device='cuda:0', grad_fn=<MeanBackward0>)

In [141]:
cos = torch.nn.CosineSimilarity(dim=-1)

In [142]:
s = cos(embeddings, positive_embeddings)

In [143]:
s.mean()

tensor(0.1221, device='cuda:0', grad_fn=<MeanBackward0>)

In [122]:
# a = torch.mean(torch.relu(0.5 - s))
a = torch.mean(torch.max(torch.zeros(s.shape[0]).to(device), (0.5 - s)))

In [123]:
a

tensor(0.3107, device='cuda:0', grad_fn=<MeanBackward0>)

In [144]:
contrastive = -1 * torch.ones(embeddings.size(dim=0)).to(device)

In [146]:
loss_cosine_embedding = cs_loss(embeddings, negative_embeddings, contrastive)

In [147]:
loss_cosine_embedding

tensor(0.2975, device='cuda:0', grad_fn=<MeanBackward0>)

In [148]:
n = cos(embeddings, negative_embeddings)

In [149]:
n.mean()

tensor(0.2833, device='cuda:0', grad_fn=<MeanBackward0>)

In [150]:
# b = torch.mean(torch.relu(0.5 - n))
b = torch.mean(torch.max(torch.zeros(n.shape[0]).to(device), (0.5 - n)))

In [151]:
b

tensor(0.2489, device='cuda:0', grad_fn=<MeanBackward0>)

In [152]:
s.shape

torch.Size([264])

In [153]:
# Calculate distances
pos_distance = torch.norm(embeddings - positive_embeddings, p=2, dim=1)
neg_distance = torch.norm(embeddings - negative_embeddings, p=2, dim=1)
# Calculate loss
loss = torch.mean(torch.relu(pos_distance - neg_distance + 0))

In [156]:
pos_distance.mean()

tensor(7.2529, device='cuda:0', grad_fn=<MeanBackward0>)

In [157]:
neg_distance.mean()

tensor(6.3096, device='cuda:0', grad_fn=<MeanBackward0>)

In [154]:
loss

tensor(1.0811, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
outputs.encoder_last_hidden_state[:2,:,:].shape

In [None]:
outputs.encoder_last_hidden_state[2:,:,:].shape

In [None]:
output_probs.shape

In [None]:
int(output_probs.shape[0] / 2)

In [None]:
output_probs_pos = output_probs[:2,:,:]
output_probs_pos = output_probs_pos.view(-1,
                                 model.config.vocab_size)
output_probs_neg = output_probs[2:,:,:]
output_probs_neg = output_probs_neg.view(-1,
                                 model.config.vocab_size)
gt_logits = batch['labels'][:2]
gt_logits = gt_logits.view(-1)

In [None]:
output_probs_pos.shape

In [None]:
output_probs_neg.shape

In [None]:
gt_logits.shape

In [None]:
probs_pos = output_probs_pos[~gt_logits.eq(-100)]
probs_neg = output_probs_neg[~gt_logits.eq(-100)]

In [None]:
gt = gt_logits[~gt_logits.eq(-100)]

In [None]:
if gt.dim() == probs_pos.dim() - 1:
    gt = gt.unsqueeze(-1)

In [None]:
nll_pos = -probs_pos.gather(dim=-1, index=gt)
nll_neg = -probs_neg.gather(dim=-1, index=gt)

In [None]:
nll_sq_pos = nll_pos.squeeze(-1)
nll_sq_neg = nll_neg.squeeze(-1)

In [None]:
target_one = torch.ones(nll_sq_pos.shape[0])

In [None]:
import torch.nn as nn
loss = nn.MarginRankingLoss()
input1 = nll_sq_pos
input2 = nll_sq_neg
target = target_one.to("cuda:0")
output = loss(input1, input2, target)
# output.backward()

In [None]:
output

In [None]:
len(outputs.encoder_last_hidden_state)

In [None]:
outputs.encoder_last_hidden_state

In [None]:
outputs.encoder_last_hidden_state.size(dim=1)

In [None]:
print(len(outputs.logits))
print(outputs.logits[0].size())

In [None]:
cosine_loss = torch.nn.CosineEmbeddingLoss()

In [None]:
loss_cs = cosine_loss(outputs.encoder_last_hidden_state[0], outputs.encoder_last_hidden_state[1], torch.ones(outputs.encoder_last_hidden_state.size(dim=1)).to(torch.device('cuda')))

In [None]:
loss_cs

In [None]:
# =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = Test =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  =
# load best model
logger.info("Loading Best Result is at epoch {} for Testing".format(best_epoch))

unwrapped_model = accelerator.unwrap_model(model)
config          = config.from_pretrained(args.output_dir+'/best')
tokenizer       = tokenizer.from_pretrained(args.output_dir+'/best', config=config)
unwrapped_model = unwrapped_model.from_pretrained(args.output_dir+'/best', config=config)
model           = accelerator.prepare(unwrapped_model)

if args.model_type == 'bart' or args.model_type == 't5':
    task_specific_params = model.config.task_specific_params
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    model.config.update(params)
else:
    raise ValueError('{} model type not implemented'.format(args.model_type))

# start Test
logger.info("Collecting Testing Result...")
model.eval()

test_predict     = []
test_groundtruth = []
for step, batch in enumerate(tqdm(test_dataloader, leave=False)):
    with torch.no_grad():
        generated_tokens = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

        generated_tokens = accelerator.pad_across_processes(
            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
        )
        labels = batch["labels"]

        if not args.pad_to_max_length:
            # If we did not pad to max length, we need to pad the labels too
            labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

        generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
        labels = accelerator.gather(labels).cpu().numpy()

        if args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]

        decoded_preds  = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        decoded_preds  = [' '.join(sent.split('\n')) for sent in decoded_preds]
        decoded_labels = [' '.join(sent.split('\n')) for sent in decoded_labels]

        test_predict.extend(decoded_preds)
        test_groundtruth.extend(decoded_labels)

print(raw_datasets['test']['dialogue'][0])

if args.len_output == 'real':
    new_test_predict = []
    for sample in test_predict:
        try:
            gen_sum = sample.split('Summary: ')[2]
            new_test_predict.append(gen_sum)
        except:
            new_test_predict.append(sample)
    test_predict = new_test_predict

logger.info("")
logger.info("ROUGE score on test set")
test_scores = py_rouge_scores(test_predict, test_groundtruth)
logger.info("")


# Save generated summaries
if args.len_input == 'predict':
    os.makedirs(args.output_dir+'/predict_gen_samples', exist_ok=True)
else:
    os.makedirs(args.output_dir+'/gen_samples', exist_ok=True)

for i in range(len(test_predict)):
    test_id        = raw_datasets['test']['id'][i]
    test_dialogue  = raw_datasets['test']['dialogue'][i]
    test_summary   = raw_datasets['test']['summary'][i]
    test_predict_s = test_predict[i]

    if args.len_input == 'predict':
        with open(args.output_dir+'/predict_gen_samples/'+str(test_id)+'.txt', 'w') as f:
            test_dialogue = test_dialogue.encode('ascii', 'ignore').decode('ascii')
            f.write(test_dialogue)
            f.write('\n\n')
            f.write('Golden Summary:\n')
            test_summary = test_summary.encode('ascii', 'ignore').decode('ascii')
            f.write(test_summary)
            f.write('\n\n')
            f.write('Generate Summary:\n')
            test_predict_s = test_predict_s.encode('ascii', 'ignore').decode('ascii')
            f.write(test_predict_s)
    else:
        with open(args.output_dir+'/gen_samples/'+str(test_id)+'.txt', 'w') as f:
            test_dialogue = test_dialogue.encode('ascii', 'ignore').decode('ascii')
            f.write(test_dialogue)
            f.write('\n\n')
            f.write('Golden Summary:\n')
            test_summary = test_summary.encode('ascii', 'ignore').decode('ascii')
            f.write(test_summary)
            f.write('\n\n')
            f.write('Generate Summary:\n')
            test_predict_s = test_predict_s.encode('ascii', 'ignore').decode('ascii')
            f.write(test_predict_s)