# Lab: LLM Pretraining

Here we directly reuse the decoder architecture built in the previous sections and pretrain it on a small text corpus.
  

In [2]:
import math
from typing import List, Optional, Tuple, Union
import os, urllib
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from omegaconf import OmegaConf
from llm_lab.model.vanilla_decoder import VanillaDecoderModel
from transformers import AutoTokenizer
from datasets import load_dataset

## Data

In [47]:
from torch.utils.data import Dataset, DataLoader

class GPTPretrainDataset(Dataset):
    """Next-token-prediction dataset built from one long text string.

    The text is tokenized once, then cut into windows of ``max_length``
    tokens whose start positions are ``stride`` tokens apart. Each target
    window is the matching input window shifted one token to the right, so
    ``target[j]`` is the token that follows ``input[j]``.

    Args:
        text: Raw training text.
        tokenizer: Object exposing ``encode(text, allowed_special=...)``.
            The keyword matches tiktoken's API; other tokenizers may not
            accept it.
        max_length: Number of tokens per window.
        stride: Offset between consecutive window starts. A value smaller
            than ``max_length`` produces overlapping windows.
    """

    def __init__(self, text, tokenizer, max_length, stride):
        super().__init__()
        # Tokenize the entire text once.
        ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

        # A window starting at `s` also needs token s + max_length for its
        # shifted target, hence the exclusive upper bound below.
        starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
        self.target_ids = [
            torch.tensor(ids[s + 1:s + max_length + 1]) for s in starts
        ]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def create_data_loader(text, batch_size=4, max_length=256,
                       stride=128, shuffle=True, drop_last=True, num_workers=0):
    """Tokenize ``text`` with tiktoken's 'gpt2' encoding and wrap it in a DataLoader.

    Batches are ``(inputs, targets)`` pairs drawn from ``GPTPretrainDataset``.
    Each window holds ``max_length`` tokens, and the targets are shifted one
    token to the right.

    NOTE(review): ``tiktoken`` is referenced here but not imported anywhere in
    this file. It must be imported for this function to run.
    """
    gpt2_encoding = tiktoken.get_encoding('gpt2')
    windows = GPTPretrainDataset(
        text=text,
        tokenizer=gpt2_encoding,
        max_length=max_length,
        stride=stride,
    )
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

In [41]:
def read_text_data(file_path, url):
    """Return the text cached at ``file_path``, downloading it from ``url`` if absent.

    On a cache miss, the response body is decoded as UTF-8 and written to
    ``file_path`` so that later calls read the local copy.

    Args:
        file_path: Local cache path. Missing parent directories are created.
        url: Source URL. It is only fetched when ``file_path`` does not exist.

    Returns:
        The text contents as a ``str``.
    """
    # `import urllib` alone does not load the `urllib.request` submodule,
    # so import it explicitly before calling urlopen.
    import urllib.request

    if os.path.exists(file_path):
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()

    with urllib.request.urlopen(url) as response:
        text_data = response.read().decode('utf-8')

    # Cache paths such as "./pretraining/the-verdict.txt" may point into a
    # directory that does not exist yet.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text_data)
    return text_data

In [50]:
def test_data_component():
    """Smoke-test the data pipeline: fetch the sample text and print one batch."""
    sample_url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"
    corpus = read_text_data("the-verdict.txt", sample_url)
    loader = create_data_loader(text=corpus)

    # Only the first batch is inspected; an empty loader prints nothing.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(first_batch)


test_data_component()

[tensor([[  423,  4750,   326,  ...,   262,  8216,    13],
        [   11,   508,   550,  ...,   526,   198,   198],
        [  271, 10899,   550,  ..., 29543,  2745,    11],
        [   26,   475,   314,  ...,   287,   683,   438]]), tensor([[ 4750,   326,  9074,  ...,  8216,    13,   314],
        [  508,   550, 18459,  ...,   198,   198,  3347],
        [10899,   550,   366,  ...,  2745,    11,   314],
        [  475,   314,  2936,  ...,   683,   438,   273]])]


In [None]:
def tokenizer_and_chunk_text(examples):
    """Placeholder for a tokenize-and-chunk step. Not implemented yet.

    NOTE(review): the original cell contained only this signature with no
    body, which is a SyntaxError. The name suggests a batched
    ``datasets.Dataset.map`` callback (``load_dataset`` is imported in this
    file). Confirm the intended tokenizer and chunk size before implementing.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("tokenizer_and_chunk_text is not implemented yet")

## Training

In [72]:
def compute_batch_loss(input_batch, target_batch, model, device):
    """Return the cross-entropy of ``model`` on one batch, averaged over positions.

    Both batches are moved to ``device``. The first two dimensions of the
    logits are merged, so each position is scored as an independent
    classification against the matching entry of the flattened targets.
    Those two dimensions are presumably batch and sequence; TODO confirm
    against the model.

    Returns:
        Scalar loss tensor, still attached to the autograd graph so the
        caller can call ``backward()`` on it.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)
    scores = torch.flatten(model(inputs), start_dim=0, end_dim=1)
    return F.cross_entropy(scores, targets.reshape(-1))

def train_model_epoch(model,
                      train_loader,
                      val_loader,
                      optimizer,
                      device,
                      num_epochs):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Each step zeroes the gradients, computes the batch loss with
    ``compute_batch_loss``, backpropagates, and applies one optimizer step.

    Args:
        model: Module to train. It is put in training mode at the start of
            every epoch.
        train_loader: Iterable of ``(input_batch, target_batch)`` pairs.
        val_loader: Accepted for interface compatibility but currently
            unused; no validation pass is run.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of full passes over ``train_loader``.

    Returns:
        ``(train_losses, model)``. ``train_losses`` holds the Python float
        loss of every optimizer step, in order.
    """
    # TODO: evaluate on val_loader (under model.eval() / torch.no_grad()) if
    # validation curves are wanted; the caller already builds that loader.
    train_losses = []
    for _ in range(num_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = compute_batch_loss(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            # .item() already returns a plain, graph-free Python float.
            train_losses.append(loss.item())
    return train_losses, model

In [None]:
def train_main(model_config, train_settings):
    """Build the model, data loaders and optimizer from config, then train.

    Args:
        model_config: Model hyper-parameters. ``context_length`` also sets the
            token window size of the data loaders.
        train_settings: Training hyper-parameters read here: ``seed``,
            ``file_path``, ``url``, ``learning_rate``, ``weight_decay``,
            ``batch_size``, ``stride`` and ``num_epochs``. An optional
            ``train_ratio`` (default 0.90) sets the fraction of characters
            used for training; the rest is held out for validation.

    Returns:
        ``(train_losses, model)`` as returned by ``train_model_epoch``.
    """
    torch.manual_seed(train_settings.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    text_data = read_text_data(train_settings.file_path, train_settings.url)

    # NOTE(review): `LLM` is not defined or imported in this file. The
    # imported architecture is `VanillaDecoderModel`; confirm which class
    # (and which constructor signature) is intended.
    model = LLM(config=model_config)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_settings.learning_rate,
                                  weight_decay=train_settings.weight_decay)

    # Split on characters, not tokens: the head is used for training and the
    # tail for validation.
    train_ratio = train_settings.get("train_ratio", 0.90)
    split_position = int(len(text_data) * train_ratio)
    train_data = text_data[:split_position]
    val_data = text_data[split_position:]

    loader_kwargs = dict(batch_size=train_settings.batch_size,
                         max_length=model_config.context_length,
                         stride=train_settings.stride,
                         num_workers=0)
    train_loader = create_data_loader(text=train_data,
                                      shuffle=True,
                                      drop_last=True,
                                      **loader_kwargs)
    # Validation should see every window in a fixed order. It is not
    # shuffled, and the final partial batch is kept, because the held-out
    # tail (10% by default) may yield only a few windows.
    val_loader = create_data_loader(text=val_data,
                                    shuffle=False,
                                    drop_last=False,
                                    **loader_kwargs)

    train_losses, model = train_model_epoch(model=model,
                                            train_loader=train_loader,
                                            val_loader=val_loader,
                                            optimizer=optimizer,
                                            device=device,
                                            num_epochs=train_settings.num_epochs)
    return train_losses, model
    

## Training Entry

In [74]:
if __name__ == '__main__':
    # Decoder hyper-parameters. vocab_size matches tiktoken's 'gpt2' encoding,
    # which create_data_loader uses.
    model_config = OmegaConf.create({
        "vocab_size": 50257,     # vocabulary size
        "context_length": 256,   # shortened context length (orig: 1024)
        # NOTE(review): train_main chunks with train_settings.stride; this
        # entry may be unused.
        "stride": 128,
        "d_model": 768,          # model dimension
        "num_heads": 12,         # number of attention heads
        "num_layers": 12,        # number of layers
        "dropout": 0.1,          # dropout rate
        "qkv_bias": False,       # query-key-value bias
    })

    train_settings = OmegaConf.create({
        "learning_rate": 5e-4,
        "num_epochs": 10,
        "batch_size": 2,
        "weight_decay": 0.1,
        "stride": 128,
        "seed": 1,
        "file_path": "./pretraining/the-verdict.txt",
        "url": "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt",
    })

    # Train and report the per-step losses.
    train_losses, model = train_main(model_config=model_config,
                                     train_settings=train_settings)
    print(train_losses)

    # The trained weights are not saved. To keep them, call
    # torch.save(model.state_dict(), "model.pth").
    

[10.889654159545898, 10.55605697631836, 10.390718460083008, 9.955503463745117, 9.866605758666992, 9.824695587158203, 9.696834564208984, 9.723470687866211, 9.311979293823242, 9.14416790008545, 8.976319313049316, 8.796220779418945, 8.653111457824707, 8.667469024658203, 8.531828880310059, 8.036416053771973, 7.521772861480713, 7.099956512451172, 6.590849876403809, 5.911465644836426, 5.8712005615234375, 5.545934677124023, 5.661034107208252, 5.0880608558654785, 5.054011344909668, 4.862744331359863, 4.2214741706848145, 4.4426798820495605, 4.785312175750732, 3.9887170791625977, 3.7415623664855957, 4.1082444190979, 4.18638801574707, 4.1566009521484375, 2.5507609844207764, 2.6057915687561035, 3.7073099613189697, 2.823371648788452, 2.622760057449341, 3.088858127593994, 1.9870753288269043, 2.2322254180908203, 3.128173351287842, 1.7956585884094238, 3.165675401687622, 1.8216800689697266, 2.4981274604797363, 2.861833095550537, 2.2204294204711914, 2.6053457260131836, 1.7304390668869019, 0.838503956794