<a href="https://colab.research.google.com/github/peteryushunli/huggingface_tutorials/blob/main/Fine_tuning_a_masked_language_model_(PyTorch)_Rap_Lyrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning a masked language model (PyTorch)

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

In [1]:
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs

Collecting datasets
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers[sentencepiece]
  Downloading transformers-4.32.1-py3-none-any.whl (7.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━

You will need to setup git, adapt your email and name in the following cell.

In [2]:
!git config --global user.email "peteryushunli@gmail.com"
!git config --global user.name "Peter Li"

You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials.

In [3]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
from transformers import AutoModelForMaskedLM

model_checkpoint = "distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [5]:
distilbert_num_parameters = model.num_parameters() / 1_000_000
print(f"'>>> DistilBERT number of parameters: {round(distilbert_num_parameters)}M'")
print(f"'>>> BERT number of parameters: 110M'")

'>>> DistilBERT number of parameters: 67M'
'>>> BERT number of parameters: 110M'


In [6]:
text = "This is a great [MASK]."

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [8]:
import torch

inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

'>>> This is a great deal.'
'>>> This is a great success.'
'>>> This is a great adventure.'
'>>> This is a great idea.'
'>>> This is a great feat.'


In [9]:
from datasets import load_dataset

rap_dataset = load_dataset("Cropinky/rap_lyrics_english")
rap_dataset

Downloading readme:   0%|          | 0.00/304 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Resolving data files:   0%|          | 0/47 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/76.0 [00:00<?, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/80.0 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/118 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/11.0 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.63M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.14M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/407k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/997k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/619k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/367k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/369k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/224k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.54M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.80M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.17M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/396k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/404k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/236k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/404k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/342k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/537k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.64M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/372k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/712k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/266k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/155k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/570k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/372k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/460k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/719k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/258k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/581k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.78M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.81M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.09M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/339k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/496k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/277k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/475k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/115k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/195k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1181216
    })
})

In [10]:
sample = rap_dataset["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Lyrics: {row['text']}'")
    #print(f"'>>> Label: {row['label']}'")


'>>> Lyrics: Givin' head like she knew me for years'

'>>> Lyrics: Search her purse, search her clean'

'>>> Lyrics: Like I don't no nuffin''


In [11]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = rap_dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)
tokenized_datasets

Map:   0%|          | 0/1181216 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1150 > 512). Running this sequence through the model will result in indexing errors


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 1181216
    })
})

In [12]:
tokenizer.model_max_length

512

In [13]:
chunk_size = 128

In [14]:
# Slicing produces a list of lists for each feature
tokenized_samples = tokenized_datasets["train"][:3]

for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> Lyric {idx} length: {len(sample)}'")

'>>> Lyric 0 length: 4'
'>>> Lyric 1 length: 4'
'>>> Lyric 2 length: 5'


In [15]:
concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated lyric length: {total_length}'")

'>>> Concatenated lyric length: 13'


In [16]:
chunks = {
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    for k, t in concatenated_examples.items()
}

for chunk in chunks["input_ids"]:
    print(f"'>>> Chunk length: {len(chunk)}'")

'>>> Chunk length: 13'


In [17]:
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

In [18]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

Map:   0%|          | 0/1181216 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 115018
    })
})

In [19]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

"##onna [SEP] [CLS] kanye west [SEP] [CLS] lauryn hill [SEP] [CLS] jean grae [SEP] [CLS] lil kim [SEP] [CLS] missy elliot [SEP] [CLS] rah digga [SEP] [CLS] mc lyte [SEP] [CLS] remy ma [SEP] [CLS] missy elliott [SEP] [CLS] [SEP] [CLS] [SEP] [CLS] foxy brown [SEP] [CLS] < bos > [SEP] [CLS] my beyonce [ chorus : lil durk ] [SEP] [CLS] ooh, i like the way she move [SEP] [CLS] shorty my baby, my everything, she the truth [SEP] [CLS] together we cool, me and her can't lose [SEP] [CLS] keep'em on their feet, baby, i know they so confused [SEP] [CLS] shorty my beyonce"

In [20]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [21]:
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



'>>> [CLS] lil pump [SEP] [CLS] lil wayne [SEP] [CLS] lil durk [SEP] [CLS] lil b [SEP] [CLS] lil uzi vert [SEP] [CLS] lil baby [SEP] [CLS] lil reese [SEP] [CLS] lil boosie [SEP] [CLS] the notorious b. i [MASK] g. [SEP] [CLS] big pun [SEP] [CLS] big l [SEP] [CLS] nas [SEP] [CLS] 50 cent [SEP] [CLS] prodigy [SEP] [CLS] action bronson [SEP] [CLS] ill bill [SEP] [CLS] wu - tang [MASK] [SEP] [CLS] raekwon [SEP] [CLS] ghostface killah [SEP] [CLS] rza [SEP] [CLS] [MASK]za [SEP] [CLS] ol'dirty bastard [SEP] [CLS] method man [SEP] [CLS] inspectah [MASK] [SEP] [CLS] u - god [SEP] [CLS] masta killa [SEP] [CLS] [MASK]pad'

'>>> ##on [MASK] [SEP] [CLS] kanye west [SEP] [CLS] lauryn hill [SEP] [CLS] jean grae [SEP] [CLS] lil kim [SEP] [CLS] missy elliot [SEP] [CLS] rah digga [SEP] [CLS] [MASK] lyte [SEP] [CLS] remy ma [SEP] [CLS] missy elliott [SEP] [CLS] [SEP] [CLS] [SEP] [CLS] foxy [MASK] [SEP] [CLS] < bos > [SEP] [CLS] my [MASK] [ [MASK] : lil durk ] [SEP] [CLS] ooh, i like the [MASK] she move [

In [22]:
import collections
import numpy as np

from transformers import default_data_collator

wwm_probability = 0.2


def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels

    return default_data_collator(features)

In [23]:
samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] lil pump [SEP] [CLS] lil [MASK] [SEP] [CLS] [MASK] [MASK] [MASK] [SEP] [CLS] lil b [SEP] [CLS] lil uzi vert [SEP] [CLS] lil baby [SEP] [CLS] lil reese [SEP] [CLS] lil boosie [SEP] [CLS] the notorious [MASK]. i. g. [SEP] [CLS] big pun [SEP] [CLS] big l [SEP] [CLS] nas [SEP] [CLS] 50 cent [SEP] [CLS] [MASK] [SEP] [CLS] [MASK] bronson [SEP] [CLS] ill bill [SEP] [CLS] wu - tang clan [SEP] [CLS] raekwon [SEP] [CLS] ghostface killah [SEP] [CLS] [MASK] [MASK] [SEP] [CLS] [MASK] [MASK] [SEP] [CLS] [MASK]'dirty bastard [SEP] [CLS] method man [SEP] [CLS] inspectah deck [SEP] [CLS] u - god [SEP] [CLS] masta killa [SEP] [CLS] [MASK] [MASK]'

'>>> ##onna [SEP] [CLS] kanye west [SEP] [CLS] lauryn [MASK] [SEP] [CLS] jean grae [SEP] [CLS] lil kim [SEP] [CLS] [MASK] elliot [SEP] [CLS] rah digga [SEP] [CLS] mc lyte [SEP] [CLS] [MASK] ma [SEP] [CLS] [MASK] elliott [SEP] [CLS] [SEP] [CLS] [SEP] [CLS] foxy brown [SEP] [CLS] [MASK] bos > [SEP] [CLS] my beyonce [ chorus : lil durk ] [SEP] [CLS] o

In [24]:
train_size = 10_000
test_size = int(0.1 * train_size)

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size, test_size=test_size, seed=42
)
downsampled_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 1000
    })
})

In [25]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [32]:
from transformers import TrainingArguments

batch_size = 64
# Show the training loss with every epoch
logging_steps = len(downsampled_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-rap-lyrics-v1",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    push_to_hub=True,
    fp16=True,
    logging_steps=logging_steps,
)

In [33]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [34]:
import math

eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

>>> Perplexity: 64.59


In [35]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,2.965,2.652665
2,2.7027,2.597796
3,2.6217,2.563443


TrainOutput(global_step=471, training_loss=2.761963356325834, metrics={'train_runtime': 155.7314, 'train_samples_per_second': 192.639, 'train_steps_per_second': 3.024, 'total_flos': 994208670720000.0, 'train_loss': 2.761963356325834, 'epoch': 3.0})

In [36]:
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

>>> Perplexity: 12.92


In [37]:
trainer.push_to_hub()

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/4.09k [00:00<?, ?B/s]

'https://huggingface.co/peteryushunli/distilbert-base-uncased-finetuned-rap-lyrics-v1/tree/main/'

In [38]:
def insert_random_mask(batch):
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]
    masked_inputs = data_collator(features)
    # Create a new "masked" column for each column in the dataset
    return {"masked_" + k: v.numpy() for k, v in masked_inputs.items()}

In [39]:
downsampled_dataset = downsampled_dataset.remove_columns(["word_ids"])
eval_dataset = downsampled_dataset["test"].map(
    insert_random_mask,
    batched=True,
    remove_columns=downsampled_dataset["test"].column_names,
)
eval_dataset = eval_dataset.rename_columns(
    {
        "masked_input_ids": "input_ids",
        "masked_attention_mask": "attention_mask",
        "masked_labels": "labels",
    }
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [40]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

batch_size = 64
train_dataloader = DataLoader(
    downsampled_dataset["train"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    eval_dataset, batch_size=batch_size, collate_fn=default_data_collator
)

In [41]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

In [42]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [43]:
from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [44]:
from huggingface_hub import get_full_repo_name

model_name = "distilbert-base-uncased-finetuned-rap-lyrics-accelerate"
repo_name = get_full_repo_name(model_name)
repo_name

'peteryushunli/distilbert-base-uncased-finetuned-rap-lyrics-accelerate'

In [46]:
from huggingface_hub import Repository

output_dir = repo_name
repo = Repository(output_dir, clone_from=repo_name)

Cloning https://huggingface.co/peteryushunli/distilbert-base-uncased-finetuned-rap-lyrics-accelerate into local empty directory.


OSError: ignored

In [None]:
from tqdm.auto import tqdm
import torch
import math

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        losses.append(accelerator.gather(loss.repeat(batch_size)))

    losses = torch.cat(losses)
    losses = losses[: len(eval_dataset)]
    try:
        perplexity = math.exp(torch.mean(losses))
    except OverflowError:
        perplexity = float("inf")

    print(f">>> Epoch {epoch}: Perplexity: {perplexity}")

    # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}", blocking=False
        )

>>> Epoch 0: Perplexity: 11.397545307900472
>>> Epoch 1: Perplexity: 10.904909330983092
>>> Epoch 2: Perplexity: 10.729503505340409

In [None]:
from transformers import pipeline

mask_filler = pipeline(
    "fill-mask", model="huggingface-course/distilbert-base-uncased-finetuned-imdb"
)

In [None]:
preds = mask_filler(text)

for pred in preds:
    print(f">>> {pred['sequence']}")

'>>> this is a great movie.'
'>>> this is a great film.'
'>>> this is a great story.'
'>>> this is a great movies.'
'>>> this is a great character.'