In [110]:
import math
import torch 
import torch.nn as nn
from einops import einsum, rearrange

# Seed torch's global RNG so that random weight initialization below is reproducible.
# The trailing ";" keeps the notebook from displaying the Generator this call returns.
torch.manual_seed(123);

In [111]:
class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension with a learnable affine map.

    Each feature vector is shifted to zero mean and scaled to unit (biased) variance,
    then multiplied by ``scale`` and offset by ``shift`` (both of size ``d_model``).

    Args:
        cfg: config dict; only ``cfg["d_model"]`` is read.
        eps: small constant added to the variance before the reciprocal square root.
    """

    def __init__(self, cfg, eps: float = 1e-5):
        super().__init__()
        d_model = cfg["d_model"]
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))
        self.shift = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, i.e. mean of squared deviations.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) * torch.rsqrt(sigma2 + self.eps)
        return normalized * self.scale + self.shift

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Unlike ``LayerNorm`` there is no mean subtraction and no learnable shift: each
    vector is divided by its RMS and multiplied by a per-feature ``scale``.

    Args:
        cfg: config dict; only ``cfg["d_model"]`` is read.
        eps: stabilizer added to the mean square before the reciprocal square root.
    """

    def __init__(self, cfg, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(cfg["d_model"]))

    def forward(self, x):
        mean_sq = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.scale

In [112]:
class GeLU(nn.Module):
    """GELU activation, tanh approximation:

    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)
        return 0.5 * x * (1 + torch.tanh(inner))

class SiLU(nn.Module):
    """SiLU (a.k.a. swish) activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [113]:
class FeedForwardGPT2(nn.Module):
    """Position-wise MLP: Linear(d_model -> hidden_dim) -> GeLU -> Linear(hidden_dim -> d_model).

    Reads ``cfg["d_model"]``, ``cfg["hidden_dim"]`` and ``cfg["qkv_bias"]``.

    NOTE(review): the bias of both projections follows the ``qkv_bias`` flag; the
    reference GPT-2 MLP keeps biases here independently of the QKV setting, so
    confirm this coupling is intended (it matters when loading pretrained weights).
    """

    def __init__(self, cfg):
        super().__init__()
        d_model, hidden_dim, use_bias = cfg["d_model"], cfg["hidden_dim"], cfg["qkv_bias"]
        self.fc1 = nn.Linear(d_model, hidden_dim, bias=use_bias)
        self.fc2 = nn.Linear(hidden_dim, d_model, bias=use_bias)
        self.gelu = GeLU()

    def forward(self, x):
        hidden = self.gelu(self.fc1(x))
        return self.fc2(hidden)

class FeedForwardLlama2(nn.Module):
    """Gated MLP: ``fc3(silu(fc1(x)) * fc2(x))`` with bias-free projections.

    Reads ``cfg["d_model"]`` and ``cfg["hidden_dim"]``.
    """

    def __init__(self, cfg):
        super().__init__()
        d_model, hidden_dim = cfg["d_model"], cfg["hidden_dim"]
        self.fc1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch (goes through SiLU)
        self.fc2 = nn.Linear(d_model, hidden_dim, bias=False)  # linear branch multiplied by the gate
        self.fc3 = nn.Linear(hidden_dim, d_model, bias=False)  # projection back to d_model
        self.silu = SiLU()

    def forward(self, x):
        gate = self.silu(self.fc1(x))
        up = self.fc2(x)
        return self.fc3(gate * up)

In [114]:
class RoPE(nn.Module):
    """Placeholder for rotary position embeddings (RoPE).

    TODO: implement. The class currently defines no ``__init__`` logic and no
    ``forward``, so calling an instance raises ``NotImplementedError`` from
    ``nn.Module``.
    """
    pass

In [137]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    Args:
        d_in: feature size of the input ``x``.
        d_out: total projection size, split evenly across ``n_heads``.
        n_heads: number of attention heads (must divide ``d_out``).
        ctx_len: side length of the causal mask, i.e. the largest number of
            positions that can be attended over.
        qkv_bias: whether the query/key/value projections carry a bias term.

    With ``use_cache=True`` the projected keys/values of every call are appended to
    ``cache_k``/``cache_v`` and ``ptr_current_pos`` counts the query positions already
    processed, so a sequence can be fed incrementally. Call ``reset_cache()`` before
    starting a new sequence.
    """

    def __init__(self, d_in, d_out, n_heads, ctx_len, qkv_bias=False):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads!"

        self.n_heads = n_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)

        # Ones strictly above the diagonal mark "future" key positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1))

        # KV cache buffers are non-persistent, so they never appear in state_dict().
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.ptr_current_pos = 0

    def forward(self, x, use_cache=False):
        # Unpacking enforces a 3-D (batch, seq, features) input.
        _batch, _seq, _feat = x.shape
        split_heads = "batch_size seq_len (n_heads head_dim) -> batch_size n_heads seq_len head_dim"

        q = self.W_query(x)  # (batch_size, seq_len, d_out)
        k = self.W_key(x)
        v = self.W_value(x)

        if use_cache:
            # Start or extend the cache along the sequence axis; attend over all of it.
            if self.cache_k is None:
                self.cache_k, self.cache_v = k, v
            else:
                self.cache_k = torch.cat((self.cache_k, k), dim=1)
                self.cache_v = torch.cat((self.cache_v, v), dim=1)
            k, v = self.cache_k, self.cache_v

        q = rearrange(q, split_heads, n_heads=self.n_heads)
        k = rearrange(k, split_heads, n_heads=self.n_heads)
        v = rearrange(v, split_heads, n_heads=self.n_heads)

        scores = einsum(q, k, "... s1 head_dim, ... s2 head_dim -> ... s1 s2")

        n_q, n_k = q.shape[-2], k.shape[-2]
        # Queries occupy absolute rows [offset, offset + n_q) of the causal mask;
        # without the cache they always start at row 0.
        offset = self.ptr_current_pos if use_cache else 0
        scores.masked_fill_(self.mask.bool()[offset:offset + n_q, :n_k], -torch.inf)
        if use_cache:
            self.ptr_current_pos += n_q

        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        merged = einsum(weights, v, "... s1 s2, ... s2 head_dim -> ... s1 head_dim")
        merged = rearrange(merged, "batch_size n_heads seq_len head_dim -> batch_size seq_len (n_heads head_dim)")
        return self.out_proj(merged)

    def reset_cache(self):
        """Drop the cached keys/values and rewind the position pointer to 0."""
        self.cache_k, self.cache_v = None, None
        self.ptr_current_pos = 0

In [138]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: ``x + MHA(norm1(x))`` followed by ``x + FF(norm2(x))``.

    Reads ``d_model``, ``n_heads``, ``ctx_len`` and ``qkv_bias`` from ``cfg`` (the
    feed-forward and norm sub-modules read their own keys). ``cfg["drop_rate"]`` is
    not used: there is no dropout. ``use_cache`` is forwarded to the attention layer only.
    """

    def __init__(self, cfg):
        super().__init__()
        d_model = cfg["d_model"]
        self.mha = MultiHeadAttention(
            d_in=d_model,
            d_out=d_model,
            n_heads=cfg["n_heads"],
            ctx_len=cfg["ctx_len"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForwardGPT2(cfg)
        self.norm1 = LayerNorm(cfg)
        self.norm2 = LayerNorm(cfg)

    def forward(self, x, use_cache=False):
        attn_out = self.mha(self.norm1(x), use_cache=use_cache)
        x = x + attn_out
        ff_out = self.ff(self.norm2(x))
        return x + ff_out

In [None]:
class GPTModel(nn.Module):
    """Decoder-only language model: token + learned absolute position embeddings,
    a stack of ``TransformerBlock``s, a final ``LayerNorm`` and a linear LM head.

    Reads ``vocab_size``, ``d_model``, ``ctx_len``, ``n_layers`` and ``qkv_bias`` from
    ``cfg`` (the blocks read further keys).

    With ``use_cache=True`` the model remembers, in ``current_pos``, the absolute
    position of the next token so incremental calls receive the right position
    embeddings. Call ``reset_kv_cache()`` before starting a new sequence.
    """

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["d_model"])
        self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["d_model"])
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = LayerNorm(cfg)
        # NOTE(review): the LM-head bias follows the qkv_bias flag — confirm intended.
        self.out_head = nn.Linear(cfg["d_model"], cfg["vocab_size"], bias=cfg["qkv_bias"])

        # Absolute position of the next token when decoding with the KV cache.
        self.current_pos = 0

    def forward(self, in_idx, use_cache=False):
        """Map token ids of shape ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``.

        Raises:
            IndexError: if the positions to embed exceed the context length
                (``cfg["ctx_len"]``). With ``use_cache=True`` this includes every
                token already fed since the last ``reset_kv_cache()``.
        """
        _, seq_len = in_idx.shape

        start = self.current_pos if use_cache else 0
        end = start + seq_len
        max_pos = self.pos_emb.num_embeddings
        # Check before touching any state: a failed call must not advance current_pos,
        # and an out-of-range embedding lookup on GPU is an unrecoverable device assert.
        if end > max_pos:
            raise IndexError(
                f"positions [{start}, {end}) exceed the context length {max_pos}; "
                "shorten the input (or reset the KV cache when using use_cache=True)"
            )

        tok_embeds = self.tok_emb(in_idx)
        # Create the position ids on the input's device so the model also works off-CPU.
        pos_ids = torch.arange(start, end, device=in_idx.device)
        if use_cache:
            self.current_pos = end
        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)  # add a batch dim for broadcasting

        x = tok_embeds + pos_embeds
        for blk in self.transformer_blocks:
            x = blk(x, use_cache=use_cache)

        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

    def reset_kv_cache(self):
        """Clear every block's attention KV cache and rewind ``current_pos`` to 0."""
        for blk in self.transformer_blocks:
            blk.mha.reset_cache()
        self.current_pos = 0

In [143]:
def generate_text_simple(model, idx, max_new_tokens, ctx_len):
    """Greedy decoding without a KV cache.

    Every step re-runs ``model`` on (at most) the last ``ctx_len`` tokens and appends
    the arg-max of the final position's logits.

    Args:
        model: module mapping ``(batch, seq)`` token ids to ``(batch, seq, vocab)``
            logits; it is switched to eval mode.
        idx: ``(batch, seq)`` tensor of prompt token ids.
        max_new_tokens: number of tokens to append.
        ctx_len: size of the sliding window fed to the model.

    Returns:
        ``idx`` extended by ``max_new_tokens`` columns.
    """
    model.eval()
    tokens = idx
    for _ in range(max_new_tokens):
        window = tokens[:, -ctx_len:]
        with torch.no_grad():
            last_logits = model(window)[:, -1, :]  # (batch_size, vocab_size)
        next_token = torch.argmax(last_logits, dim=-1, keepdim=True)  # (batch_size, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens

def generate_text_cached(model, idx, max_new_tokens, ctx_len):
    """Greedy decoding that reuses the model's KV cache.

    The prompt (at most its last ``ctx_len`` tokens) is encoded once; afterwards only
    the newly generated token is fed per step, with earlier keys/values read from the
    cache.

    Args:
        model: module with ``forward(idx, use_cache=...)`` and ``reset_kv_cache()``
            (e.g. ``GPTModel``); it is switched to eval mode.
        idx: ``(batch, seq)`` tensor of prompt token ids.
        max_new_tokens: number of tokens to append; ``<= 0`` returns ``idx`` unchanged.
        ctx_len: only the last ``ctx_len`` prompt tokens are encoded. Unlike
            ``generate_text_simple`` there is no sliding window during generation, so
            the encoded prompt plus new tokens must fit in the model's context length.

    Returns:
        ``idx`` extended by ``max_new_tokens`` columns.
    """
    model.eval()
    # Start from an empty cache: keys/values and positions left over from an earlier
    # call would otherwise be attended to by this generation.
    model.reset_kv_cache()
    if max_new_tokens <= 0:
        return idx
    try:
        # Without no_grad, every cached K/V tensor keeps its autograd graph alive
        # across steps, so memory grows with each generated token.
        with torch.no_grad():
            logits = model(idx[:, -ctx_len:], use_cache=True)
            for step in range(max_new_tokens):
                next_idx = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch_size, 1)
                idx = torch.cat((idx, next_idx), dim=1)
                # Skip the final forward pass: its logits would never be used.
                if step + 1 < max_new_tokens:
                    logits = model(next_idx, use_cache=True)
    finally:
        # Release the cached tensors once generation finishes (or fails).
        model.reset_kv_cache()
    return idx

In [145]:
import time
import tiktoken

GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "ctx_len": 1024,         # Context length
    "d_model": 768,          # Embedding dimension
    "hidden_dim": 768*4,     # Feed-forward hidden size
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate (not read by any module defined above)
    "qkv_bias": False        # Query-Key-Value bias
}

# Re-seed so re-running this cell alone rebuilds the same random weights. The seed
# matches the setup cell and nothing in between draws from the RNG, so a fresh
# Restart & Run All produces the same model either way.
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval()

start_context = "Hello, I am"
tokenizer = tiktoken.get_encoding("gpt2")
encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print(encoded_tensor)


def run_timed(generate_fn, max_new_tokens=20, ctx_len=200):
    """Generate from ``encoded_tensor`` with ``generate_fn``, print the decoded text and
    the wall-clock time (generation + decoding), and return the token ids."""
    start = time.perf_counter()
    token_ids = generate_fn(model, encoded_tensor, max_new_tokens=max_new_tokens, ctx_len=ctx_len)
    print(tokenizer.decode(token_ids.squeeze(0).tolist()))
    end = time.perf_counter()
    print(f"elapsed time: {end - start:.2f}s")
    return token_ids


naive_ids = run_timed(generate_text_simple)
cached_ids = run_timed(generate_text_cached)

tensor([[15496,    11,   314,   716]])
Hello, I am therNeilliveicans Vancecf believes Tours abbrevioard noticing pilesgenonis Merry THEM Meshcol centered Mirror
elapsed time: 0.64s
Hello, I am therNeilliveicans Vancecf believes Tours abbrevioard noticing pilesgenonis Merry THEM Meshcol centered Mirror
elapsed time: 0.23s
