# Summarization
## This notebook outlines the concepts behind fine-tuning a summarization model using T5 (Text-to-Text Transfer Transformer), an encoder-decoder model

In [1]:
import torch
torch.cuda.empty_cache()

In [3]:
! pip install -q datasets transformers rouge-score nltk

ERROR: transformers 4.6.1 has requirement huggingface-hub==0.0.8, but you'll have huggingface-hub 0.0.9 which is incompatible.

# Fine-tuning a model on a summarization task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/master/examples/images/summarization.png?raw=1)

We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using the `Trainer` API.

In [4]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

## Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.  

In [5]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

Using custom data configuration default

Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499...

Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499. Subsequent calls will reuse this data.

The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each of the training, validation and test sets:

In [6]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

To access an actual element, you need to select a split first, then give an index:

In [7]:
raw_datasets["train"][0]

{'document': 'Recent reports have linked some France-based players with returns to Wales.\n"I\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\'d rather spend that money on keeping players in Wales," said Davies.\nThe WRU provides £2m to the fund and £1.3m comes from the regions.\nFormer Wales and British and Irish Lions fly-half Davies became WRU chairman on Tuesday 21 October, succeeding deposed David Pickering following governing body elections.\nHe is now serving a notice period to leave his role as Newport Gwent Dragons chief executive after being voted on to the WRU board in September.\nDavies was among the leading figures among Dragons, Ospreys, Scarlets and Cardiff Blues officials who were embroiled in a protracted dispute with the WRU that ended in a £60m deal in August this year.\nIn the wake of that deal being done, Davies said the £3.3m should be spent on ensuring current Wales-based stars remain there.\nIn recent weeks, Racing Metro fla

To get a sense of what the data looks like, the following function will show some examples picked randomly from the dataset.

In [8]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
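The rejection loop in `show_random_elements` works, but Python's standard library can draw distinct indices directly. Below is a minimal sketch of just the index-selection step using `random.sample`, shown on a row count rather than a real dataset so it runs without downloading anything; `pick_distinct_indices` is a hypothetical helper name, not part of the original notebook.

```python
import random

def pick_distinct_indices(n_rows, num_examples=5, seed=None):
    """Return `num_examples` distinct row indices in [0, n_rows)."""
    if num_examples > n_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # local generator: reproducible when a seed is given
    # random.sample draws without replacement, so no duplicate check is needed
    return rng.sample(range(n_rows), num_examples)

picks = pick_distinct_indices(204045, num_examples=5, seed=0)
print(picks)
```

The resulting list could be used in place of `picks` inside `show_random_elements`, since the function already indexes the dataset with a list (`dataset[picks]`).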

In [9]:
show_random_elements(raw_datasets["train"])

,document,id,summary
0,"Qantas had already cancelled flights to and from Tasmania and parts of New Zealand. Other airlines have also grounded flights in the region, stranding thousands of travellers.\nAustralian airline Qantas said it was too dangerous to fly through the ash.\nThe Puyehue-Cordon Caulle volcano has been erupting for more than a week.\nStrong winds have carried the fine particles of ash from the volcano to southern New Zealand and Australia at between 6,000 and 10,600m (20,000 and 35,000ft).\nThe particles have the potential to seriously damage jet engines.\nQantas said all its flights to and from Melbourne and Auckland, New Zealand, were being grounded from Sunday evening local time.\nAll its flights in and out of Tasmania and the New Zealand cities of Christchurch, Queenstown and Wellington have already been cancelled.\nQantas said about 8,000 passengers would be affected by its cancellations.\nVirgin Australia said it was cancelling 34 domestic flights and one international flight from Melbourne on Sunday.\nBudget carrier Jetstar has also cancelled flights within New Zealand and flights from New Zealand to Australia and from Tasmania to the rest of Australia.\nPassengers at Tasmania's Hobart airport told ABC news they might be stuck on the island for several days. 
Ferries from Tasmania to Australia's mainland are booked up for at least two days.\nAir New Zealand said it had not cancelled any flights, but was adjusting flight paths to steer aircraft below the ash.\nNew Zealand's Civil Aviation Authority warned that the ash was expected to be detected at the cruising level of aircraft but had not yet been seen below 20,000ft.\n""Given that the volcanic activity is continuing, it is expected that New Zealand airspace may be affected by these plumes for at least a week,"" it said in a statement.\nThe volcano has already caused severe disruption to flights in South America, with planes grounded for several days in Chile, Argentina, Uruguay and Brazil.\nIt is the first serious eruption of the volcano chain since 1960, when the area was hit by a massive earthquake.\nThousands of people are living in temporary shelters after being evacuated from the area around the volcano.",13740877,"Qantas has cancelled all its flights in and out of Melbourne, Australia, because of ash drifting over the Pacific Ocean from a volcano in Chile."
1,"Eighteen years ago - in 1999 - Mr Blair first advocated active military interventionism to overturn dictators and protect civilians.\nNow, Mrs May has repudiated much of what he said then.\nShe talked of ""the failed policies of the past"", before making her crucial declaration of new foreign policy doctrine: ""The days of Britain and America intervening in sovereign countries in an attempt to remake the world in our own image are over.""\nOf course, by saying that she was also overturning the approach of her predecessor, David Cameron. The current prime minister has also dismissed her predecessor's armed intervention in Libya.\nIts aftermath - a failed state, far from recovery - haunts Britain still.\nThis declaration of an apparently radical shift in policy by the prime minister should be read in conjunction with what appears to be an extraordinary British U-turn over Syria, which was set out in colourful terms by her foreign secretary only a few hours earlier.\nBoris Johnson conceded the most bitter and recent failure of British foreign policy when he openly acknowledged what amounts to a fundamental defeat over Syria.\nHe called Britain's stance ""catastrophic"", shifting from the pledge of support over many years to the non-jihadist opponents of President Assad, to a position where Britain - together with the United States - retreated from the field and left it open to Russian military dominance.\nMr Johnson told a committee in the House of Lords that President Assad should now be permitted to run for election as part of a ""democratic resolution"" of the civil war - although he did also make clear there could be no sustainable peace in Syria as long as he remains.\nHe admitted the downsides of doing ""such a complete flip-flop"", but said the UK had been unable at any stage to fulfil its mantra that the Syrian president should go.\nMr Johnson was accepting Russia's victory - and at the same time swallowing the bitter pill of defeat for London and for 
Washington.\nHe said that had flowed from the refusal of the House of Commons, in August 2013, to back punitive British military action against President Assad for his use of chemical weapons - something the Syrian leader still denies.\nWithin days, President Obama had followed Britain in retreat.\nPublic appetite in both countries for almost any military intervention overseas had drained away after the years of intervention in Afghanistan, in Iraq, and in Libya.\nIt is very difficult to see circumstances in which Britain or the US will send forces against a sovereign government in the future.\nExtremists - non-state actors - are almost the only acceptable target now.\nThe Foreign Office does not believe their political master was as explicit as I suggest, and believe that the essentials of British policy on Syria have not fundamentally changed.\nCertainly, the prime minister did leave herself some wriggle room.\nShe argued against the sort of increased isolationism which President Donald Trump has championed, and urged the maintenance of the ""special relationship"" as a way to provide joint leadership in the world.\nShe said the two nations should not ""stand idly by when the threat is real"".\nNevertheless, the political presentation of British foreign policy by the prime minister and foreign secretary has deployed a distinctly new and sometimes startling language.\nThe direction being set in response to past failures and disappointments is different.\nIt may be largely a public recognition of some brutal realities, which have been emerging over several years, but it is new and important.",38776377,Theresa May's Philadelphia speech is hugely significant - arguably the biggest by a British prime minister in the US since Tony Blair's in Chicago.
2,"The deal was made at talks in Berlin involving foreign ministers from the two countries, and France and Germany.\nIt reportedly covers mortars, tanks and heavy weapons below 100mm calibre.\nBut Germany warned the talks had also emphasised differences over the year-old conflict between Ukraine's military and pro-Russian rebels.\n""It was again a very long, very intensive discussion which in parts was very controversial,"" German Foreign Minister Frank-Walter Steinmeier said.\n""During these talks today the differences of opinion between Kiev and Moscow also became clear once again.""\nBut, he added, all parties had also reaffirmed the commitment to a ceasefire agreed in February in the Belarussian capital, Minsk.\nBoth sides are largely thought to have adhered to the deal - until a recent escalation of fighting in the flashpoints of Donetsk airport and Shyrokyne village, on the outskirts of the strategic town of Mariupol.\nA joint final statement by the ministers in Berlin expressed ""concern"" at the escalation.\n""What is decisive is that in light of the worsening situation, we agreed today not only to continue with the withdrawal of heavy weapons but also to include other categories of weapons in the withdrawal,"" Mr Steinmeier said.\n""Now, tanks, armoured vehicles, mortars and heavy weapons below a 100 mm (3.94 inches) calibre will be included in the withdrawal commitment.""\nBoth Ukraine and the rebels claim to have withdrawn heavy weapons from the line of contact, although sporadic shelling has continued.\nUkraine, Western leaders and Nato say there is clear evidence that Russia has helped the rebels with troops and heavy weapons. 
Russia denies that, insisting that any Russians on the rebel side are ""volunteers"".\nMore than 6,000 people have been killed in clashes since the rebels seized large parts of the Donetsk and Luhansk regions last April - a month after Russia annexed Ukraine's southern Crimea peninsula.",32296796,"Russia and Ukraine have agreed to call for the withdrawal of more types of weapons in Ukraine's east, as fresh clashes renew fears for a truce there."
3,"Opposition MSPs had called for Fiona Hyslop to be quizzed for a second time over the £150,000 grant.\nThe committee instead decided to write to her with follow up questions.\nMs Hyslop gave evidence to the education and culture committee last week, but opposition members claimed her answers were not satisfactory.\nLib Dem, Tory and Labour MSPs wrote to committee convener Stewart Maxwell asking him to recall the minister.\nHowever, the committee unanimously agreed to follow up the evidence session in writing.\nMs Hyslop had insisted funding was appropriate, transparent and in line with amounts given to other events.\nQuestions were raised over the awarding of the money after it emerged former SNP adviser Jennifer Dempsie set up meetings between the festival's promoters DF Concerts and ministers, including Ms Hyslop, ahead of the application for the funding.\nMs Dempsie was working on a contract for DF Concerts as a project manager on the festival, which moved to a new location at Strathallan this year.\nLib Dem Liam McArthur, Tory Liz Smith and Labour members John Pentland and Mark Griffin were the MSPs who signed the letter which has been sent to convener Stewart Maxwell.\nThey believe Ms Hyslop failed to provide satisfactory answers when she appeared at the parliament last week.\nHowever, the minister believed she had acted properly and the funding from the major events budget had been approved ""following a detailed consideration of options"" for operational costs associated with the transition to the festival's new site.\nMr Pentland said that while the event was something to be proud of, questions still remained unanswered.\nHe added: ""We have seen suggestions that the event would not have been viable, would have to have been a one-day event or would have to have been moved out of Scotland without public cash. 
None of this seems credible.""\nA spokesman for Ms Hyslop said: ""The education and culture committee's busy remit covers a range of important issues such as school attainment and improving access to childcare - yet opposition parties seem more interested in pursuing conspiracy theories that have already been comprehensively refuted.\n""Ms Hyslop has already answered questions from the opposition for well over an hour on this at the committee and stayed on until no members had any questions left to ask.\n""She has also answered questions on T in the Park both in writing and in the Scottish Parliament chamber, as well as publishing over 600 pages of relevant documents.""",34452182,Holyrood's culture committee is to write to the culture secretary with further questions about government funding for T in the Park festival.
4,"Taylor, 28, will sit out the World Cup qualifier in Serbia on 11 June after being sent off for the challenge that broke the leg of Republic captain Seamus Coleman in Friday's 0-0 draw.\nThe suspension could increase to three games if Fifa reviews the incident once it receives the referee's report.\nGareth Bale will also miss the Serbia game after being booked on Friday.\nIt was the Real Madrid forward's second yellow card of the qualification campaign.\nRepublic boss Martin O'Neill, preparing his side for a friendly against Iceland on Tuesday, said the tackles by Taylor and Bale - on John O'Shea - were ""very, very poor"".\nAnd Irish Prime Minister Enda Kenny told the Irish Times the tackle by Taylor was ""horrific"".\nO'Shea needed stitches after he was tackled by Bale in the second half of Friday's game, which left Wales four points behind the Republic and Serbia with five matches left.\nWales boss Chris Coleman defended his players, and said Bale did not even think the challenge on O'Shea merited a booking.",39409418,Wales defender Neil Taylor could face more than a one-game ban for his red card against the Republic of Ireland.


The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):

In [10]:
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Retu

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [11]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}
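To make these scores less opaque, here is a minimal, simplified sketch of ROUGE-1 for a single prediction/reference pair: clipped unigram overlap turned into precision, recall and F-measure. It assumes lowercased whitespace tokenization and no stemming, so it will not reproduce the `rouge_score`-backed metric exactly; `rouge1_scores` is a hypothetical name used only for illustration.

```python
from collections import Counter

def rouge1_scores(prediction, reference):
    """Simplified ROUGE-1: clipped unigram overlap between two strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Each unigram counts at most as often as it appears in both texts
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return {"precision": 0.0, "recall": 0.0, "fmeasure": 0.0}
    precision = overlap / len(pred_tokens)  # share of predicted unigrams found in the reference
    recall = overlap / len(ref_tokens)      # share of reference unigrams recovered
    fmeasure = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}

print(rouge1_scores("hello there", "hello there"))
print(rouge1_scores("general kenobi", "hello there general"))
```

The second call has one shared unigram ("general"), giving precision 1/2, recall 1/3 and F-measure 0.4. With `use_stemmer=True`, the real metric would additionally match words such as "running" and "run", which this sketch does not. The `low`/`mid`/`high` fields in the metric's output come from bootstrap aggregation over the scored examples.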

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary), put them in the format the model expects, and generate the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [12]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [13]:
tokenizer("Hello, this one sentence!")

{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.

Instead of one sentence, we can pass along a list of sentences:

In [14]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

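To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets. For T5, the targets use the same special tokens as the inputs, so the output is identical to the previous cell: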

In [15]:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize: " (the model can also translate, and it needs the prefix to know which task it has to perform).

In [16]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. Together with `max_length`, this ensures that any text longer than `max_length` tokens (`max_input_length` for the documents, `max_target_length` for the summaries) is truncated. The padding will be dealt with later on (in a data collator), so that examples are padded to the longest length in the batch rather than in the whole dataset.

In [17]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
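To see what truncation plus padding-per-batch buys, here is a minimal sketch on plain lists of made-up token IDs instead of real tokenizer output. The helper names are hypothetical, the ID 7 is arbitrary, and the pad ID of 0 matches T5's `pad_token_id`. As a simplification, the sketch slices sequences directly, whereas the real tokenizer keeps room for special tokens such as T5's end-of-sequence token.

```python
PAD_ID = 0  # T5's pad_token_id
max_input_length = 1024

def truncate(seqs, max_length):
    # Simplification: the real tokenizer also preserves special tokens like </s>
    return [seq[:max_length] for seq in seqs]

def pad_to_longest(batch, pad_id=PAD_ID):
    # Dynamic padding: only as long as the longest sequence in this batch
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (longest - len(seq)) for seq in batch]

def padding_cost(batches, pad_to=None):
    """Pad tokens added when each batch is padded to `pad_to`
    (or to its own longest sequence when pad_to is None)."""
    total = 0
    for batch in batches:
        target = pad_to if pad_to is not None else max(len(s) for s in batch)
        total += sum(target - len(s) for s in batch)
    return total

# Hypothetical token-ID sequences of different lengths
lengths = [3, 5, 1200, 4]
examples = [[7] * n for n in lengths]

truncated = truncate(examples, max_input_length)
print([len(s) for s in truncated])  # [3, 5, 1024, 4]

batches = [truncated[0:2], truncated[2:4]]
padded = [pad_to_longest(b) for b in batches]
print(padding_cost(batches))                           # 1022 pad tokens with per-batch padding
print(padding_cost(batches, pad_to=max_input_length))  # 3060 pad tokens if everything is padded to 1024
```

The first batch only needs padding up to 5 tokens, which is where most of the savings come from.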

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [18]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[21603, 10, 17716, 2279, 43, 5229, 128, 1410, 18, 390, 1508, 28, 5146, 12, 10256, 5, 96, 196, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 976, 243, 28571, 5, 37, 549, 8503, 795, 17586, 51, 12, 8, 3069, 11, 3996, 13606, 51, 639, 45, 8, 6266, 5, 18263, 10256, 11, 2390, 11, 7262, 10371, 7, 3971, 18, 17114, 28571, 1632, 549, 8503, 13404, 30, 2818, 1401, 1797, 6, 7229, 53, 20, 12151, 1955, 8356, 49, 53, 826, 3, 19585, 643, 9768, 5, 216, 19, 230, 3122, 3, 9, 2103, 1059, 12, 1175, 112, 1075, 38, 24260, 350, 16103, 10282, 7, 5752, 4297, 227, 271, 3, 11060, 30, 12, 8, 549, 8503, 1476, 16, 1600, 5, 28571, 47, 859, 8, 1374, 5638, 859, 10282, 7, 6, 411, 7, 2026, 63, 7, 6, 14586, 7677, 11, 26911, 2419, 7, 4298, 113, 130, 10960, 52, 26786, 16, 3, 9, 813, 11674, 11044, 28, 8, 549, 8503, 24, 3492, 16, 3, 9, 3996, 3328, 51, 1154, 16, 1660, 48, 215, 5, 86, 8,

To apply this function to all the document/summary pairs in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This will apply the function to all the elements of all the splits in `raw_datasets`, so our training, validation and test data will be preprocessed in one single command.

In [19]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (in which case the cached data cannot be reused). For instance, it should detect the change if you pick a different `model_checkpoint` and rerun the notebook, since the tokenizer used inside the function changes. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to skip the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts in batches. This lets us leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to process the texts in a batch concurrently.
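With `batched=True`, `map` calls the function on slices of the dataset rather than on one example at a time: the function receives a dict mapping each column name to a list of values (which is why `preprocess_function` iterates over `examples["document"]`) and returns a dict of lists. The sketch below is a simplified, hypothetical re-implementation over a plain dict of columns (`batched_map` and `add_prefix` are illustrative names, not 🤗 Datasets API), assuming the library's default `batch_size` of 1000.

```python
import math

def batched_map(columns, function, batch_size=1000):
    """Toy stand-in for Dataset.map(function, batched=True) on a dict of columns."""
    n_rows = len(next(iter(columns.values())))
    results = {}
    for start in range(0, n_rows, batch_size):
        # Each call sees a dict of lists covering at most `batch_size` rows
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in function(batch).items():
            results.setdefault(name, []).extend(values)
    # Keep the original columns and add (or overwrite) the returned ones
    return {**columns, **results}

def add_prefix(examples):
    # Mirrors the first line of preprocess_function, without tokenization
    return {"prefixed": ["summarize: " + doc for doc in examples["document"]]}

toy = {"document": ["first article", "second article", "third article"],
       "summary": ["one", "two", "three"]}
out = batched_map(toy, add_prefix, batch_size=2)
print(out["prefixed"])

# Number of function calls per split of XSum with the default batch_size of 1000
for name, n_rows in [("train", 204045), ("validation", 11332), ("test", 11334)]:
    print(name, math.ceil(n_rows / 1000))  # train 205, validation 12, test 12
```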

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the `AutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [20]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

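Note that we don't get a warning about newly initialized weights, as we would when loading a checkpoint into a model with a task-specific head that was not pretrained (for example, a classification head). This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.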

To instantiate a `Seq2SeqTrainer`, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

In [21]:
batch_size = 4
args = Seq2SeqTrainingArguments(
    "test-summarization",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
)

Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the `batch_size` defined at the top of the cell, customize the weight decay and train for a single epoch. Since the `Seq2SeqTrainer` will save the model regularly (by default every 500 steps) and our dataset is quite large, we tell it to keep at most three checkpoints, deleting older ones. Lastly, we use the `predict_with_generate` option (so that evaluation generates summaries, which ROUGE can then score) and activate mixed precision training (to go a bit faster).
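A quick back-of-the-envelope check of the size of one epoch with these settings, assuming a single GPU and no gradient accumulation (the split sizes come from the `DatasetDict` printed earlier). It suggests why training on the full XSum training split with `t5-small` can take a long time even for one epoch.

```python
import math

train_examples = 204045  # num_rows of raw_datasets["train"]
eval_examples = 11332    # num_rows of raw_datasets["validation"]
batch_size = 4           # per_device_train_batch_size and per_device_eval_batch_size
num_devices = 1          # assumption: a single GPU, no gradient accumulation
num_train_epochs = 1

# The last, partial batch still counts as a step
steps_per_epoch = math.ceil(train_examples / (batch_size * num_devices))
eval_batches = math.ceil(eval_examples / (batch_size * num_devices))

print(steps_per_epoch * num_train_epochs)  # 51012 optimizer steps
print(eval_batches)                        # 2833 evaluation batches, each run through generate()
```

Because `predict_with_generate=True`, each evaluation batch also runs autoregressive decoding, which is slower than a single forward pass.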

Then, we need a special kind of data collator, which will pad not only the inputs to the maximum length in the batch, but also the labels:

In [22]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
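`DataCollatorForSeq2Seq` pads the labels with `label_pad_token_id`, which defaults to -100: the index the cross-entropy loss ignores, so padded label positions do not contribute to the loss. Below is a minimal NumPy sketch of that label-padding step only, assuming right-padding as used by T5; `pad_labels` is a hypothetical helper, not the collator's actual code, and the label IDs are illustrative.

```python
import numpy as np

def pad_labels(label_lists, label_pad_token_id=-100):
    """Right-pad variable-length label sequences into one 2-D array."""
    longest = max(len(labels) for labels in label_lists)
    # Start from an array full of the ignore index, then copy each sequence in
    batch = np.full((len(label_lists), longest), label_pad_token_id, dtype=np.int64)
    for row, labels in enumerate(label_lists):
        batch[row, :len(labels)] = labels
    return batch

labels = [[100, 19, 430, 7142, 5, 1], [8774, 1]]  # illustrative label IDs
padded_labels = pad_labels(labels)
print(padded_labels)
```

Passing `model=model` also lets the collator build `decoder_input_ids` from the padded labels through the model's `prepare_decoder_input_ids_from_labels` method, when the model provides one.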

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:

In [23]:
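import nltk
import numpy as np

# nltk.sent_tokenize (used below) needs the Punkt sentence-splitter data.
# Newer NLTK releases may look for "punkt_tab" instead of "punkt".
nltk.download("punkt", quiet=True)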

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}
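Two steps of `compute_metrics` are plain array manipulation and can be checked without a model: replacing -100 in the labels with the pad ID before decoding, and counting non-pad tokens to get `gen_len`. A small NumPy sketch with made-up token IDs, assuming T5's `pad_token_id` of 0 (T5's `generate()` output also starts with the decoder start token, which is that same pad ID, so it is not counted).

```python
import numpy as np

pad_token_id = 0  # T5's pad token ID (also its decoder start token)

# Hypothetical batch: labels padded by the collator with -100,
# generated IDs starting with the decoder start token and padded with 0
labels = np.array([[100, 19, 430, 1], [8774, 1, -100, -100]])
predictions = np.array([[0, 100, 19, 1, 0], [0, 8774, 1, 0, 0]])

# Same replacement as in compute_metrics: -100 cannot be decoded, so use the pad ID
decodable_labels = np.where(labels != -100, labels, pad_token_id)

# Same length computation as in compute_metrics (cast to int for readable printing)
prediction_lens = [int(np.count_nonzero(pred != pad_token_id)) for pred in predictions]

print(decodable_labels.tolist())          # [[100, 19, 430, 1], [8774, 1, 0, 0]]
print(prediction_lens)                    # [3, 2]
print(float(np.mean(prediction_lens)))   # 2.5
```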

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [24]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

We can now fine-tune our model by just calling the `train` method:

In [25]:
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: ignored

In [None]:
# !pip install -q GPUtil

# import torch
# from GPUtil import showUtilization as gpu_usage
# from numba import cuda

# def free_gpu_cache():
#     print("Initial GPU Usage")
#     gpu_usage()                             

#     torch.cuda.empty_cache()

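#     # Warning: closing the device through numba also tears down the CUDA
#     # context PyTorch is using; PyTorch usually cannot use the GPU again
#     # until the kernel is restarted.
#     cuda.select_device(0)
#     cuda.close()
#     cuda.select_device(0)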

#     print("GPU Usage after emptying the cache")
#     gpu_usage()

# free_gpu_cache() 

Initial GPU Usage
| ID | GPU | MEM |
------------------
|  0 |  0% |  1% |
GPU Usage after emptying the cache
| ID | GPU | MEM |
------------------
|  0 |  4% |  1% |
