# Fine-tuning a masked language model (TensorFlow)

Install the Transformers and Datasets libraries to run this notebook.

In [1]:
!pip install datasets transformers[sentencepiece]
!apt install git-lfs

You will need to set up git. Replace the email and name in the following cell with your own.

In [2]:
!git config --global user.email "sajjad_ramezani@ind.iust.ac.ir"
!git config --global user.name "Sajjad"

In [3]:
from transformers import TFAutoModelForMaskedLM

model_checkpoint = "bert-base-cased"
model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)

All model checkpoint layers were used when initializing TFBertForMaskedLM.

All the layers of TFBertForMaskedLM were initialized from the model checkpoint at bert-base-cased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.


In [4]:
model(model.dummy_inputs)  # Build the model
model.summary()

Model: "tf_bert_for_masked_lm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 bert (TFBertMainLayer)      multiple                  107719680 
                                                                 
 mlm___cls (TFBertMLMHead)   multiple                  23286340  
                                                                 
Total params: 108,340,804
Trainable params: 108,340,804
Non-trainable params: 0
_________________________________________________________________


In [5]:
text = "trump claims [MASK]."

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [7]:
import numpy as np
import tensorflow as tf

inputs = tokenizer(text, return_tensors="np")
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = np.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)[0, 1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
# We negate the array before argsort to get the largest, not the smallest, logits
top_5_tokens = np.argsort(-mask_token_logits)[:5].tolist()

for token in top_5_tokens:
    print(f">>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

>>> trump claims win.
>>> trump claims won.
>>> trump claims only.
>>> trump claims wins.
>>> trump claims are.
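
The cell above ranks candidates by raw logits. As a minimal sketch in plain Python (no TensorFlow or NumPy required), the same ranking can be paired with a softmax so each candidate also gets a probability. The function name `top_k_with_probs` and the toy scores are made up for illustration and are not part of the notebook.

```python
import math


def top_k_with_probs(logits, k):
    """Return (index, probability) pairs for the k largest logits."""
    # Subtract the max before exponentiating for numerical stability
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort indices by descending logit, mirroring np.argsort(-logits)
    order = sorted(range(len(logits)), key=lambda i: -logits[i])
    return [(i, probs[i]) for i in order[:k]]


# Hypothetical scores for a 4-token vocabulary
toy_logits = [0.5, 2.0, -1.0, 3.0]
top2 = top_k_with_probs(toy_logits, k=2)
print(top2)
```

Softmax preserves the ordering of the logits, so the ranking matches the argsort-based one; only the scale of the scores changes.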


In [8]:
from datasets import load_dataset
imdb_dataset = load_dataset(
    "csv",
    column_names=["text", "label"],
    data_files={
        "train": "/content/label_World News_news.csv",
        "test": "/content/label_World News_sample_news.csv",
    },
)

Using custom data configuration default-b8cede8c23c16fca


Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-b8cede8c23c16fca/0.0.0/51cce309a08df9c4d82ffd9363bbe090bf173197fc01a71b034e8594995a1a58...


Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-b8cede8c23c16fca/0.0.0/51cce309a08df9c4d82ffd9363bbe090bf173197fc01a71b034e8594995a1a58. Subsequent calls will reuse this data.


In [9]:
sample = imdb_dataset["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Review: {row['text']}'")
    print(f"'>>> Label: {row['label']}'")


'>>> Review: Ukraine Mariupol Descends despair'
'>>> Label: World News'

'>>> Review: War White Gold North Carolina'
'>>> Label: World News'

'>>> Review: russian Threats push Finland join NATO Alliance'
'>>> Label: World News'


In [10]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)
tokenized_datasets



DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids'],
        num_rows: 974
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids'],
        num_rows: 109
    })
})

In [11]:
tokenizer.model_max_length

512

In [12]:
chunk_size = 128

In [13]:
# Slicing produces a list of lists for each feature
tokenized_samples = tokenized_datasets["train"][:3]

for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> Review {idx} length: {len(sample)}'")

'>>> Review 0 length: 3'
'>>> Review 1 length: 15'
'>>> Review 2 length: 9'


In [14]:
concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated reviews length: {total_length}'")

'>>> Concatenated reviews length: 27'


In [15]:
chunks = {
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    for k, t in concatenated_examples.items()
}

for chunk in chunks["input_ids"]:
    print(f"'>>> Chunk length: {len(chunk)}'")

'>>> Chunk length: 27'


In [16]:
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result
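
To see what the grouping step does to short headlines, here is a small self-contained sketch that applies the same concatenate-then-chunk logic to a toy batch. The helper name `group_into_chunks`, the token ids, and the chunk size of 4 are illustrative assumptions, not values from the notebook.

```python
def group_into_chunks(batch, chunk_size):
    # Flatten every column into one long list
    concatenated = {k: [tok for seq in v for tok in seq] for k, v in batch.items()}
    total_length = len(concatenated["input_ids"])
    # Round down to a multiple of chunk_size, dropping the remainder
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }
    # Labels start out as a copy of the inputs; masking happens later
    result["labels"] = [chunk[:] for chunk in result["input_ids"]]
    return result


# Three hypothetical tokenized headlines (101 and 102 stand in for [CLS] and [SEP])
toy_batch = {
    "input_ids": [[101, 7, 102], [101, 8, 9, 102], [101, 10, 11, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
}
grouped = group_into_chunks(toy_batch, chunk_size=4)
print(grouped["input_ids"])  # [[101, 7, 102, 101], [8, 9, 102, 101]]
```

The 11 toy tokens yield two full chunks, and the trailing three tokens are discarded. Chunks also cut across headline boundaries, which is why a decoded chunk can begin mid-word.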

In [17]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 102
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 11
    })
})

In [18]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

'##spected Outbreak Nears million [SEP] [CLS] australian Prime Minister Tackles Child campaign event Viral Clip [SEP] [CLS] China Eastern Plane Crash intentional official find report [SEP] [CLS] russian soldier plead Guilty Ukraine War Crimes Trial kill Civilian [SEP] [CLS] russian Gymnast ban wear Pro war z symbol ukrainian Rival [SEP] [CLS] Ukraine Hopes Swap Steel Mill Fighters russian pow [SEP] [CLS] Finland Sweden Submit Applications join NATO [SEP] [CLS] Queen Elizabeth make Special Appearance London Tube stop [SEP] [CLS] Worth shot Pope Francis say need Tequila sore knee [SEP] [CLS] belgian'

In [19]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
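
As a rough picture of what random token masking involves, the sketch below follows the BERT recipe: each token is selected with probability `mlm_probability`, and a selected token is replaced by the mask id 80% of the time, by a random id 10% of the time, and left unchanged otherwise. This is a simplified stand-in rather than the collator's actual code. It skips special-token handling, and `MASK_ID`, `VOCAB_SIZE`, and `mask_tokens` are hypothetical names.

```python
import random

MASK_ID = 103      # hypothetical id standing in for [MASK]
VOCAB_SIZE = 1000  # hypothetical vocabulary size
IGNORE = -100      # label value that the loss skips


def mask_tokens(input_ids, mlm_probability, rng):
    masked = list(input_ids)
    labels = [IGNORE] * len(input_ids)
    for i, token in enumerate(input_ids):
        if rng.random() >= mlm_probability:
            continue  # token not selected: it stays visible and is not scored
        labels[i] = token  # the model must predict the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_ID  # replace with the mask token
        elif roll < 0.9:
            masked[i] = rng.randrange(VOCAB_SIZE)  # replace with a random token
        # otherwise keep the original token in place
    return masked, labels


rng = random.Random(0)
ids = list(range(200, 240))  # hypothetical token ids
masked, labels = mask_tokens(ids, mlm_probability=0.15, rng=rng)
print(sum(label != IGNORE for label in labels), "of", len(ids), "tokens selected")
```

The random-replacement branch is one reason a decoded batch can contain unexpected word pieces in addition to `[MASK]` tokens.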

In [20]:
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] text [SEP] [CLS] [MASK]ive Medic Body [MASK] show Firsthand Horror Mariupol [SEP] [CLS] Russia fire senior [MASK] mean Ukraine War [SEP] [CLS] Red Cross [MASK]s Hundreds ukrainian pow Mariup [MASK] [SEP] [CLS] [MASK] GOP Attorney General Primary big lie battleground [SEP] [CLS] War White Gold North Carolina [SEP] [CLS] Trump Reportedly consider quit Presidential [MASK] [MASK]pe surface [SEP] [CLS] [MASK] Lyonne bring Ex Fred Armisen SNL [MASK]ologue joke sex tape [SEP] [CLS] Brooklyn Half Marathon Runner Col [MASK]ses die cross finish [MASK] [SEP] [CLS] apparently ailing Putin surround [MASK]s Kremlin Chaos say Ex brit [MASK] Official [SEP] [CLS] North Korea Su'

'>>> ##spected Outbreak Nears million [SEP] [CLS] australian Prime Minister Tackles Child campaign [MASK] Vira [MASK] Clip [SEP] [CLS] China Eastern Plane Crash intentional official find report [SEP] [CLS] r [MASK]ian soldier [MASK]d [MASK]uilty Ukraine [MASK] Crimes Trial kill Civilian [SEP] [CLS]roducedussian Gym

In [21]:
import collections
import numpy as np

# from transformers.data import tf_default_data_collator
from transformers.data.data_collator import tf_default_data_collator
wwm_probability = 0.2


def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id

    return tf_default_data_collator(features)
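
The core of the whole-word collator is the map from words to token positions. The sketch below isolates that step on a toy input so the grouping is easy to inspect. The `word_ids` list is a hypothetical example in which one word splits into three word pieces, and `words_to_token_indices` is an illustrative name.

```python
import collections


def words_to_token_indices(word_ids):
    mapping = collections.defaultdict(list)
    current_word = None
    current_word_index = -1
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue  # special tokens such as [CLS] and [SEP] belong to no word
        if word_id != current_word:
            # A new word starts here
            current_word = word_id
            current_word_index += 1
        mapping[current_word_index].append(idx)
    return dict(mapping)


# Hypothetical word_ids: [CLS], three pieces of word 0, one piece of word 1, [SEP]
toy_word_ids = [None, 0, 0, 0, 1, None]
mapping = words_to_token_indices(toy_word_ids)
print(mapping)  # {0: [1, 2, 3], 1: [4]}
```

Masking word 0 then means masking token positions 1, 2, and 3 together, so the model never sees part of a word it is asked to predict.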

In [22]:
samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] text [SEP] [CLS] Captive [MASK] [MASK] Bodycam show Firsthand Horror Mariupol [SEP] [CLS] Russia fire senior [MASK] mean Ukraine War [SEP] [CLS] Red Cross Registers Hundreds ukrainian pow Mariupol [SEP] [CLS] Georgia GOP Attorney General Primary [MASK] lie [MASK] [MASK] [SEP] [CLS] War [MASK] Gold North Carolina [SEP] [CLS] [MASK] Reportedly consider quit Presidential Race Tape surface [SEP] [CLS] [MASK] Lyonne bring [MASK] Fred Armisen SNL Monologue joke sex tape [SEP] [CLS] Brooklyn Half Marathon Runner Collapses die cross finish line [SEP] [CLS] apparently ailing [MASK] surround Docs Kremlin Chaos [MASK] [MASK] [MASK] [MASK] [MASK] Official [SEP] [CLS] North Korea Su'

'>>> ##spected [MASK] [MASK] Nears [MASK] [SEP] [CLS] australian Prime Minister Tackles Child campaign [MASK] Viral [MASK] [MASK] [SEP] [CLS] China Eastern [MASK] [MASK] [MASK] intentional official [MASK] report [SEP] [CLS] russian soldier plead Guilty Ukraine War [MASK] [MASK] Trial [MASK] Civilian [SEP] 

In [23]:
train_size = 90
test_size = 10

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size, test_size=test_size, seed=42
)
downsampled_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 90
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 10
    })
})

In [24]:
tf_train_dataset = downsampled_dataset["train"].to_tf_dataset(
    columns=["input_ids", "attention_mask", "labels"],
    collate_fn=data_collator,
    shuffle=True,
    batch_size=32,
)

tf_eval_dataset = downsampled_dataset["test"].to_tf_dataset(
    columns=["input_ids", "attention_mask", "labels"],
    collate_fn=data_collator,
    shuffle=False,
    batch_size=32,
)

In [25]:
model_name = "ss"

In [26]:
from transformers import create_optimizer
from transformers.keras_callbacks import PushToHubCallback
import tensorflow as tf

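# Note: this counts the batches in a single epoch; multiply by the number of
# epochs if the learning-rate schedule should span the whole training run
num_train_steps = len(tf_train_dataset)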
optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_warmup_steps=1_000,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)
model.compile(optimizer=optimizer)

# Train in mixed-precision float16
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# callback = PushToHubCallback(
#     output_dir=f"{model_name}-finetuned-imdb", tokenizer=tokenizer
# )

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Tesla T4, compute capability 7.5
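
The optimizer cell combines a warmup phase with a decaying learning rate. The sketch below is a simplified linear-warmup, linear-decay schedule written in plain Python to show how the two step counts interact. It is not the code that `create_optimizer` runs, and the function name `lr_at_step` and the step count of 300 are assumptions chosen for illustration.

```python
def lr_at_step(step, init_lr, num_warmup_steps, num_train_steps):
    """Simplified schedule: linear warmup, then linear decay to zero."""
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to init_lr
        return init_lr * step / num_warmup_steps
    # Decay: fall linearly from init_lr to 0 over the remaining steps
    decay_steps = max(1, num_train_steps - num_warmup_steps)
    progress = min(1.0, (step - num_warmup_steps) / decay_steps)
    return init_lr * (1.0 - progress)


# Hypothetical run that takes 300 optimizer steps with a 1,000-step warmup
final_lr = lr_at_step(300, init_lr=2e-5, num_warmup_steps=1_000, num_train_steps=300)
print(f"{final_lr:.1e}")

# With a warmup shorter than the run, the peak rate is reached and then decays
peak_lr = lr_at_step(30, init_lr=2e-5, num_warmup_steps=30, num_train_steps=300)
print(f"{peak_lr:.1e}")
```

Under this simplified schedule, a run that ends before the warmup does never reaches `init_lr`, so on a small dataset it can be worth checking the warmup length against the total number of steps.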


In [27]:
import math

eval_loss = model.evaluate(tf_eval_dataset)
print(f"Perplexity: {math.exp(eval_loss):.2f}")

Perplexity: 457.44
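
Perplexity is the exponential of the mean cross-entropy loss, which is why the cell above passes the evaluation loss to `math.exp`. The following toy calculation, using made-up token probabilities, shows the relationship without needing a model.

```python
import math


def perplexity(token_probs):
    """Exponential of the mean negative log-likelihood of the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    mean_loss = sum(nll) / len(nll)
    return math.exp(mean_loss)


# A model that guesses uniformly among 4 options has perplexity 4
uniform = perplexity([0.25, 0.25, 0.25, 0.25])
# A more confident model (hypothetical probabilities) scores lower
confident = perplexity([0.9, 0.8, 0.95])
print(f"{uniform:.2f} {confident:.2f}")
```

Lower is better, and a perfect model would score 1. Because the data collator masks tokens at random each time the evaluation set is read, the reported perplexity can vary somewhat from one evaluation to the next.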


In [28]:
model.fit(tf_train_dataset, validation_data=tf_eval_dataset,epochs=150)

Epoch 1/150
...
Epoch 77/150
Epoch 78

<keras.callbacks.History at 0x7f6094e66510>

In [29]:
eval_loss = model.evaluate(tf_eval_dataset)
print(f"Perplexity: {math.exp(eval_loss):.2f}")

Perplexity: 13.28


In [30]:
from transformers import pipeline

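# Note: this loads a DistilBERT checkpoint published on the Hub,
# not the BERT model fine-tuned in the cells above
mask_filler = pipeline(
    "fill-mask", model="huggingface-course/distilbert-base-uncased-finetuned-imdb"
)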

In [31]:
preds = mask_filler(text)

for pred in preds:
    print(f">>> {pred['sequence']}")

>>> trump claims innocence.
>>> trump claims victory.
>>> trump claims this.
>>> trump claims that.
>>> trump claims nothing.
