# Chapter 3 - Pre-train a tiny LLM

In [2]:
import os
from pathlib import Path
import sys

current_path = Path(os.getcwd())
parent_path  = str(current_path.parent.absolute())

print(parent_path)
sys.path.append(parent_path)

/home/gopi/Documents/small_llm/llmbook/chapters


In [3]:
from accelerate import Accelerator
import os
import torch
import bitsandbytes as bnb
import torch.nn as nn
from dataclasses import dataclass
import math
from datasets import load_dataset



In [8]:
%load_ext autoreload
%autoreload 2

from chapter1.simplebooks import get_dataloaders, get_tokenizer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
train_loader, valid_loader = get_dataloaders(batch_size=12, num_workers=4)

Loading dataset from /home/gopi/Documents/small_llm/llmbook/data/simplebooks/simplebooks-2-raw/
Total train tokens 1676477
Total validation tokens 189785


In [24]:
from chapter2.gptlikemodel import SLLM, SLLMConfig

In [25]:
config = SLLMConfig()
config

SLLMConfig(d_model=128, d_head=128, bias=False, dropout=0.0, context_window=50, n_heads=2, vocab_size=52000, n_layers=2)

In [26]:
model = SLLM(config)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

total_size_bytes = total_params * 4
total_size_mb = total_size_bytes / (1024 * 1024)

print(f"Model size: {total_size_mb:.2f} MB")

Total parameters: 13,698,976
Model size: 52.26 MB


In [27]:
def generate_text(model, idx, max_new_tokens, context_size):
    for _ in range(max_new_tokens):
        idx_trim = idx[:,-context_size:]
        
        with torch.no_grad():
            logits = model(idx_trim)
        
        logits = logits[:,-1,:]
        probas = torch.softmax(logits, dim=-1)
        
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)
        
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

In [28]:
start_context = "wonderful spring is awaited."
tokenizer = get_tokenizer()
encoded = tokenizer.encode(start_context)

encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print(f"Encoded tensor {encoded_tensor}")

Loading tokenizer from /home/gopi/Documents/small_llm/llmbook/data/simplebooks-tokenizer
Encoded tensor tensor([[ 3111, 18733,  1316,   357, 14783,    14]])


In [29]:
model.eval()

out = generate_text(model, encoded_tensor, 5, context_size=50)

print("Output", out.squeeze(0).tolist())

Output [3111, 18733, 1316, 357, 14783, 14, 39860, 5082, 25350, 7337, 45039]


In [30]:
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(f"Decoded text: {decoded_text}")

Decoded text: wonderful spring is awaited.Still pictured Scar


In [41]:
class LLMLoss(nn.Module):
    def __init__(self):
        super(LLMLoss, self).__init__()
    
    def forward(self, logits, targets):
        loss = torch.nn.functional.cross_entropy(logits.flatten(0,1), targets.flatten())
        
model.eval()        
logits = model(encoded_tensor)
print(logits.shape)


torch.Size([1, 6, 52000])


In [42]:
logits

tensor([[[ 0.2436,  0.4534, -0.2164,  ..., -0.0228,  0.1009,  0.0285],
         [-0.5120,  0.6210,  0.3816,  ..., -1.3161,  0.2259,  0.6086],
         [-0.6945, -0.5430, -0.2874,  ...,  0.3271,  0.0516,  0.0516],
         [-1.0814, -0.7064,  0.9559,  ..., -0.3215, -0.2887, -0.0172],
         [-0.0137,  0.5215, -1.4794,  ...,  0.3573, -0.6554, -0.3192],
         [ 0.0354, -0.8348, -1.0299,  ...,  0.0443,  0.0224, -0.0852]]],
       grad_fn=<ViewBackward0>)

In [59]:
target_context = "nothing is awaited t.let"
target_encoded = tokenizer.encode(target_context)
target_tensors = torch.tensor(target_encoded).unsqueeze(0)
target_tensors.shape

torch.Size([1, 6])

In [60]:
target_tensors.flatten()

tensor([12323,   357, 14783,   257,    14,  1499])

In [61]:
logits.flatten(0,1)

tensor([[ 0.2436,  0.4534, -0.2164,  ..., -0.0228,  0.1009,  0.0285],
        [-0.5120,  0.6210,  0.3816,  ..., -1.3161,  0.2259,  0.6086],
        [-0.6945, -0.5430, -0.2874,  ...,  0.3271,  0.0516,  0.0516],
        [-1.0814, -0.7064,  0.9559,  ..., -0.3215, -0.2887, -0.0172],
        [-0.0137,  0.5215, -1.4794,  ...,  0.3573, -0.6554, -0.3192],
        [ 0.0354, -0.8348, -1.0299,  ...,  0.0443,  0.0224, -0.0852]],
       grad_fn=<ViewBackward0>)

In [62]:
torch.nn.functional.cross_entropy(logits.flatten(0,1), target_tensors.flatten())

tensor(10.8142, grad_fn=<NllLossBackward0>)

In [63]:
target_context = "sum is awaited.let"
target_encoded = tokenizer.encode(target_context)
target_tensors = torch.tensor(target_encoded).unsqueeze(0)
target_tensors.shape

torch.Size([1, 6])

In [64]:
torch.nn.functional.cross_entropy(logits.flatten(0,1), target_tensors.flatten())

tensor(10.8329, grad_fn=<NllLossBackward0>)

In [65]:
class LLMLoss(nn.Module):
    def __init__(self):
        super(LLMLoss, self).__init__()
    
    def forward(self, logits, targets):
        loss = torch.nn.functional.cross_entropy(logits.flatten(0,1), targets.flatten())
        return loss

In [69]:
from accelerate import Accelerator
from transformers import AdamW, get_scheduler
from tqdm import tqdm

accelerator = Accelerator()

optimizer = AdamW(model.parameters(), lr=3e-5)



num_epochs = 1

lr_scheduler = get_scheduler(
  "linear",
  optimizer=optimizer,
  num_warmup_steps=0,
  num_training_steps=num_training_steps
)

train_dataloader, eval_dataloader, model, optimizer, scheduler = accelerator.prepare(
     train_loader, valid_loader, model, optimizer, lr_scheduler
 )
num_training_steps = num_epochs * len(train_dataloader)

progress_bar = tqdm(range(num_training_steps))

loss_fn = LLMLoss()



model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        features,target = batch
        logits = model(features)
        loss = loss_fn(logits, target)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        progress_bar.update(1)
      


  0%|                                                                                                                                                                              | 0/8382 [00:27<?, ?it/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the envir

In [74]:
start_context = "wonderful spring is awaited."
tokenizer = get_tokenizer()
encoded = tokenizer.encode(start_context)
model.to("cpu")

encoded_tensor = torch.tensor(encoded).unsqueeze(0)
model.eval()

out = generate_text(model, encoded_tensor, 5, context_size=50)


Loading tokenizer from /home/gopi/Documents/small_llm/llmbook/data/simplebooks-tokenizer


In [75]:
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(f"Decoded text: {decoded_text}")

Decoded text: wonderful spring is awaited... the. and


## Save and Load Model

In [97]:
save_directory = str(Path(current_path.parent.parent.absolute(), "bin"))

# save state dictionary
accelerator.wait_for_everyone()
accelerator.save_model(model, save_directory)

In [103]:
accelerator.save_model(model, save_directory, max_shard_size="1GB", safe_serialization=True)

In [109]:
from accelerate import load_checkpoint_in_model

new_model = SLLM(config)
device = accelerator.device
load_checkpoint_in_model(model, save_directory)

In [111]:
start_context = "wonderful spring is awaited."
tokenizer = get_tokenizer()
encoded = tokenizer.encode(start_context)
model.to("cpu")

encoded_tensor = torch.tensor(encoded).unsqueeze(0)
model.eval()

out = generate_text(model, encoded_tensor, 5, context_size=50)
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(f"Decoded text: {decoded_text}")

Loading tokenizer from /home/gopi/Documents/small_llm/llmbook/data/simplebooks-tokenizer
Decoded text: wonderful spring is awaited... the. and
