In [1]:
import nltk
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')
# nltk.download('wordnet')
# nltk.download('stopwords')

In [2]:
import math
import os
import pprint
import logging
import random

import datasets
import nltk
from nltk.corpus import wordnet
from nltk.corpus import stopwords
import numpy as np
import torch
from tqdm.auto import tqdm

import transformers
from accelerate import Accelerator
from filelock import FileLock
from transformers import AdamW, get_scheduler, set_seed

from transformers.file_utils import is_offline_mode
from transformers.utils.versions import require_version

from args import parse_args
from data_loader import raw_data_loader, data_processor
from model_loader import model_loader
from rouge_s import py_rouge_scores
from utils import label_smoothed_nll_loss, postprocess_text, cosine_embedding_loss

from transformers import (
    MODEL_MAPPING,
    SchedulerType,
)

# Every config class registered in transformers' MODEL_MAPPING, and the
# `model_type` string each one declares. MODEL_TYPES is used below as the
# allowed `choices` for --model_type; narrow it if only some types apply.
MODEL_CONFIG_CLASSES = [config_cls for config_cls in MODEL_MAPPING.keys()]
MODEL_TYPES = tuple(map(lambda config_cls: config_cls.model_type, MODEL_CONFIG_CLASSES))

[nltk_data] Downloading package punkt to /home/worachotn/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/worachotn/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/worachotn/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/worachotn/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Root logging configuration (a no-op if the root logger already has handlers)
# with timestamped INFO-level records, plus this notebook's own logger.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

In [4]:
from transformers import SchedulerType
import argparse

In [5]:
def str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so ``--flag False`` would silently enable it.

    Args:
        value: the raw command-line string (a bool is passed through unchanged).

    Returns:
        True for "1"/"true"/"t"/"yes"/"y" and False for "0"/"false"/"f"/"no"/"n"/""
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports a
            usage error instead of guessing.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


parser = argparse.ArgumentParser(description="bart")

# --- Data files ---------------------------------------------------------------
parser.add_argument(
    "--train_file",
    type=str,
    default=None,
    help="A csv or a json file containing the training data.",
)
parser.add_argument(
    "--validation_file",
    type=str,
    default=None,
    help="A csv or a json file containing the validation data.",
)
parser.add_argument(
    "--test_file",
    type=str,
    default=None,
    help="A csv or a json file containing the test data.",
)

# --- Preprocessing ------------------------------------------------------------
parser.add_argument(
    "--ignore_pad_token_for_loss",
    type=str2bool,  # not `bool`: bool("False") would be True
    default=True,
    help="Whether to ignore the tokens corresponding to "
    "padded labels in the loss computation or not.",
)
parser.add_argument(
    "--max_source_length",
    type=int,
    default=1024,
    help="The maximum total input sequence length after "
    "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
    "--source_prefix",
    type=str,
    default=None,
    help="A prefix to add before every source text " "(useful for T5 models).",
)
parser.add_argument(
    "--preprocessing_num_workers",
    type=int,
    default=None,
    help="The number of processes to use for the preprocessing.",
)
parser.add_argument(
    "--overwrite_cache",
    type=str2bool,  # not `bool`: bool("False") would be True
    default=None,
    help="Overwrite the cached training and evaluation sets",
)
# --- Target length / generation ----------------------------------------------
parser.add_argument(
    "--min_target_length",
    type=int,
    default=1,
    help="The minimal total sequence length for target text",
)
parser.add_argument(
    "--max_target_length",
    type=int,
    default=128,
    help="The maximum total sequence length for target text after "
    "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded. "
    "Also used as the generation max length during ``evaluate`` and ``predict``.",
)
parser.add_argument(
    "--length_penalty",
    type=float,
    default=1.0,
    help="large - longer sequence, small - shorter sequence",
)
parser.add_argument(
    "--num_beams",
    type=int,
    default=4,
    help="Number of beams to use for evaluation. This argument will be "
    "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
)
parser.add_argument(
    "--pad_to_max_length",
    action="store_true",
    help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
)

# --- Model / tokenizer --------------------------------------------------------
parser.add_argument(
    "--model_name_or_path",
    type=str,
    help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--config_name",
    type=str,
    default=None,
    help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
    "--tokenizer_name",
    type=str,
    default=None,
    help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
    "--text_column",
    type=str,
    default=None,
    help="The name of the column in the datasets containing the full texts (for summarization).",
)
parser.add_argument(
    "--summary_column",
    type=str,
    default=None,
    help="The name of the column in the datasets containing the summaries (for summarization).",
)
parser.add_argument(
    "--use_slow_tokenizer",
    action="store_true",
    help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)

# --- Batch sizes --------------------------------------------------------------
parser.add_argument(
    "--per_device_train_batch_size",
    type=int,
    default=8,
    help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
    "--per_device_eval_batch_size",
    type=int,
    default=8,
    help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
    "--per_device_test_batch_size",
    type=int,
    default=8,
    help="Batch size (per device) for the test dataloader.",
)
# --- Optimisation and LR schedule ---------------------------------------------
parser.add_argument("--learning_rate", type=float, default=5e-5,
                    help="Initial learning rate (after the potential warmup period) to use.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3,
                    help="Total number of training epochs to perform.")
parser.add_argument("--max_train_steps", type=int, default=None,
                    help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--lr_scheduler_type", type=SchedulerType, default="linear",
                    help="The scheduler type to use.",
                    choices=["linear", "cosine", "cosine_with_restarts",
                             "polynomial", "constant", "constant_with_warmup"])
parser.add_argument("--num_warmup_steps", type=int, default=0,
                    help="Number of steps for the warmup in the lr scheduler.")

# --- Paths, seed and model family ---------------------------------------------
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--cache_dir", type=str, default=None,
                    help="Cache directory for pre-trained models.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--model_type", type=str, default=None,
                    help="Model type to use if training from scratch.", choices=MODEL_TYPES)
# --- Length / topic control ---------------------------------------------------
# NOTE(review): --len_input and --len_output reuse the --ctrlen_model help text
# ("Use the ctrlen model or not"), which does not describe their choices.
parser.add_argument("--len_input", type=str, default="no", help="Use the ctrlen model or not",
                    choices=("no", "topic", "length", "topic-length"))
parser.add_argument("--len_output", type=str, default="no", help="Use the ctrlen model or not",
                    choices=("no", "real"))
parser.add_argument("--ctrlen_model", action="store_true", default=False,
                    help="Use the ctrlen model or not")
parser.add_argument("--sim_window_size", type=int, default=5, help="window size for computing loss.")
parser.add_argument("--sim_loss", type=float, default=0,
                    help="the loss weight for similarity scores.")
parser.add_argument("--special_len_token_init", type=str, default=None,
                    help="ways to initialize special token for length (random, zero, token_embs)")
parser.add_argument("--embedding_lr", type=float, default=5e-5,
                    help="Initial learning rate for embedding layers.")
parser.add_argument("--len_start", type=int, default=1, help="start length.")
parser.add_argument("--len_end", type=int, default=100, help="end length.")

# --- Data handling ------------------------------------------------------------
parser.add_argument("--data_aug", action="store_true", default=False,
                    help="whether to perform data augmentation or not")
parser.add_argument("--pred_len", action="store_true", default=False,
                    help="whether to use the golden length or predicted length")
parser.add_argument("--shuffle", action="store_true", default=False,
                    help="whether to shuffle the dataset to balance train/validation/test")

# --- Loss variants ------------------------------------------------------------
parser.add_argument("--label_smoothing", type=float, default=0.0,
                    help="hyperparameter for label smoothing.")
parser.add_argument("--contrastive", type=str, default="no", help="Use contrastive or not",
                    choices=("no", "synonym", "random", "combine"))
parser.add_argument("--tagging", type=str, default="no", help="Use tagging or not",
                    choices=("no", "word", "prompt"))
parser.add_argument("--alpha", type=float, default=0.5, help="Initial alpha")
parser.add_argument("--margin", type=float, default=0.5, help="Initial margin")

# --- Run mode -----------------------------------------------------------------
parser.add_argument("--run_test", action="store_true", default=False, help="Run for testing")
parser.add_argument("--debug", action="store_true", default=False,
                    help="Use the debug mode or not")

_StoreTrueAction(option_strings=['--debug'], dest='debug', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Use the debug mode or not', metavar=None)

In [6]:
# Parse an explicit empty argument list. Inside a notebook kernel, sys.argv holds
# the kernel's own arguments, so every option starts from its default and is then
# overridden explicitly in the next cell.
args = parser.parse_args([])

# Set parameters (override the argparse defaults for this run)

In [7]:
# Run configuration: values that override the argparse defaults for this experiment.
# Applied in this order onto `args`.
RUN_OVERRIDES = {
    "len_input": 'topic-length',
    "len_output": 'no',
    "output_dir": './output/1',
    "train_file": './data/dialogsum/dialogsum.train.jsonl',
    "validation_file": './data/dialogsum/dialogsum.dev.jsonl',
    "test_file": './data/dialogsum/dialogsum.test.jsonl',
    "text_column": 'dialogue',
    "summary_column": 'summary',
    "model_name_or_path": 'facebook/bart-large',
    "model_type": 'bart',
    "max_target_length": 128,
    "num_beams": 4,
    "learning_rate": 5e-5,
    "weight_decay": 1e-3,
    "label_smoothing": 0.1,
    "length_penalty": 1.0,
    "num_train_epochs": 15,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 32,
    "per_device_eval_batch_size": 8,
    "per_device_test_batch_size": 8,
    "num_warmup_steps": 0,
    "cache_dir": './output/cache',
    "overwrite_cache": True,
    "seed": 12345,
    "contrastive": "combine",
    "tagging": "no",
    "run_test": True,
}
for override_name, override_value in RUN_OVERRIDES.items():
    # setattr on an argparse.Namespace accepts any name, so a misspelled option
    # would silently create an unused attribute; fail loudly instead.
    if not hasattr(args, override_name):
        raise AttributeError(f"Unknown argument in RUN_OVERRIDES: {override_name!r}")
    setattr(args, override_name, override_value)

In [8]:
# Initialize the accelerator. The accelerator will handle device placement for us.
accelerator = Accelerator(mixed_precision="fp16")
# Every process needs its device, not only the one that logs, so it is read
# outside the logging branch below (previously non-main processes had no `device`).
device = accelerator.device

# Setup logging, we only want one process per machine to log things on the screen.
# accelerator.is_local_main_process is only True for one process per machine.
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

In [9]:
# Seed every RNG the run touches so a fixed --seed gives reproducible results.
if args.seed is not None:
    # transformers' helper seeds Python, NumPy and torch (CPU + CUDA); the explicit
    # per-library calls below repeat that with the same value.
    set_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # NOTE(review): setting PYTHONHASHSEED at runtime only affects child processes;
    # this interpreter's str-hash seed was fixed when it started.
    os.environ["PYTHONHASHSEED"] = str(args.seed)
    # Deterministic cuDNN. Note that `enabled = False` turns cuDNN off entirely,
    # which is typically slower than relying on `deterministic` alone.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False

In [10]:
# Only the main process creates the output directory; the barrier keeps the other
# processes from continuing before it exists.
if accelerator.is_main_process and args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()

In [11]:
raw_datasets = raw_data_loader(args)
# Report the columns and the row count of each split.
for _split_name in ('train', 'validation', 'test'):
    _split = raw_datasets[_split_name]
    print(_split.features.keys(), _split.num_rows)

dict_keys(['id', 'dialogue', 'summary', 'synonym_dialogue', 'random_dialogue']) 1500
dict_keys(['id', 'dialogue', 'summary', 'synonym_dialogue', 'random_dialogue']) 50
dict_keys(['id', 'dialogue', 'summary', 'synonym_dialogue', 'random_dialogue']) 150


In [14]:
# Print every field of one training example to inspect the preprocessing.
num = 4
# Fetch the row once. `dataset[feature][num]` would materialise the whole column
# for every field just to read a single value.
example = raw_datasets['train'][num]
for feature in raw_datasets['train'].features:
    print(feature)
    print(example[feature])
    print("-" * 20)

id
train_4
--------------------
dialogue
Topic of Summary: dance. Length of Summary: 16. Dialogue: #Person1#: Watsup, ladies! Y'll looking'fine tonight. May I have this dance?
#Person2#: He's cute! He looks like Tiger Woods! But, I can't dance. . .
#Person1#: It's all good. I'll show you all the right moves. My name's Malik.
#Person2#: Nice to meet you. I'm Wen, and this is Nikki.
#Person1#: How you feeling', vista? Mind if I take your friend'round the dance floor?
#Person2#: She doesn't mind if you don't mind getting your feet stepped on.
#Person1#: Right. Cool! Let's go!
--------------------
summary
Malik invites Nikki to dance. Nikki agrees if Malik doesn't mind getting his feet stepped on.
--------------------
synonym_dialogue
Topic of Summary: dancing. Length of Summary: 16. Dialogue: #Person1#: Watsup, ladies! Y'll looking'fine tonight. May I have this dance?
#Person2#: He's cute! He looks like Tiger Woods! But, I can't dance. . .
#Person1#: It's all good. I'll show you all the r

In [54]:
# Load config, tokenizer and model through the project's model_loader helper.
# The log output shows this call may also resize the token embedding layer.
config, tokenizer, model = model_loader(accelerator, logger, args)

loading configuration file config.json from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-large",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

loading weights file pytorch_model.bin from cache at ./output/cache/models--facebook--bart-large/snapshots/cb48c1365bd826bd521f650dc2e0940aee54720c/pytorch_model.bin
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.33.3"
}

All model checkpoint weights were used when initializing BartForConditionalGeneration.

All the weights of BartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BartForConditionalGeneration for predictions without further training.
Generation config file not found, using a generation config created from the model config.
You are resizing the embedding layer withou

In [55]:
# Vocabulary size and special-token configuration, one item per line.
print(
    model.vocab_size,
    tokenizer.SPECIAL_TOKENS_ATTRIBUTES,
    tokenizer.additional_special_tokens,
    sep="\n",
)

50265
['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']
[]


In [56]:
# Tokenize the raw splits and build the train / validation / test DataLoaders.
dataloader, processed_dataset = data_processor(logger, args, accelerator, raw_datasets, tokenizer, model)
train_dataloader, eval_dataloader, test_dataloader = dataloader
# Unpack into named variables rather than `_`. At notebook top level `_` is IPython's
# "last output" alias, and assigning to it shadows the output history.
train_dataset, eval_dataset, test_dataset = processed_dataset

Running tokenizer on dataset:   0%|          | 0/1500 [00:00<?, ? examples/s]



Running tokenizer on dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/150 [00:00<?, ? examples/s]

11/26/2023 17:33:24 - INFO - __main__ - Sample 970 of the training set: {'input_ids': [0, 48931, 9, 19584, 35, 2593, 544, 4, 41852, 9, 19584, 35, 504, 4, 33854, 35, 849, 41761, 134, 10431, 35, 2497, 662, 6, 2649, 4, 50118, 10431, 41761, 176, 10431, 35, 2497, 662, 6, 21958, 4, 50118, 10431, 41761, 134, 10431, 35, 9918, 47, 1137, 162, 2540, 147, 64, 38, 465, 10, 28124, 2081, 116, 50118, 10431, 41761, 176, 10431, 35, 5143, 259, 6, 141, 64, 939, 244, 47, 452, 6, 21958, 116, 50118, 10431, 41761, 134, 10431, 35, 38, 1017, 101, 7, 2081, 23221, 2920, 1932, 88, 5, 382, 1932, 2540, 4, 50118, 10431, 41761, 176, 10431, 35, 1832, 47, 33, 41, 1316, 42, 827, 6, 21958, 116, 50118, 10431, 41761, 134, 10431, 35, 3216, 6, 259, 16, 127, 1316, 346, 4, 50118, 10431, 41761, 176, 10431, 35, 392, 939, 2540, 192, 4576, 116, 50118, 10431, 41761, 134, 10431, 35, 9136, 6, 259, 16, 127, 12373, 6, 30, 5, 169, 99, 16, 5, 731, 452, 116, 50118, 10431, 41761, 176, 10431, 35, 2477, 18, 731, 16, 132, 4, 4981, 2920, 1932, 

In [57]:
# Columns and row count of the tokenized training split.
print(train_dataset.features.keys(), train_dataset.num_rows)

dict_keys(['input_ids', 'attention_mask', 'synonym_inputs', 'random_inputs', 'labels']) 1500


In [58]:
# Sanity check: decode the first two training batches back to text.
# The batch is only read, never modified.
for step, batch in enumerate(train_dataloader):
    for batch_key, values in batch.items():
        print(values.shape)
        if "attention_mask" in batch_key:
            # Masks hold 0/1 flags, not token ids; decoding them yields meaningless text.
            continue
        for indx in range(values.shape[0]):
            print("index: ", indx)
            if batch_key == 'labels':
                # -100 marks loss-ignored positions and is not a decodable id; map it
                # to the pad id on a copy so the batch itself is left untouched.
                label_ids = values[indx].masked_fill(values[indx] == -100, tokenizer.pad_token_id)
                print(tokenizer.decode(label_ids, skip_special_tokens=True))
            else:
                print(tokenizer.decode(values[indx]))
    if step == 1:
        break

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([12, 360])
index:  0
<s>Topic of Summary: search for books. Length of Summary: 37. Dialogue: #Person1#: Sir, you've been using the online catalogue for quite a while. Is there anything I can do to help you?
#Person2#: Well, I've got to write a paper about Hollywood in the 30s and 40s, and I'm really struggling. There are hundreds of books, and I just don't know where to begin.
#Person1#: Your topic sounds pretty big. Why don't you narrow it down to something like.., uh... the history of the studios during that time?
#Person2#: You know, I was thinking about doing that, but more than 30 books came up when I typed in'movie studios'.
#Person1#: You could cut that down even further by listing the specific years you want. Try adding '1930s' or '1940s' or maybe 'Golden Age'.
#Person2#: 'Golden Age' is a good idea, Let me type that in. Hey, look, just 6 books this time That's a lot better.
#Person1#: Oh, another thing you might consider. Have you tried looking for any magazines or 

In [59]:
# Peek at the first test batch: which tensors the collator produces, and their shapes.
for step, batch in enumerate(test_dataloader):
    print(batch.keys())
    for _tensor_name in ('input_ids', 'attention_mask', 'labels', 'decoder_input_ids'):
        print(batch[_tensor_name].shape)
    break

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])
torch.Size([8, 448])
torch.Size([8, 448])
torch.Size([8, 72])
torch.Size([8, 72])


In [19]:
# Peek at the first validation batch: which tensors the collator produces, and their shapes.
for step, batch in enumerate(eval_dataloader):
    print(batch.keys())
    for _tensor_name in ('input_ids', 'attention_mask', 'labels', 'decoder_input_ids'):
        print(batch[_tensor_name].shape)
    break

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])
torch.Size([8, 256])
torch.Size([8, 256])
torch.Size([8, 48])
torch.Size([8, 48])


# Training (the cells above sanity-check the DataLoaders)

In [33]:
# = = = Training Preparation = = =
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
# Biases and normalisation weights are not decayed. BERT-style checkpoints name their
# norms "LayerNorm", whereas BART/T5 use "...layer_norm.weight" and BART's embedding
# norm is "layernorm_embedding.weight". Matching only "LayerNorm.weight" would
# silently apply weight decay to every BART/T5 norm.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "layernorm_embedding.weight"]

# With the CTRLen model the shared embedding matrix is optimised in its own group
# (at `args.embedding_lr`, added below), so it is also kept out of the decay group.
if args.ctrlen_model:
    no_decay_emb_matrix = no_decay + ["shared"]
else:
    no_decay_emb_matrix = list(no_decay)

optimizer_grouped_parameters = [
    {
        # decayed: everything that is not a bias/norm (nor, for CTRLen, the shared embedding)
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay_emb_matrix)],
        "weight_decay": args.weight_decay,
    },
    {
        # not decayed: biases and normalisation weights
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

if args.ctrlen_model:
    # Separate learning rate for the (shared) embedding matrix; the attribute path
    # differs between the BART and T5 backbones of the CTRLen wrapper.
    if args.model_type == 'bart':
        optimizer_grouped_parameters.extend([{
            "params": model.seq2seq_model.model.shared.parameters(),
            "lr": args.embedding_lr}])
    elif args.model_type == 't5':
        optimizer_grouped_parameters.extend([{
            "params": model.seq2seq_model.shared.parameters(),
            "lr": args.embedding_lr}])
    else:
        raise ValueError('{} model type not implemented'.format(args.model_type))

# optimizer
# NOTE(review): transformers.AdamW is deprecated in recent releases in favour of
# torch.optim.AdamW. The defaults differ (eps, default weight_decay), so switching
# would change training and is left as a deliberate decision.
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Wrap model, optimizer and dataloaders for the current device / distributed setup.
model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader
)

# Scheduler and the step/epoch bookkeeping it depends on. This runs after
# `accelerator.prepare`, because preparing may change len(train_dataloader)
# (e.g. when batches are sharded across processes).
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is not None:
    # An explicit step budget wins; derive how many epochs it needs.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
else:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    args.lr_scheduler_type,
    optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)

# = = = = = = = = = = = = = = = = Train = = = = = = = = = = = = = = = = = = =
# Effective batch size seen by one optimizer update.
total_batch_size = (
    args.per_device_train_batch_size
    * accelerator.num_processes
    * args.gradient_accumulation_steps
)

for _message in (
    "***** Running training *****",
    f" Num examples = {len(train_dataset)}",
    f" Num Epochs = {args.num_train_epochs}",
    f" Instantaneous batch size per device = {args.per_device_train_batch_size}",
    f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}",
    f" Gradient Accumulation steps = {args.gradient_accumulation_steps}",
    f" Total optimization steps = {args.max_train_steps}",
):
    logger.info(_message)

# One progress bar per machine; it advances once per optimizer update.
progress_bar = tqdm(
    range(args.max_train_steps),
    desc="Training: ",
    disable=not accelerator.is_local_main_process,
)
completed_steps = 0

# Bookkeeping for losses and validation results across epochs.
val_results, acc_losses = [], []
best_r2_f1, best_epoch = None, 0

# Push the summarization generation settings (min/max target length, length
# penalty, beam count) from the CLI arguments into the model config.
if args.model_type == 'bart' or args.model_type == 't5':
    # After `accelerator.prepare` the model may be wrapped (e.g. in DDP), and the
    # wrapper has no `.config`; unwrap to reach the underlying model's config.
    model_config = accelerator.unwrap_model(model).config
    # `task_specific_params` is None for checkpoints that do not define it.
    task_specific_params = model_config.task_specific_params or {}
    params = task_specific_params.get('summarization', {})
    params['min_length'] = args.min_target_length
    params['max_length'] = args.max_target_length
    params['length_penalty'] = args.length_penalty
    params['num_beams'] = args.num_beams
    # NOTE(review): newer transformers versions read generation settings from
    # `generation_config`; confirm these values actually reach `generate()`.
    model_config.update(params)
else:
    raise ValueError('{} model type not implemented'.format(args.model_type))



In [37]:
# =  =  =  =  =  =  =  =  =  =  =  =  =  =  =  = Train =  =  =  =  =  =  =  =  =  =  =  =  =  =  =
# An optimizer update runs after every `gradient_accumulation_steps` micro-batches,
# plus once at the last batch of an epoch to flush a partial window. That gives
# exactly ceil(len(train_dataloader) / gradient_accumulation_steps) updates per
# epoch, matching `num_update_steps_per_epoch` used to size the LR schedule.
num_train_batches = len(train_dataloader)
for epoch in range(args.num_train_epochs):
    # train
    model.train()
    for step, batch in enumerate(train_dataloader):

        if args.ctrlen_model:  # CTRLen model returns its own loss
            outputs, loss = model(batch, tokenizer)
        else:  # w/ and w/o label smoothing (always better with label smoothing)
            outputs = model(**batch)
            if args.label_smoothing == 0:
                loss = outputs.loss
            else:
                output_logits = outputs.logits
                output_probs = torch.nn.functional.log_softmax(output_logits, dim=-1)
                # Flatten to (tokens, vocab). The vocab size is read from the logits
                # because a prepared (e.g. DDP-wrapped) model may not expose `.config`.
                output_probs = output_probs.view(-1, output_logits.size(-1))

                # Loss-ignored label positions may be marked with -100 (the Hugging Face
                # convention). Map them to the pad id so `ignore_index=pad_token_id`
                # actually masks them and no negative index reaches the loss helper.
                gt_labels = batch['labels']
                gt_labels = gt_labels.masked_fill(gt_labels == -100, tokenizer.pad_token_id).reshape(-1)

                loss, _nll_loss = label_smoothed_nll_loss(
                    output_probs, gt_labels, args.label_smoothing, ignore_index=tokenizer.pad_token_id
                )

        acc_losses.append(loss.item())
        # Scale by the accumulation window so accumulated gradients are not inflated.
        loss = loss / args.gradient_accumulation_steps
        accelerator.backward(loss)

        # `(step + 1)` so the first update happens after a full window, not after
        # the very first micro-batch.
        if (step + 1) % args.gradient_accumulation_steps == 0 or step == num_train_batches - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_postfix(lr=lr_scheduler.get_last_lr()[0], loss=np.mean(acc_losses[-50:]))
            completed_steps += 1

        if completed_steps >= args.max_train_steps:
            break

    # Also stop the epoch loop once the step budget is spent, so later epochs do not
    # run extra forward/backward passes (or optimizer steps past the schedule).
    if completed_steps >= args.max_train_steps:
        break


272
