In [1]:
import torch
import pprint
import evaluate
import nltk
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer
from transformers import AutoTokenizer
from datasets import load_dataset, load_metric
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer


In [2]:
max_input_length = 1024
max_target_length = 128

In [3]:
train = load_dataset("cnn_dailymail", "3.0.0", split="train[:8]")
val = load_dataset("cnn_dailymail", "3.0.0", split="validation[:2]")
val.shape, train.shape

((2, 3), (8, 3))

In [4]:
model_checkpoint ='t5-small'
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

pad_on_right = tokenizer.padding_side == "right"

In [5]:
def preprocess(examples):
    inputs = ['summarize:' + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True,padding='max_length')

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["highlights"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [6]:
tokenized_train = train.map(preprocess, batched=True)
tokenized_valid = val.map(preprocess, batched=True)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]



In [7]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
batch_size = 16
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [20]:
metric = evaluate.load("rouge", trust_remote_code=True)

In [9]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [10]:
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    "finetuned-t5-small",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    save_steps=2
)

In [11]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [12]:
trainer.train()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,No log,3.059961,16.5266,3.125,11.2045,16.5266,19.0
2,No log,3.053365,16.5266,3.125,11.2045,16.5266,19.0
3,No log,3.047538,16.5266,3.125,11.2045,16.5266,19.0
4,No log,3.042458,16.5266,3.125,11.2045,16.5266,19.0
5,No log,3.03825,16.5266,3.125,11.2045,16.5266,19.0
6,No log,3.034699,16.5266,3.125,11.2045,16.5266,19.0
7,No log,3.031888,16.5266,3.125,11.2045,16.5266,19.0
8,No log,3.029845,16.5266,3.125,11.2045,16.5266,19.0
9,No log,3.028525,16.5266,3.125,11.2045,16.5266,19.0
10,No log,3.027865,16.5266,3.125,11.2045,16.5266,19.0




TrainOutput(global_step=10, training_loss=2.3597612380981445, metrics={'train_runtime': 367.6403, 'train_samples_per_second': 0.218, 'train_steps_per_second': 0.027, 'total_flos': 21654688235520.0, 'train_loss': 2.3597612380981445, 'epoch': 10.0})

In [13]:
def predict_summary(document):
  device = model.device
  tokenized = tokenizer([document], truncation=True, padding ='longest',return_tensors='pt')
  tokenized = {k: v.to(device) for k, v in tokenized.items()}
  tokenized_result = model.generate(**tokenized, max_length=128)
  tokenized_result = tokenized_result.to('cpu')
  predicted_summary = tokenizer.decode(tokenized_result[0])
  return predicted_summary

In [17]:
doc = train[2]['article']
predict_summary(doc)

'<pad><extra_id_0>: "I probably had a 30-, 35-foot free fall. And there\'s cars on fire. The whole bridge is down" driver: "it just gave way, and it just gave way, all the way down" dozens of people were on the bridge when it collapsed. "it just gave way, and it just gave way, all the way to the ground," survivor says.</s>'

In [18]:
train[2]['highlights']

'NEW: "I thought I was going to die," driver says .\nMan says pickup truck was folded in half; he just has cut on face .\nDriver: "I probably had a 30-, 35-foot free fall"\nMinnesota bridge collapsed during rush hour Wednesday .'

In [19]:
doc

'MINNEAPOLIS, Minnesota (CNN) -- Drivers who were on the Minneapolis bridge when it collapsed told harrowing tales of survival. "The whole bridge from one side of the Mississippi to the other just completely gave way, fell all the way down," survivor Gary Babineau told CNN. "I probably had a 30-, 35-foot free fall. And there\'s cars in the water, there\'s cars on fire. The whole bridge is down." He said his back was injured but he determined he could move around. "I realized there was a school bus right next to me, and me and a couple of other guys went over and started lifting the kids off the bridge. They were yelling, screaming, bleeding. I think there were some broken bones."  Watch a driver describe his narrow escape » . At home when he heard about the disaster, Dr. John Hink, an emergency room physician, jumped into his car and rushed to the scene in 15 minutes. He arrived at the south side of the bridge, stood on the riverbank and saw dozens of people lying dazed on an expansive