In [None]:
# load dataset
from datasets import load_dataset
# Pull the TinyStories corpus and shuffle both splits with a fixed seed so the
# sub-samples taken below are reproducible.
dataset = load_dataset("roneneldan/TinyStories")
shuffled_train_dataset, shuffled_val_dataset = (
    dataset[split_name].shuffle(seed=42) for split_name in ("train", "validation")
)

# Number of training stories to keep.
train_sample_size = 50000

# Keep the validation sample in the same proportion to the training sample as
# the full validation split is to the full training split.
val_to_train_ratio = len(shuffled_val_dataset) / len(shuffled_train_dataset)
val_sample_size = int(val_to_train_ratio * train_sample_size)

# Take the first N rows of each (already shuffled) split.
sampled_train_dataset = shuffled_train_dataset.select(range(train_sample_size))
sampled_val_dataset = shuffled_val_dataset.select(range(val_sample_size))

# From here on `dataset` is a plain dict holding only the sampled splits.
dataset = dict(train=sampled_train_dataset, validation=sampled_val_dataset)

print(dataset['train'][:10])
print(len(dataset['train']))
model_type = 'naive'

In [None]:
from torch.utils.data import Dataset
class TinyDataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` adapter around an indexable collection.

    The wrapped object only needs to support ``len()`` and ``[]``; items are
    returned exactly as the wrapped object produces them.
    """

    def __init__(self, dataset):
        # Only a reference is stored; nothing is copied or transformed.
        self.dataset = dataset

    def __getitem__(self, idx):
        """Return the wrapped collection's item at ``idx``."""
        return self.dataset[idx]

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.dataset)

In [None]:
# Tokenize
import sentencepiece as spm
import os
from typing import List
class Tokenizer:
    """Convenience wrapper around a SentencePiece model file.

    Exposes the vocabulary size and the special-token ids reported by the
    model, plus ``encode`` / ``decode`` helpers.
    """

    def __init__(self, model_path):
        """Load the SentencePiece model stored at ``model_path``.

        Fails with an ``AssertionError`` (carrying the path) when the file
        does not exist.
        """
        assert os.path.isfile(model_path), model_path
        self.sp_model = spm.SentencePieceProcessor(model_file=model_path)
        processor = self.sp_model
        # Cache vocabulary size and special-token ids as reported by the model.
        self.vocab_size: int = processor.vocab_size()
        self.bos_id: int = processor.bos_id()
        self.eos_id: int = processor.eos_id()
        self.pad_id: int = processor.pad_id()
        print(f"vocab_size: {self.vocab_size} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Tokenize ``s`` into ids, optionally wrapping them with BOS / EOS.

        Args:
            s: Text to tokenize.
            bos: Prepend ``self.bos_id`` when true.
            eos: Append ``self.eos_id`` when true.

        Returns:
            The list of token ids.
        """
        head = [self.bos_id] if bos else []
        tail = [self.eos_id] if eos else []
        return head + self.sp_model.encode(s) + tail

    def decode(self, ids: List[int]) -> str:
        """Turn a list of token ids back into text via the SentencePiece model."""
        return self.sp_model.decode(ids)

tokenizer = Tokenizer("tokenizer.model")

# Round-trip smoke test: encoding the first training story (with BOS/EOS) and
# decoding it again must give back the original text.
sample_text = dataset['train'][0]['text']
encoded_text = tokenizer.encode(sample_text, bos=True, eos=True)
decoded_text = tokenizer.decode(encoded_text)
assert decoded_text == sample_text

In [None]:
# Llama architecture (diagram credit: GPT-4)

# Input
#   |
#   v
# Embeddings
#   |
#   v
# RMS Norm
#   |
#   v
# (Start of Nx transformer blocks)
# Self-Attention (Grouped Multi-Query Attention) with KV Cache
#   |                         |
#   |                         |-----------------,
#   v                         v                 |
# Q (Rotary)              K (Rotary)           V
#   |                         |                 |
#   '----------+--------------'                 |
#              |                                |
#              +---->(Add & Norm)               |
#              |                                |
#              v                                |
#         RMS Norm                              |
#              |                                |
#              v                                |
#   Feed Forward SwiGLU                         |
#              |                                |
#              +---->(Add & Norm)<--------------+
#              |
# (End of Nx transformer blocks)
#              |
#              v
#            Linear
#              |
#              v
#           Softmax


In [None]:
import torch
import numpy as np
# Utilities
def prep_data(dataset, tokenizer, device):
    """Tokenize every split into one flat 1-D tensor of token ids.

    Each story is encoded on its own with real BOS and EOS tokens and the
    resulting id lists are concatenated in order. (Formatting the integer
    ``bos_id`` / ``eos_id`` into the text would only insert their digits as
    ordinary characters, so the special tokens are requested from the
    tokenizer instead.) The input records are not modified.

    Args:
        dataset: Mapping of split name -> iterable of records with a 'text' key.
        tokenizer: Object exposing ``encode(text, bos=..., eos=...)`` returning
            a list of int token ids.
        device: Device the resulting tensors are moved to.

    Returns:
        Dict mapping each split name to a 1-D ``torch.long`` tensor of token ids.
    """
    encoded_ds = {}
    for split in dataset.keys():
        token_ids = []
        for data in dataset[split]:
            token_ids.extend(tokenizer.encode(data['text'], bos=True, eos=True))
        # Explicit dtype keeps an empty split integer-typed as well.
        encoded_ds[split] = torch.tensor(token_ids, dtype=torch.long).to(device)
    return encoded_ds

def get_ds(encoded_ds, split, config, device):
    """Sample a random batch of (input, target) windows from one split.

    Targets are the inputs shifted one token to the right.

    Args:
        encoded_ds: Mapping of split name -> 1-D tensor of token ids.
        split: Key of the split to sample from.
        config: Dict providing 'context_window' and 'batch_size'.
        device: Device the returned batch is moved to.

    Returns:
        Tuple ``(x, y)`` of long tensors, each of shape
        ``(batch_size, context_window)``.

    Raises:
        ValueError: If the split holds fewer than ``context_window + 1`` tokens.
    """
    ds = encoded_ds[split]
    context_window = config['context_window']
    # A window starting at i consumes tokens i .. i + context_window (the extra
    # one is the last target), so valid starts are 0 .. len - context_window - 1.
    # randint's upper bound is exclusive, hence len - context_window.
    n_starts = ds.size(0) - context_window
    if n_starts <= 0:
        raise ValueError(
            f"split '{split}' has {ds.size(0)} tokens; need at least "
            f"{context_window + 1} for context_window={context_window}"
        )
    starts = torch.randint(0, n_starts, (config['batch_size'],)).tolist()
    x = torch.stack([ds[i:i + context_window] for i in starts]).long().to(device)
    y = torch.stack([ds[i + 1:i + context_window + 1] for i in starts]).long().to(device)
    return x, y

@torch.no_grad()  # don't compute gradients for this function
def evaluate_loss(model, config, device, ds, eval_iters=10):
    """Estimate the mean loss on every split from a few random batches.

    The model is switched to eval mode for the measurement and afterwards put
    back into whichever mode (train or eval) it was in on entry, even if a
    batch raises.

    Args:
        model: Module called as ``model(x, y)`` and returning ``(logits, loss)``.
        config: Batch configuration forwarded to ``get_ds``.
        device: Device forwarded to ``get_ds``.
        ds: Mapping of split name -> encoded token tensor.
        eval_iters: Number of random batches averaged per split (default 10).

    Returns:
        Dict mapping each split name to the mean of the sampled batch losses.
    """
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in ds.keys():
            losses = []
            for _ in range(eval_iters):
                xb, yb = get_ds(ds, split, config, device)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            out[split] = np.mean(losses)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    return out

In [None]:
from torch import nn
import torch

# Llama Model modules

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable gain.

    Computes ``weight * x / sqrt(mean(x^2, last dim) + eps)``.

    Args:
        dim: Shape of the learnable gain (initialised to ones); it is
            multiplied onto the normalised input, so it must broadcast with it.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` over its last dimension and apply the gain."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps)
        return self.weight * normalised

In [None]:
# Naive Model
from torch import nn
from torch.nn import functional as F
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("device: ", device)
class NaiveModel(nn.Module):
    """Baseline next-token model: embedding -> RMSNorm -> two-layer MLP.

    There is no attention, so every position is predicted from its own token
    only.

    Args:
        config: Dict providing 'vocab_size', 'd_model' and 'context_window'.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        # The norm gain is shaped (context_window, d_model), so the sequence
        # length fed to forward() must broadcast against context_window.
        self.rms = RMSNorm((config['context_window'], config['d_model']))
        self.linear = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            nn.ReLU(),
            nn.Linear(config['d_model'], config['vocab_size'])
        )
        print("Model params: ", sum(p.numel() for p in self.parameters()))

    def forward(self, idx, targets=None):
        """Compute logits, plus the cross-entropy loss when targets are given.

        Args:
            idx: Integer tensor of token ids, indexed into the embedding table.
            targets: Optional tensor of target token ids with as many elements
                as ``idx``.

        Returns:
            ``logits`` alone when ``targets`` is None, otherwise the tuple
            ``(logits, loss)``.
        """
        x = self.embedding(idx)
        x = self.rms(x)
        logits = self.linear(x)
        if targets is None:
            return logits
        # The loss is produced on the logits' device already; no transfer via a
        # module-level `device` global is needed.
        loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
        return logits, loss

# Hyper-parameters shared by model construction, batching and training.
config = dict(
    vocab_size=tokenizer.vocab_size,
    batch_size=32,
    context_window=64,
    d_model=128,
)

if model_type == 'naive':
    # Build the baseline model, tokenize the corpus, then push a single random
    # training batch through the model as a smoke test and show its loss.
    model = NaiveModel(config).to(device)
    encoded_ds = prep_data(dataset, tokenizer, device)
    xs, ys = get_ds(encoded_ds, 'train', config, device)
    logits, loss = model(xs, ys)
    print(loss)

In [None]:
# Llama model
# Number of transformer blocks. Not referenced anywhere else in the visible
# notebook; presumably a placeholder for the Llama model still to be written.
n_transformer_blocks = 1


In [None]:
# train
from torch.optim import Adam
import time
import pandas as pd
# Training hyper-parameters: learning rate, number of optimisation steps
# (stored under 'epochs') and how often, in steps, to evaluate and log.
config.update({
    'lr': 1e-3,
    'epochs': 1000,
    'log_interval': 10,
})
# Pass the configured learning rate explicitly so that editing config['lr']
# actually changes the optimizer instead of silently using Adam's default.
optimizer = Adam(model.parameters(), lr=config['lr'])
def train(model, optimizer, config):
    """Run ``config['epochs']`` optimisation steps, logging periodically.

    Each step draws one random training batch. Every ``config['log_interval']``
    steps the model is evaluated with ``evaluate_loss`` and a progress line is
    printed. Batches come from the module-level ``encoded_ds`` and are placed
    on the module-level ``device``.

    Args:
        model: Module called as ``model(x, y)`` and returning ``(logits, loss)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        config: Dict providing 'epochs' and 'log_interval' plus the batch
            settings consumed by ``get_ds``.

    Returns:
        The result of ``DataFrame.plot()`` over the recorded per-split losses.
    """
    history = []
    total_steps = config['epochs']
    log_every = config['log_interval']
    tick = time.time()
    for epoch in range(total_steps):
        optimizer.zero_grad()
        xs, ys = get_ds(encoded_ds, 'train', config, device)
        _, loss = model(xs, ys)
        loss.backward()
        optimizer.step()

        if epoch % log_every != 0:
            continue

        # Wall-clock time since the previous log line (evaluation excluded).
        batch_time = time.time() - tick
        evaluated_loss = evaluate_loss(model, config, device, encoded_ds)
        history.append(evaluated_loss)
        eta = batch_time * (total_steps - epoch)/log_every
        print(f"Epoch {epoch} | val loss {evaluated_loss['validation']:.3f} | Time {batch_time:.3f} | ETA in seconds {eta:.3f}")
        tick = time.time()
    print("validation loss: ", history[-1]['validation'])
    return pd.DataFrame(history).plot()

# Run training; train() returns the result of DataFrame.plot() over the logged
# losses, which becomes the output value of this cell.
train(model, optimizer, config)

In [None]:
# Inference
def generate(model, prompt, tokenizer, config, device, max_new_tokens, temperature):
    """Sample ``max_new_tokens`` tokens continuing ``prompt`` and print them.

    The model is always fed exactly ``config['context_window']`` tokens: short
    prompts are left-padded, long prompts keep only their most recent tokens,
    and after every sampled token the window slides forward by one position.

    Args:
        model: Module called as ``model(idx)`` and returning logits of shape
            (batch, sequence, vocab).
        prompt: Text to continue.
        tokenizer: Object exposing ``encode``, ``decode``, ``pad_id`` and ``bos_id``.
        config: Dict providing 'context_window'.
        device: Device on which the token tensors are created.
        max_new_tokens: Number of tokens to sample.
        temperature: Softmax temperature applied to the logits; 0 selects the
            most likely token (greedy decoding).

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    model.eval()
    context_window = config['context_window']
    context = torch.tensor(
        tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long, device=device
    ).unsqueeze(0)

    # Keep only the most recent tokens when the prompt exceeds the window.
    context = context[:, -context_window:]

    missing = context_window - context.size(1)
    if missing > 0:
        # A negative pad_id (SentencePiece reports one when the model has no pad
        # piece) is not a valid embedding index, so fall back to BOS, then to 0.
        pad_id = tokenizer.pad_id
        if pad_id < 0:
            pad_id = tokenizer.bos_id if tokenizer.bos_id >= 0 else 0
        padding = torch.full((1, missing), pad_id, dtype=context.dtype, device=device)
        context = torch.cat([padding, context], dim=1)
    print(f"encoded prompt shape: {context.shape}")

    next_tokens = []
    with torch.no_grad():  # inference only, no autograd graph needed
        for _ in range(max_new_tokens):
            logits = model(context)[:, -1, :]
            if temperature == 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                # Temperature must scale the logits *before* softmax; rescaling
                # probabilities is undone by multinomial's normalisation.
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            next_tokens.append(next_token.item())
            # Slide the window along the sequence axis (dim 1), not the batch axis.
            context = torch.cat([context[:, 1:], next_token], dim=1)
    print(tokenizer.decode(next_tokens))

# Sample 10 new tokens continuing the prompt "Once upon " at temperature 0.7.
generate(model, "Once upon ", tokenizer, config, device, 10, 0.7)