# Decoder (GPT) LLM From Scratch

### Importing necessary modules and setting device

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import mmap
import random
import pickle
import argparse
from tqdm import tqdm
import fmmap as mmap
import random
import pickle
import transformers
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.nn.utils.rnn import pad_sequence

# Pick the best available backend: CUDA first, then Apple's MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(device)  # Show which device was selected

### Initialising tokenizer and setting hyperparameters

In [None]:
# Load the pretrained tokenizer published under the "bert-base-uncased" name
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Number of entries the tokenizer reports; passed to GPTLanguageModel below,
# where it sizes the token-embedding table and the output projection
vocab_size = len(tokenizer)

In [None]:
# Hyperparameters shared by the model definition and the training cells below.
block_size = 32  # Rows in the positional-embedding table and side of the causal mask
batch_size = 16  # Sequences per batch produced by the DataLoaders
learning_rate= 3e-3  # Learning rate handed to the AdamW optimizer
num_epochs = 500  # Full passes over the training DataLoader
eval_iters = 100  # Batches averaged in estimate_loss; also reused as the logging interval (in epochs)
n_embd = 32  # Width of the token/positional embeddings and of every block
n_head = 1  # Attention heads per transformer block
dropout = 0.2  # Dropout probability used in the attention and feed-forward layers
n_layer= 1  # Number of transformer blocks stacked in the model

#### Importing large txt file in chunks

In [None]:
class ChunkedTextIterableDataset(IterableDataset):
    """Stream a text file as fixed-length chunks of token ids.

    The file is read line by line, each stripped line is encoded with
    ``tokenizer.encode(..., add_special_tokens=True)`` and the resulting ids
    are appended to a running buffer. Every time the buffer holds at least
    ``max_length`` ids, chunks of exactly ``max_length`` ids are yielded.
    Whatever is left at end of file is yielded as one final, shorter chunk.

    Args:
        filename: Path of the UTF-8 text file to stream.
        tokenizer: Object exposing ``encode(text, add_special_tokens=True)``
            that returns a list of token ids.
        max_length: Number of token ids per yielded chunk (must be >= 1).

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """

    def __init__(self, filename, tokenizer, max_length=512):
        if max_length < 1:
            # A non-positive length would make the drain loop below spin forever.
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self.filename = filename
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        size = self.max_length
        buffer = []
        with open(self.filename, 'r', encoding='utf-8') as f:
            for line in f:
                buffer.extend(self.tokenizer.encode(line.strip(), add_special_tokens=True))

                # Drain *every* full chunk, not just one: a single long line can
                # add several chunks' worth of ids, and leaving them buffered
                # would let the final chunk grow past max_length.
                while len(buffer) >= size:
                    yield buffer[:size]
                    del buffer[:size]

            if buffer:  # Trailing partial chunk (always shorter than max_length)
                yield buffer

In [None]:
def collate_batch(batch):
    """Turn a list of token-id sequences into one batch-first tensor.

    Shorter sequences are filled up with the tokenizer's pad token id so that
    every row has the length of the longest sequence in the batch.
    """
    as_tensors = []
    for seq in batch:
        as_tensors.append(torch.tensor(seq))

    return pad_sequence(
        as_tensors,
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    )


In [None]:
# The model's positional-embedding table has only `block_size` rows, and the
# training loop drops one token when it shifts inputs against labels. Chunks
# therefore carry block_size + 1 tokens so the model sees exactly block_size
# positions. (The dataset's default of 512 overruns the table.)
sequence_length = block_size + 1

# shuffle stays False: a DataLoader cannot shuffle an IterableDataset.
train_dataset = DataLoader(
    ChunkedTextIterableDataset('output_train.txt', tokenizer=tokenizer, max_length=sequence_length),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_batch,
)

val_dataset = DataLoader(
    ChunkedTextIterableDataset('output_val.txt', tokenizer=tokenizer, max_length=sequence_length),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_batch,
)

In [None]:
# Sanity check: print the first five training batches and their shapes.
# The enumerate index already counts batches, so no separate counter is needed.
for i, batch in enumerate(train_dataset):
    print("i:", i, "\n")
    print("batch:", batch, "\n")
    print("batch shape:", batch.shape, "\n")

    if i == 4:  # Stop right after the fifth batch without fetching a sixth
        break


## LLM Architecture

#### Feed-Forward Network (FFN) for non-linearity and regularization

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the width, ReLU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),  # Regularisation against overfitting
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


#### Head class: projects the input into Key, Query and Value and applies masked self-attention

In [None]:
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key, query and value projections.

    forward() takes ``x`` of shape (B, T, C) and returns (B, T, head_size);
    position ``t`` can only attend to positions ``<= t``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        # Lower-triangular causal mask, built once and moved with the module
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        # Attention scores ("affinities"), scaled by sqrt(head_size)
        wei = torch.matmul(q, k.transpose(-2, -1)) / (self.head_size ** 0.5)  # (B, T, T)

        # Reuse the precomputed mask instead of allocating a new (T, T) tensor
        # on every call. Build one on the fly only for sequences longer than
        # the buffer, so such inputs keep working as before.
        if T <= self.tril.size(0):
            mask = self.tril[:T, :T]
        else:
            mask = torch.tril(torch.ones((T, T), device=x.device))

        # Hide future positions; (T, T) broadcasts over the batch dimension
        wei = wei.masked_fill(mask == 0, float('-inf'))

        # Normalise the scores into attention probabilities
        wei = F.softmax(wei, dim=-1)

        # Dropout on the attention weights
        wei = self.dropout(wei)

        # Weighted aggregation of the values
        out = torch.matmul(wei, v)  # (B, T, head_size)

        return out

#### Multi-Head Attention Layer for capturing dependencies in data

In [None]:
class MultiHeadAttention(nn.Module):
    """Multiple self-attention heads run in parallel, then merged.

    The outputs of the heads are concatenated along the feature axis,
    projected back to ``n_embd`` features and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size

        heads = []
        for _ in range(num_heads):
            heads.append(Head(head_size))
        self.heads = nn.ModuleList(heads)

        self.proj = nn.Linear(head_size*num_heads, n_embd)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Each head yields (B, T, head_size); concatenating on the last axis
        # gives (B, T, num_heads * head_size)
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

#### Block layer to sequence self-attention and FFN

In [None]:
class Block(nn.Module):
    """Transformer block: self-attention (communication), then feed-forward (computation).

    Each sub-layer is wrapped in a residual connection followed by a
    LayerNorm, i.e. normalisation is applied after the residual sum.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        per_head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head_width)  # Self-attention
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Residual connection around self-attention, then normalise
        attended = self.ln1(x + self.sa(x))
        # Residual connection around the feed-forward network, then normalise
        return self.ln2(attended + self.ffwd(attended))

#### GPT Language Model Architecture, combining all of the above classes

In [None]:
class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Token and positional embeddings are summed, passed through ``n_layer``
    transformer blocks and a final LayerNorm, then projected to vocabulary
    logits.

    Args:
        vocab_size: Number of distinct token ids; sizes the embedding table
            and the output projection.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)  # Token embedding
        self.positional_embedding = nn.Embedding(block_size, n_embd)   # Positional embedding
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])  # Decoder blocks
        self.ln_f = nn.LayerNorm(n_embd)              # Final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)  # Final linear layer

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, when labels are given, the loss.

        Args:
            input_ids: Long tensor of token ids, shape (B, T) with T <= block_size.
            labels: Optional long tensor of target ids, shape (B, T).

        Returns:
            ``(logits, loss)`` where logits has shape (B, T, vocab_size) and
            loss is the mean cross-entropy, or ``None`` when ``labels`` is None.

        Raises:
            ValueError: If T exceeds the number of positional embeddings.
        """
        B, T = input_ids.shape
        if T > self.positional_embedding.num_embeddings:
            raise ValueError(
                f"Sequence length {T} exceeds the maximum context of "
                f"{self.positional_embedding.num_embeddings} positions"
            )

        # Token embeddings
        token_embeddings = self.token_embedding_table(input_ids)              # (B, T, C)

        # Positional embeddings, created on the same device as the input so
        # the model does not depend on a module-level device variable
        position_ids = torch.arange(0, T, dtype=torch.long, device=input_ids.device)
        positional_embeddings = self.positional_embedding(position_ids)       # (T, C)

        # Combine token and position embeddings
        hidden_states = token_embeddings + positional_embeddings.unsqueeze(0) # (B, T, C)

        # Pass through the transformer blocks
        for block in self.blocks:
            hidden_states = block(hidden_states)                              # (B, T, C)

        # Final layer normalisation
        hidden_states = self.ln_f(hidden_states)                              # (B, T, C)

        # Logits over the vocabulary for every position
        logits = self.lm_head(hidden_states)                                  # (B, T, vocab_size)

        # The loss is only defined when targets are supplied (generation passes none).
        # reshape() rather than view(): labels are usually a strided slice such
        # as batch[:, 1:], which view(-1) rejects when it is not contiguous.
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``index``.

        Args:
            index: Long tensor of shape (B, T) holding the starting context.
            max_new_tokens: Number of tokens to append.

        Returns:
            Long tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit the positional-embedding table
            index_cond = index[:, -block_size:]
            logits, _ = self(index_cond)
            logits = logits[:, -1, :]                             # Keep the last time step
            probs = F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            index = torch.cat((index, index_next), dim=1)         # (B, T + 1)
        return index

# Build a freshly initialised model
model = GPTLanguageModel(vocab_size)
# Optional: reload the model pickled by the training cell instead of starting
# from scratch. Only unpickle files you trust: pickle can run arbitrary code.
# with open('model-01.pkl','rb') as f:
#      model = pickle.load(f)

# Move the parameters and buffers to the selected device before creating the optimizer
model = model.to(device)
# AdamW: the Adam variant with decoupled weight decay
optimizer = torch.optim.AdamW(params = model.parameters(), lr = learning_rate)


## Training Model

### Creating function to estimate loss during training

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the training and validation loaders.

    Up to ``eval_iters`` batches are drawn from each loader and evaluated
    with the model in eval mode; the model is switched back to train mode
    before returning, even if evaluation raises.

    Returns:
        Dict with keys ``'train'`` and ``'test'`` mapping to a 0-dim tensor
        holding the mean loss (NaN if the loader produced no batches).
    """
    from itertools import islice

    # Batches come from the DataLoaders defined in this file; the 'test'
    # split is served by the validation loader.
    loaders = {'train': train_dataset, 'test': val_dataset}

    out = {}
    model.eval()
    try:
        for split, loader in loaders.items():
            losses = []
            for batch in islice(loader, eval_iters):
                # Same input/target shift as the training loop
                X = batch[:, :-1].to(device)
                Y = batch[:, 1:].to(device)
                logits, loss = model(X, Y)
                losses.append(loss.item())
            out[split] = torch.tensor(losses).mean()
    finally:
        model.train()
    return out

### Training Loop

In [None]:
# Plain full-precision training: one optimizer step per batch.
model.train()

for epoch in range(num_epochs):
    # The first batch's loss is reported once every `eval_iters` epochs
    report_this_epoch = epoch % eval_iters == 0

    for batch_idx, batch in enumerate(train_dataset):
        # Next-token prediction: inputs are all tokens but the last,
        # labels are the same tokens shifted left by one
        input_ids, labels = batch[:, :-1].to(device), batch[:, 1:].to(device)

        optimizer.zero_grad()
        logits, loss = model(input_ids, labels=labels)
        loss.backward()
        optimizer.step()

        if report_this_epoch and batch_idx == 0:
            print(f"Loss at epoch {epoch}: {loss.item()}")


# Persist the whole trained model object
with open('model-01.pkl','wb') as f:
    pickle.dump(model, f)

In [None]:
from torch.cuda.amp import GradScaler, autocast

# Mixed-precision training with gradient accumulation.
# Mini-batches whose gradients are summed before each optimizer step.
accumulation_steps = 4
# Report the loss every `eval_interval` epochs (same cadence as the plain loop).
eval_interval = eval_iters

# CUDA AMP only applies on a CUDA device; elsewhere the scaler and autocast are
# disabled and this loop runs as ordinary full-precision training.
use_amp = device == 'cuda'
scaler = GradScaler(enabled=use_amp)

model.train()

for epoch in range(num_epochs):
    optimizer.zero_grad()
    pending = 0  # Batches accumulated since the last optimizer step

    for batch_idx, batch in enumerate(train_dataset):
        input_ids = batch[:, :-1].to(device)
        labels = batch[:, 1:].to(device)

        # Run the forward pass in mixed precision
        with autocast(enabled=use_amp):
            logits, loss = model(input_ids, labels=labels)

        # Scale the loss and normalise it by the accumulation length. Gradients
        # are NOT zeroed here, so they add up across the accumulated batches.
        scaler.scale(loss / accumulation_steps).backward()
        pending += 1

        if pending == accumulation_steps:
            # Unscales the gradients and calls (or skips) optimizer.step()
            scaler.step(optimizer)
            scaler.update()  # Update the scale for the next iteration
            optimizer.zero_grad()
            pending = 0

        if epoch % eval_interval == 0 and batch_idx == 0:
            # Reports the un-normalised loss of this batch
            print(f"Loss at epoch {epoch}: {loss.item()}")

    # The loader wraps an IterableDataset and has no len(), so the trailing
    # partial group is flushed once the epoch's iterator is exhausted.
    if pending:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()


In [None]:
# Sample 150 new tokens from the trained model and decode them to text.
# Switch to eval mode first: the training cells leave the model in train mode,
# where dropout would randomly perturb the activations during sampling.
model.eval()

# Start from a single token with id 0 (presumably the tokenizer's padding token; verify)
context = torch.zeros((1,1), dtype = torch.long, device = device)
generated_chars = tokenizer.decode(model.generate(context, max_new_tokens=150)[0].tolist())