In [None]:
! pip install transformers
! pip install datasets
! pip install sentencepiece
! pip install rouge_score
! pip install wandb

In [None]:
# Natural Language Processing 
# Assignment 3
# Implementation of Pegasus and DistilBart
# Modeled from https://colab.research.google.com/github/
# elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_
# two_langs.ipynb#scrollTo=ClRTrG2ETUm3

# run in Colab

import torch
import numpy as np
import datasets
import nltk
from datetime import datetime
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, \
	Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from google.colab import drive
# Mount Google Drive at /content/gdrive; the CSV data files and the saved
# checkpoint used later in this notebook are read from "/content/gdrive/My Drive/...".
drive.mount("/content/gdrive")


Mounted at /content/gdrive


In [None]:
#---------------------------------
# Initialize variables and 
# start Wandb utilization
#---------------------------------
import wandb
wandb.login()

# Run-mode switch read by the later cells:
#   0 -> load the saved checkpoint from Drive and only evaluate it
#   1 -> load the hub weights named by model_name and fine-tune them first
from_scratch = 0
language = "english"
# alternative model: "google/pegasus-cnn_dailymail"
model_name = "sshleifer/distilbart-xsum-12-3"
# maximum token lengths for the articles (encoder) and the summaries (decoder)
encoder_max_length, decoder_max_length = 256, 64

# sentence tokenizer data used by nltk.sent_tokenize
nltk.download("punkt", quiet=True)
metric = datasets.load_metric("rouge")

# Both CSV files are passed to a single load_dataset call and only the first
# 3000 rows of the resulting "train" split are kept
# (presumably the files are concatenated in list order, test file first -- TODO confirm).
csv_files = [
  '/content/gdrive/My Drive/Auburn/NLP/cnn_dailymail_test_assignmemt3.csv',
  '/content/gdrive/My Drive/Auburn/NLP/cnn_dailymail_train_assignmemt3.csv',
]
data = datasets.load_dataset('csv', data_files=csv_files, split="train[:3000]")

#---------------------------------
# Function Definitions
#---------------------------------

# returns a dictionary consisting of article and highlights entries
def build_dict(example):
  """Drop rows whose article is missing or empty from a batch.

  Args:
    example: batch mapping with parallel "article" and "highlights" sequences.

  Returns:
    dict with "article" and "highlights" lists containing only the kept rows,
    in their original order.
  """
  article = []
  highlights = []
  for text, summary in zip(example["article"], example["highlights"]):
    # Truthiness (instead of len(...) > 0) also skips a None article -- e.g. a
    # blank CSV cell -- where len() would raise TypeError.
    if text:
      article.append(text)
      highlights.append(summary)
  return {"article": article, "highlights": highlights}

# preprocesses/prepares data for use in batches
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
  """Tokenize a batch of articles and summaries into model features.

  Args:
    batch: mapping with "article" and "highlights" entries.
    tokenizer: callable tokenizer exposing `pad_token_id`.
    max_source_length: length articles are padded/truncated to.
    max_target_length: length summaries are padded/truncated to.

  Returns:
    dict with every field the tokenizer returned for the articles, plus
    "labels": the summary token ids with each pad id replaced by -100
    (presumably so padded positions are ignored by the loss).
  """
  articles = batch["article"]
  summaries = batch["highlights"]
  encoded_articles = tokenizer(articles, padding="max_length", truncation=True, max_length=max_source_length)
  encoded_summaries = tokenizer(summaries, padding="max_length", truncation=True, max_length=max_target_length)

  # start from the encoder-side fields, then attach the masked labels
  features = {field: values for field, values in encoded_articles.items()}
  pad_id = tokenizer.pad_token_id
  features["labels"] = [
    [token if token != pad_id else -100 for token in summary_ids]
    for summary_ids in encoded_summaries["input_ids"]
  ]
  return features

# strip whitespace and put each sentence on its own line (the format ROUGE-Lsum expects)
def postprocess_text(predicates, labels):
  """Strip each text and rewrite it with one sentence per line.

  Args:
    predicates: decoded prediction strings.
    labels: decoded reference strings.

  Returns:
    (predicates, labels) as two new lists of newline-joined sentences.
  """
  def _one_sentence_per_line(texts):
    # sentence boundaries come from nltk's sentence tokenizer
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

  return _one_sentence_per_line(predicates), _one_sentence_per_line(labels)

# decode predictions and labels, then compute ROUGE F-measures and the mean generated length
def compute_metrics(eval_predicates):
  """Compute ROUGE scores and the mean generated length for an evaluation pass.

  Args:
    eval_predicates: pair (predictions, labels) of token-id arrays; when
      predictions is a tuple, only its first element is used.

  Returns:
    dict mapping each ROUGE type to its mid F-measure scaled to 0-100, plus
    "gen_len" (mean count of non-pad tokens per prediction); every value is
    rounded to 5 decimals.
  """
  predicates, labels = eval_predicates
  if isinstance(predicates, tuple):
    predicates = predicates[0]
  pad_id = tokenizer.pad_token_id
  # -100 is an ignore marker, not a vocabulary id. Generated predictions can be
  # padded with it as well as the labels, so map it to the pad id in BOTH arrays
  # before decoding; this also keeps those positions out of gen_len below.
  predicates = np.where(np.asarray(predicates) != -100, predicates, pad_id)
  labels = np.where(np.asarray(labels) != -100, labels, pad_id)
  decoded_predicates = tokenizer.batch_decode(predicates, skip_special_tokens=True)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  decoded_predicates, decoded_labels = postprocess_text(decoded_predicates, decoded_labels)
  result = metric.compute(predictions=decoded_predicates, references=decoded_labels,
                          rouge_types=["rouge1", "rouge2", "rouge3", "rougeL", "rougeLsum"], use_stemmer=True)
  # keep only the mid F-measure of each ROUGE type, as a percentage
  result = {key: value.mid.fmeasure * 100.0 for key, value in result.items()}
  prediction_lengths = [np.count_nonzero(predicate != pad_id) for predicate in predicates]
  result["gen_len"] = np.mean(prediction_lengths)
  result = {key: round(value, 5) for key, value in result.items()}
  return result

Using custom data configuration default-959af8ed08faccd5
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-959af8ed08faccd5/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff)


In [None]:
#---------------------------------
# Build dataset and call model
#---------------------------------
# Drop rows without an article and remove the CSV index column.
dataset = data.map(build_dict, batched=True, remove_columns=["Unnamed: 0"])

# Seed the shuffle so the train/validation split is the same on every run.
# Unseeded, each run draws a new 18% validation set, so a checkpoint reloaded
# with from_scratch == 0 could be evaluated on rows it was trained on.
# (Checkpoints trained before this seed was introduced are still affected.)
split_seed = 42
train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.18, seed=split_seed).values()

if from_scratch == 1:
  # Start from the hub weights named by model_name; they are fine-tuned later
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
  tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
  # Replace with whatever checkpoint you want to use
  checkpoint_dir = '/content/gdrive/My Drive/Auburn/NLP/results/checkpoint-500'
  model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir, local_files_only=True)
  tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, local_files_only=True)

def tokenize_split(split):
  """Tokenize one text split, replacing its text columns with model features."""
  return split.map(
    lambda batch: batch_tokenize_preprocess(batch, tokenizer, encoder_max_length, decoder_max_length),
    batched=True,
    remove_columns=split.column_names,
  )

train_data = tokenize_split(train_data_txt)
validation_data = tokenize_split(validation_data_txt)

In [None]:
#---------------------------------
# Training Parameters
#---------------------------------

# Hyperparameters, collected in one mapping and unpacked into the arguments object.
training_config = dict(
  output_dir="results",
  num_train_epochs=1,
  do_train=True,
  do_eval=True,
  per_device_train_batch_size=4,
  per_device_eval_batch_size=4,
  warmup_steps=500,
  weight_decay=0.1,
  label_smoothing_factor=0.1,
  predict_with_generate=True,  # evaluate on generated sequences so ROUGE can be computed
  logging_dir="logs",
  logging_steps=50,
  save_total_limit=3,
)
training_args = Seq2SeqTrainingArguments(**training_config)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
  model=model,
  args=training_args,
  train_dataset=train_data,
  eval_dataset=validation_data,
  data_collator=data_collator,
  tokenizer=tokenizer,
  compute_metrics=compute_metrics,
)

# Prepare wandB run: record the key hyperparameters in the run config
wandb_config = {
  "per_device_train_batch_size": training_args.per_device_train_batch_size,
  "learning_rate": training_args.learning_rate,
  "dataset": "CNN_Dailymail ",
}
wandb_inst = wandb.init(project="Bart_CNN_Dailymail", config=wandb_config)

# Name the run after the wall-clock time (HHMMSS) taken right after init
now = datetime.now()
current_time = now.strftime("%H%M%S")
wandb_inst.name = f"sesh_{current_time}"

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [None]:
#---------------------------------
# Train and evaluate 
#---------------------------------
# Fine-tune only when requested; a reloaded checkpoint goes straight to evaluation.
should_train = (from_scratch == 1)
if should_train:
  trainer.train()

# Evaluate on the validation split and print every reported metric.
results = trainer.evaluate()
for metric_name in results:
  print("Key: ", metric_name, " Value: ", results[metric_name])

# Close the Weights & Biases run
wandb_inst.finish()

***** Running Evaluation *****
  Num examples = 540
  Batch size = 4


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Key:  eval_loss  Value:  3.6606717109680176
Key:  eval_rouge1  Value:  36.94819
Key:  eval_rouge2  Value:  17.2156
Key:  eval_rouge3  Value:  10.31946
Key:  eval_rougeL  Value:  27.27825
Key:  eval_rougeLsum  Value:  34.14828
Key:  eval_gen_len  Value:  43.55741
Key:  eval_runtime  Value:  134.9688
Key:  eval_samples_per_second  Value:  4.001
Key:  eval_steps_per_second  Value:  1.0


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
eval/gen_len,▁
eval/loss,▁
eval/rouge1,▁
eval/rouge2,▁
eval/rouge3,▁
eval/rougeL,▁
eval/rougeLsum,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁

0,1
eval/gen_len,43.55741
eval/loss,3.66067
eval/rouge1,36.94819
eval/rouge2,17.2156
eval/rouge3,10.31946
eval/rougeL,27.27825
eval/rougeLsum,34.14828
eval/runtime,134.9688
eval/samples_per_second,4.001
eval/steps_per_second,1.0
