In [502]:
import torch
from torch import nn
import torch.nn.functional as F
import einops

In [503]:
with open('/Users/Darrell/Desktop/tiny-shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [504]:
vocab = sorted(list(set(text)))
vocab_size = len(vocab)

batch_size = 64
seq_len = 3
n_embd = 5
head_size = 8
n_heads = 11
depth = 6

lr = 1e-4
wd = 1e-2
betas = (0.9, 0.99)
eps = 1e-8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [505]:
stoi = {s:i for i, s in enumerate(vocab)}
itos = {i:s for i, s in enumerate(vocab)}
tokenize = lambda s: [stoi[c] for c in s]
detokenize = lambda c: ''.join([itos[x] for x in c])

data = torch.tensor(tokenize(text))
n = int(len(data)*.9)
train_data = data[:n]
val_data = data[n:]

In [506]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([data[i:i+seq_len] for i in ix])
    y = torch.stack([data[i+1: i+1 + seq_len] for i in ix])
    return x, y

In [507]:
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

In [508]:
class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate)*x

In [509]:
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return torch.nn.functional.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

In [510]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn 
        self.drop = nn.Dropout(p=0.1)

    def forward(self, x, **kwargs):
        y = self.fn(x, **kwargs)
        x = self.drop(x)
        return y + x

In [511]:
class Token_Embedding(nn.Module):
    def __init__(self, vocab_size=vocab_size, n_embd=n_embd):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)

    def forward(self, x):
        return self.embedding(x)

In [512]:
class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = float(head_dim)/(base ** torch.arange(0, head_dim, 2).float())
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = None
        self.batch_size_cached = None
        self.cos_cached: torch.tensor | None = None
        self.sin_cached: torch.tensor | None = None

    def trig(self, seq_len: int, device=device, dtype=torch.bfloat16) -> torch.Tensor:
        if seq_len != self.seq_len_cached: 
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j -> ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).float().to(device)

            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()

        return self.cos_cached, self.sin_cached
    
    def forward(self, q, k):
        _, _, seq_len, _ = q.shape
        cos, sin = self.trig(seq_len)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

            

In [513]:
class MultiQueryAttentionHead(nn.Module):
    def __init__(self, n_heads, head_size, n_embd, attention_drop = 0.1, ff_drop = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.attention_drop = nn.Dropout(attention_drop)
        self.qkv = nn.Linear(n_embd, n_heads*head_size + 2*head_size)
        self.rotary = RotaryEmbedding(head_size)
        self.LNorm = LayerNorm(n_embd)
        self.ff_out = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(ff_drop),
            nn.Linear(n_heads * head_size, n_embd, bias=False)
        )
    
    def forward(self, x):
        # x = (B, T, E) ---> (B, num_heads, T, h_size)
        B, T, E = x.shape
        x = self.LNorm(x)

        #TODO: see if this qkv thing is actually faster than doing seperate q, k, v 
        qkv = self.qkv(x) #(B, T, n_heads*head_size + 2*head_size)

        # q has shape (B, n_heads, T, head_size)
        q = qkv[:, :, : n_heads*head_size].view((B, T, n_heads, head_size)).transpose(-2, -3)
        # k has shape (B, T, head_size)
        k = qkv[:, :, n_heads*head_size:n_heads*head_size+head_size].view((B, 1, T, head_size))
        # v has shape (B, T, head_size)
        v = qkv[:, :, -head_size:].view((B, 1, T, head_size))
        
        q, k = self.rotary(q, k)

        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
        y = einops.rearrange(y, 'b h t d -> b t (h d)')
        return self.ff_out(y)

In [514]:
class LaTeXModel(nn.Module):
    def __init__(self, depth, n_heads, head_size, n_embd=n_embd, vocab_size=vocab_size) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.head_size = head_size
        self.n_embd = n_embd

        self.token_embedding = Token_Embedding(n_embd = n_embd)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            block = Residual(MultiQueryAttentionHead(n_heads, head_size, n_embd))
            self.layers.append(block)
        
        self.LNorm = LayerNorm(n_embd)
        self.to_logits = nn.Linear(n_embd, vocab_size, bias=False)

        #self.to_logits.weight = self.token_embedding.weight
        #nn.init.normal_(self.token_embedding.weight, std=0.02)

    def forward(self, x, targets=None):
        x = self.token_embedding(x)

        for layer in self.layers:
            x = layer(x)
        embeds = self.LNorm(x)

        logits = self.to_logits(embeds)
        
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss

In [515]:
model = LaTeXModel(depth, n_heads, head_size)
optimizer = torch.optim.AdamW(model.parameters(), lr = lr, weight_decay = wd, betas = betas, eps = eps)

In [516]:
for iter in range(2):
    # every once in a while evaluate the loss on train and val sets
    #if iter % eval_interval == 0 or iter == max_iters - 1:
        #losses = estimate_loss()
        #print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')
    # evaluate the loss
    logits, loss = model(xb, yb)
    print(loss)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
#context = torch.zeros((64, 256), dtype=torch.long)
#print(detokenize(model.generate(context, max_new_tokens=50)[0].tolist()))


tensor(4.3521, grad_fn=<NllLossBackward0>)
tensor(4.3362, grad_fn=<NllLossBackward0>)


In [517]:
x = torch.ones((64, 3, 88))
y, z = x.chunk(2, dim=-1)

In [518]:
y.shape

torch.Size([64, 3, 44])

In [519]:
z.shape

torch.Size([64, 3, 44])

In [520]:
(F.silu(z) * y).shape

torch.Size([64, 3, 44])