# *Lab: Pretraining

Here we present a simplified llama implementation based [Huggingface implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L731) to illustrate different components on the Llama decoder model.

The key components are
* RMS Norm
* Rotary Position Embedding
* Grouped Query Attention
* Feedfoward network (FFN)
  

In [1]:
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

## Model Architecture

### Attention and FFN

In [32]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, context_length, num_heads, dropout, qkv_bias=False):
        super().__init__()
        
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.W_Q = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_model, bias=qkv_bias)
        
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Buffers are not updated during training but are part of the module's state.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length)))
        
    def forward(self, x):
        # x shape: (batch_size, num_tokens, d_model)
        (batch_size, num_tokens, d_model) = x.shape
        
        queries = self.W_Q(x)
        keys = self.W_K(x)
        values = self.W_V(x)
        
        # change to shape easier for multi-head attention
        # (batch_size, num_tokens, num_heads, d_head)
        keys = keys.view(batch_size, num_tokens, self.num_heads,  self.head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads,  self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        
        queries = queries.transpose(1, 2) # 
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        
        attention_logits = queries @ keys.transpose(2, 3) / queries.shape[-1] ** 0.5
        
        mask_bool = self.mask.bool()[:num_tokens,:num_tokens]
        
        attention_logits.masked_fill(mask_bool, -torch.inf)
        
        attention_weights = torch.softmax(attention_logits, dim=-1)
        
        # attention_weights: (batch_size, num_heads, num_tokens, num_tokens)
        attention_weights = self.dropout(attention_weights)
        
        # context_vec: (batch_size, num_heads, num_tokens, d_head)
        context_vec = (attention_weights @ values)
        
        context_vec = context_vec.transpose(2, 3).contiguous().view(batch_size, num_tokens, self.d_model)
        
        out_vec = self.W_O(context_vec)
        
        return out_vec
    
class LayerNorm(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(d_model))
        self.shift = nn.Parameter(torch.zeros(d_model))
        
    def forward(self, x):
        mean = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True)
        norm_x = (x - mean) / (var + self.eps) * self.scale + self.shift
        return norm_x
    
class FeedForward(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.up_proj = nn.Linear(d_model, 4 * d_model)
        self.act = nn.GELU()
        self.down_proj = nn.Linear(4 * d_model, d_model)
        
    def forward(self, x):
        
        x = self.up_proj(x)
        x = self.act(x)
        x = self.down_proj(x)
        
        return x
        

In [33]:
class TransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.config = config
        
        self.att = MultiHeadAttention(d_model = config.d_model, 
                                      context_length=config.context_length, 
                                      num_heads=config.num_heads,
                                      dropout=config.dropout,
                                      qkv_bias=config.qkv_bias)
        
        self.ff = FeedForward(d_model=config.d_model)
        self.norm1 = LayerNorm(d_model=config.d_model)
        self.norm2 = LayerNorm(d_model=config.d_model)
        
        self.dropout = nn.Dropout(config.dropout)
        
    def forward(self, x):
        
        shortcut = x
        # pre-norm
        x = self.norm1(x)
        x = self.att(x)
        x = self.dropout(x)
        
        x = x + shortcut
        
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.dropout(x)
        x = x + shortcut
        
        return x
        

In [34]:

class LLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.context_length, config.d_model)
        self.emb_dropout = nn.Dropout(config.dropout)
        
        self.transformer_backbone = nn.Sequential(*[TransformerLayer(self.config) for _ in range(self.config.num_layers)])
        
        self.final_norm = LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_idx):
        (batch_size, num_tokens) = input_idx.shape
        
        token_embedings = self.token_emb(input_idx)
        pos_embeddings = self.pos_emb(torch.arange(num_tokens, device=input_idx.device))
        
        input_embeddings = token_embedings + pos_embeddings
        
        x = self.emb_dropout(input_embeddings)
        
        x = self.transformer_backbone(x)
        
        x = self.final_norm(x)
        logits = self.lm_head(x)
        
        return logits 



        
        

In [35]:
from omegaconf import OmegaConf
import tiktoken
def main():
    
    config_dict = {
        "vocab_size": 50257,
        "context_length": 1024,
        "d_model": 768,
        "num_heads": 12,
        "num_layers": 12,
        "dropout": 0.1,
        "qkv_bias": False        
    }
    
    config = OmegaConf.create(config_dict)
    torch.manual_seed(123)
    model = LLM(config)
    
    text = "Hello"
    tokenizer = tiktoken.get_encoding("gpt2")
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
    
    print(model(input_ids))
    
main()
    

tensor([[[-0.1013, -0.1647, -0.1755,  ..., -0.0994,  0.0870, -0.3435]]],
       grad_fn=<UnsafeViewBackward0>)


## Data

In [None]:
from torch.utils.data import Dataset, DataLoader

class GPTPretrainDataset(Dataset):
    def __init__(self, text, tokenizer, max_length, stride):
        super().__init__()
        self.input_ids = []
        self.target_ids = []
        
        # Tokenizer the entire text
        token_ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
        
        # use a sliding window approach to chunk the input text corpus
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i: i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))
            
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def create_data_loader(text, batch_size=4, max_length=256, 
                       stride=128, shuffle=True, drop_last=True, num_workers=0):
    tokenizer = tiktoken.get_encoding('gpt2')
    
    dataset = GPTPretrainDataset(text=text,
                                 tokenizer=tokenizer,
                                 max_length=max_length,
                                 stride=stride)
    
    data_loader = DataLoader(dataset, 
                             batch_size=batch_size,
                             shuffle=shuffle,
                             drop_last=drop_last,
                             num_workers=num_workers)
    
    return data_loader

In [None]:
def train_model(model, 
                train_loader, 
                val_loader, 
                optimizer,
                device,
                num_epochs):
    
    train_losses, val_losses, track_token_seen = [],[],[]
    tokens_seen = 0
    global_steps = -1
    
    for epoch in range(num_epochs):
        model.train()
        
        for input_batch, target_batch in train_loader:
            