# Fine-tune T5 on XSum

## Libraries and environment preparation

In [None]:
#Install essential packages
! pip install datasets transformers rouge-score nltk wandb

Collecting datasets
  Downloading datasets-1.15.1-py3-none-any.whl (290 kB)

In [None]:
#install Git-LFS
!apt install git-lfs

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following NEW packages will be installed:
  git-lfs
0 upgraded, 1 newly installed, 0 to remove and 37 not upgraded.
Need to get 2,129 kB of archives.
After this operation, 7,662 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 git-lfs amd64 2.3.4-1 [2,129 kB]
Fetched 2,129 kB in 1s (1,524 kB/s)
Selecting previously unselected package git-lfs.
(Reading database ... 155222 files and directories currently installed.)
Preparing to unpack .../git-lfs_2.3.4-1_amd64.deb ...
Unpacking git-lfs (2.3.4-1) ...
Setting up git-lfs (2.3.4-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...


In [None]:
#Colab Environment Check for GPU and RAM
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

#GPU check
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Your runtime has 27.3 gigabytes of available RAM

Fri Nov 19 19:20:12 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------

Make sure your version of Transformers is at least 4.11.0, since the functionality used below was introduced in that version:

In [None]:
# Make sure your version of Transformers is at least 4.11.0 
# to run the following code correctly:

import transformers
print(transformers.__version__)

4.12.5
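
Reading the printed version by eye is easy to get wrong (as strings, "4.9.2" sorts after "4.11.0"). The following is a minimal sketch of a numeric comparison; `parse_version` and `version_at_least` are hypothetical helper names, not part of Transformers, and pre-release suffixes such as `dev0` are simply ignored.

```python
import re


def parse_version(version):
    """Return the leading numeric components of a dotted version string as a tuple of ints."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component, e.g. "dev0"
        parts.append(int(match.group()))
    return tuple(parts)


def version_at_least(installed, required):
    """True if `installed` is the same as or newer than `required`."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    # Pad the shorter tuple with zeros so that "4.11" compares equal to "4.11.0".
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# "4.12.5" is the version printed by the cell above.
print(version_at_least("4.12.5", "4.11.0"))
```

In the notebook, the same check could be applied to `transformers.__version__` and wrapped in an `assert`.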


In [None]:
from transformers import AutoTokenizer    
# Huggingface Automodel class
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model_checkpoint = "t5-small"

In [None]:
# Import Wandb 
import os
import wandb
API_KEY = "<your-wandb-api-key>"  # placeholder: supply your own key and never commit a real one
os.environ["WANDB_API_KEY"] = API_KEY

wandb: Currently logged in as: shusunny (use `wandb login --relogin` to force relogin)

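A key written directly into a notebook is visible to anyone who can read the file. One possible alternative, sketched here under the assumption that `WANDB_API_KEY` has been set outside the notebook (for example in the shell or a secrets manager), is to read it from the environment and fail loudly if it is missing. `resolve_api_key` is a hypothetical helper, not a wandb function.

```python
import os


def resolve_api_key(env=None, name="WANDB_API_KEY"):
    """Return the key stored under `name`, or raise if it is missing or empty."""
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"Set the {name} environment variable before running this notebook.")
    return key


# Demonstration with an explicit mapping, so no real secret appears in the notebook.
print(len(resolve_api_key({"WANDB_API_KEY": "dummy-value"})))
```

Called without arguments, the helper reads from `os.environ`, so the notebook itself never contains the secret.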

## Loading the dataset

In [None]:
# import dataset and metrics with huggingface
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

Using custom data configuration default


Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499...


Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499. Subsequent calls will reuse this data.


In [None]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [None]:
# Visualize the Data

import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=3):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
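
The retry loop in `show_random_elements` draws indices until it has `num_examples` distinct ones. A shorter way to get the same kind of result, shown here only as a sketch, is `random.sample`, which samples without replacement. `pick_unique_indices` and its `seed` parameter are hypothetical names introduced for illustration.

```python
import random


def pick_unique_indices(dataset_len, num_examples, seed=None):
    """Return `num_examples` distinct indices in [0, dataset_len)."""
    if num_examples > dataset_len:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # a seed makes the selection reproducible
    return rng.sample(range(dataset_len), num_examples)


# 204045 is the size of the XSum training split shown above.
picks = pick_unique_indices(204045, 4, seed=0)
print(len(picks), len(set(picks)))
```

The resulting list could be used in place of `picks` when indexing the dataset.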

In [None]:
show_random_elements(raw_datasets["train"], num_examples=4)

Unnamed: 0,document,summary,id
0,"The shooting in Cannon Street, Bolton, on 10 October was one of four which took place in 24 hours.\nGreater Manchester Police said the shooting outside a pub was an ""isolated incident"".\nIt was said to be the only one of the four incidents on the day that was not connected to a gang feud in Salford.",A 22-year-old man has been arrested on suspicion of attempted murder over the shooting of a man in the stomach in Greater Manchester earlier this month.,34590168
1,"The suspects include an army general, senior police officers and local politicians.\nThe court hearing comes after an investigation into the trafficking of Bangladeshis and Rohingya Muslims from Myanmar through Thailand.\nSeveral mass graves were discovered near the Thai-Malaysian border this year, sparking an international outcry.\nThe bodies were thought to be those of migrants trapped in camps in the jungle.\nBangladeshi migrants, and ethnic Rohingyas who face persecution in Myanmar, are normally brought by people smugglers to Thailand where they are effectively held ransom until they can raise money for their onward journey to Malaysia and beyond.\nAn investigation by the BBC's Jonathan Head has found entire communities in Thailand helping the traffickers.\nRead more on Asia's migrant crisis\nSold for sale: Tracing Thailand's human traffickers\nWhy are so many Rohingya migrants stranded at sea?\nOn the trail of Myanmar's Rohingya migrants\nThe 88 suspects appeared in court on Tuesday for a pre-trial hearing.\nThe trial begins next week - however, there are about 500 witnesses, and a court official has warned that it could take two years before the court reaches a verdict.\nDescribing the charges, one judge was quoted by AFP as saying: ""All 88 defendants together let victims starve, denied health treatments for sick victims and hid bodies on the mountain [camps] where they died.""\nMeanwhile, a senior policeman leading the investigation into human trafficking has submitted his resignation, saying he fears for his safety.\nMajor General Paween Pongsirin said his superiors had ordered him to transfer to a new position in southern Thailand, where he could be at risk from retribution from trafficking gangs.\nLast month, he said in an interview with AFP that the investigation had been wound up before the case was ""completely finished"", and that ""there are more people involved because this problem has accumulated for a long time"".\nThe Thai government launched 
a crackdown on smugglers in May, arresting suspects including local mayor Banjong Pongphon, and transferring over 50 police officers suspected of links to human traffickers.\nFollowing the crackdown, smugglers began to abandon boats carrying migrants, leading to thousands of Bangladeshis and Rohingya Muslims being stranded at sea.",Thailand has begun a court hearing in a major people smuggling case involving 88 suspected human traffickers.,34774879
2,"He is almost half-way through his 31-city tour of UK cities during which he is taking part in debates to encourage the public to vote to remain in the EU.\nIzzard claimed a British exit (Brexit) from the EU could lead to a recession.\nBut his claims were rejected by Leave campaigner Sammy Wilson from the DUP who described them as ""not factual"".\nThe pair disagreed over a number of Brexit-related issues, including the consequences for the economy, border controls, freedom of movement and immigration.\nIzzard, who lived in Bangor, County Down, from the age of one to five, visited Methodist College in Belfast on Thursday evening as part of his 'Stand Up for Europe' campaign.\nSpeaking to BBC Newsline, he said: ""Immigration is an issue and people do get worried about it but if we change it - if we pull out then we'll go into recession.\n""Then you're worried about jobs. Jobs will get even worse, You'll have even less opportunity for jobs because the economy will have gone down.\n""So immigration is a thing that we need to control but I don't think Brexit - pulling out - is going to make the big difference.""\nHowever Mr Wilson told BBC NI's The View programme: ""If you look at the countries which are in recession at the moment, they are the countries which are most closely tied into the EU and especially into the eurozone.\n""If you look at the countries where young people don't have a chance for the future - 50% youth unemployment in Spain, Italy, the economy of Greece in ruins.\n""Why? Because they tied themselves into the European project and they tied themselves into the Euro.""\nThe referendum takes place on 23 June.","The comedian Eddie Izzard has returned to Northern Ireland, where he spent part of his childhood, to campaign for the UK to stay in the European Union.",36438971
3,"Judge Lord Eassie said the Court of Justice of the European Union should give its opinion on the proposal.\nThe case was brought by The Scotch Whisky Association, which argued the legislation breached European law.\nHolyrood ministers have said minimum pricing was vital to address Scotland's ""unhealthy relationship with drink"".\nThe Court of Session judgement means there could be a delay of up to two years before Scottish government plans to set a 50p rate per unit of alcohol can be implemented.\nBy Reevel AldersonHome affairs correspondent, BBC Scotland\nAlthough the Court of Session is Scotland's highest civil court, it is now asking another court for its views on the government's minimum pricing proposals.\nJudges in Edinburgh will ask the Court of Justice of the European Union for a ""preliminary ruling"" - on whether the proposals would be valid under EU law.\nThe procedure exists to ensure EU law is properly applied in each country.\nBut before the case can be heard by the court in Luxembourg, the questions it will be asked must be decided.\nThis will involve another hearing in Edinburgh at which the Scottish government and the Scotch Whisky Association (SWA) along with other parties to the case will give their views.\nOverall, it could take between 15 months and two years before a ruling will be given by Luxembourg.\nEven then it may not be possible for the Scottish government to implement the policy, enacted in May 2012, for a minimum price of 50p per unit of alcohol.\nEither the government or the SWA could appeal to the UK Supreme in London, a process which would take several months more.\nHowever today's reference to Luxembourg would at least mean that there could be no further appeal in the case to a European court.\nThe Scotch Whisky Association, whose members account for more than 90% of the industry's production, had appealed against a Court of Session ruling that the minimum alcohol pricing policy was within the powers of Scottish ministers and 
not incompatible with EU law.\nTwo major European wine and spirit organisations are also party to the SWA's appeal.\nScottish Health Secretary Alex Neil said he was ""frustrated"" at the challenge to a democratic decision of the Scottish Parliament but expressed determination to see it through.\nMr Neil said a final decision would be made by the Court of Session, once legal opinion was received from the Court of Justice, in Luxembourg.\nHe added: ""The first time we went to the Court of Session they gave us a ringing endorsement and were very clear that what we were doing was perfectly legitimate in law and I'm very confident we will end up with that decision being reinforced in two years' time.""\nIn his written judgement, Lord Eassie said: ""We have come to the view that - as heralded in the debate before us - the present proceedings raise aspects of those tests and of the role of the national court which are not clearly established.\n""There are thus aspects relating to the Scottish ministers claim of justification under article 36 TFEU (Treaty of the Functioning of the EU) which we consider that it would be of help to have the guidance of the Court of Justice of the European Law.""\nLiver deaths\nLegislation to bring in the government's price plan was passed by parliament in May 2012 but ongoing legal challenges have prevented the policy from being implemented.\nAccording to NHS figures, Scottish deaths from chronic liver disease are among the highest in Europe, while alcohol kills the equivalent of 20 people a week in Scotland.\nScottish ministers said their minimum pricing plan, under which the Â­cheapest bottle of wine would be Â£4.69 and a four-pack of lager would cost at least Â£3.52, would help tackle the problem.\nScotch Whisky Association chief executive David Frost, said of the latest ruling: ""We are pleased that the Court of Session in Edinburgh is referring the minimum unit pricing case to the Court of Justice of the European Union.\n""From the 
outset we said that we believed minimum unit pricing was contrary to European Union law and that it was likely in the end to go to the European Court.\n""We also believe minimum unit pricing would be ineffective in tackling alcohol misuse and would damage the Scotch Whisky industry in the UK and overseas.""\nThe UK government previously shelved plans for minimum pricing in England and Wales, after Prime Minister David Cameron cited concerns over evidence it would not work and possible legal challenges.",A legal challenge to the Scottish government's policy on minimum alcohol pricing has been referred to a European court by the Court of Session.,27219905


## Preprocessing the data

In [None]:
# Import tokenizer from model checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# If you are using one of the five T5 checkpoints, the inputs must be prefixed
# with "summarize: " (T5 is a multi-task model).

if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

For XSum, the documents are about 1,500 tokens long and the summaries about 160 tokens. Here we truncate the inputs to 1024 tokens and the targets to 128 tokens.

In [None]:
# Tokenize the inputs and targets (applied to the dataset with map)

max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
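
To make the effect of the prefix and of the two length limits concrete, here is a toy sketch that replaces the real tokenizer with a whitespace split. `toy_encode` and `preprocess` are hypothetical stand-ins; the actual T5 tokenizer works on subword pieces and appends an end-of-sequence token, so real token counts will differ.

```python
def toy_encode(text, max_length):
    """Whitespace 'tokenizer' (illustration only): one token per word, truncated to max_length."""
    return text.split()[:max_length]


def preprocess(documents, summaries, prefix="summarize: ",
               max_input_length=1024, max_target_length=128):
    """Prefix each document, then truncate inputs and targets independently."""
    inputs = [toy_encode(prefix + doc, max_input_length) for doc in documents]
    labels = [toy_encode(summary, max_target_length) for summary in summaries]
    return {"input_ids": inputs, "labels": labels}


# A 2000-word document is cut to 1024 tokens; the short summary is left untouched.
batch = preprocess(["word " * 2000], ["short summary"])
print(len(batch["input_ids"][0]), len(batch["labels"][0]))  # 1024 2
print(batch["input_ids"][0][0])  # summarize:
```

Note that the prefix is added before truncation, so it counts toward the input limit.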

In [None]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

In [None]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'document', 'id', 'input_ids', 'labels', 'summary'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['attention_mask', 'document', 'id', 'input_ids', 'labels', 'summary'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['attention_mask', 'document', 'id', 'input_ids', 'labels', 'summary'],
        num_rows: 11334
    })
})

In [None]:
# Code to find out the max tokenized length
"""
max_document = 0
max_summary = 0
my_splits = ['test', 'train', 'validation']
for i in my_splits:
  for item in tokenized_datasets[i]['input_ids']:
      if len(item) > max_document:
          max_document = len(item)


  for item in tokenized_datasets[i]['labels']:
      if len(item) > max_summary:
          max_summary = len(item)

max_document, max_summary

"""

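
The disabled cell above scans every split for the longest tokenized document and summary. The same idea can be sketched more compactly with `max`; the nested dictionary below is a hypothetical stand-in for `tokenized_datasets`, and `max_lengths` is an illustrative helper name.

```python
def max_lengths(splits, columns=("input_ids", "labels")):
    """Return the longest sequence length per column across all splits.

    `splits` maps a split name to a mapping from column name to a list of token-id lists.
    """
    return {
        column: max(len(seq) for split in splits.values() for seq in split[column])
        for column in columns
    }


# Toy data standing in for the tokenized dataset.
toy_splits = {
    "train": {"input_ids": [[1, 2, 3], [4]], "labels": [[5, 6]]},
    "validation": {"input_ids": [[7, 8, 9, 10]], "labels": [[11]]},
}
print(max_lengths(toy_splits))  # {'input_ids': 4, 'labels': 2}
```

Because `preprocess_function` truncates, running such a scan on the already tokenized data can report at most 1024 and 128; measuring the true maxima would require tokenizing without truncation.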

## Fine-tuning the model

In [None]:
# Load the pretrained model from the checkpoint and print its configuration
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model.config

T5Config {
  "_name_or_path": "t5-small",
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to German: "
    },
    "translation_en_to_fr": {
      "early_stopping": true,

In [None]:
# data collator: pad the inputs and labels during each batch to save space
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
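
As a rough illustration of what padding per batch means, the sketch below pads each batch only to its own longest sequence. It is a simplified pure-Python stand-in, not the library implementation: `pad_batch` is a hypothetical name, the pad id 0 matches `pad_token_id` in the config printed above, and -100 is the label padding value that the loss ignores (the same value `compute_metrics` later replaces before decoding). The real collator also returns tensors and can prepare decoder inputs.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad inputs and labels to the longest sequence in this batch only."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_in, n_lab = len(f["input_ids"]), len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * (max_input - n_in))
        # The mask marks real tokens with 1 and padding with 0.
        batch["attention_mask"].append([1] * n_in + [0] * (max_input - n_in))
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (max_label - n_lab))
    return batch


features = [
    {"input_ids": [5, 6, 7], "labels": [8]},
    {"input_ids": [9], "labels": [10, 11]},
]
print(pad_batch(features)["labels"])  # [[8, -100], [10, 11]]
```

Padding to the batch maximum rather than to 1024 is what saves memory when most documents in a batch are short.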

The training arguments are defined next, followed by a `compute_metrics` function that tells `Seq2SeqTrainer` how to compute the metrics from the predictions, with a bit of post-processing to decode the predictions into text:

In [None]:
# Define training args, batch size and epochs
# batch size max 8 for input length 1024 on Colab Pro

batch_size = 8
epochs = 1
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_strategy = "epoch",
    save_total_limit=3,
    num_train_epochs=epochs,
    predict_with_generate=True,
    fp16=True,
    #push_to_hub=True,
)

In [None]:
import nltk
import numpy as np
nltk.download('punkt')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
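
For intuition about the numbers `compute_metrics` returns, here is a simplified sketch of a ROUGE-1 F-measure as unigram overlap between a prediction and a reference. It is not the `rouge` metric used above: that implementation applies its own tokenization and stemming (`use_stemmer=True`) and the notebook reports an aggregated score scaled by 100. `rouge1_f` is a hypothetical helper name.

```python
from collections import Counter


def rouge1_f(prediction, reference):
    """Unigram-overlap F-measure between two strings (lowercased, whitespace-split)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared word.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


# 3 of 3 predicted words match (precision 1.0); 3 of 6 reference words are covered (recall 0.5).
print(round(rouge1_f("the cat sat", "the cat sat on the mat"), 4))  # 0.6667
```

ROUGE-2 follows the same pattern with bigrams, and ROUGE-L uses the longest common subsequence instead of n-gram counts.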


In [None]:
# Pass into the trainer

train_dataset=tokenized_datasets["train"]
eval_dataset=tokenized_datasets["validation"]

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp fp16 backend


We can now fine-tune our model by calling the `train` method:

In [None]:
# keep track with wandb
wandb.init(project="Transformers")

In [None]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running training *****
  Num examples = 204045
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 25506
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | Rougel | Rougelsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 2.68 | 2.450021 | 28.7115 | 8.0254 | 22.6179 | 22.6134 | 18.8176 |


The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running Evaluation *****
  Num examples = 11332
  Batch size = 8
Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-25506
Configuration saved in t5-small-finetuned-xsum/checkpoint-25506/config.json
Model weights saved in t5-small-finetuned-xsum/checkpoint-25506/pytorch_model.bin
tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-25506/tokenizer_config.json
Special tokens file saved in t5-small-finetuned-xsum/checkpoint-25506/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=25506, training_loss=2.7389122431188304, metrics={'train_runtime': 10230.7036, 'train_samples_per_second': 19.944, 'train_steps_per_second': 2.493, 'total_flos': 5.260162153729229e+16, 'train_loss': 2.7389122431188304, 'epoch': 1.0})
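
The step count in the log follows from the dataset size, batch size and number of epochs. The sketch below reproduces that arithmetic under the assumptions of a single device and a final partial batch that is kept rather than dropped; `total_optimization_steps` is a hypothetical helper, not a Trainer API.

```python
import math


def total_optimization_steps(num_examples, batch_size, epochs, grad_accum=1):
    """Number of optimizer updates for a single-device run that keeps the last partial batch."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return updates_per_epoch * epochs


# Full training split: 204045 examples, batch size 8, 1 epoch.
print(total_optimization_steps(204045, 8, 1))  # 25506
```

The same formula gives 625 updates per epoch for a 5000-example subset at batch size 8, or 18750 updates over 30 epochs.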

In [None]:
# Init new logging params
wandb.init(project="Transformers")

Run history:
eval/gen_len,▁
eval/loss,▁
eval/rouge1,▁
eval/rouge2,▁
eval/rougeL,▁
eval/rougeLsum,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████

Run summary:
eval/gen_len,18.8176
eval/loss,2.45002
eval/rouge1,28.7115
eval/rouge2,8.0254
eval/rougeL,22.6179
eval/rougeLsum,22.6134
eval/runtime,539.7401
eval/samples_per_second,20.995
eval/steps_per_second,2.625
train/epoch,1.0


## Trying with a smaller dataset

In [None]:
# Select to get smaller dataset
small_train = raw_datasets['train'].select(list(range(0, 5000)))
small_val = raw_datasets['validation'].select(list(range(0, 500)))
small_train

Dataset({
    features: ['document', 'summary', 'id'],
    num_rows: 5000
})

In [None]:
tokenized_train = small_train.map(preprocess_function, batched=True)
tokenized_val = small_val.map(preprocess_function, batched=True)
tokenized_train

Dataset({
    features: ['attention_mask', 'document', 'id', 'input_ids', 'labels', 'summary'],
    num_rows: 5000
})

In [None]:
# Import a new T5-small
model_small = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

loading configuration file https://huggingface.co/t5-small/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fe501e8fd6425b8ec93df37767fcce78ce626e34cc5edc859c662350cf712e41.406701565c0afd9899544c1cb8b93185a76f00b31e5ce7f6e18bbaef02241985
Model config T5Config {
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "pre

In [None]:
# data collator: pad the inputs and labels during each batch to save space
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model_small)

In [None]:
# Define training args, batch size and epochs
# batch size max 16 on Colab Pro

batch_size = 8
epochs = 30
model_name = model_checkpoint.split("/")[-1]
args_small = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum-small",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_strategy = "epoch",
    save_total_limit=3,
    num_train_epochs=epochs,
    predict_with_generate=True,
    fp16=True,
    #push_to_hub=True,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
# Pass into the trainer

train_dataset=tokenized_train
eval_dataset=tokenized_val

trainer_small = Seq2SeqTrainer(
    model_small,
    args_small,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp fp16 backend


In [42]:
trainer_small.train()

The following columns in the training set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running training *****
  Num examples = 5000
  Num Epochs = 30
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 18750
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | Rougel | Rougelsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 3.1016 | 2.691421 | 23.2339 | 4.4677 | 17.9706 | 17.9657 | 18.726 |
| 2 | 2.9066 | 2.642424 | 25.2222 | 5.3867 | 19.8453 | 19.8184 | 18.728 |
| 3 | 2.8256 | 2.612656 | 25.795 | 5.9446 | 20.0724 | 20.0238 | 18.808 |
| 4 | 2.7699 | 2.596567 | 26.0262 | 6.1109 | 20.3682 | 20.3342 | 18.774 |
| 5 | 2.7422 | 2.584278 | 26.7538 | 6.5249 | 20.8441 | 20.8379 | 18.8 |
| 6 | 2.7231 | 2.574184 | 26.4894 | 6.5564 | 20.8198 | 20.7872 | 18.824 |
| 7 | 2.7074 | 2.566135 | 26.6138 | 6.4601 | 20.8963 | 20.8639 | 18.794 |
| 8 | 2.6697 | 2.560822 | 26.6053 | 6.4283 | 20.6104 | 20.576 | 18.816 |
| 9 | 2.6432 | 2.555719 | 26.8249 | 6.6531 | 20.9511 | 20.908 | 18.818 |
| 10 | 2.6359 | 2.549645 | 26.9239 | 6.7976 | 21.1995 | 21.1711 | 18.842 |


The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
Saving model checkpoint to t5-small-finetuned-xsum-small/checkpoint-625
Configuration saved in t5-small-finetuned-xsum-small/checkpoint-625/config.json
Model weights saved in t5-small-finetuned-xsum-small/checkpoint-625/pytorch_model.bin
tokenizer config file saved in t5-small-finetuned-xsum-small/checkpoint-625/tokenizer_config.json
Special tokens file saved in t5-small-finetuned-xsum-small/checkpoint-625/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
Saving model checkpoint to t5-small-finetuned-xsum-small/checkpoint-1250
Configuration saved i


The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
Saving model checkpoint to t5-small-finetuned-xsum-small/checkpoint-11250
Configuration saved in t5-small-finetuned-xsum-small/checkpoint-11250/config.json
Model weights saved in t5-small-finetuned-xsum-small/checkpoint-11250/pytorch_model.bin
tokenizer config file saved in t5-small-finetuned-xsum-small/checkpoint-11250/tokenizer_config.json
Special tokens file saved in t5-small-finetuned-xsum-small/checkpoint-11250/special_tokens_map.json
Deleting older checkpoint [t5-small-finetuned-xsum-small/checkpoint-9375] due to args.save_total_limit
The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document.
***** Running Evaluation *****
  Num examples = 50

TrainOutput(global_step=18750, training_loss=2.5982989680989586, metrics={'train_runtime': 7877.977, 'train_samples_per_second': 19.04, 'train_steps_per_second': 2.38, 'total_flos': 3.853964378583859e+16, 'train_loss': 2.5982989680989586, 'epoch': 30.0})
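
The "Deleting older checkpoint" message in the log is the effect of `save_total_limit=3` combined with a save at the end of every epoch. A simplified simulation of that rotation, assuming the oldest checkpoint is always the one removed and no best-model checkpoint is being protected, might look like this (`rotate_checkpoints` is a hypothetical helper):

```python
from collections import deque


def rotate_checkpoints(saved_steps, save_total_limit):
    """Simulate keeping only the newest `save_total_limit` checkpoints."""
    kept = deque()
    deleted = []
    for step in saved_steps:
        kept.append(step)
        while len(kept) > save_total_limit:
            deleted.append(kept.popleft())  # the oldest checkpoint goes first
    return list(kept), deleted


# One save per epoch at 625 steps per epoch, through epoch 18 (step 11250).
steps = [625 * epoch for epoch in range(1, 19)]
kept, deleted = rotate_checkpoints(steps, save_total_limit=3)
print(kept, deleted[-1])  # [10000, 10625, 11250] 9375
```

This matches the log excerpt above, where saving `checkpoint-11250` triggers the deletion of `checkpoint-9375`.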

In [43]:
wandb.finish()

Run history:
eval/gen_len,▆▇▁▁▄▃▄▅▄▅▅▆▅▇▆▅▆▆▆▅▆▇▇████▇██▇█
eval/loss,▁▁█▇▆▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
eval/rouge1,██▁▃▄▄▅▅▅▅▅▅▆▆▇▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇
eval/rouge2,██▁▃▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▆▆
eval/rougeL,██▁▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇
eval/rougeLsum,██▁▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇
eval/runtime,▂▂▂▆▅▃▂█▆▅▅▇▅▅▅▆▅▄▃▄▅▄▂▂▅▃▁▃▄▄▃▂
eval/samples_per_second,▆▇▇▃▄▅▇▁▃▄▄▂▄▄▄▃▄▅▆▅▄▅▇▇▄▆█▆▅▅▆▇
eval/steps_per_second,▇▇▇▃▄▅▇▁▄▄▄▂▄▄▄▃▄▅▆▅▄▅▇▇▄▆█▆▅▅▆▇
train/epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████

Run summary:
eval/gen_len,18.882
eval/loss,2.5311
eval/rouge1,28.1674
eval/rouge2,7.3609
eval/rougeL,22.0142
eval/rougeLsum,21.9739
eval/runtime,24.0905
eval/samples_per_second,20.755
eval/steps_per_second,2.615
train/epoch,30.0
