# Chapter 3 - Pre-train a tiny LLM

In this chapter we will pre-train our tiny LLM on the SimpleBooks dataset.

In [1]:
import os
from pathlib import Path
import sys
import warnings

warnings.simplefilter("ignore")
# Put the parent directory on sys.path so the sibling chapter packages
# (chapter1, chapter2, ...) can be imported from this notebook.
current_path = Path.cwd()
parent_path = str(current_path.parent.absolute())
sys.path.append(parent_path)

## Prepare dataset

Get the train, validation and test dataloaders from the code we wrote in Chapter 1. We will use the GPT-2 tokenizer from the Transformers library.

In [None]:
%load_ext autoreload
%autoreload 1

from chapter1.simplebooks import get_dataloaders, get_tokenizer


# Load train,validation and test datasets
train_loader, valid_loader, test_loader = get_dataloaders(batch_size=12, \
                num_workers=4)



## Load Model

We have close to 1.7 million tokens for training, 193K tokens for validation and 192K tokens for testing.


In [None]:
from chapter2.gptlikemodel import SLLM, SLLMConfig


# Initialize the model class
config = SLLMConfig()

print(f"Model configuration {config}\n")

model = SLLM(config)

total_params = sum(p.numel() for p in model.parameters())

print(f"Total parameters: {total_params:,}")

# Use each parameter's real element size instead of assuming 4-byte float32, so the
# estimate stays correct if the model is cast to another dtype (e.g. float16/bfloat16).
total_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
total_size_mb = total_size_bytes / (1024 * 1024)

print(f"Model size: {total_size_mb:.2f} MB")

## Query the model

In [3]:
import torch
from chapter1.simplebooks import get_tokenizer

def generate_text(model, idx, max_new_tokens, context_size):
    """
    Greedily extend a batch of token sequences with `model`.

    Arguments:
        model:
            llm model for text generation. It is called as model(tokens), and its
            output is indexed as [:, -1, :] to get the last position's logits.
        idx:
            Input token tensor; new tokens are concatenated along dim 1.
        max_new_tokens:
            Number of output tokens to be generated.
        context_size:
            Model context window: only the last `context_size` tokens are fed
            to the model at each step.

    Returns:
        `idx` with `max_new_tokens` greedily chosen tokens appended along dim 1.
    """
    for _step in range(max_new_tokens):
        window = idx[:, -context_size:]

        with torch.no_grad():
            step_logits = model(window)

        # Greedy choice: arg-max of the softmax over the last position's logits.
        probs = torch.softmax(step_logits[:, -1, :], dim=-1)
        next_id = torch.argmax(probs, dim=-1, keepdim=True)

        idx = torch.cat((idx, next_id), dim=1)
    return idx

def invoke_model(model, start_context, max_new_tokens=5, context_size=50):
    """
    Greedily continue `start_context` with `model` and print every stage.

    Arguments:
        model:
            llm model for text generation; it is switched to eval mode here.
        start_context:
            Non-empty prompt string.
        max_new_tokens:
            Number of tokens to generate (default 5, the previously fixed value).
        context_size:
            Context window passed to generate_text (default 50).
    """
    # Reject None before calling len() on it (len(None) would raise TypeError first).
    assert start_context is not None and len(start_context) > 0

    print(f"Input context: '{start_context}'\n")
    tokenizer = get_tokenizer()
    encoded = tokenizer.encode(start_context)

    # Convert to tensor, add a batch dimension and place it on the model's device,
    # so this also works while the model is on the GPU (e.g. inside the training loop).
    device = next(model.parameters()).device
    encoded_tensor = torch.tensor(encoded).unsqueeze(0).to(device)
    print(f"Encoded tensor {encoded_tensor} No Tokens: {encoded_tensor.size()[-1]} \n")

    model.eval()
    with torch.no_grad():
        out = generate_text(model, encoded_tensor, max_new_tokens, context_size=context_size)
    print(f"Output {out} No Tokens: {out.size()[-1]}")

    # tolist() also brings GPU results back to plain Python ints for the tokenizer.
    decoded_text = tokenizer.decode(out.squeeze(0).tolist())
    print(f"Decoded text: '{decoded_text}'")

In [None]:
# Query the not-yet-trained model on CPU.
start_context = "wonderful spring is awaited."
model = model.cpu()
invoke_model(model, start_context)

The output token count is 11, i.e. 5 more than the input token count, because `invoke_model` asks `generate_text` for 5 new tokens.

## Define loss function

In [None]:
import torch.nn as nn
import torch



class LLMLoss(nn.Module):
    """
    Cross-entropy between next-token logits and target token ids.

    forward(logits, targets):
        logits  -- tensor whose first two axes are flattened together and whose
                   remaining axis holds per-class scores (e.g. (batch, seq, vocab)).
        targets -- integer class ids, flattened to one axis to line up with the logits.
    Returns the mean cross-entropy over all flattened positions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, targets):
        # Each (batch, position) pair becomes one classification sample.
        flat_logits = logits.flatten(0, 1)
        flat_targets = targets.flatten()
        return torch.nn.functional.cross_entropy(flat_logits, flat_targets)

        
def batch_loss(loss_fn, input_batch, target_batch, model, device):
    """
    Compute `loss_fn(model(inputs), targets)` for one batch without tracking gradients.

    Arguments:
        loss_fn: callable taking (logits, targets).
        input_batch / target_batch: tensors moved to `device` before use.
        model: callable producing logits from the input batch.
        device: target device for both tensors.
    Raises:
        AssertionError if model, input_batch or target_batch is None.
    """
    for required in (model, input_batch, target_batch):
        assert required is not None

    inputs, targets = input_batch.to(device), target_batch.to(device)

    # Evaluation helper: no autograd graph is built for this loss.
    with torch.no_grad():
        return loss_fn(model(inputs), targets)

def loader_loss(loss_fn, data_loader, model, device="cpu", num_batches=None):
    """
    Average per-batch loss of `model` over `data_loader` (or its first batches).

    Arguments:
        loss_fn: callable taking (logits, targets).
        data_loader: iterable of (features, target) batches that supports len().
        model: model evaluated through batch_loss.
        device: device the batches are moved to; must match the model's device.
        num_batches: if given, only the first `num_batches` batches are evaluated
            (capped at len(data_loader)); None evaluates every batch.

    Returns:
        The mean of the per-batch losses returned by batch_loss, or float("nan")
        when there is nothing to evaluate (instead of dividing by zero).
    """
    assert data_loader is not None
    assert model is not None

    total_batches = len(data_loader)
    if num_batches is None:
        num_batches = total_batches
    else:
        num_batches = min(num_batches, total_batches)
    if num_batches <= 0:
        return float("nan")

    total_loss = 0
    for batch_idx, (features, target) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break
        total_loss += batch_loss(loss_fn, features, target, model, device)

    return total_loss / num_batches

            
        
    
        

In [None]:
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

loss_fn = LLMLoss()

model = model.to(device)
model.eval()
# Sanity check: print the loss of the first two training batches.
for batch_no, (features, target) in enumerate(train_loader, start=1):
    loss = batch_loss(loss_fn, features, target, model, device)

    print(f"Batch {batch_no} Loss {loss}")

    if batch_no >= 2:
        break



In [None]:
# Average loss over the whole training loader.
train_loss = loader_loss(loss_fn, train_loader, model, device=device)
print(f"Train data loss {train_loss}")

## Training Loop

In [None]:
import torch
import math
from tqdm import tqdm

## Learning rate warmup followed by cosine decay

n_epochs = 10
initial_lr = 1e-4
min_lr = 1e-6
top_lr = 0.01
warmup_steps = 20
total_training_steps = n_epochs * len(train_loader)
# Fall back to CPU so the cell also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
progress_bar = tqdm(range(total_training_steps))
# Evaluate on valid_loader every `eval_freq` steps (e.g. 500). None disables periodic
# evaluation, because a full pass over the validation set is expensive.
eval_freq = None

lr_increment = (top_lr - initial_lr) / warmup_steps


def lr_at_step(step):
    """Learning rate for optimizer step `step`: linear warmup to top_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return initial_lr + step * lr_increment
    # max(1, ...) guards against division by zero when there are no post-warmup steps.
    decay_steps = max(1, total_training_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_steps
    return min_lr + (top_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)
loss_fn = LLMLoss()

global_steps = -1
tokens_seen = 0

track_lrs = []

train_losses = []
eval_losses  = []

model = model.to(device)

for epoch in range(n_epochs):

    losses = []
    model.train()
    for features, target in train_loader:
        features = features.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        global_steps += 1

        lr = lr_at_step(global_steps)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        logits = model(features)
        loss = loss_fn(logits, target)

        tokens_seen += features.numel()

        loss.backward()

        # Clip once warmup is over. The cosine phase starts at step == warmup_steps,
        # so `>=` (not `>`) avoids skipping the first post-warmup step.
        if global_steps >= warmup_steps:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)

        optimizer.step()

        losses.append(loss.item())
        track_lrs.append(lr)

        progress_bar.update(1)

        if eval_freq and global_steps % eval_freq == 0:
            model.eval()
            # Pass `device`: loader_loss defaults to CPU, which would not match a GPU model.
            eval_loss = loader_loss(loss_fn, valid_loader, model, device)
            model.train()
            print(f"Epoch {epoch} Evaluation Loss {eval_loss} LR {lr}")
            eval_losses.append((epoch, eval_loss))

    avg_train_loss = sum(losses) / len(losses) if losses else float("nan")
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch} Avg Train Loss {avg_train_loss} LR {lr}")
    invoke_model(model, start_context)

progress_bar.close()
        
        
        
        
        

        


## Learning rate warmup and Cosine decay

## Load saved model for inference

In [4]:
## Load the saved model
from chapter2.gptlikemodel import SLLM, SLLMConfig
import torch
from pathlib import Path

# Keep current_path a Path (as in the first cell) so later cells can call .parent on it.
current_path = Path(os.getcwd())
model_path   = str(current_path.parent.parent.absolute())


config = SLLMConfig()
model = SLLM(config)

save_directory = model_path + "/bin/"
model_name = "small_llm-v1-52-0.855"

# map_location="cpu": the checkpoint may have been saved from a GPU run. Loading it
# onto the CPU first lets this cell work on machines without CUDA.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (newer torch versions also accept weights_only=True).
model.load_state_dict(torch.load(save_directory + model_name, map_location="cpu"))

<All keys matched successfully>

In [8]:
start_context = "It is a"
invoke_model(model, start_context=start_context)


Input context: 'It is a'

Loading tokenizer from /home/gopi/Documents/small_llm/llmbook/data/simplebooks-tokenizer
Encoded tensor tensor([[679, 357, 259]]) No Tokens: 3 

Output tensor([[  679,   357,   259,  7485,    14, 10148, 26114,     2]]) No Tokens: 8
Decoded text: 'It is a yeast. Page wisps"'


In [None]:
## Improvements to invoking the model for text generation

* currently: greedy search
* beam search
* temperature
* top-k


https://huggingface.co/blog/introducing-csearch
Deterministic methods, e.g. greedy search and beam search, generate text by selecting the text continuation with the highest likelihood measured by the language model. However, as widely discussed in previous studies [3][4], deterministic methods often lead to the problem of model degeneration, i.e., the generated text is unnatural and contains undesirable repetitions.

To address the issues posed by deterministic methods, stochastic methods generate text by introducing randomness during the decoding process. Two widely-used stochastic methods are (i) top-k sampling [3] and (ii) nucleus sampling (also called top-p sampling) [4].

While nucleus sampling can generate text free of repetitions, the semantic coherence of the generated text is not well-maintained. For instance, the generated phrase 'AI is not journalism' is incoherent with respect to the given prefix, i.e. 'DeepMind Company'.

We note that this semantic inconsistency problem can partially be remedied by lowering the temperature. However, reducing the temperature brings nucleus sampling closer to greedy search, which can be seen as a trade-off between greedy search and nucleus sampling. Generally, it is challenging to find a prompt and model-independent temperature that avoids both the pitfalls of greedy search and nucleus sampling.


In [48]:
import torch


def greedy_search(**kwargs):
    """
    Pick the most probable next token.

    Expects the keyword `logits` (scores over the vocabulary on the last axis).
    Any other keyword, such as `temperature`, is accepted and ignored.
    Returns the arg-max index along the last axis, with that axis kept (size 1).
    """
    scores = kwargs['logits']
    return torch.softmax(scores, dim=-1).argmax(dim=-1, keepdim=True)


def generate_text(model, idx, max_new_tokens,
                  context_size,
                  search_fn=greedy_search,
                  temperature=1.0):
    """
    Autoregressively extend `idx` by `max_new_tokens` tokens using a pluggable decoder.

    Arguments:
        model:
            llm model for text generation. It is called as model(tokens), and the
            result is indexed as [:, -1, :] to get the last position's logits.
        idx:
            Input token tensor; new tokens are concatenated along dim 1.
        max_new_tokens:
            Number of output tokens to be generated.
        context_size:
            Model context window: only the last `context_size` tokens are fed to the model.
        search_fn:
            Decoding strategy called as search_fn(logits=..., temperature=...) with the
            last-position logits. It must return token ids that can be concatenated to
            `idx` along dim 1 (presumably shape (batch, 1)).
        temperature:
            Forwarded unchanged to `search_fn`.

    Returns:
        `idx` with the generated tokens appended along dim 1.
    """
    tokens = idx
    for _step in range(max_new_tokens):
        window = tokens[:, -context_size:]

        with torch.no_grad():
            step_logits = model(window)

        last_logits = step_logits[:, -1, :]
        next_token = search_fn(logits=last_logits, temperature=temperature)

        tokens = torch.cat((tokens, next_token), dim=1)
    return tokens


In [54]:
def invoke_model(model, tokenizer,
                 start_context,
                 search_fn=greedy_search,
                 temperature=1.0,
                 max_new_tokens=5,
                 context_size=50):
    """
    Continue `start_context` with `model` using a pluggable decoding strategy and print the result.

    Arguments:
        model:
            llm model for text generation; it is switched to eval mode here.
        tokenizer:
            Object providing encode(str) and decode(token ids).
        start_context:
            Non-empty prompt string.
        search_fn:
            Decoding strategy forwarded to generate_text (default: greedy_search).
        temperature:
            Forwarded to `search_fn` via generate_text.
        max_new_tokens:
            Number of tokens to generate (default 5, the previously fixed value).
        context_size:
            Context window passed to generate_text (default 50).
    """
    # Reject None before calling len() on it (len(None) would raise TypeError first).
    assert start_context is not None and len(start_context) > 0

    print(f"Input context: '{start_context}'")
    encoded = tokenizer.encode(start_context)
    # Put the prompt on the same device as the model's weights (CPU or GPU).
    device = next(model.parameters()).device
    encoded_tensor = torch.tensor(encoded).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        out = generate_text(model, encoded_tensor, max_new_tokens,
                            context_size=context_size,
                            search_fn=search_fn,
                            temperature=temperature)

    # tolist() also brings GPU results back to plain Python ints for the tokenizer.
    decoded_text = tokenizer.decode(out.squeeze(0).tolist())
    print(f"Decoded text: '{decoded_text}'\n")

In [55]:
tokenizer = get_tokenizer()
start_context = "It is a"

# Greedy decoding always picks the arg-max token, so both runs print the same text.
for i in range(2):
    invoke_model(model, tokenizer, start_context, search_fn=greedy_search)


Loading tokenizer from /home/gopi/Documents/small_llm/llmbook/data/simplebooks-tokenizer
Input context: 'It is a'
Decoded text: 'It is a yeast?" Whoo- Whoo'

Input context: 'It is a'
Decoded text: 'It is a yeast?" Whoo- Whoo'



In [56]:
def probabilistic_search(**kwargs):
    """
    Sample the next token from the softmax distribution over `logits`.

    Keyword arguments:
        logits:
            Scores over the vocabulary on the last axis, e.g. (batch, vocab).
        temperature (optional, default 1.0):
            Logits are divided by it before the softmax, so values below 1 sharpen
            the distribution and values above 1 flatten it. Previously this keyword
            was silently ignored even though generate_text always passes it.
            A temperature <= 0 falls back to greedy (arg-max) selection.

    Returns:
        Sampled token ids, one per row (shape (batch, 1)).
    """
    logits = kwargs['logits']
    temperature = kwargs.get('temperature', 1.0)
    if temperature <= 0:
        # T -> 0 makes the scaled softmax one-hot on the arg-max.
        return torch.argmax(logits, dim=-1, keepdim=True)
    probas = torch.softmax(logits / temperature, dim=-1)
    idx_next = torch.multinomial(probas, num_samples=1)
    return idx_next


In [57]:
start_context = "It is a"
# Sampling is random, so the two runs can produce different continuations.
for i in range(2):
    invoke_model(model, tokenizer, start_context, search_fn=probabilistic_search)

Input context: 'It is a'
Decoded text: 'It is a campfire!" Painted Weasel.'

Input context: 'It is a'
Decoded text: 'It is a lifeboat, yer, shy'



In [95]:
import numpy as np

words = ["a","tree","space"]

logits = np.asarray([0.2,0.11,0.5])
temp_range = np.linspace(0,1,11)
experiments = 50
# Seeded generator so the sampled counts are reproducible on re-run.
rng = np.random.default_rng(0)


def softmax(x):
    """Numerically stable softmax: subtracting max(x) avoids exp overflow at small temperatures."""
    shifted = np.exp(x - np.max(x))
    return shifted / shifted.sum()


for temperature in temp_range:
    if temperature <= 0:
        # Dividing by zero is undefined; the T -> 0 limit is greedy (arg-max) selection.
        continue
    probs = softmax(logits / temperature)
    # Round only for display. Sampling uses the exact probabilities, because rounded
    # values need not sum to 1.
    print(f"@ Temperature {temperature:.2f} values {np.round(probs, 3)}")

    idxs = rng.multinomial(experiments, probs)

    for word, choosen_freq in zip(words, idxs):
        print(f"\t{word} choosen {choosen_freq} times out of {experiments} trials")
            
            

@ Temperature 0.10 values [0.047 0.019 0.935]
	a choosen 3 times out of 50 trials
	tree choosen 1 times out of 50 trials
	space choosen 46 times out of 50 trials
@ Temperature 0.20 values [0.163 0.104 0.732]
	a choosen 9 times out of 50 trials
	tree choosen 3 times out of 50 trials
	space choosen 38 times out of 50 trials
@ Temperature 0.30 values [0.224 0.166 0.61 ]
	a choosen 7 times out of 50 trials
	tree choosen 9 times out of 50 trials
	space choosen 34 times out of 50 trials
@ Temperature 0.40 values [0.256 0.203 0.541]
	a choosen 13 times out of 50 trials
	tree choosen 9 times out of 50 trials
	space choosen 28 times out of 50 trials
@ Temperature 0.50 values [0.273 0.228 0.498]
	a choosen 19 times out of 50 trials
	tree choosen 13 times out of 50 trials
	space choosen 18 times out of 50 trials
@ Temperature 0.60 values [0.285 0.245 0.47 ]
	a choosen 12 times out of 50 trials
	tree choosen 8 times out of 50 trials
	space choosen 30 times out of 50 trials
@ Temperature 0.70 value

In [79]:
def temperature_scaling(**kwargs):
    """
    Temperature-scaled sampling of the next token.

    Keyword arguments:
        logits:
            Scores over the vocabulary on the last axis, e.g. (batch, vocab).
        temperature (optional, default 1.0):
            Divides the logits before the softmax. A value <= 0 falls back to greedy.

    Returns:
        Token ids, one per row (shape (batch, 1)).

    The token is sampled rather than taken by arg-max: for any T > 0, the arg-max of
    softmax(logits / T) equals the arg-max of logits, so arg-max made temperature a no-op.
    """
    logits = kwargs['logits']
    temperature = kwargs.get('temperature', 1.0)
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    probas = torch.softmax(logits / temperature, dim=-1)
    idx_next = torch.multinomial(probas, num_samples=1)
    return idx_next

    
    

In [59]:
temperature = 0.7
start_context = "It is a"
# NOTE(review): temperature only changes the output if search_fn actually uses it;
# check probabilistic_search.
for i in range(2):
    invoke_model(model, tokenizer, start_context,
                 search_fn=probabilistic_search, temperature=temperature)

Input context: 'It is a'
Decoded text: 'It is a yeast. Page rap increased'

Input context: 'It is a'
Decoded text: 'It is a campfire. Page a fa'



## Accelerator

In [None]:
import torch
from accelerate import Accelerator
from transformers import get_scheduler
from tqdm import tqdm

accelerator = Accelerator()

# torch.optim.AdamW replaces the deprecated transformers.AdamW; weight_decay=0.0
# keeps transformers.AdamW's default (torch's default is 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.0)

num_epochs = 1

# Prepare first, then size the schedule from the prepared dataloader (Accelerate
# may shard it across processes). Previously num_training_steps was used before it
# was assigned, which raises NameError on a fresh kernel.
train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
    train_loader, valid_loader, model, optimizer
)
num_training_steps = num_epochs * len(train_dataloader)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
scheduler = lr_scheduler

progress_bar = tqdm(range(num_training_steps))

loss_fn = LLMLoss()

model.train()
for epoch in range(num_epochs):
    for features, target in train_dataloader:
        logits = model(features)
        loss = loss_fn(logits, target)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        # Reset gradients; without this they accumulate across every step.
        optimizer.zero_grad()
        progress_bar.update(1)

progress_bar.close()
      


In [None]:
# Greedy generation check on CPU.
start_context = "wonderful spring is awaited."
tokenizer = get_tokenizer()
model.to("cpu")
model.eval()

encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded).unsqueeze(dim=0)

out = generate_text(model, encoded_tensor, max_new_tokens=5, context_size=50)


In [None]:
generated_ids = out.squeeze(0).tolist()
decoded_text = tokenizer.decode(generated_ids)
print(f"Decoded text: {decoded_text}")

## Save and Load Model

In [None]:
# Wrap in Path(...) so this works whether current_path is a Path or a plain string:
# the "Load saved model" cell may rebind it to os.getcwd(), a str with no .parent.
save_directory = str(Path(current_path).parent.parent.absolute() / "bin")

# save state dictionary
accelerator.wait_for_everyone()
accelerator.save_model(model, save_directory)

In [None]:
# Save again with explicit options: max_shard_size="1GB" caps the size of each shard
# file, and safe_serialization=True requests the safetensors format.
accelerator.save_model(model, save_directory, max_shard_size="1GB", safe_serialization=True)

In [None]:
from accelerate import load_checkpoint_in_model

# Build a fresh model and load the saved weights into it, so the save/load
# round-trip is actually exercised (previously the weights went into `model`).
new_model = SLLM(config)
device = accelerator.device
load_checkpoint_in_model(new_model, save_directory)
# Rebind `model` so the following generation cell runs on the reloaded weights.
model = new_model

In [None]:
# Greedy generation check on CPU.
start_context = "wonderful spring is awaited."
tokenizer = get_tokenizer()
model.to("cpu")
model.eval()

encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded).unsqueeze(dim=0)

out = generate_text(model, encoded_tensor, max_new_tokens=5, context_size=50)

generated_ids = out.squeeze(0).tolist()
decoded_text = tokenizer.decode(generated_ids)
print(f"Decoded text: {decoded_text}")