<a href="https://colab.research.google.com/github/stormliucong/llm-junkyard/blob/main/Training_a_LLM_with_Shakespear.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')
path = "/content/drive/MyDrive/Colab Notebooks/gpt2"
import os
if not os.path.exists(path):
    os.makedirs(path)
%cd $path

In [None]:
%pip install nbstripout
nbstripout Training-a-LLM-with-Shakespear.ipynb

In [None]:
%pip install torch

# Training of GPT2 style LLM

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import os
# import tqdm
from tqdm.notebook import tqdm
import numpy as np
import requests
from torch.utils.data import Dataset, DataLoader
import tiktoken
import random
import string

## GPT2 Model

In [59]:
# GPT2 Model
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, heads, max_seq, rope=True):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.heads = heads
        self.head_dim = d_model // heads
        assert self.head_dim * heads == d_model, "d_model must be divisible by heads"
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        if rope:
            self.rotary_embedding = RotaryPositionalEmbedding(self.head_dim, max_seq)
        else:
            self.rotary_embedding = None

    def forward(self, query, key, value, mask=None):
        # b, n, s, d
        batch_size = query.size(0)
        # linear
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)
        # split
        query = query.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2) # b,n,s,d
        key = key.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)

        # optional
        # RoPE for positional encoding
        if self.rotary_embedding is not None:
          query = self.rotary_embedding(query)
          key = self.rotary_embedding(key)
        value = value.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        # attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        else:
            scores = scores.masked_fill(torch.tril(torch.ones_like(scores)) == 0, -1e9)
        # softmax
        attention = F.softmax(scores, dim=-1) # b, n, s, s
        # output
        output = torch.matmul(attention, value).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # linear projection
        return self.output_linear(output)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, heads, d_ff, max_seq, dropout=0.1, rope=False):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(d_model, heads, max_seq, rope)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        # self.norm1 = nn.LayerNorm(d_model)
        # self.norm2 = nn.LayerNorm(d_model)
        # switch to RMSNorm(d_model)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # pre residue connection
        residue = x
        # attention
        x = self.attention(x, x, x, mask)
        x = self.dropout(x)
        x = self.norm1(x + residue)
        # feed forward
        residue = x
        x = self.feed_forward(x)
        x = self.dropout(x)
        x = self.norm2(x + residue)
        return x

class GPT2(nn.Module):
    def __init__(self, vocab_size, d_model, max_seq, n_layers, heads, d_ff, dropout=0.1,rope=False):
        super(GPT2, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq, d_model)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(d_model, heads, d_ff, max_seq, dropout, rope) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, vocab_size)
        self.apply(self._init_weights)
        self.d_model = d_model


    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, mask=None):
        # if mask is None, casual mask will added automatically.
        token_emb = self.embedding(x)
        pos_emb = self.position_embedding(torch.arange(x.size(1), device=x.device))
        x = token_emb + pos_emb
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)
        x = self.fc(x)
        return x

### RoPE - Rotary Position Embedding
- add position into q and k.
- In order to have dot product of q_i, k_j reflect relative position i-j, rotate original by θ_i,
- treate x as complext number x = a+bi, rotate θ gives a (cosθ+sinθ) + b (cosθ-sinθ)i. split x into [x1, x2], after rotate [x1cos-x2sin, x1sin+x2cos]
- add frequencies for different dimension. f = base^-2d_i/d_model

In [58]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_seq, base=10000):
        super(RotaryPositionalEmbedding, self).__init__()
        self.d_model = d_model # this should be d_model // n_headers
        self.max_seq = max_seq
        self.base = base
        # register buffer
        self.register_buffer('frequencies', self._get_frequencies())
        cos, sin = self._get_cos_sin(seq_len=self.max_seq)
        self.register_buffer('cos', cos)
        self.register_buffer('sin', sin)

    def _get_frequencies(self):
        frequencies = 1.0 / (self.base ** (torch.arange(0, self.d_model, 2).float() / self.d_model))
        return frequencies

    def _get_cos_sin(self, seq_len=None, device=None):
        pos_i = torch.arange(seq_len).unsqueeze(1)
        if device is not None:
          pos_i = pos_i.to(device)
        freqs = self.frequencies.unsqueeze(0)
        args = pos_i * freqs
        cos = torch.cos(args)
        sin = torch.sin(args)
        return cos, sin

    def _apply_rope(self, x, cos, sin):
        # split x into x1, x2
        x1, x2 = x.chunk(2, dim=-1)
        assert x1.size() == x2.size()
        # rotate x1, x2
        x1 = x1 * cos - x2 * sin
        x2 = x1 * sin + x2 * cos
        # concatenate
        return torch.cat([x1, x2], dim=-1)

    def forward(self, x):
        if x.size(2) == self.max_seq:
            con, sin = self.cos, self.sin
        else:
            cos, sin = self._get_cos_sin(seq_len=x.size(2), device=x.device)
            con, sin = cos.unsqueeze(0), sin.unsqueeze(0)
        return self._apply_rope(x, con, sin)


### RMSNorm
- LayerNorm: (x-μ)/σ. Shape of μ [batch_size, seq_length, 1].
- weighted parameters: α * (x-μ)/σ + β. σ can be zero, catestrophic cancellation, lead to numeric instability.
- RMSNorm: x/RMS(x): γ * x/RMS(x). Easy computing and avoid precision loss; also better gradient flow, each one is independent not relying on μ.
- LayerNorm(x) = LayerNorm(x + c1). absolute position information is not captured. position [1,2,3] will have the same as position [10,20,30].
- μ is not your friends empirically vulnerable to outliers.
- BatchNorm: (x-μ)/σ. Shape of μ [1, 1, d_model]. Not good for LLM b/c batch variability can be large and different between training and testing. Even tricky with flexible input seq_len


In [30]:
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps
    def forward(self, x):
        return self.gamma * x / (torch.norm(x, dim=-1, keepdim=True) + self.eps)

# Pytorch has RMSNorm nn.RMSNorm(x)

## Dataset

In [31]:
# Dataset
class TextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.seq = self.tokenizer.encode(data)
        self.seq_num = len(self.seq) - max_length - 1

    def __len__(self):
        return self.seq_num

    def __getitem__(self, idx):
        input = self.seq[idx:idx+self.max_length]
        target = self.seq[idx+1:idx+self.max_length+1]
        return torch.tensor(input, dtype=torch.long), torch.tensor(target, dtype=torch.long)

## Train

In [37]:
# Train
class Trainer:
    def __init__(self, model, device, learning_rate = 1e-3, weight_decay = 0.01, warmup_steps = 1000, max_steps = 10000, gradient_accumulation_steps = 1, grad_clip = 1.0, save_dir = "./checkpoints"):
        self.model = model
        self.optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()
        self.device = device
        self.model.to(self.device)
        self.scheduler = self._get_scheduler(warmup_steps, max_steps)
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.grad_clip = grad_clip
        self.save_dir = save_dir
        self.global_step = 0
        self.best_loss = float('inf')
        os.makedirs(save_dir, exist_ok=True)

    def _get_scheduler(self, warmup_steps, max_steps):
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            # Cosine decay after warmup
            progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            return max(0.0, 0.5 * (1.0 + torch.cos(torch.pi * torch.tensor(progress))))
        return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train(self, train_dataloader, validate_dateloader, epochs):
        self.model.train()

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            total_loss = 0
            pb = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
            for batch_idx, (data, target) in enumerate(pb):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output.view(-1, output.size(-1)), target.view(-1))
                loss = loss / self.gradient_accumulation_steps if self.gradient_accumulation_steps > 1 else loss
                loss.backward()
                if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                    self.optimizer.step()
                    self.scheduler.step()
                    self.global_step += 1
                # update pb for loss
                total_loss += loss.item() * self.gradient_accumulation_steps
                # update pb per 100
                if batch_idx % 100 == 0:
                    pb.set_postfix(loss=f"{total_loss / (batch_idx + 1):4f}")
            # call validate
            self.validate(validate_dateloader)
        print(f"Training finished")
        return 1

    def validate(self, dataloader):
        self.model.eval()
        total_loss = 0
        pb = tqdm(dataloader, desc="Validation")
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(pb):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output.view(-1, output.size(-1)), target.view(-1))
                total_loss += loss.item()
                # update pb per 100
                if batch_idx % 100 == 0:
                    pb.set_postfix(loss=f"{total_loss / (batch_idx + 1):4f}")

        if (total_loss / (batch_idx + 1)) < self.best_loss:
            self.best_loss = (total_loss / (batch_idx + 1))
            self.save_checkpoint()

        return 1

    def save_checkpoint(self, filename='best_model.pt'):
        checkpoint_path = os.path.join(self.save_dir, filename)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'model_config': self.model_config,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': self.global_step,
            'best_loss': self.best_loss
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, filename='best_model.pt'):
        checkpoint_path = os.path.join(self.save_dir, filename)
        checkpoint = torch.load(checkpoint_path)
        self.model_config = checkpoint['model_config']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']
        print(f"Checkpoint loaded from {checkpoint_path}")

## Generate

In [38]:
class TextGenerator:
    def __init__(self, model, tokenizer, device, max_length=100, eos_token=None):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.model.to(self.device)
        self.model.eval()
        if eos_token is None:
            self.eos_token = self.tokenizer.eot_token
        else:
            self.eos_token = eos_token

    def generate(self, prompt, num_samples=1, temperature=1.0, top_k=50):
        tokens = self.tokenizer.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0) # 0, s
        for _ in range(num_samples):
            for _ in range(self.max_length - len(tokens)):
                logits = self.model(tokens)
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                filtered_logits = logits.topk(top_k)[0]
                probs = F.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                tokens = torch.cat((tokens, next_token), dim=1)
                if next_token.item() == self.eos_token:
                    break
            generated_text = self.tokenizer.decode(tokens[0].tolist())
            print(generated_text)

## Training Shakespear GPT2

In [67]:
# define all training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# get gpt2 tokenizer and vocab_size
tokenizer = tiktoken.get_encoding("gpt2")
vocab_size = tokenizer.n_vocab
# model parameters
model_config = {
    "vocab_size": vocab_size,
    "d_model": 128,
    "max_seq": 128,
    "n_layers": 4,
    "heads": 4,
    "d_ff": 256,
    "dropout": 0.1,
    "rope": True
}
# training parameters
batch_size=2
learning_rate = 1e-3
weight_decay = 0.01
warmup_steps = 1000
max_steps = 10000
gradient_accumulation_steps = 5
grad_clip = 1.0
epochs = 2
# train and val data
if os.path.exists("data/t8.shakespeare.txt"):
    # if the file exists, read from it
    with open("data/t8.shakespeare.txt", "r") as f:
        data = f.read()
else:
    shakespear_content_url = "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
    # read the content from the URL
    data = requests.get(shakespear_content_url).text
    # save to content to local file
    if not os.path.exists("data"):
        os.makedirs("data")
    with open("data/t8.shakespeare.txt", "w") as f:
        f.write(data)
train_data, val_data = data[:int(len(data)*0.8)], data[int(len(data)*0.8):]
train_dataloader = DataLoader(TextDataset(train_data[1:10000], tokenizer, model_config['max_seq']), batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(TextDataset(val_data[1:1000], tokenizer, model_config['max_seq']), batch_size=batch_size, shuffle=True)
model = GPT2(**model_config)
# init model and trainer
trainer = Trainer(model, device, learning_rate, weight_decay, warmup_steps, max_steps, gradient_accumulation_steps, grad_clip, "checkpoints")
# load model if available
if os.path.exists("checkpoints/best_model.pt"):
    try:
      trainer.load_checkpoint()
    except Exception as e:
      print(e)
# start training
trainer.train(train_dataloader, val_dataloader, epochs)

Checkpoint loaded from checkpoints/best_model.pt
Epoch 1/2


Epoch 1/2:   0%|          | 0/1352 [00:00<?, ?it/s]

Validation:   0%|          | 0/120 [00:00<?, ?it/s]

Epoch 2/2


Epoch 2/2:   0%|          | 0/1352 [00:00<?, ?it/s]

Validation:   0%|          | 0/120 [00:00<?, ?it/s]

Training finished


1

## Inference on Trained GPT2

In [68]:
# generate
# load bst model
model = GPT2(**model_config)
model.load_state_dict(torch.load("checkpoints/best_model.pt")['model_state_dict'])
eos_token = tokenizer.eot_token
text_generator = TextGenerator(model, tokenizer, device, max_length=100, eos_token=eos_token)
prompt = "he is a "
text_generator.generate(prompt, num_samples=1, temperature=1.0, top_k=50)

he is a D!!#!!"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


# New Section