In [494]:
import torch
import torch.nn as nn
import numpy as np  # reserved for later use
import einops
import time

In [495]:
import tiktoken
# Shared GPT-2 BPE tokenizer; read as a module-level global by DataLoader and inference().
tokenizer = tiktoken.get_encoding("gpt2")

In [496]:
class Config():
    """Model sizes, optimizer hyperparameters and analytic parameter counts.

    Every value is fixed in ``__init__``; the parameter-count dictionaries are
    derived from the architecture sizes by formula, not read from a model.
    """

    def __init__(self):
        # --- architecture ---
        self.vocab_size = 50257
        self.embedding_dim = 1024
        self.mlp_dim = 4 * self.embedding_dim
        self.num_blocks = 4
        self.num_heads = 8
        self.context_len = 1024
        self.attention_dim = self.embedding_dim // self.num_heads

        # --- device: prefer CUDA, then mps (Apple Silicon), then CPU ---
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"

        # --- hyperparams ---
        self.lr = 3e-4
        self.betas = (0.9, 0.95)  # for controlling momentum
        self.eps = 1e-8
        self.weight_decay = 0.1

        # --- parameter accounting ---
        dim = self.embedding_dim
        # Each of W_Q / W_K / W_V / W_Out is a square (dim x dim) matrix per block.
        square_per_stack = self.num_blocks * dim * dim
        self.learnable_params_dict = {
            "embedding": self.vocab_size * dim,
            "positional_embedding": self.context_len * dim,
            "MLPs (Weights)": self.num_blocks * 2 * dim * self.mlp_dim,
            "MLPs (Biases)": self.num_blocks * (self.mlp_dim + dim),
            "W_Qs": square_per_stack,
            "W_Ks": square_per_stack,
            "W_Vs": square_per_stack,
            "W_Out": square_per_stack,
        }
        self.learnable_params = sum(self.learnable_params_dict.values())

        # Listed separately because the key marks it as tied to the embedding weights.
        self.non_learnable_params_dict = {
            "deembedding (tied to embedding weights)": self.vocab_size * dim,
        }
        self.non_learnable_params = sum(self.non_learnable_params_dict.values())


# Module-level instance shared by the cells below.
config = Config()

In [497]:
class DataLoader:
    """Sequential batch loader over a text file tokenized with the module-level ``tokenizer``.

    The whole file is tokenized once into a 1-D tensor. ``next_batch`` walks
    through it in non-overlapping windows of ``B * T`` tokens and wraps back
    to the start once a full window (plus the one extra token needed for the
    shifted targets) no longer fits.

    Args:
        B: number of sequences per batch.
        T: number of tokens per sequence.
        path: text file to load; defaults to the original hard-coded file name.
    """

    def __init__(self, B, T, path="tiny_shakespeare.txt"):
        self.batch_size = B # num of sequences processed together in each batch
        self.seq_len = T # how many tokens are in each sequence/batch

        with open(path, "r") as f:
            text = f.read()

        encoding = tokenizer.encode(text)
        self.tokens = torch.tensor(encoding)

        self.current_pos = 0 # maintain the index of the current data sample

        print(f"loaded {len(self.tokens)} tokens with batch size of {self.batch_size} sequences and {self.seq_len} tokens per sequence in the batch")
        print(f"each epoch has {len(self.tokens) / (self.batch_size * self.seq_len)} batches, with {self.seq_len * self.batch_size} tokens per batch, for a total of {self.seq_len * self.batch_size * (len(self.tokens) / (self.batch_size * self.seq_len))} tokens")
        print("*"*50)

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape ``(B, T)``.

        ``y`` is ``x`` shifted one token to the right, so each batch consumes
        ``B * T + 1`` tokens starting at ``current_pos``.
        """
        B, T = self.batch_size, self.seq_len
        x = self.tokens[self.current_pos:self.current_pos+B*T]
        y = self.tokens[self.current_pos+1:self.current_pos+B*T+1]
        x = x.view(B, T)
        y = y.view(B, T)
        self.current_pos += B * T

        # The NEXT batch needs B*T inputs plus one look-ahead token for the
        # targets. The previous check (`len - pos + 1 < B*T`) had the sign of
        # the 1 flipped, so with exactly B*T or B*T-1 tokens left it did not
        # wrap and the following call failed in `view` on a short slice.
        if self.current_pos + B * T + 1 > len(self.tokens):
            self.current_pos = 0

        return x, y

In [498]:
class Embedding(nn.Module):
    """Sum of a token embedding and a learned positional embedding.

    ``W_E`` maps token ids to vectors; ``W_pos`` maps positions
    ``0 .. context_len - 1`` to vectors of the same width.
    """

    def __init__(self, config):
        super().__init__()
        self.W_E = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.W_pos = nn.Embedding(config.context_len, config.embedding_dim)

    def forward(self, tokens):
        """Embed ``tokens``; positions are taken from dimension 1 of the input.

        Non-tensor input (e.g. nested lists) is converted with ``torch.tensor``.
        """
        token_ids = tokens if isinstance(tokens, torch.Tensor) else torch.tensor(tokens)

        # Position indices live on the same device as the token ids.
        seq_len = token_ids.shape[1]
        position_ids = torch.arange(seq_len, device=token_ids.device)

        # The positional term broadcasts over the leading dimension.
        return self.W_E(token_ids) + self.W_pos(position_ids)

class DeEmbedding(nn.Module):
    """Bias-free linear projection from ``embedding_dim`` to ``vocab_size``."""

    def __init__(self, config):
        super().__init__()
        self.W_D = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

    def forward(self, x):
        """Project the last dimension of ``x`` onto the vocabulary."""
        return self.W_D(x)

In [499]:
class Attention(nn.Module):
    """Multi-head causal self-attention with bias-free projections.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads, combined with scaled dot-product attention under a
    lower-triangular mask, merged back and passed through ``W_out``.
    """

    def __init__(self, config):
        super().__init__()

        width = config.embedding_dim
        # Creation order (Q, K, V, out) is kept so seeded initialisation is unchanged.
        self.W_Q = nn.Linear(width, width, bias=False)
        self.W_K = nn.Linear(width, width, bias=False)
        self.W_V = nn.Linear(width, width, bias=False)

        self.W_out = nn.Linear(width, width, bias=False)

        self.num_heads = config.num_heads
        self.attention_dim = config.attention_dim

        # Lower-triangular ones matrix, stored as a buffer so it moves with the module.
        all_ones = torch.ones(config.context_len, config.context_len)
        self.register_buffer("causal_mask", torch.tril(all_ones))

    def _split_heads(self, projected):
        """Reshape ``(batch, seq, head*dim)`` into ``(batch, head, seq, dim)``."""
        return einops.rearrange(projected, 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)

    def forward(self, x):
        """Apply causal self-attention to a 3-D input ``x``."""
        # Unpacking enforces a 3-D input; only the sequence length is needed.
        _, seq_len, _ = x.shape

        queries = self._split_heads(self.W_Q(x))
        keys = self._split_heads(self.W_K(x))
        values = self._split_heads(self.W_V(x))

        # Scaled dot-product scores: divide by sqrt of the configured attention_dim.
        scale = torch.sqrt(torch.tensor(self.attention_dim))
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        # Positions where the mask is 0 (future tokens) get -inf before softmax.
        hidden_future = self.causal_mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden_future, float('-inf'))

        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back into the channel dimension.
        mixed = torch.matmul(weights, values)
        merged = einops.rearrange(mixed, 'batch head seq dim -> batch seq (head dim)', head=self.num_heads)

        return self.W_out(merged)

In [500]:
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Widths go ``embedding_dim -> mlp_dim -> embedding_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.layer1 = nn.Linear(config.embedding_dim, config.mlp_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(config.mlp_dim, config.embedding_dim)

    def forward(self, x):
        """Expand, apply GELU, then project back to the input width."""
        expanded = self.layer1(x)
        activated = self.gelu(expanded)
        return self.layer2(activated)

In [501]:
class TransformerBlock(nn.Module):
    """One residual block: attention sub-layer followed by an MLP sub-layer.

    Each sub-layer's output is added to its own input (residual connection).
    """

    def __init__(self, config):
        super().__init__()
        self.Attention_Layers = Attention(config)
        self.MLP_Layers = MLP(config)

    def forward(self, x):
        """Return ``h + MLP(h)`` where ``h = x + Attention(x)``."""
        after_attention = x + self.Attention_Layers(x)
        after_mlp = after_attention + self.MLP_Layers(after_attention)
        return after_mlp

In [502]:
class Transformer(nn.Module):
    """Embedding -> ``num_blocks`` transformer blocks -> de-embedding -> softmax.

    The de-embedding weight is the same tensor object as the token-embedding
    weight (weight tying).
    """

    def __init__(self, config):
        super().__init__()

        self.embed = Embedding(config)

        stacked = [TransformerBlock(config) for _ in range(config.num_blocks)]
        self.blocks = nn.ModuleList(stacked)

        self.deembed = DeEmbedding(config)
        # tie weights: output projection shares the token-embedding matrix
        self.deembed.W_D.weight = self.embed.W_E.weight

    def forward(self, x):
        """Return a softmax distribution over the vocabulary for every position of ``x``."""
        hidden = self.embed(x)

        for layer in self.blocks:
            hidden = layer(hidden)

        vocab_scores = self.deembed(hidden)

        # Normalise over the last (vocabulary) dimension.
        return torch.softmax(vocab_scores, dim=-1)

In [503]:
def inference(inference_config, inference_model):
    """Run one forward pass on a fixed prompt and print the predicted tokens.

    Args:
        inference_config: config used to build a model (and pick its device)
            when ``inference_model`` is ``None``; otherwise unused.
        inference_model: model to evaluate, or ``None`` to use a freshly
            initialised ``Transformer``.

    Returns:
        None. All results are printed.
    """
    text = "They fear us"
    tokens = tokenizer.encode(text)
    x = torch.tensor(tokens).unsqueeze(0)  # add a leading batch dimension

    # Put the input on the same device as the model.
    if inference_model is not None:
        print("using passed in model for inference")
        device = next(inference_model.parameters()).device
    else:
        device = inference_config.device
        # Move the fresh model too; previously only the input was moved, which
        # left model and input on different devices whenever device != "cpu".
        inference_model = Transformer(inference_config).to(device)
        print("using random model for inference")
    x = x.to(device)

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        out = inference_model(x)
    pred_tokens = out.argmax(dim=-1)
    print(f"predicted tokens: {pred_tokens}")

    # Only the prediction at the last position is the next token after the prompt.
    next_token = pred_tokens[0, -1].item()
    predicted_word = tokenizer.decode([next_token])
    print(f"predicted word: {predicted_word}")

    print(f"full sentence: {text}{predicted_word}")

    print("*"*50)
    print("sanity check: all predicted tokens")
    flat_preds = pred_tokens.flatten()  # flatten once instead of on every iteration
    last_index = len(flat_preds) - 1
    for num, token in enumerate(flat_preds):
        decoded = tokenizer.decode([token.item()])

        if num == last_index:
            print(f"** Token {token} -> '{decoded}' **")
        else:
            print(f"Token {token} -> '{decoded}'")

def training(model_config, batch_size=8):
    """Seed the RNGs, then build the data loader, model and optimizer.

    NOTE: no optimisation steps are executed here. The loader, optimizer and
    ``losses`` list are only constructed, so the returned model still has its
    initial weights.

    Args:
        model_config: config providing ``device``, ``context_len`` and the
            AdamW hyperparameters ``lr``, ``betas``, ``eps``, ``weight_decay``.
        batch_size: sequences per batch for the loader (default 8, the value
            that was previously hard-coded).

    Returns:
        The ``Transformer`` instance, moved to ``model_config.device``.
    """
    device = model_config.device
    print(f"Using device: {device}")

    torch.manual_seed(42)

    if device == "cuda":
        torch.cuda.manual_seed(42)

    # Sequence length follows the model's context length rather than a
    # duplicated literal, so the two cannot drift apart.
    train_loader = DataLoader(batch_size, model_config.context_len)

    model = Transformer(model_config).to(device)
    # model = torch.compile(model) # temporary comment to resolve errors with metal
    losses = []

    # Hyperparameters come from the config instead of repeating its values inline.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=model_config.lr,
        betas=model_config.betas,
        eps=model_config.eps,
        weight_decay=model_config.weight_decay,
    )

    return model

def display_params_info():
    """Print the parameter-count breakdown stored on the module-level ``config``."""
    learnable_total = config.learnable_params
    non_learnable_total = config.non_learnable_params

    report_lines = (
        f"learnable params dict: {config.learnable_params_dict}",
        f"total # of learnable params: {learnable_total:,}",
        f"non-learnable params dict: {config.non_learnable_params_dict}",
        f"total # of non-learnable params: {non_learnable_total:,}",
        f"** total # of params: {(learnable_total + non_learnable_total):,} **",
    )
    for line in report_lines:
        print(line)

# Entry point: print the analytic parameter counts, build the model via
# training(), then run a single-prompt inference with that model.
if __name__ == "__main__":
    display_params_info()
    # `model` is bound at module level, so it stays available after this block runs.
    model = training(config)
    inference(config, model)

learnable params dict: {'embedding': 51463168, 'positional_embedding': 1048576, 'MLPs (Weights)': 33554432, 'MLPs (Biases)': 20480, 'W_Qs': 4194304, 'W_Ks': 4194304, 'W_Vs': 4194304, 'W_Out': 4194304}
total # of learnable params: 102,863,872
non-learnable params dict: {'deembedding (tied to embedding weights)': 51463168}
total # of non-learnable params: 51,463,168
** total # of params: 154,327,040 **
Using device: mps
loaded 338025 tokens with batch size of 8 sequences and 1024 tokens per sequence in the batch
each epoch has 41.2628173828125 batches, with 8192 tokens per batch, for a total of 338025.0 tokens
**************************************************
using passed in model for inference
predicted tokens: tensor([[2990, 3252,  514]], device='mps:0')
predicted word:  us
full sentence: They fear us us
**************************************************
sanity check: all predicted tokens
Token 2990 -> 'They'
Token 3252 -> ' fear'
** Token 514 -> ' us' **


In [504]:
# %% [markdown]
# ## TODOs
# - [x] working training
# - [ ] Add support for displaying calculated # of learnable and non-learnable params
# - [ ] Attention sink?