# Fine-tuning T5 for text summarization
## T5
- T5 stands for Text-To-Text Transfer Transformer.
- A single T5 model can perform many different NLP tasks, such as text classification, language translation, text summarization and question answering, which makes it a very flexible model.

**What does it do?**  
1. Converts every problem into a text-to-text generation task.
    - For example, in language translation (English to Italian):  
    Input: I love you  
    Output: Ti amo
    
    - For example, in text classification:  
    Input: This product is trash.  
    Output: Negative

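2. During pretraining, learns to fill in masked-out spans of text. T5 uses span corruption: each dropped span is replaced by a sentinel token rather than a single [MASK] token.
3. Uses task-specific prefixes to guide the model during fine-tuning.
   For example, it prepends a short text prefix such as "summarize: " to the input to indicate which task it is performing.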
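To make the pretraining objective in point 2 concrete, here is a minimal sketch of T5-style span corruption on whitespace-separated words. It is an illustration under simplifying assumptions, not T5's actual preprocessing: the helper `span_corrupt` is hypothetical, the spans are picked by hand (T5 samples them randomly, corrupting about 15% of tokens), and real inputs are SentencePiece tokens rather than words. The sentinel names `<extra_id_0>`, `<extra_id_1>`, ... follow the sentinel tokens in the T5 vocabulary.

```python
def span_corrupt(tokens, spans):
    """Replace each (start, end) span of `tokens` with a sentinel (hypothetical helper).

    Returns (input_tokens, target_tokens): the input keeps the unmasked words with
    one sentinel per dropped span; the target lists each sentinel followed by the
    words it replaced, closed by one extra sentinel. `spans` must be sorted and
    non-overlapping.
    """
    inputs, targets = [], []
    cursor = 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[cursor:start])  # keep the words before the span
        inputs.append(sentinel)              # the span collapses to one sentinel
        targets.append(sentinel)
        targets.extend(tokens[start:end])    # the target spells out the dropped words
        cursor = end
    inputs.extend(tokens[cursor:])
    targets.append(f"<extra_id_{len(spans)}>")
    return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
inp, tgt = span_corrupt(tokens, [(2, 4), (8, 9)])
print(" ".join(inp))  # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(tgt))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```

The model reads the first sequence and learns to generate the second, so even pretraining is framed as a text-to-text task.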

When it was introduced, T5 achieved state-of-the-art results on a wide range of NLP benchmarks. Because every task shares the same text-to-text format, it is versatile, easy to fine-tune, and an efficient way to transfer knowledge across tasks.

## Implementation

In [1]:
!pip install datasets evaluate transformers rouge-score nltk

Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... done
Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... done
  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=3aa9bfbebbbf257e5eb7b26c906a452b326077d1154a975407688a9261c90a05
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge-score
Installing collected packages: rouge-score, evaluate
Successfully installed evaluate-0.4.1 rouge-score-0.1.2


If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.

Make sure your version of Transformers is at least 4.11.0; the `text_target` tokenizer argument used below requires a more recent release, so the latest version is the safest choice:

In [2]:
import transformers

print(transformers.__version__)

4.37.0


We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

In [3]:
from transformers.utils import send_example_telemetry

send_example_telemetry("summarization_notebook", framework="pytorch")

In [4]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

### 1. Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and the 🤗 Evaluate library to get the metric we need for evaluation (to compare our model to the benchmark). This is easily done with the functions `load_dataset` and `load`, respectively.  

In [34]:
from datasets import load_dataset
from evaluate import load

raw_datasets = load_dataset("xsum")
metric = load("rouge")


The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets:

In [35]:
print(raw_datasets)

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})


To access an actual element, you need to select a split first, then give an index:

In [37]:
print(raw_datasets["train"][0])



To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.

In [8]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [38]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,document,summary,id
0,"The sentences were suspended for three years, meaning they will not go to prison unless they reoffend, he adds.\nThe video shows three men and three unveiled women dancing on the streets and rooftops of Tehran.\nIn six months, it has been viewed by over one million people on YouTube.\nThe majority of people involved in the video were sentenced to six months in prison, with one member of the group given one year, lawyer Farshid Rofugaran was quoted by Iran Wire as saying.\nThe ""Happy we are from Tehran"" video was brought to the attention of the Iranian authorities in May, after receiving more than 150,000 views.\nMembers of the group behind the video were subsequently arrested by Iranian police for violating Islamic laws of the country, which prohibit dancing with members of the opposite sex and women from appearing without a headscarf.\nThey later appeared on state-run TV saying they were actors who had been tricked into make the Happy video for an audition.\nThe arrests drew condemnation from international rights groups and sparked a social media campaign calling for their release.\nWilliams, whose song was nominated for an Oscar earlier this year, also protested at the arrests.\n""It is beyond sad that these kids were arrested for trying to spread happiness,"" he wrote on Facebook.","Six Iranians arrested for appearing in a video dancing to Pharrell Williams' song Happy have been sentenced to up to one year in prison and 91 lashes, their lawyer says.",29272732
1,"Romelu Lukaku put the visitors in front with a penalty after keeper Jack Butland had brought down Tom Cleverley.\nSeamus Coleman headed in a Cleverley corner and Aaron Lennon intercepted a pass before slotting in as Everton went 3-0 up at the break.\nLukaku also had a header tipped on to the crossbar by Butland, while a poor Stoke struggled to create chances.\nRelive Everton's win against Stoke\nFollow reaction to Saturday's games\nMedia playback is not supported on this device\nA lot was made about Everton boss Roberto Martinez's dance moves at a Jason Derulo concert in the week, but it was his team impressing with their performance at the Britannia Stadium.\nEngland manager Roy Hodgson was at the game and he will have liked what he saw from Toffees midfielders Cleverley and Ross Barkley.\nThe industrious Cleverley burst through before being brought down for Everton's penalty, while his delivery from corners was a constant threat and led to Coleman's goal.\nBarkley's attacking instincts also played a part in the win and he could have had an assist when he crossed for Lukaku, whose header from close range was brilliantly saved by Butland.\n""I thought we were very strong in every department,"" said Martinez. ""Cleverley had a big influence in the game throughout.""\nStoke have endured a month to forget since their last league win against Norwich on 13 January.\nMark Hughes' side have been knocked out the Capital One Cup after a semi-final defeat on penalties by Liverpool, while they were beaten by Crystal Palace in the FA Cup.\nThe Potters have gained just one point from 12 in the league, dropping from seventh to 11th, and scored just one goal in six games.\nThe home side gave a debut to record £18.3m signing Gianelli Imbula but, like the rest of his team-mates, the midfielder struggled to make any kind of impact.\n""I thought Imbila did OK. 
I felt sorry for him because as a debut that was a hard one to come into,"" said Hughes.\nStoke boss Mark Hughes: ""We huffed and puffed and didn't really create again and that is a concern for us.\nMedia playback is not supported on this device\n""A disappointing day. We made mistakes at key times in the game and couldn't recover.\n""We have to pick ourselves up and start doing the fundamentals and basics.""\nEverton manager Roberto Martinez: ""We defended really well when we had to but the amount of opportunities we created is pleasing. If anything we should have scored three or four more in the second half.\n""We have to make sure we don't drop our standards now.""\nStoke's next game is at Bournemouth on 13 February, while Everton host West Brom on the same day with both games kicking off at 15:00 GMT.\nMatch ends, Stoke City 0, Everton 3.\nSecond Half ends, Stoke City 0, Everton 3.\nAttempt saved. Mame Biram Diouf (Stoke City) header from the right side of the six yard box is saved in the centre of the goal. Assisted by Joselu with a cross.\nFoul by Peter Odemwingie (Stoke City).\nBryan Oviedo (Everton) wins a free kick on the left wing.\nSubstitution, Everton. Leon Osman replaces James McCarthy.\nCorner, Everton. Conceded by Marc Muniesa.\nSubstitution, Everton. Kevin Mirallas replaces Ross Barkley.\nAttempt missed. Stephen Ireland (Stoke City) right footed shot from outside the box is too high. Assisted by Joselu.\nCorner, Stoke City. Conceded by Phil Jagielka.\nOffside, Stoke City. Joselu tries a through ball, but Glen Johnson is caught offside.\nAttempt missed. Ramiro Funes Mori (Everton) header from the centre of the box misses to the right. Assisted by Tom Cleverley with a cross.\nCorner, Everton. Conceded by Jack Butland.\nAttempt saved. Arouna Koné (Everton) right footed shot from the centre of the box is saved in the bottom left corner. Assisted by Ross Barkley.\nSubstitution, Stoke City. Joselu replaces Marko Arnautovic.\nSubstitution, Everton. 
Arouna Koné replaces Romelu Lukaku.\nGiannelli Imbula (Stoke City) wins a free kick in the defensive half.\nFoul by Gareth Barry (Everton).\nCorner, Stoke City. Conceded by Ramiro Funes Mori.\nAttempt blocked. Giannelli Imbula (Stoke City) left footed shot from the centre of the box is blocked. Assisted by Stephen Ireland.\nAttempt blocked. Ross Barkley (Everton) right footed shot from the centre of the box is blocked. Assisted by Romelu Lukaku.\nAttempt saved. Glen Johnson (Stoke City) left footed shot from outside the box is saved in the top centre of the goal. Assisted by Marko Arnautovic.\nAttempt missed. James McCarthy (Everton) right footed shot from outside the box is close, but misses to the left. Assisted by Romelu Lukaku.\nAttempt blocked. Ross Barkley (Everton) right footed shot from the left side of the box is blocked. Assisted by Romelu Lukaku.\nAttempt blocked. Peter Odemwingie (Stoke City) left footed shot from the centre of the box is blocked.\nOffside, Everton. Seamus Coleman tries a through ball, but Aaron Lennon is caught offside.\nAttempt blocked. Ross Barkley (Everton) right footed shot from the left side of the box is blocked. Assisted by Gareth Barry with a cross.\nAttempt saved. Romelu Lukaku (Everton) left footed shot from the centre of the box is saved in the centre of the goal.\nAttempt blocked. Romelu Lukaku (Everton) left footed shot from the centre of the box is blocked. Assisted by Aaron Lennon.\nSubstitution, Stoke City. Peter Odemwingie replaces Xherdan Shaqiri.\nSubstitution, Stoke City. Stephen Ireland replaces Ibrahim Afellay.\nXherdan Shaqiri (Stoke City) wins a free kick on the right wing.\nFoul by Bryan Oviedo (Everton).\nMame Biram Diouf (Stoke City) is shown the yellow card for a bad foul.\nFoul by Mame Biram Diouf (Stoke City).\nAaron Lennon (Everton) wins a free kick on the right wing.\nAttempt saved. Romelu Lukaku (Everton) header from very close range is saved in the top centre of the goal. 
Assisted by Ross Barkley with a cross.\nFoul by Erik Pieters (Stoke City).\nSeamus Coleman (Everton) wins a free kick on the right wing.\nCorner, Stoke City. Conceded by Gareth Barry.","Everton moved up to seventh after beating Stoke, who suffered a third successive Premier League defeat.",35447704
2,"The large billboard on Warwick Road urged electors to vote for Cat Smith - Labour's candidate for Lancaster and Fleetwood, 70 miles away.\nLabour's candidate for Carlisle is Lee Sherriff.\nA Labour spokesman said the advertising company working for the Lancaster party branch had made the error.\nThe spokesman said: ""This has now been removed. Carlisle Labour Party campaign were not involved in any way.""\nThe candidates for the Carlisle constituency are:\nThe candidates for the Lancaster and Fleetwood constituency are:",Labour has admitted a mistake was made after a poster featuring the wrong election candidate was put up in Carlisle.,32333596
3,"The images capture everything from personal moments to be shared with friends and family to large public events, image curation on a global scale.\nPenguin art director Jim Stoddart has pulled together a selection of those that caught his eye, for publication in a book.\nLife on Instagram, curated by Jim Stoddart, is published by Particular Books.","Since its launch in 2010, the photo-sharing site Instagram has grown at a phenomenal rate, with nearly 100 million pictures and videos being posted everyday.",37230012
4,"The Swiss, 32, broke Raonic in the first game and went on to win 6-4 6-4 6-4 in one hour and 42 minutes.\nHe will take on top seed Novak Djokovic in Sunday's final, after in four sets.\nFederer is trying to win his 18th Grand Slam title, and his first since beating Andy Murray at Wimbledon in 2012.\nFifteen years after making his first appearance at the All England Club, Federer has the chance to extend the record he has already set for major victories and break new ground for Wimbledon titles in the men's game.\n""That was a big victory,"" said the Swiss, who lost in the second round last year. ""I really had to focus on every point. I know that is always the case at this stage but it was hard.\n""I had to be very careful on my service games and I knew there were only going to be a few chances on his serve, but I am very, very happy.\n""I played some great tennis under pressure at times because I didn't play well here last year, and I expect a lot of myself. In the second week I have played better as the week has gone on.\n""Now I can look forward to another great match with Novak.""\nRaonic had made history just by reaching the last four, as the first Canadian man to do so, but suggestions the 23-year-old was ready to strike a blow for the younger generation proved misguided.\nThe difference in experience was vast, with Federer playing in his 35th Grand Slam semi-final and unbeaten in eight previous Wimbledon semi-finals.\nMoving superbly, attacking the net when possible and patiently waiting for his chances on the return, the Swiss looked as sharp as ever on the familiar ground of Centre Court.\nRaonic topped the standings for aces going into the semi-final, hit the second-fastest serve of the tournament at 141mph and dropped serve just twice.\nBut despite lacking his opponent's raw power, Federer had only been broken once and he offered up just a single break point as he dominated the match.\nHe got a huge boost with an immediate break following a double 
fault and an error from Raonic, and calmly served his way out of trouble at 4-3 on his way to clinching the set.\nThere was the expected flow of huge Raonic serves as the second set sped by, before Federer made his move at 4-4.\nA sweeping backhand down the line put the pressure on at 0-30 and Raonic succumbed with a wayward smash, allowing Federer to arrow another backhand winner.\nThe pattern repeated itself at 4-4 in the third, when Raonic opened with a double fault and soon found himself at 0-40, thumping a forehand over the baseline on the second break point.\nFederer drew a gasp from the 15,000 spectators with an unexpectedly rash forehand drive-volley when trying to close out the match, but a forehand into the corner brought up match point and a big serve finished the job.",Seven-time champion Roger Federer dismantled the big-serving game of Canadian Milos Raonic to reach his ninth Wimbledon final.,28165048


The metric is an instance of `evaluate.EvaluationModule`:

In [39]:
metric

EvaluationModule(name: "rouge", module_type: "metric", features: [{'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id=None)}, {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}], usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [11]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}
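To get a feel for what the `rouge1` score measures, the sketch below computes a simplified ROUGE-1 F1 (clipped unigram overlap) for a single prediction/reference pair. The helper `rouge1_f1` is hypothetical: it only lowercases and splits on whitespace, with no stemming or punctuation handling, so its numbers can differ from the `rouge_score` implementation behind `evaluate`.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1 on lowercased whitespace tokens (hypothetical helper)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter & Counter keeps the minimum count per word, i.e. clipped matches
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


print(rouge1_f1("hello there", "hello there"))       # 1.0
print(rouge1_f1("the cat sat", "the cat ran away"))  # about 0.571 (= 4/7)
```

In the second pair, 2 of 3 predicted words match (precision 2/3) and 2 of 4 reference words are covered (recall 1/2), giving an F1 of 4/7.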

### 2. Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary), put them in the format the model expects, and generate the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [12]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [13]:
tokenizer("Hello, this one sentence!")

{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.

Instead of one sentence, we can pass along a list of sentences:

In [14]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

To prepare the targets for our model, we need to tokenize them using the `text_target` parameter. This will make sure the tokenizer uses the special tokens corresponding to the targets:

In [15]:
print(tokenizer(text_target=["Hello, this one sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize:" (the model can also translate, so it needs the prefix to know which task to perform).

In [16]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We feed them to the `tokenizer` with the argument `truncation=True`, which ensures that any input longer than `max_length` tokens is truncated to that length (1024 tokens for the documents and 128 for the summaries in the cell below). Padding is handled later by a data collator, so that we pad examples to the longest length in each batch rather than in the whole dataset.
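Truncation and dynamic padding can be sketched in plain Python. The helpers `truncate` and `pad_batch` are hypothetical simplifications (the real tokenizer also keeps the end-of-sequence token when it truncates); the token IDs are the ones printed earlier for the two example sentences, and `0` is T5's padding token ID.

```python
def truncate(ids, max_length):
    """Keep at most `max_length` IDs (simplified: ignores special-token handling)."""
    return ids[:max_length]


def pad_batch(batch, pad_id=0):
    """Pad each sequence to the longest one *in this batch* and build attention masks."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    # 1 marks a real token, 0 marks padding the model should ignore
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]]
padded = pad_batch(batch)
print(padded["input_ids"][1])       # [100, 19, 430, 7142, 5, 1, 0]
print(padded["attention_mask"][1])  # [1, 1, 1, 1, 1, 1, 0]
print(truncate(batch[0], 4))        # [8774, 6, 48, 80]
```

Since the longest sequence in this batch has 7 IDs, the batch is padded to 7, not to `max_input_length`, which saves a lot of computation on short examples.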

In [17]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [40]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[21603, 10, 37, 423, 583, 13, 1783, 16, 20126, 16496, 6, 80, 13, 8, 844, 6025, 4161, 6, 19, 341, 271, 14841, 5, 7057, 161, 19, 4912, 16, 1626, 5981, 11, 186, 7540, 16, 1276, 15, 2296, 7, 5718, 2367, 14621, 4161, 57, 4125, 387, 5, 15059, 7, 30, 8, 4653, 4939, 711, 747, 522, 17879, 788, 12, 1783, 44, 8, 15763, 6029, 1813, 9, 7472, 5, 1404, 1623, 11, 5699, 277, 130, 4161, 57, 18368, 16, 20126, 16496, 227, 8, 2473, 5895, 15, 147, 89, 22411, 139, 8, 1511, 5, 1485, 3271, 3, 21926, 9, 472, 19623, 5251, 8, 616, 12, 15614, 8, 1783, 5, 37, 13818, 10564, 15, 26, 3, 9, 3, 19513, 1481, 6, 18368, 186, 1328, 2605, 30, 7488, 1887, 3, 18, 8, 711, 2309, 9517, 89, 355, 5, 3966, 1954, 9233, 15, 6, 113, 293, 7, 8, 16548, 13363, 106, 14022, 84, 47, 14621, 4161, 6, 243, 255, 228, 59, 7828, 8, 1249, 18, 545, 11298, 1773, 728, 8, 8347, 1560, 5, 611, 6, 255, 243, 72, 1709, 1528, 161, 228, 43, 118, 4006, 91, 12, 766, 8, 3, 19513, 1481, 410, 59, 5124, 5, 96, 196, 17, 19, 1256, 68, 27, 103, 317, 132

To apply this function to all the document/summary pairs in our dataset, we use the `map` method of the `raw_datasets` object we created earlier. This applies the function to all the elements of all the splits, so our training, validation and test data are preprocessed in one single command.

In [41]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)


Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The library is normally smart enough to detect when the function you pass to `map` has changed (in which case the cached data must not be used); for instance, it will detect if you change `model_checkpoint` and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts in batches. This leverages the full benefit of the fast tokenizer we loaded earlier, which uses multi-threading to process the texts in a batch concurrently.
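Batched mapping can be pictured with the hypothetical helper `map_batched` below: the function receives a dict of column lists holding up to `batch_size` rows at a time, and its outputs are concatenated. This is a sketch of the idea, not the 🤗 Datasets implementation (for one, the real `map` also keeps the existing columns). With the library's default `batch_size` of 1000, the 204,045 training rows are processed in 205 batches and the 11,332 validation rows in 12.

```python
import math


def map_batched(columns, fn, batch_size=1000):
    """Apply `fn` to consecutive row slices of a dict-of-lists and concatenate the results."""
    num_rows = len(next(iter(columns.values())))
    output = {}
    for start in range(0, num_rows, batch_size):
        # Every column is sliced the same way, so rows stay aligned
        batch = {key: values[start:start + batch_size] for key, values in columns.items()}
        for key, values in fn(batch).items():
            output.setdefault(key, []).extend(values)
    return output


def count_words(batch):
    """Toy stand-in for preprocess_function: one output value per input row."""
    return {"n_words": [len(doc.split()) for doc in batch["document"]]}


toy = {"document": ["a b", "c", "d e f"]}
print(map_batched(toy, count_words, batch_size=2))  # {'n_words': [2, 1, 3]}

# Number of batches for the train and validation splits with batch_size=1000
print(math.ceil(204045 / 1000), math.ceil(11332 / 1000))  # 205 12
```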

### 3. Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the `AutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [42]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

Note that we don't get a warning about newly initialized weights, as we would when loading a checkpoint for a classification task. This means all the weights of the pretrained model were used and there is no randomly initialized head in this case.

To instantiate a `Seq2SeqTrainer`, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires the name of an output directory, which will be used to save the model checkpoints; all other arguments are optional:

In [43]:
batch_size = 16
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
)
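As a quick sanity check on these arguments, the number of optimizer steps per epoch follows from the training-set size and the effective batch size. This sketch assumes a single GPU and the default `gradient_accumulation_steps=1`; adjust those two values if your setup differs.

```python
import math

num_train_examples = 204045      # size of the xsum train split printed earlier
per_device_batch_size = 16       # batch_size in the cell above
num_devices = 1                  # assumption: a single GPU
gradient_accumulation_steps = 1  # assumption: the default value

effective_batch = per_device_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_train_examples / effective_batch)
print(steps_per_epoch)  # 12753
```

The last of those steps sees a partial batch of 13 examples, since 204,045 is not a multiple of 16.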

In [44]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
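`DataCollatorForSeq2Seq` pads the inputs to the longest sequence in each batch, pads the labels with `-100` (its default `label_pad_token_id`), and, because we pass `model`, also prepares the `decoder_input_ids` from the labels. The `-100` value matters because PyTorch's cross-entropy loss ignores that index by default. The numpy sketch below uses hypothetical helpers `pad_labels` and `masked_cross_entropy` to show the effect: the loss is averaged over real label positions only, so whatever the model predicts at a padded position does not change it.

```python
import numpy as np

IGNORE_INDEX = -100  # label value skipped by the loss


def pad_labels(label_batch, pad_value=IGNORE_INDEX):
    """Pad label sequences to the longest one in the batch with `pad_value`."""
    longest = max(len(labels) for labels in label_batch)
    return np.array([labels + [pad_value] * (longest - len(labels)) for labels in label_batch])


def masked_cross_entropy(logits, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX."""
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != IGNORE_INDEX
    safe_labels = np.where(mask, labels, 0)  # placeholder index; these positions are masked out
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return token_nll[mask].mean()


labels = pad_labels([[3, 1], [2]])
print(labels.tolist())  # [[3, 1], [2, -100]]

logits = np.zeros((2, 2, 4))           # batch of 2, 2 positions, vocabulary of 4
loss_before = masked_cross_entropy(logits, labels)
logits[1, 1] = [10.0, -5.0, 3.0, 0.0]  # change only the padded position
loss_after = masked_cross_entropy(logits, labels)
print(np.isclose(loss_before, loss_after))  # True
```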

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:

In [45]:
import nltk
import numpy as np

# sent_tokenize needs the Punkt sentence models; newer NLTK releases look for "punkt_tab"
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Padded positions may hold -100, which can't be decoded: map them to the pad token.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLsum expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Scale the scores to percentages
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length (number of non-pad tokens per prediction)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
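The two array manipulations used when computing the metrics can be checked on a toy array: replacing `-100` with the pad ID so every position is decodable, and counting non-pad tokens to get the generated length. The IDs below are made up; `0` is T5's `pad_token_id`. Depending on the Transformers version, the trainer may pad generated predictions from different batches with `-100`, so the replacement is useful for predictions as well as labels.

```python
import numpy as np

PAD_ID = 0  # T5's pad_token_id

# Made-up generated IDs: the second row ends with -100 padding
predictions = np.array([[31, 7, 1, PAD_ID], [12, 1, -100, -100]])

# Map -100 to the pad ID so a tokenizer could decode every position
safe_predictions = np.where(predictions != -100, predictions, PAD_ID)

# Generated length = number of non-pad tokens per row
gen_lens = [np.count_nonzero(row != PAD_ID) for row in safe_predictions]
print(gen_lens, np.mean(gen_lens))  # [3, 2] 2.5
```

Without the replacement, the two `-100` entries would also count as tokens, giving a length of 4 for the second row.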

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [46]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

We can now fine-tune our model by calling the `train` method:

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss


Training runs correctly, but we stop it before completing even a single epoch, because one epoch takes almost 3 hours.