# Chapter 2 Tutorial: Understanding LLMs and Pre-training

In this tutorial, we will explore the mechanics of LLM architectures, with an emphasis on the differences between masked models and causal models. In the first section, we'll examine some existing pretrained models to understand how they produce their outputs. Once we've demonstrated how LLMs do what they do, we will run an abbreviated training loop to provide a glimpse into the training process.

## Installation and Imports

In [2]:
!pip install datasets "transformers[sentencepiece,torch]"
!pip install apache_beam


In [1]:
import torch
from datasets import load_dataset, DatasetDict

from transformers import (
    BertTokenizer,
    BertForMaskedLM,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    DataCollatorForLanguageModeling,
    AutoConfig,
    AutoTokenizer,
    Trainer,
    TrainingArguments
)

## Understanding Masked LMs

In [2]:
## The first model we will look at is BERT, which is trained with masked tokens. As an example,
## the text below masks the word "box" from a well-known movie quote.

text = "Life is like a [MASK] of chocolates."

In [3]:
## We'll now see how BERT is able to predict the missing word. We can use HuggingFace to load
## a copy of the pretrained model and tokenizer.

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

In [5]:
## Next, we'll feed our example text into the tokenizer.

encoded_input = tokenizer(text, return_tensors='pt')
print('encoded_input:', encoded_input, '\n')
print('input_ids:', encoded_input['input_ids'])
print('attention_mask:', encoded_input['attention_mask'])

encoded_input: {'input_ids': tensor([[ 101, 2166, 2003, 2066, 1037,  103, 1997, 7967, 2015, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])} 

input_ids: tensor([[ 101, 2166, 2003, 2066, 1037,  103, 1997, 7967, 2015, 1012,  102]])
attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


In [6]:
## input_ids represents the tokenized output. Each integer can be mapped back to the corresponding string.

print(tokenizer.decode([7967]))

chocolate


In [7]:
## The model will then receive the output of the tokenizer. We can look at the BERT model to see exactly how
## it was constructed and what the outputs will be like.

model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [9]:
## The model starts by embedding each of the 30,522 possible tokens into 768 dimensions. At this point the
## embedding is simply a representation of each token, without any information about the tokens' relationships
## to one another in the text. The encoder attention blocks are then applied, updating the embeddings so that
## they encode each token's contribution to the chunk of text and its interactions with the other tokens.
## Notably, this includes the masked token as well. The final stage is the language model head, which projects
## the embedding at every position back to 30,522 dimensions. Each entry of that vector is a logit (an
## unnormalized score) for one vocabulary token; applying a softmax at the masked position turns these scores
## into the probability that each token is the correct choice to fill the mask.


model_output = model(**encoded_input)
output = model_output["logits"]

print('model_output:', model_output, '\n')
print('output:', output, '\n')
print(output.shape)

model_output: MaskedLMOutput(loss=None, logits=tensor([[[ -6.8028,  -6.7572,  -6.7647,  ...,  -6.1487,  -5.9161,  -4.0691],
         [-10.1066, -10.5149,  -9.8165,  ...,  -9.6765,  -8.1962,  -4.5285],
         [ -8.0093,  -8.2118,  -8.2052,  ...,  -8.1465,  -4.5270,  -5.7145],
         ...,
         [-11.9653, -11.9873, -11.6436,  ..., -10.3985,  -9.1539,  -6.8904],
         [-11.7146, -11.2433, -11.6727,  ...,  -9.9394,  -9.8698,  -6.7365],
         [-13.5199, -14.0397, -13.5588,  ..., -13.8563, -10.4240,  -9.8423]]],
       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None) 

output: tensor([[[ -6.8028,  -6.7572,  -6.7647,  ...,  -6.1487,  -5.9161,  -4.0691],
         [-10.1066, -10.5149,  -9.8165,  ...,  -9.6765,  -8.1962,  -4.5285],
         [ -8.0093,  -8.2118,  -8.2052,  ...,  -8.1465,  -4.5270,  -5.7145],
         ...,
         [-11.9653, -11.9873, -11.6436,  ..., -10.3985,  -9.1539,  -6.8904],
         [-11.7146, -11.2433, -11.6727,  ...,  -9.9394,  -9.8698,  -6.736

In [14]:
## Find the position of the [MASK] token and keep only the logits for that position.

tokens = encoded_input['input_ids'][0].tolist()
masked_index = tokens.index(tokenizer.mask_token_id)
logits = output[0, masked_index, :]

print('tokens:', tokens, '\n')
print('masked_index:', masked_index, '\n')
print(logits.shape)

tokens: [101, 2166, 2003, 2066, 1037, 103, 1997, 7967, 2015, 1012, 102] 

masked_index: 5 

torch.Size([30522])


In [13]:
## Convert the logits to probabilities, then keep the five most likely tokens and decode them.

probs = logits.softmax(dim=-1)
values, predictions = probs.topk(5)
sequence = tokenizer.decode(predictions)

print('Top 5 predictions:', sequence)
print(values)

Top 5 predictions: box bag bowl jar cup
tensor([0.1764, 0.1688, 0.0419, 0.0336, 0.0262], grad_fn=<TopkBackward0>)


Printing the top 5 predictions and their respective scores, we see that BERT accurately chooses "box" as the most likely replacement for the mask token.
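
To make the softmax and top-k step concrete, here is a minimal pure-Python sketch of the same computation on a toy vocabulary. The logits and the six-word vocabulary are hypothetical values chosen for illustration, not BERT's actual scores, and the helper names `softmax` and `top_k` are ours.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    # Subtracting the max improves numerical stability without changing the result.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (index, probability) pairs for the k largest probabilities."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical logits over a 6-token toy vocabulary (not taken from BERT).
toy_vocab = ["box", "bag", "bowl", "jar", "cup", "car"]
toy_logits = [2.0, 1.9, 0.5, 0.3, 0.1, -3.0]

probs = softmax(toy_logits)
for index, p in top_k(probs, 3):
    print(f"{toy_vocab[index]}: {p:.3f}")
```

Because softmax preserves the ordering of the logits, the top-k tokens are the same whether they are ranked by logit or by probability; the softmax only makes the scores interpretable.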

## Understanding Causal LMs

In [15]:
## We now repeat a similar exercise with the causal LLM GPT-2. This model generates
## text following an input, instead of replacing a mask within the text.

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

In [16]:
## We can examine the model again, noting the similarities to BERT: an embedding, 12 attention blocks,
## and a linear transformation that brings the output back to the size of the tokenizer's vocabulary.
## GPT-2 uses a different tokenizer from BERT, with a larger vocabulary (50,257 tokens versus 30,522).

model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
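
The printed modules look similar, but the two model families differ in which positions each token may attend to. BERT's encoder lets every token see the whole sequence, which is why the `[MASK]` position can use context on both sides. GPT-2 restricts each token to itself and earlier positions, which is what makes left-to-right generation possible. The sketch below builds both attention patterns for a toy sequence length; the helper names are hypothetical and the masks are plain lists rather than the tensors the libraries use internally.

```python
def bidirectional_mask(seq_len):
    """Every position may attend to every position (encoder-style, as in BERT)."""
    return [[1] * seq_len for _ in range(seq_len)]

def causal_mask(seq_len):
    """Position i may attend only to positions 0..i (decoder-style, as in GPT-2)."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

print("bidirectional:")
for row in bidirectional_mask(4):
    print(row)

print("causal:")
for row in causal_mask(4):
    print(row)
```

In the causal pattern, row i contains ones only up to column i, so the prediction at any position cannot depend on tokens that come after it.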

In [17]:
## We'll use a different text example, since this model works by producing tokens sequentially
## rather than filling a mask.

text = "Swimming at the beach is"
model_inputs = tokenizer(text, return_tensors='pt')
model_inputs

{'input_ids': tensor([[10462, 27428,   379,   262, 10481,   318]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [18]:
## After applying the model, the logits at the last position hold the scores for the next token.
## We select that vector with index -1 and take the argmax to pick the most likely token.

output = model(**model_inputs)
next_token_logits = output.logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
print(next_token)

tensor([257])


In [19]:
## Now add the new token to the end of the text, and feed all of it back to the model to continue
## predicting more tokens.

model_inputs['input_ids'] = torch.cat([model_inputs['input_ids'], next_token[:, None]], dim=-1)
model_inputs["attention_mask"] = torch.cat([model_inputs['attention_mask'], torch.tensor([[1]])], dim=-1)
print(model_inputs)

{'input_ids': tensor([[10462, 27428,   379,   262, 10481,   318,   257]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}


In [20]:
## Here's what we have so far. The model added the word 'a' to the input text.

print(tokenizer.decode(model_inputs['input_ids'][0]))

Swimming at the beach is a


In [22]:
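## Repeating the previous steps appends one more token per run. The output below shows two more
## tokens appended, ' great' and ' way', which indicates that the steps were run twice.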

output = model(**model_inputs)
next_token_logits = output.logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
model_inputs['input_ids'] = torch.cat([model_inputs['input_ids'], next_token[:, None]], dim=-1)
model_inputs["attention_mask"] = torch.cat([model_inputs['attention_mask'], torch.tensor([[1]])], dim=-1)
print('model_inputs:', model_inputs, '\n')
print(tokenizer.decode(model_inputs['input_ids'][0]))

model_inputs: {'input_ids': tensor([[10462, 27428,   379,   262, 10481,   318,   257,  1049,   835]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])} 

Swimming at the beach is a great way


In [23]:
## HuggingFace automates this iterative process. We'll use the quicker approach to finish our sentence.

output_generate = model.generate(**model_inputs, max_length=20, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_generate[0]))

Swimming at the beach is a great way to get a little extra energy.

The beach
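
The iterative process that `generate` automates can be sketched as a short loop, assuming greedy decoding (always taking the argmax, as in the manual steps above). The function `toy_scores` is a hypothetical stand-in for a real language model so that the example runs without downloading any weights; `greedy_decode` is our own helper, not a library function.

```python
def greedy_decode(next_token_scores, input_ids, max_length, eos_token_id=None):
    """Append the highest-scoring token until max_length or EOS is reached."""
    ids = list(input_ids)
    while len(ids) < max_length:
        scores = next_token_scores(ids)  # one score per vocabulary entry
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_token_id:
            break
    return ids

# Hypothetical stand-in for a language model over a 5-token vocabulary:
# it always favors the token id that follows the last one, wrapping around.
def toy_scores(ids):
    vocab_size = 5
    favored = (ids[-1] + 1) % vocab_size
    return [1.0 if token == favored else 0.0 for token in range(vocab_size)]

print(greedy_decode(toy_scores, [0], max_length=6))
print(greedy_decode(toy_scores, [0], max_length=6, eos_token_id=3))
```

As in the tutorial's call, `max_length` here counts the prompt tokens as well as the generated ones, so a longer prompt leaves room for fewer new tokens.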


## Pre-training a GPT-2 model from scratch

Next we'll train a GPT-2 model from scratch using English Wikipedia data. Note that we're only using a tiny subset of the data to demonstrate that the model is capable of learning. The exact same approach could be followed on the full dataset to train a more functional model, but that would require a lot of compute.

In [2]:
# dataset = load_dataset("wikipedia", "20220301.en", trust_remote_code=True)
dataset = load_dataset("retarfi/wikipedia-en-20230720-debug", trust_remote_code=True)
ds_shuffle = dataset['train'].shuffle()

raw_datasets = DatasetDict(
    {
        "train": ds_shuffle.select(range(50)),
        "valid": ds_shuffle.select(range(50, 100))
    }
)

raw_datasets

DatasetDict({
    train: Dataset({
        features: ['curid', 'title', 'text'],
        num_rows: 50
    })
    valid: Dataset({
        features: ['curid', 'title', 'text'],
        num_rows: 50
    })
})

In [3]:
print(raw_datasets['train'][0]['text'][:200])

Abraham Lincoln ( ; February 12, 1809 – April 15, 1865) was an American lawyer, politician, and statesman who served as the 16th president of the United States from 1861 until his assassination in 186


In [5]:
## We'll tokenize the text, setting the context size to 128 and thus breaking each document into chunks of 128 tokens.

context_length = 128
tokenizer = AutoTokenizer.from_pretrained("gpt2")

outputs = tokenizer(
    raw_datasets["train"][:2]["text"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}", '\n')
print(f"Input chunk lengths: {(outputs['length'])}", '\n')
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

Input IDs length: 158 

Input chunk lengths: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 107, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 52] 

Chunk mapping: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [7]:
def tokenize(element):
    # Split every document in the batch into chunks of at most context_length tokens.
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    # Keep only full-length chunks; the shorter final chunk of each document is dropped.
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 2188
    })
    valid: Dataset({
        features: ['input_ids'],
        num_rows: 2630
    })
})
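
The effect of the chunking and filtering above can be sketched without a tokenizer. This is a simplified illustration that assumes consecutive, non-overlapping chunks; the helper names and the 300-token document are hypothetical.

```python
def chunk_token_ids(token_ids, context_length):
    """Split one document's token ids into consecutive chunks of at most context_length."""
    return [
        token_ids[start:start + context_length]
        for start in range(0, len(token_ids), context_length)
    ]

def keep_full_chunks(chunks, context_length):
    """Drop chunks shorter than context_length, mirroring the filter in tokenize()."""
    return [chunk for chunk in chunks if len(chunk) == context_length]

# Hypothetical document of 300 token ids.
document = list(range(300))
chunks = chunk_token_ids(document, context_length=128)
print([len(chunk) for chunk in chunks])
print(len(keep_full_chunks(chunks, 128)))
```

Dropping the short trailing chunk wastes a few tokens per document, but it means every training example has exactly `context_length` tokens and no padding is needed.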

Now we can set up the HuggingFace Trainer as follows. Because all of the parameters are randomly initialized at the outset and the dataset is so small, we'll need many epochs for the model to make progress. By contrast, LLMs are typically pre-trained for about one epoch over a far larger and more diverse set of examples.

In [8]:
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
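    # Note: recent transformers releases size GPT-2's position embeddings from
    # n_positions, so n_ctx may be stored on the config without any effect.
    n_ctx=context_length,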
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

model = GPT2LMHeadModel(config)

In [9]:
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
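
With `mlm=False`, the collator prepares batches for causal language modeling: the labels are a copy of the input ids, and the model shifts them internally so that each position is scored on predicting the token that follows it. The sketch below illustrates that pairing and the per-token cross-entropy loss in plain Python. The helper names and the toy probability distribution are hypothetical; the token ids are the ones from the "Swimming at the beach is" example.

```python
import math

def next_token_pairs(input_ids):
    """Pair each position's token with the token it should predict (the next one)."""
    return list(zip(input_ids[:-1], input_ids[1:]))

def cross_entropy(probabilities, target_index):
    """Negative log-probability assigned to the correct next token."""
    return -math.log(probabilities[target_index])

# Token ids from the "Swimming at the beach is" example.
input_ids = [10462, 27428, 379, 262, 10481, 318]

# The model conditions on all tokens up to and including current_token.
for current_token, target_token in next_token_pairs(input_ids):
    print(current_token, "->", target_token)

# Hypothetical predicted distribution over a 4-token toy vocabulary.
toy_probs = [0.1, 0.7, 0.1, 0.1]
print(round(cross_entropy(toy_probs, target_index=1), 4))
```

A sequence of n tokens therefore yields n - 1 training targets, and the reported loss is the average of these per-token values.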

In [10]:
args = TrainingArguments(
    output_dir="wiki-gpt2",
    evaluation_strategy="steps",
    num_train_epochs=100
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"]
)


In [11]:
trainer.train()

{'loss': 7.2315, 'grad_norm': 2.4070262908935547, 'learning_rate': 4.908759124087591e-05, 'epoch': 1.82}


{'eval_loss': 7.129822731018066, 'eval_runtime': 38.7521, 'eval_samples_per_second': 67.867, 'eval_steps_per_second': 8.49, 'epoch': 1.82}
{'loss': 6.2851, 'grad_norm': 2.9584884643554688, 'learning_rate': 4.817518248175183e-05, 'epoch': 3.65}


{'eval_loss': 6.925766468048096, 'eval_runtime': 38.6199, 'eval_samples_per_second': 68.1, 'eval_steps_per_second': 8.519, 'epoch': 3.65}
{'loss': 5.7643, 'grad_norm': 3.7472305297851562, 'learning_rate': 4.726277372262774e-05, 'epoch': 5.47}


{'eval_loss': 6.874337673187256, 'eval_runtime': 38.5654, 'eval_samples_per_second': 68.196, 'eval_steps_per_second': 8.531, 'epoch': 5.47}
{'loss': 5.3205, 'grad_norm': 4.354938507080078, 'learning_rate': 4.635036496350365e-05, 'epoch': 7.3}


{'eval_loss': 6.903533458709717, 'eval_runtime': 38.6368, 'eval_samples_per_second': 68.07, 'eval_steps_per_second': 8.515, 'epoch': 7.3}
{'loss': 4.8807, 'grad_norm': 4.876628398895264, 'learning_rate': 4.5437956204379564e-05, 'epoch': 9.12}


{'eval_loss': 6.984766960144043, 'eval_runtime': 38.4419, 'eval_samples_per_second': 68.415, 'eval_steps_per_second': 8.558, 'epoch': 9.12}
{'loss': 4.442, 'grad_norm': 5.547481536865234, 'learning_rate': 4.452554744525548e-05, 'epoch': 10.95}


{'eval_loss': 7.050103187561035, 'eval_runtime': 38.7343, 'eval_samples_per_second': 67.898, 'eval_steps_per_second': 8.494, 'epoch': 10.95}
{'loss': 3.99, 'grad_norm': 6.534890174865723, 'learning_rate': 4.361313868613139e-05, 'epoch': 12.77}


{'eval_loss': 7.183302402496338, 'eval_runtime': 38.4686, 'eval_samples_per_second': 68.368, 'eval_steps_per_second': 8.552, 'epoch': 12.77}
{'loss': 3.5818, 'grad_norm': 6.865777015686035, 'learning_rate': 4.27007299270073e-05, 'epoch': 14.6}


{'eval_loss': 7.356001377105713, 'eval_runtime': 38.5519, 'eval_samples_per_second': 68.22, 'eval_steps_per_second': 8.534, 'epoch': 14.6}
{'loss': 3.1923, 'grad_norm': 7.661568641662598, 'learning_rate': 4.1788321167883216e-05, 'epoch': 16.42}


{'eval_loss': 7.499016761779785, 'eval_runtime': 38.5999, 'eval_samples_per_second': 68.135, 'eval_steps_per_second': 8.523, 'epoch': 16.42}
{'loss': 2.8103, 'grad_norm': 7.98348331451416, 'learning_rate': 4.0875912408759126e-05, 'epoch': 18.25}


{'eval_loss': 7.6842570304870605, 'eval_runtime': 38.5363, 'eval_samples_per_second': 68.247, 'eval_steps_per_second': 8.537, 'epoch': 18.25}
{'loss': 2.4543, 'grad_norm': 7.786740303039551, 'learning_rate': 3.9963503649635035e-05, 'epoch': 20.07}


{'eval_loss': 7.757016181945801, 'eval_runtime': 38.3544, 'eval_samples_per_second': 68.571, 'eval_steps_per_second': 8.578, 'epoch': 20.07}
{'loss': 2.0971, 'grad_norm': 8.255990028381348, 'learning_rate': 3.905109489051095e-05, 'epoch': 21.9}


{'eval_loss': 7.878429412841797, 'eval_runtime': 38.3349, 'eval_samples_per_second': 68.606, 'eval_steps_per_second': 8.582, 'epoch': 21.9}
{'loss': 1.7764, 'grad_norm': 8.459480285644531, 'learning_rate': 3.813868613138686e-05, 'epoch': 23.72}


{'eval_loss': 8.015480041503906, 'eval_runtime': 38.4592, 'eval_samples_per_second': 68.384, 'eval_steps_per_second': 8.555, 'epoch': 23.72}
{'loss': 1.4802, 'grad_norm': 8.08670711517334, 'learning_rate': 3.722627737226278e-05, 'epoch': 25.55}


{'eval_loss': 8.12671947479248, 'eval_runtime': 38.358, 'eval_samples_per_second': 68.564, 'eval_steps_per_second': 8.577, 'epoch': 25.55}
{'loss': 1.2251, 'grad_norm': 7.575474262237549, 'learning_rate': 3.631386861313869e-05, 'epoch': 27.37}


{'eval_loss': 8.225798606872559, 'eval_runtime': 38.3894, 'eval_samples_per_second': 68.508, 'eval_steps_per_second': 8.57, 'epoch': 27.37}
{'loss': 0.9852, 'grad_norm': 7.271633625030518, 'learning_rate': 3.5401459854014604e-05, 'epoch': 29.2}


{'eval_loss': 8.325549125671387, 'eval_runtime': 38.5015, 'eval_samples_per_second': 68.309, 'eval_steps_per_second': 8.545, 'epoch': 29.2}
{'loss': 0.781, 'grad_norm': 6.5625481605529785, 'learning_rate': 3.448905109489051e-05, 'epoch': 31.02}


{'eval_loss': 8.398702621459961, 'eval_runtime': 38.3779, 'eval_samples_per_second': 68.529, 'eval_steps_per_second': 8.573, 'epoch': 31.02}
{'loss': 0.5937, 'grad_norm': 6.837762832641602, 'learning_rate': 3.357664233576642e-05, 'epoch': 32.85}


{'eval_loss': 8.435324668884277, 'eval_runtime': 38.4345, 'eval_samples_per_second': 68.428, 'eval_steps_per_second': 8.56, 'epoch': 32.85}
{'loss': 0.4552, 'grad_norm': 6.207019329071045, 'learning_rate': 3.266423357664234e-05, 'epoch': 34.67}


{'eval_loss': 8.49921703338623, 'eval_runtime': 38.37, 'eval_samples_per_second': 68.543, 'eval_steps_per_second': 8.574, 'epoch': 34.67}
{'loss': 0.3488, 'grad_norm': 5.646033763885498, 'learning_rate': 3.175182481751825e-05, 'epoch': 36.5}


{'eval_loss': 8.638547897338867, 'eval_runtime': 38.3718, 'eval_samples_per_second': 68.54, 'eval_steps_per_second': 8.574, 'epoch': 36.5}
{'loss': 0.2674, 'grad_norm': 5.296948432922363, 'learning_rate': 3.083941605839416e-05, 'epoch': 38.32}


{'eval_loss': 8.69522476196289, 'eval_runtime': 38.44, 'eval_samples_per_second': 68.418, 'eval_steps_per_second': 8.559, 'epoch': 38.32}
{'loss': 0.2112, 'grad_norm': 5.084445953369141, 'learning_rate': 2.992700729927008e-05, 'epoch': 40.15}


{'eval_loss': 8.72973346710205, 'eval_runtime': 38.3839, 'eval_samples_per_second': 68.518, 'eval_steps_per_second': 8.571, 'epoch': 40.15}
{'loss': 0.1702, 'grad_norm': 3.9905476570129395, 'learning_rate': 2.9014598540145988e-05, 'epoch': 41.97}


{'eval_loss': 8.778406143188477, 'eval_runtime': 38.3517, 'eval_samples_per_second': 68.576, 'eval_steps_per_second': 8.579, 'epoch': 41.97}
{'loss': 0.1394, 'grad_norm': 3.1596906185150146, 'learning_rate': 2.8102189781021898e-05, 'epoch': 43.8}


{'eval_loss': 8.82097339630127, 'eval_runtime': 38.2587, 'eval_samples_per_second': 68.743, 'eval_steps_per_second': 8.599, 'epoch': 43.8}
{'loss': 0.1192, 'grad_norm': 4.753178119659424, 'learning_rate': 2.7189781021897807e-05, 'epoch': 45.62}


{'eval_loss': 8.903903007507324, 'eval_runtime': 38.3487, 'eval_samples_per_second': 68.581, 'eval_steps_per_second': 8.579, 'epoch': 45.62}
{'loss': 0.1029, 'grad_norm': 3.1203410625457764, 'learning_rate': 2.6277372262773724e-05, 'epoch': 47.45}


{'eval_loss': 8.957517623901367, 'eval_runtime': 38.3645, 'eval_samples_per_second': 68.553, 'eval_steps_per_second': 8.576, 'epoch': 47.45}
{'loss': 0.0911, 'grad_norm': 3.238842487335205, 'learning_rate': 2.5364963503649637e-05, 'epoch': 49.27}


{'eval_loss': 9.094146728515625, 'eval_runtime': 38.4145, 'eval_samples_per_second': 68.464, 'eval_steps_per_second': 8.564, 'epoch': 49.27}
{'loss': 0.0804, 'grad_norm': 2.876498222351074, 'learning_rate': 2.445255474452555e-05, 'epoch': 51.09}


{'eval_loss': 9.159856796264648, 'eval_runtime': 38.5466, 'eval_samples_per_second': 68.229, 'eval_steps_per_second': 8.535, 'epoch': 51.09}
{'loss': 0.0726, 'grad_norm': 3.280609130859375, 'learning_rate': 2.354014598540146e-05, 'epoch': 52.92}


{'eval_loss': 9.087692260742188, 'eval_runtime': 38.3021, 'eval_samples_per_second': 68.665, 'eval_steps_per_second': 8.59, 'epoch': 52.92}
{'loss': 0.0649, 'grad_norm': 3.2745323181152344, 'learning_rate': 2.2627737226277372e-05, 'epoch': 54.74}


{'eval_loss': 9.123873710632324, 'eval_runtime': 38.3945, 'eval_samples_per_second': 68.499, 'eval_steps_per_second': 8.569, 'epoch': 54.74}
{'loss': 0.0588, 'grad_norm': 2.962770700454712, 'learning_rate': 2.1715328467153285e-05, 'epoch': 56.57}


{'eval_loss': 9.17492389678955, 'eval_runtime': 38.2909, 'eval_samples_per_second': 68.685, 'eval_steps_per_second': 8.592, 'epoch': 56.57}
{'loss': 0.0547, 'grad_norm': 2.8347556591033936, 'learning_rate': 2.08029197080292e-05, 'epoch': 58.39}


{'eval_loss': 9.283421516418457, 'eval_runtime': 38.3064, 'eval_samples_per_second': 68.657, 'eval_steps_per_second': 8.589, 'epoch': 58.39}
{'loss': 0.0502, 'grad_norm': 1.4399888515472412, 'learning_rate': 1.989051094890511e-05, 'epoch': 60.22}


{'eval_loss': 9.228477478027344, 'eval_runtime': 537.9526, 'eval_samples_per_second': 4.889, 'eval_steps_per_second': 0.612, 'epoch': 60.22}
{'loss': 0.0471, 'grad_norm': 1.929141879081726, 'learning_rate': 1.897810218978102e-05, 'epoch': 62.04}


{'eval_loss': 9.261701583862305, 'eval_runtime': 981.6167, 'eval_samples_per_second': 2.679, 'eval_steps_per_second': 0.335, 'epoch': 62.04}
{'loss': 0.0434, 'grad_norm': 3.029575824737549, 'learning_rate': 1.8065693430656934e-05, 'epoch': 63.87}


{'eval_loss': 9.302042007446289, 'eval_runtime': 1053.4809, 'eval_samples_per_second': 2.496, 'eval_steps_per_second': 0.312, 'epoch': 63.87}
{'loss': 0.0405, 'grad_norm': 1.878759741783142, 'learning_rate': 1.715328467153285e-05, 'epoch': 65.69}


{'eval_loss': 9.305377006530762, 'eval_runtime': 1087.5718, 'eval_samples_per_second': 2.418, 'eval_steps_per_second': 0.303, 'epoch': 65.69}
{'loss': 0.0374, 'grad_norm': 2.630403995513916, 'learning_rate': 1.624087591240876e-05, 'epoch': 67.52}


{'eval_loss': 9.422980308532715, 'eval_runtime': 37.4337, 'eval_samples_per_second': 70.258, 'eval_steps_per_second': 8.789, 'epoch': 67.52}
{'loss': 0.0354, 'grad_norm': 0.9712914824485779, 'learning_rate': 1.5328467153284673e-05, 'epoch': 69.34}


{'eval_loss': 9.449637413024902, 'eval_runtime': 949.2747, 'eval_samples_per_second': 2.771, 'eval_steps_per_second': 0.347, 'epoch': 69.34}
{'loss': 0.0333, 'grad_norm': 3.111647605895996, 'learning_rate': 1.4416058394160584e-05, 'epoch': 71.17}


{'eval_loss': 9.510079383850098, 'eval_runtime': 37.6389, 'eval_samples_per_second': 69.875, 'eval_steps_per_second': 8.741, 'epoch': 71.17}
{'loss': 0.0316, 'grad_norm': 1.7258329391479492, 'learning_rate': 1.3503649635036497e-05, 'epoch': 72.99}


{'eval_loss': 9.481695175170898, 'eval_runtime': 982.4667, 'eval_samples_per_second': 2.677, 'eval_steps_per_second': 0.335, 'epoch': 72.99}
{'loss': 0.0296, 'grad_norm': 3.29180645942688, 'learning_rate': 1.259124087591241e-05, 'epoch': 74.82}


{'eval_loss': 9.519734382629395, 'eval_runtime': 971.4038, 'eval_samples_per_second': 2.707, 'eval_steps_per_second': 0.339, 'epoch': 74.82}
{'loss': 0.0274, 'grad_norm': 2.712430000305176, 'learning_rate': 1.1678832116788322e-05, 'epoch': 76.64}


{'eval_loss': 9.594943046569824, 'eval_runtime': 1072.964, 'eval_samples_per_second': 2.451, 'eval_steps_per_second': 0.307, 'epoch': 76.64}
{'loss': 0.0263, 'grad_norm': 0.8342414498329163, 'learning_rate': 1.0766423357664233e-05, 'epoch': 78.47}


{'eval_loss': 9.641653060913086, 'eval_runtime': 37.7352, 'eval_samples_per_second': 69.696, 'eval_steps_per_second': 8.719, 'epoch': 78.47}
{'loss': 0.025, 'grad_norm': 0.5869117975234985, 'learning_rate': 9.854014598540148e-06, 'epoch': 80.29}


{'eval_loss': 9.61470890045166, 'eval_runtime': 37.7101, 'eval_samples_per_second': 69.743, 'eval_steps_per_second': 8.724, 'epoch': 80.29}
{'loss': 0.0238, 'grad_norm': 2.373626708984375, 'learning_rate': 8.941605839416059e-06, 'epoch': 82.12}


{'eval_loss': 9.612598419189453, 'eval_runtime': 37.7357, 'eval_samples_per_second': 69.695, 'eval_steps_per_second': 8.719, 'epoch': 82.12}
{'loss': 0.0226, 'grad_norm': 2.194218635559082, 'learning_rate': 8.02919708029197e-06, 'epoch': 83.94}


{'eval_loss': 9.713486671447754, 'eval_runtime': 37.8276, 'eval_samples_per_second': 69.526, 'eval_steps_per_second': 8.697, 'epoch': 83.94}
{'loss': 0.0218, 'grad_norm': 0.6258280873298645, 'learning_rate': 7.116788321167883e-06, 'epoch': 85.77}


{'eval_loss': 9.618224143981934, 'eval_runtime': 37.6948, 'eval_samples_per_second': 69.771, 'eval_steps_per_second': 8.728, 'epoch': 85.77}
{'loss': 0.0204, 'grad_norm': 0.4551713466644287, 'learning_rate': 6.204379562043796e-06, 'epoch': 87.59}


{'eval_loss': 9.69461727142334, 'eval_runtime': 37.6701, 'eval_samples_per_second': 69.817, 'eval_steps_per_second': 8.734, 'epoch': 87.59}
{'loss': 0.0196, 'grad_norm': 0.3699822425842285, 'learning_rate': 5.2919708029197084e-06, 'epoch': 89.42}


{'eval_loss': 9.709086418151855, 'eval_runtime': 37.6626, 'eval_samples_per_second': 69.83, 'eval_steps_per_second': 8.735, 'epoch': 89.42}
{'loss': 0.0189, 'grad_norm': 0.8575927019119263, 'learning_rate': 4.379562043795621e-06, 'epoch': 91.24}


{'eval_loss': 9.753740310668945, 'eval_runtime': 37.6791, 'eval_samples_per_second': 69.8, 'eval_steps_per_second': 8.732, 'epoch': 91.24}
{'loss': 0.0184, 'grad_norm': 0.5209119319915771, 'learning_rate': 3.4671532846715328e-06, 'epoch': 93.07}



{'eval_loss': 9.741988182067871, 'eval_runtime': 37.688, 'eval_samples_per_second': 69.783, 'eval_steps_per_second': 8.73, 'epoch': 93.07}
{'loss': 0.0175, 'grad_norm': 0.39775925874710083, 'learning_rate': 2.5547445255474454e-06, 'epoch': 94.89}



{'eval_loss': 9.758973121643066, 'eval_runtime': 37.6741, 'eval_samples_per_second': 69.809, 'eval_steps_per_second': 8.733, 'epoch': 94.89}
{'loss': 0.0169, 'grad_norm': 0.5033899545669556, 'learning_rate': 1.6423357664233577e-06, 'epoch': 96.72}



{'eval_loss': 9.729644775390625, 'eval_runtime': 37.6897, 'eval_samples_per_second': 69.78, 'eval_steps_per_second': 8.729, 'epoch': 96.72}
{'loss': 0.0167, 'grad_norm': 1.0400859117507935, 'learning_rate': 7.299270072992701e-07, 'epoch': 98.54}



{'eval_loss': 9.743011474609375, 'eval_runtime': 37.676, 'eval_samples_per_second': 69.806, 'eval_steps_per_second': 8.732, 'epoch': 98.54}
{'train_runtime': 47752.8848, 'train_samples_per_second': 4.582, 'train_steps_per_second': 0.574, 'train_loss': 1.1271013758304345, 'epoch': 100.0}


TrainOutput(global_step=27400, training_loss=1.1271013758304345, metrics={'train_runtime': 47752.8848, 'train_samples_per_second': 4.582, 'train_steps_per_second': 0.574, 'total_flos': 1.42926741504e+16, 'train_loss': 1.1271013758304345, 'epoch': 100.0})

In [12]:
trainer.evaluate()


{'eval_loss': 9.74312973022461,
 'eval_runtime': 39.0233,
 'eval_samples_per_second': 67.396,
 'eval_steps_per_second': 8.431,
 'epoch': 100.0}

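The training loss is very low by the end (about 0.017 at the last logged step), so the model should perform very well on training examples it has seen. The `training_loss` of about 1.13 reported in `TrainOutput` is the average over the whole run, not the final value. The validation loss, by contrast, stays near 9.7 and drifts slightly upward over the last epochs, so the model does not generalize to the validation set. That is expected, since we deliberately overfit on a small training set.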
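To put the two loss values on a more intuitive scale, the sketch below converts them to perplexity. It assumes the reported losses are mean per-token cross-entropy in nats (the PyTorch default), and it assumes the GPT-2 tokenizer's vocabulary of 50,257 tokens. The variable names are illustrative and not part of the notebook.

```python
import math

# Values copied from the final log lines above.
final_train_loss = 0.0167
final_eval_loss = 9.74312973022461

# Assumption: the tokenizer is GPT-2's, with a vocabulary of 50,257 tokens.
vocab_size = 50257

# Perplexity is the exponential of the mean per-token cross-entropy (in nats).
train_ppl = math.exp(final_train_loss)
eval_ppl = math.exp(final_eval_loss)

# Loss a model would get by assigning equal probability to every token.
uniform_loss = math.log(vocab_size)

print(f"train perplexity: {train_ppl:.3f}")
print(f"eval perplexity:  {eval_ppl:,.0f}")
print(f"uniform-guess loss: {uniform_loss:.2f}")
```

Under these assumptions the training perplexity is close to 1, meaning the model is almost certain of each next token on the training text. The validation perplexity is in the tens of thousands, and the validation loss is only about one nat below that of a uniform guess over the vocabulary.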

We can confirm this with an example that was seen in training.

In [13]:
text = tokenizer.decode(tokenized_datasets["train"][0]['input_ids'][:16])
print(text)

Abraham Lincoln ( ; February 12, 1809 – April 15, 1865


In [14]:
model_inputs = tokenizer(text, return_tensors='pt')
print(model_inputs['input_ids'].shape)

torch.Size([1, 16])


In [22]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: mps


In [17]:
# Move the inputs to the device selected above
model_inputs["input_ids"] = model_inputs["input_ids"].to(device)
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(device)

# Generate 16 new tokens after the 16-token prompt
output_generate = model.generate(**model_inputs, max_new_tokens=16)
output_generate

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[ 4826, 13220, 12406,   357,  2162,  3945,  1105,    11,  1248,  2931,
          1849,  1906,  3035,  1315,    11, 47801,     8,   373,   281,  1605,
          6853,    11, 14971,    11,   290,  2585,   805,   508,  4983,   355,
           262,  1467]], device='mps:0')

In [18]:
sequence = tokenizer.decode(output_generate[0])
print(sequence)

Abraham Lincoln ( ; February 12, 1809 – April 15, 1865) was an American lawyer, politician, and statesman who served as the 16


After seeing this text so many times, the model recites it well: the continuation is fluent and coherent. This suggests that the tokenizer, model architecture, and training objective are well-suited to learning Wikipedia data. For comparison, we'll try the model on text from the validation set.
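One way to make "reciting" measurable is to count how many leading tokens of a generated sequence agree with the original training sequence. The helper below is a minimal sketch of that idea. The function name `matching_prefix_length` and the toy token IDs are hypothetical and do not come from the notebook's run.

```python
def matching_prefix_length(generated_ids, reference_ids):
    """Count how many leading tokens of the two sequences agree."""
    count = 0
    for gen_tok, ref_tok in zip(generated_ids, reference_ids):
        if gen_tok != ref_tok:
            break
        count += 1
    return count


# Toy token IDs for illustration only (not taken from the training data).
reference = [10, 11, 12, 13, 14, 15]
memorized = [10, 11, 12, 13, 14, 15]  # a perfect recitation
confused = [10, 11, 99, 13, 14, 15]   # diverges at the third token

print(matching_prefix_length(memorized, reference))  # 6
print(matching_prefix_length(confused, reference))   # 2
```

In the notebook, one could compare `output_generate[0].tolist()` against `tokenized_datasets["train"][0]["input_ids"][:32]`, assuming the re-tokenized prompt yields the same IDs as the stored example. A match over all 32 positions would indicate the example was memorized verbatim.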

In [19]:
text = tokenizer.decode(tokenized_datasets["valid"][0]['input_ids'][:32])
print(text)

The alkali metals consist of the chemical elements lithium (Li), sodium (Na), potassium (K), rubidium (Rb), caesium (Cs


In [23]:
model_inputs = tokenizer(text, return_tensors='pt')

# Move the inputs to the device selected earlier
model_inputs['input_ids'] = model_inputs['input_ids'].to(device)
model_inputs['attention_mask'] = model_inputs['attention_mask'].to(device)

output_generate = model.generate(**model_inputs, max_new_tokens=16)
sequence = tokenizer.decode(output_generate[0])
print(sequence)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


The alkali metals consist of the chemical elements lithium (Li), sodium (Na), potassium (K), rubidium (Rb), caesium (Cs () (Anna Lizium (Cfb) and 'para) (called


In [24]:
raw_datasets['valid'][0]['text']

'The alkali metals consist of the chemical elements lithium (Li), sodium (Na), potassium (K), rubidium (Rb), caesium (Cs), and francium (Fr). Together with hydrogen they constitute group 1, which lies in the s-block of the periodic table. All alkali metals have their outermost electron in an s-orbital: this shared electron configuration results in their having very similar characteristic properties. Indeed, the alkali metals provide the best example of group trends in properties in the periodic table, with elements exhibiting well-characterised homologous behaviour. This family of elements is also known as the lithium family after its leading element.\nThe alkali metals are all shiny, soft, highly reactive metals at standard temperature and pressure and readily lose their outermost electron to form cations with charge +1. They can all be cut easily with a knife due to their softness, exposing a shiny surface that tarnishes rapidly in air due to oxidation by atmospheric moisture and oxy

As expected, our model is completely confused this time: the original article continues with "), and francium (Fr)", while the model produces text that is not even well-formed. We would need to train for much longer, and on much more diverse data, before the model could sensibly complete prompts it has never seen. This is precisely why pre-training is such an important and powerful technique. Our small run alone took about 13 hours, so training on all of Wikipedia for every NLP application would be prohibitively expensive. There is no need to do that when we can share and reuse existing pre-trained models, as we did in the first part of this tutorial.