Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:

In [1]:
from transformers import pipeline
model = pipeline("summarization", model="facebook/bart-large-cnn")
model.save_pretrained("bart_model")

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


In [4]:
import transformers

print(transformers.__version__)

4.45.1
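
If you would rather check the requirement in code than by eye, the comparison can be sketched as follows. This is a minimal illustration using only the standard library; the helper names `parse_version` and `meets_minimum` are hypothetical and not part of Transformers.

```python
def parse_version(version: str) -> tuple:
    """Turn '4.45.1' into (4, 45, 1), stopping at suffixes such as 'dev0'."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break  # stop at the first non-numeric component
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed: str, required: str = "4.11.0") -> bool:
    # Tuples compare element by element, so (4, 45, 1) >= (4, 11, 0).
    return parse_version(installed) >= parse_version(required)


print(meets_minimum("4.45.1"))  # True
print(meets_minimum("4.9.2"))   # False
```

In the notebook you would pass `transformers.__version__` as the first argument.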


You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

In [5]:
from transformers.utils import send_example_telemetry

send_example_telemetry("summarization_notebook", framework="pytorch")

# Fine-tuning a model on a summarization task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/main/examples/images/summarization.png?raw=1)

We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using the `Trainer` API.

In [6]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

## Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and the 🤗 Evaluate library to get the metric we need for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `evaluate.load`.

In [7]:
from datasets import load_dataset
from evaluate import load

raw_datasets = load_dataset("xsum")
metric = load("rouge")

The repository for xsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/xsum.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each of the training, validation and test sets:

In [8]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

To access an actual element, you need to select a split first, then give an index:

In [9]:
raw_datasets["train"][0]

 'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.',
 'id': '35232142'}

To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.

In [10]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Draw `num_examples` distinct indices in one call.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [11]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,document,summary,id
0,"Scientists have found that the chemical components of a fragrance can transfer from one person's clothing to another's - even if any contact is brief.\nThe scent's signature lingers for days, although it lessens over time.\nThe team says this is a proof-of-principal study, but suggests that perfumes have the potential to be used as trace evidence.\nThe researchers, writing in the journal Science and Justice, said that analysing fragrances could be a useful tool in cases where there has been close physical contact, such as sexual assaults.\nLead researcher Simona Gherghel, from University College London, said: ""We thought there was a lot of potential with perfume because a lot of people use it. We know about 90% of women and 60% of men use perfume on a regular basis.\n""While there is a lot of work in forensic science on transfers - for example, the transfer of fibres or the transfer of gun-shot residue - until now there has been no research on the transfer of perfumes.""\nForensic reconstruction\nPerfumes are concocted from many different chemical components, which in combination give an individual fragrance its distinctive smell.\nThe researchers, looking at a single male fragrance, found that some of these components were easily transferred from one piece of cotton to another.\nWhen the two pieces of material were pressed together for just a minute, 15 out of 44 chemical components were detected on the second piece of fabric. If the contact time increased to 10 minutes, 18 components were measured.\nThe scientists also tracked how time affected the transfer of the volatile compounds.\nThey found that five minutes after an initial spray of fragrance, 24 out of 44 perfume components were detected on the second piece of fabric after it had been in contact for 10 minutes. 
Six hours after the perfume was applied, 12 components were transferred and seven days later, six volatile components were retained.\nDr Ruth Morgan, director of the UCL Centre for the Forensic Sciences, said: ""It is a pilot study and a proof-of-concept study. We've shown that first, perfume does transfer, and second, we can identify when that transfer has happened.\n""In the future there could well be situations where contact between two individuals is made and this is a way of discerning what kind of contact is made and when it was made.""\nHowever the team added that any evidence would have to be collected extremely quickly after an offence, which could limit its usefulness. They said it was also unlikely that fragrance would be used alone to solve a case.\nDr Morgan added: ""It is not going to be a one-stop indicator. In most investigations we would be hopeful that there would be multiple lines of investigation. We wouldn't want it to just be DNA or just a fingerprint or just perfume. But in combination, with other forms off evidence, that's the way it builds up into a very compelling picture.""\nThe team said more work now needs to be done to assess how perfumes transfer in more realistic forensic reconstructions.\nFollow Rebecca on Twitter: @BBCMorelle","Detecting traces of perfume could help in the fight against crime, a study suggests.",37160171
1,"Gen Sir Adrian Bradshaw said Nato should promote successful strategies so other countries could follow them.\nHe explained how Jordan was training its imams to practice a tolerant traditional form of Islam.\nThe deputy supreme allied commander in Europe also spoke about Afghanistan, migration, and the role of Nato.\nGen Bradshaw told BBC Radio 4's Today programme that the fight against Islamist extremism was multi-faceted.\nHe said part of Nato's approach should be based around drawing attention to de-radicalisation programmes that were working.\nGen Bradshaw added that it would be effective to promote the potential benefits of a strategy such as Jordan's.\nGen Bradshaw said: ""One of the ways in particular that we can help as Nato is by drawing attention to the problem and the need for complementary activities to take place.\n""There is a lot more that can be done to support nations who are running their own counter-radicalisation programmes.\n""In particular, for example, I've become aware of the programme sponsored by King Abdullah of Jordan and delivered by Prince Ghazi to train imams in the tolerant, traditional form of Islam, which incidentally was the traditional form that was found in the Balkans, and in so doing reduce the scope for radicalisation of populations there.\n""We can help in drawing attention to the potential benefits of that sort of activity, although it's not military activity and it's not our primary responsibility.\n""I think it is our responsibility to make sure that people understand what sort of complementary activity needs to go alongside the security activity that we're directly involved in.""\nIn a wide ranging interview, he also defended Nato's involvement in Afghanistan even though casualty figures have been rising and parts of Helmand have been taken over by the Taliban again.\nIn relation to the EU's migration crisis, he said the alliance was monitoring the influx of migrants but that it was limited in what it could do because 
it was not a security operation.\nCritics including Republican US presidential candidate Donald Trump have accused Nato of being obsolete.\nThe general vehemently denied the suggestion, saying that its response to Russia's behaviour over the past couple of years proved that its role as a deterrent was alive and necessary.","Nato could do more to support countries that are running counter-radicalisation programmes, the UK's top representative in the alliance has told the BBC.",35974677
2,"Play had been halted with the final group on the 16th hole, with the fourth round still to play on Sunday.\nJohnson Wagner and Scott Piercy, both also American, are tied for second on 14 under par in Napa Valley.\nScotland's Laird shot a four-under 68 to draw level with Englishman Casey.\nMichael Kim of the United States carded a seven-under-par 65 to share the clubhouse lead with compatriot Brendan Steele on 11 under, while another American, five-time major winner Phil Mickelson, is at nine under.\nStarting times for the fourth round have been moved as more rain is forecast on Sunday.\nWe've launched a new BBC Sport newsletter, bringing all the best stories, features and video right to your inbox. You can sign up here.","American Patton Kizzire leads on 15 under par after the third round of the rain-delayed Safeway Open in California, USA, with British pair Paul Casey and Martin Laird two shots back.",37670770
3,"Also added to Historic England's latest Heritage at Risk Register are Brighton Old Town and a church dubbed ""the birthplace of feminism"".\nThe sites are considered to be at risk of being lost through neglect or decay.\nBut Historic England said there are fewer entries on the register than last year.\nSites including the Grade-I listed grounds at Castle Howard in North Yorkshire and the world's oldest ""pub music hall"", Wilton's Music Hall in London, have come off the register after being restored and saved.\nSee more historic and quirky buildings on BBC England's Pinterest board\nExperts warned that the gap between the cost of repairs and the value of restored properties was growing, driven partly by a skills shortage and a lack of scaffolding in some areas which pushed up costs.\nDuncan Wilson, chief executive of the government heritage agency, said ""thousands of historic sites"" were at risk of being lost across the country.\n""Many lie decaying and neglected and the gap between the cost of repair and their end value is growing,"" he said.\n""The good news is this year there are fewer entries on the Heritage at Risk Register than last year.\n""But as some places are rescued, others fall into disrepair.""\nLondon Zoo's aviary, designed by Lord Snowdon and built in 1965, is in need of repair but has secured Heritage Lottery funding to turn it into a new space for animals and visitors.\nA 16th Century shipwreck in Dunwich, Suffolk, thought to be that of an armed merchant vessel, has been added to the list after a bronze gun was stolen from the site.",A 16th Century shipwreck and London Zoo's aviary are among English heritage sites now considered to be at risk.,37719556
4,"Billy Whitehouse's debut goal, James Bolton's driven effort and Nicky Wroe's cool finish gave the home side a surprise three-goal lead at half-time.\nPadraig Amond gave the visitors hope late on before Jake Hibbs placed one into the top corner to restore the two-goal cushion.\nThe Mariners' Jon Nolan smashed in a late consolation as Halifax held on.\nGrimsby Town manager Paul Hurst told BBC Radio Humberside:\nMedia playback is not supported on this device\n""We were certainly punished. They were very dangerous on the break and I think we had more of the ball, but then give it away in some bad areas.\n""The first two goals were an inability to defend a set piece and a ball in the box. They were pretty lethal in their attacks.\n""In the end, what happens in both boxes determines games of football and certainly Halifax were very clinical and we were perhaps a bit toothless.\n""We did have our opportunities to score and when you're behind in the game you need to take them.""",Halifax moved out of the National League relegation zone with a stunning victory over high-flying Grimsby.,35568166


The metric is an `EvaluationModule` from the 🤗 Evaluate library:

In [12]:
metric

EvaluationModule(name: "rouge", module_type: "metric", features: [{'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id=None)}, {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}], usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [13]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}
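
To get a feel for what these scores measure, here is a simplified sketch of ROUGE-1 computed by hand. It assumes plain whitespace tokenization and no stemming, so it will not reproduce the library's numbers on real text; the function name `rouge1_f1` is hypothetical.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between a prediction and a reference."""
    pred_counts = Counter(prediction.split())
    ref_counts = Counter(reference.split())
    # Clipped overlap: each word counts at most as often as it appears in both.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


print(rouge1_f1("hello there", "hello there"))            # 1.0
print(rouge1_f1("general kenobi", "general grievous"))    # 0.5
```

ROUGE-2 follows the same idea with bigrams, and ROUGE-L replaces n-gram overlap with the longest common subsequence.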

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [14]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [15]:
tokenizer("Hello, this one sentence!")

{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.

Instead of one sentence, we can pass along a list of sentences:

In [16]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

To prepare the targets for our model, we need to tokenize them using the `text_target` parameter. This will make sure the tokenizer uses the special tokens corresponding to the targets:

In [17]:
print(tokenizer(text_target=["Hello, this one sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize:" (the model can also translate, and it needs the prefix to know which task it has to perform).

In [18]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that an input longer than what the selected model can handle is truncated to the `max_length` we pass (or, if none is given, to the maximum length accepted by the model). The padding will be dealt with later on (in a data collator) so we pad examples to the longest length in the batch and not the whole dataset.

In [19]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [20]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[21603, 10, 37, 423, 583, 13, 1783, 16, 20126, 16496, 6, 80, 13, 8, 844, 6025, 4161, 6, 19, 341, 271, 14841, 5, 7057, 161, 19, 4912, 16, 1626, 5981, 11, 186, 7540, 16, 1276, 15, 2296, 7, 5718, 2367, 14621, 4161, 57, 4125, 387, 5, 15059, 7, 30, 8, 4653, 4939, 711, 747, 522, 17879, 788, 12, 1783, 44, 8, 15763, 6029, 1813, 9, 7472, 5, 1404, 1623, 11, 5699, 277, 130, 4161, 57, 18368, 16, 20126, 16496, 227, 8, 2473, 5895, 15, 147, 89, 22411, 139, 8, 1511, 5, 1485, 3271, 3, 21926, 9, 472, 19623, 5251, 8, 616, 12, 15614, 8, 1783, 5, 37, 13818, 10564, 15, 26, 3, 9, 3, 19513, 1481, 6, 18368, 186, 1328, 2605, 30, 7488, 1887, 3, 18, 8, 711, 2309, 9517, 89, 355, 5, 3966, 1954, 9233, 15, 6, 113, 293, 7, 8, 16548, 13363, 106, 14022, 84, 47, 14621, 4161, 6, 243, 255, 228, 59, 7828, 8, 1249, 18, 545, 11298, 1773, 728, 8, 8347, 1560, 5, 611, 6, 255, 243, 72, 1709, 1528, 161, 228, 43, 118, 4006, 91, 12, 766, 8, 3, 19513, 1481, 410, 59, 5124, 5, 96, 196, 17, 19, 1256, 68, 27, 103, 317, 132

To apply this function to all the examples in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This will apply the function to all the elements of all the splits in `raw_datasets`, so our training, validation and testing data will be preprocessed in one single command.

In [21]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (and thus that the cached data should not be used). For instance, it will properly detect if you change the model checkpoint and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts by batches together. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to treat the texts in a batch concurrently.
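
The shape of the data that a batched function receives can be sketched with plain Python. This is only an illustration of the "dictionary of lists in, dictionary of lists out" contract, not the 🤗 Datasets implementation; `batched_map` and `preprocess_batch` are hypothetical stand-ins, and the word count replaces real tokenization.

```python
def preprocess_batch(batch: dict) -> dict:
    # Hypothetical stand-in for tokenization: count whitespace-separated words.
    return {"num_words": [len(doc.split()) for doc in batch["document"]]}


def batched_map(rows: list, fn, batch_size: int = 2) -> list:
    """Apply `fn` to chunks of rows, passing each chunk as a dict of lists."""
    out = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Columnar view of the chunk: one list per column.
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        new_columns = fn(batch)
        # Merge the new columns back into each row.
        for i, row in enumerate(chunk):
            merged = dict(row)
            for key, values in new_columns.items():
                merged[key] = values[i]
            out.append(merged)
    return out


rows = [{"document": "a b c"}, {"document": "d e"}, {"document": "f"}]
result = batched_map(rows, preprocess_batch)
print([r["num_words"] for r in result])  # [3, 2, 1]
```

This is why `preprocess_function` above indexes `examples["document"]` as a list and returns lists for every key.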

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the `AutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [22]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

Note that we don't get a warning like in our classification example. This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.

To instantiate a `Seq2SeqTrainer`, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

In [23]:
batch_size = 16
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=True,
)



Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the `batch_size` defined at the top of the cell and customize the weight decay. Since the `Seq2SeqTrainer` will save the model regularly and our dataset is quite large, we tell it to make three saves maximum. Lastly, we use the `predict_with_generate` option (to properly generate summaries) and activate mixed precision training (to go a bit faster).

The last argument sets everything up so we can push the model to the [Hub](https://huggingface.co/models) regularly during training. Remove it if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally under a name that is different from the name of the repository it will be pushed to, or if you want to push your model under an organization and not your own namespace, use the `hub_model_id` argument to set the repo name (it needs to be the full name, including your namespace: for instance `"sgugger/t5-finetuned-xsum"` or `"huggingface/t5-finetuned-xsum"`).
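
As a small sketch of how the full repo name fits together, the helper below builds the relevant keyword arguments as a plain dictionary. The helper `build_hub_kwargs` and the namespace `"my-org"` are hypothetical; `output_dir`, `push_to_hub` and `hub_model_id` are the argument names of the training arguments class.

```python
def build_hub_kwargs(model_checkpoint: str, namespace: str) -> dict:
    """Assemble the Hub-related arguments for a fine-tuning run."""
    # Drop any "owner/" prefix from the checkpoint name.
    model_name = model_checkpoint.split("/")[-1]
    repo_name = f"{model_name}-finetuned-xsum"
    return {
        "output_dir": repo_name,                     # local folder for checkpoints
        "push_to_hub": True,
        "hub_model_id": f"{namespace}/{repo_name}",  # full name, namespace included
    }


hub_kwargs = build_hub_kwargs("t5-small", namespace="my-org")
print(hub_kwargs["hub_model_id"])  # my-org/t5-small-finetuned-xsum
```

These could then be unpacked into `Seq2SeqTrainingArguments(**hub_kwargs, ...)` alongside the other settings.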

Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels:

In [24]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
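
To make the padding behaviour concrete, here is a rough sketch of the idea using plain lists. It is not the collator's actual code: the real `DataCollatorForSeq2Seq` also returns tensors and, when given the model, can prepare decoder inputs. The function `pad_batch` is hypothetical, the pad token ID of 0 is the one T5 uses, and -100 is the label value the loss ignores.

```python
def pad_batch(features: list, pad_token_id: int = 0, label_pad_id: int = -100) -> dict:
    """Pad inputs and labels to the longest sequence in this batch only."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # Real tokens get 1, padding gets 0.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Labels are padded with -100 so padded positions do not contribute to the loss.
        n_label_pad = max_label - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_id] * n_label_pad)
    return batch


features = [
    {"input_ids": [8774, 6, 48, 80, 7142, 55, 1], "labels": [8774, 1]},
    {"input_ids": [100, 19, 430, 7142, 5, 1], "labels": [100, 19, 1]},
]
batch = pad_batch(features)
print(batch["input_ids"][1])  # [100, 19, 430, 7142, 5, 1, 0]
print(batch["labels"][0])     # [8774, 1, -100]
```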

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:

In [25]:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Extract a few results
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
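
Two steps in this function are easy to misread: swapping -100 back to the pad token before decoding, and counting non-pad tokens to get the generated length. The sketch below shows both on plain lists, assuming a pad token ID of 0 (as for T5); the helper names are hypothetical.

```python
def replace_ignore_index(labels, pad_token_id=0, ignore_index=-100):
    """Swap the loss-ignore value for the pad token so the IDs can be decoded."""
    return [
        [pad_token_id if tok == ignore_index else tok for tok in seq]
        for seq in labels
    ]


def mean_generated_length(predictions, pad_token_id=0):
    """Average number of non-pad tokens per generated sequence."""
    lengths = [sum(1 for tok in seq if tok != pad_token_id) for seq in predictions]
    return sum(lengths) / len(lengths)


labels = [[8774, 1, -100, -100], [100, 19, 430, 1]]
print(replace_ignore_index(labels))  # [[8774, 1, 0, 0], [100, 19, 430, 1]]

# Assumed shape of generated output: a leading pad token as the decoder start
# token, then the summary tokens, then padding up to the batch length.
predictions = [[0, 8774, 6, 1, 0, 0], [0, 100, 19, 430, 7142, 1]]
print(mean_generated_length(predictions))  # 4.0
```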

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [28]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)


We can now finetune our model by just calling the `train` method:

In [29]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | Rougel | Rougelsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 2.7074 | 2.47801 | 28.3005 | 7.7413 | 22.2784 | 22.2767 | 18.8253 |




TrainOutput(global_step=12753, training_loss=2.7688838121498986, metrics={'train_runtime': 10763.8184, 'train_samples_per_second': 18.957, 'train_steps_per_second': 1.185, 'total_flos': 5.402928774709248e+16, 'train_loss': 2.7688838121498986, 'epoch': 1.0})
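
The numbers in this output can be cross-checked against the dataset size and batch size used above. The arithmetic below assumes a single device, no gradient accumulation, and that the final partial batch is kept.

```python
import math

num_train_examples = 204045   # size of the XSum training split
batch_size = 16               # per-device batch size set earlier
train_runtime = 10763.8184    # seconds, from the TrainOutput above

# One epoch needs one step per batch, rounding up for the last partial batch.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
print(steps_per_epoch)                                # 12753

# Throughput figures reported by the trainer.
print(round(num_train_examples / train_runtime, 3))   # 18.957
print(round(steps_per_epoch / train_runtime, 3))      # 1.185
```

The step count and both throughput values match `global_step`, `train_samples_per_second` and `train_steps_per_second` in the output, which corresponds to a run of just under three hours for one epoch.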

You can now upload the result of the training to the Hub by executing this instruction:

In [30]:
trainer.push_to_hub()

CommitInfo(commit_url='https://huggingface.co/Ahmed-shetaia/t5-small-finetuned-xsum/commit/693b8d6815d3da087f4d13c25b4e4d7740b0b68d', commit_message='End of training', commit_description='', oid='693b8d6815d3da087f4d13c25b4e4d7740b0b68d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Ahmed-shetaia/t5-small-finetuned-xsum', endpoint='https://huggingface.co', repo_type='model', repo_id='Ahmed-shetaia/t5-small-finetuned-xsum'), pr_revision=None, pr_num=None)