If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.

You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/language-modeling).

# Fine-tuning a language model

## Preparing the dataset

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
from datasets import load_dataset
datasets = load_dataset('wikitext', 'wikitext-103-raw-v1', cache_dir='/mnt/NAS/users/ntr/tmp/text_models_dumps/wiki_dataset')

Reusing dataset wikitext (/mnt/NAS/users/ntr/tmp/text_models_dumps/wiki_dataset/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)


You can replace the dataset above with any dataset hosted on [the hub](https://huggingface.co/datasets) or use your own files. Just uncomment the following cell and replace the paths with values that will lead to your files:

In [4]:
# datasets = load_dataset("text", data_files={"train": "path_to_train.txt", "validation": "path_to_validation.txt"})

To get a sense of what the data looks like, the following function shows some examples picked at random from the dataset.

In [5]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(datasets["train"])

Unnamed: 0,text
0,"The Melbourne Cricket Club decided that their players should have a uniform while playing the match . They chose the colours red , white and blue for the Victorian team 's match jerseys . These remain the official colours of the Melbourne Cricket Club till date . The Victorian team , wearing the multi @-@ coloured jerseys , and with friends and ladies in tow , reached Launceston on the SS Shamrock . The team was met with much fanfare on its arrival and the players were put up at the Cornwall Hotel , now known as the Batman Fawkner Inn . This was followed by considerable banqueting , which led to the team not having enough time to practice for the match . \n"
1,"Although Colen Campbell was employed by Thomas Coke in the early 1720s , the oldest existing working and construction plans for Holkham were drawn by Matthew Brettingham , under the supervision of Thomas Coke , in 1726 . These followed the guidelines and ideals for the house as defined by Kent and Burlington . The Palladian revival style chosen was at this time making its return in England . The style made a brief appearance in England before the Civil War , when it was introduced by Inigo Jones . However , following the Restoration it was replaced in popular favour by the Baroque style . The "" Palladian revival "" , popular in the 18th century , was loosely based on the appearance of the works of the 16th @-@ century Italian architect Andrea Palladio . However it did not adhere to Palladio 's strict rules of proportion . The style eventually evolved into what is generally referred to as Georgian , still popular in England today . It was the chosen style for numerous houses in both town and country , although Holkham is exceptional for both its severity of design and for being closer than most in its adherence to Palladio 's ideals . \n"
2,
3,"Alan E. Diehl , a former safety manager for the U.S. Navy , described the USS Iowa incident in his 2003 book Silent Knights : Blowing the Whistle on Military Accidents and Their Cover @-@ Ups . Diehl called the incident and its aftermath the worst military cover @-@ up he had ever seen . \n"
4,
5,
6,
7,
8,
9,= = = Lawsuits = = = \n


As we can see, some of the texts are a full paragraph of a Wikipedia article while others are just titles or empty lines.
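As a rough illustration of that observation, the sketch below sorts a few lines into empty lines, section titles, and paragraphs. The `classify_line` helper and the rule "a title starts and ends with `=`" are assumptions made for this example, not part of the `datasets` library; the sample strings are taken from the output above.

```python
from collections import Counter


def classify_line(text):
    """Label a raw WikiText line as 'empty', 'heading', or 'paragraph' (hypothetical helper)."""
    stripped = text.strip()
    if not stripped:
        return "empty"
    # Assumption: section titles are wrapped in '=' signs, e.g. "= = = Lawsuits = = ="
    if stripped.startswith("=") and stripped.endswith("="):
        return "heading"
    return "paragraph"


samples = [
    "",
    "= = = Lawsuits = = = \n",
    "Diehl called the incident and its aftermath the worst military cover @-@ up he had ever seen . \n",
    "",
]
counts = Counter(classify_line(text) for text in samples)
print(counts)
```

The same function could be applied to `datasets["train"]["text"]` to get the proportions for the full split, at the cost of iterating over the whole dataset.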

## Masked language modeling

For masked language modeling (MLM) we tokenize the dataset and group it into fixed-size blocks, with one additional step: we randomly mask some tokens (by replacing them with the tokenizer's mask token, e.g. `[MASK]` for BERT or `<mask>` for BART) and adjust the labels to include only the masked tokens (we don't have to predict the non-masked tokens).
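To make the relationship between masked inputs and labels concrete, here is a minimal pure-Python sketch. The function name `mask_tokens`, the token ids, and the mask id `4` are made up for illustration. It is a simplification: every selected token is replaced by the mask token, and unselected positions get the label `-100`, the value that PyTorch's cross-entropy loss ignores by default.

```python
import random

IGNORE_INDEX = -100  # positions with this label do not contribute to the loss


def mask_tokens(input_ids, mask_token_id, mlm_probability=0.15, rng=None):
    """Return (masked_inputs, labels) for one sequence (simplified sketch)."""
    rng = rng or random.Random()
    masked_inputs, labels = [], []
    for token_id in input_ids:
        if rng.random() < mlm_probability:
            # Selected: hide the token and ask the model to predict it.
            masked_inputs.append(mask_token_id)
            labels.append(token_id)
        else:
            # Not selected: keep the token and exclude it from the loss.
            masked_inputs.append(token_id)
            labels.append(IGNORE_INDEX)
    return masked_inputs, labels


token_ids = list(range(1000, 1100))  # made-up token ids
masked, labels = mask_tokens(token_ids, mask_token_id=4, rng=random.Random(0))
print(sum(label != IGNORE_INDEX for label in labels), "tokens selected for prediction")
```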

In [7]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import Trainer, TrainingArguments

def tokenize_function(examples):
    return tokenizer(examples["text"])

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder. We could pad instead if the model supported it;
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# block_size = tokenizer.model_max_length
block_size = 128

model_checkpoint = "facebook/bart-base"

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=12, remove_columns=["text"])

  

Loading cached processed dataset at /mnt/NAS/users/ntr/tmp/text_models_dumps/wiki_dataset/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-f634e0480f896944.arrow


  


And like before, we group texts together and chunk them in samples of length `block_size`. You can skip that step if your dataset is composed of individual sentences.
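To see what the grouping does, the sketch below runs the same logic on a tiny made-up batch. `group_texts_toy` is a copy of `group_texts` that takes `block_size` as a parameter so the example is self-contained; the token ids are invented.

```python
def group_texts_toy(examples, block_size):
    """Concatenate all sequences, then cut them into blocks of block_size tokens."""
    concatenated = {k: sum(examples[k], []) for k in examples}
    total_length = len(concatenated["input_ids"])
    # Drop the remainder that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# Three "tokenized texts" of different lengths: 13 tokens in total.
toy_batch = {"input_ids": [[0, 11, 12, 2], [0, 2], [0, 21, 22, 23, 24, 25, 2]]}
chunks = group_texts_toy(toy_batch, block_size=4)
print(chunks["input_ids"])
# → [[0, 11, 12, 2], [0, 2, 0, 21], [22, 23, 24, 25]]
```

Note that block boundaries ignore the original text boundaries, and the final token (`2`) is dropped because it does not fill a complete block.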

In [9]:
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=12,
)

            

Loading cached processed dataset at /mnt/NAS/users/ntr/tmp/text_models_dumps/wiki_dataset/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91/cache-155ef5fd874a9e3e.arrow

In [10]:
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

We also use a special `data_collator`. The `data_collator` is a function responsible for taking the samples and batching them into tensors. With the default collator, nothing special happens at this stage. Here, however, we want to apply the random masking. We could do it as a preprocessing step (like the tokenization), but then the tokens would be masked the same way at every epoch. By doing this step inside the `data_collator`, we ensure the random masking is drawn anew each time we go over the data.

To do this masking for us, the library provides a `DataCollatorForLanguageModeling`. We can adjust the probability of the masking:

In [11]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
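The effect of masking at collation time can be sketched without the library. The `dynamic_mask` function below is a hypothetical stand-in, not the collator's actual code. It assumes the 80/10/10 rule that, to my understanding, the collator applies by default: of the selected tokens, 80% become the mask token, 10% become a random token, and 10% are left unchanged. The token ids, mask id, and vocabulary size are made up.

```python
import random


def dynamic_mask(input_ids, mask_token_id, vocab_size, mlm_probability=0.15, rng=random):
    """Draw a fresh random mask for one sequence (simplified sketch)."""
    inputs = list(input_ids)
    labels = [-100] * len(input_ids)  # -100 = ignored by the loss
    for pos, token_id in enumerate(input_ids):
        if rng.random() >= mlm_probability:
            continue  # position not selected for prediction
        labels[pos] = token_id
        roll = rng.random()
        if roll < 0.8:
            inputs[pos] = mask_token_id              # 80%: mask token
        elif roll < 0.9:
            inputs[pos] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


sample = list(range(100, 300))  # one made-up sequence of 200 token ids
rng = random.Random(0)
epoch_1 = dynamic_mask(sample, mask_token_id=4, vocab_size=1000, rng=rng)
epoch_2 = dynamic_mask(sample, mask_token_id=4, vocab_size=1000, rng=rng)
print("same mask in both epochs:", epoch_1[1] == epoch_2[1])
```

Because the mask is drawn each time the function is called, the same sample is masked differently in each pass, which is what masking once during preprocessing would not give us.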

Then we just have to pass everything to `Trainer` and begin training:

In [12]:
training_args = TrainingArguments(
    output_dir="/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm",
    evaluation_strategy = "epoch",
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    data_collator=data_collator, 
)

In [13]:
trainer.train()

Epoch,Training Loss,Validation Loss,Runtime,Samples Per Second
1,1.7955,1.539457,14.292,138.958


TrainOutput(global_step=118563, training_loss=2.0065071245787114, metrics={'train_runtime': 36666.0367, 'train_samples_per_second': 3.234, 'total_flos': 101560630269247488, 'epoch': 1.0})

In [16]:
import math
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

Perplexity: 4.49
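Perplexity here is simply the exponential of the mean cross-entropy loss. As a worked example, the sketch below applies it to the validation loss from the training table; the `perplexity` helper is defined only for this illustration.

```python
import math


def perplexity(mean_cross_entropy_loss):
    """Perplexity is exp(loss) when the loss is a mean cross-entropy in nats."""
    return math.exp(mean_cross_entropy_loss)


validation_loss = 1.539457  # "Validation Loss" column of the training table
print(f"Perplexity: {perplexity(validation_loss):.2f}")  # → Perplexity: 4.66
```

The value printed by the notebook (4.49) corresponds to a loss of about 1.50 rather than 1.54. A likely explanation is that `trainer.evaluate()` draws a new random mask on each run, so the evaluation loss varies slightly from one call to the next.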


In [15]:
model.save_pretrained('/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm')
tokenizer.save_pretrained('/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm')

('/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm/tokenizer_config.json',
 '/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm/special_tokens_map.json',
 '/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm/vocab.json',
 '/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm/merges.txt',
 '/mnt/NAS/users/ntr/tmp/text_models_dumps/bart_mlm/added_tokens.json')