In [30]:
 %load_ext nb_mypy

The nb_mypy extension is already loaded. To reload it, use:
  %reload_ext nb_mypy


In [31]:
import torch
import torch.nn as nn
from typing import Optional

torch.manual_seed(1337)

<torch._C.Generator at 0x130e2ac70>

In [32]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [69]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
batch_size = 4
block_size = 8
embed_size = 32


In [34]:
charToIndex : dict[str, int] = { ch:i for i,ch in enumerate(chars) }
indexToChar : dict[int, str] = { i:ch for i,ch in enumerate(chars) }

def encode(text: str) -> list[int]: 
    return [charToIndex[c] for c in text]

def decode(values: list[int]) -> str:
    return ''.join([indexToChar[value] for value in values])


In [35]:
data = torch.tensor(encode(text), dtype=torch.long)

In [36]:
n = int(0.9 * len(data))

train_data = data[:n]
val_data  = data[n:]

In [37]:
def get_batch(train = True):
    data = train_data if train else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [41]:
class Head(nn.Module):
    def __init__(self, block_size, embed_size, head_size):
        super().__init__()
        self.key = torch.nn.Linear(embed_size, head_size, bias=False)
        self.query = torch.nn.Linear(embed_size, head_size, bias=False)
        self.value = torch.nn.Linear(embed_size, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        wei = q @ k.transpose(-2, -1) * C**-0.5 # (B, T, T)
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # future tokens have -inf affinities
        wei = torch.nn.functional.softmax(wei, dim=-1)

        return wei @ v # (B, T, head_size)

In [75]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, block_size, embed_size, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(block_size, embed_size, head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

In [76]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, block_size, embed_size, vocab_size, head_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        self.position_embedding_table = nn.Embedding(block_size, embed_size)

        self.sa_heads = MultiHeadAttention(4, block_size, embed_size, (int) (embed_size / 4))
        self.ln_head = nn.Linear(embed_size, vocab_size)
    
    # idx and targets are both (batch_size, block_size) tensors of integers
    def forward(self, idx, targets = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, T = idx.shape

        token_embeddings = self.token_embedding_table(idx) # (B=batch_size, T=block_size, C=embed_size)
        position_embeddings = self.position_embedding_table(torch.arange(T))
        x = token_embeddings + position_embeddings # (B, T, C)
        x = self.sa_heads(x) # apply one head of self-attention (B, T, head_size)
        logits = self.ln_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = nn.functional.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        
        return logits, loss

    def generate(self, idx, max_new_tokens) -> torch.Tensor:
        # idx is (B, T) array of indexes
        for _ in range(max_new_tokens):
            # crop idx to the block_size
            idx_crop = idx[:, -block_size:]
            logits, _ = self.forward(idx_crop)
            logits = logits[:, -1, :] # becomes (B, C)
            probs = nn.functional.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

In [77]:
import torch

torch.manual_seed(1337)

m = TransformerLanguageModel(block_size, embed_size, vocab_size, head_size)

xb, yb = get_batch()
logits, loss = m.forward(xb, yb)

print(logits.shape)
print(loss)

print(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=100))
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x32 and 16x65)

In [48]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32

for steps in range(1000):
    xb, yb = get_batch(True)
    logits, loss = m.forward(xb, yb)

    if loss is not None:    
        optimizer.zero_grad(True)
        loss.backward()
        optimizer.step()

if loss is not None:            
    print(loss.item())

2.634603261947632


In [49]:
eval_iters = 1000

@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in [True, False]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xs, ys = get_batch(split)
            logits, loss = m.forward(xs, ys)
            if loss is not None:
                losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out

In [52]:
max_iters = 5000
eval_interval = 300

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses[True]:.4f}, val loss {losses[False]:.4f}")
    
    xb, yb = get_batch(True)
    logits, loss = m.forward(xb, yb)

    if loss is not None:    
        optimizer.zero_grad(True)
        loss.backward()
        optimizer.step()

step 0: train loss 2.4125, val loss 2.4323
step 300: train loss 2.4106, val loss 2.4243
step 600: train loss 2.4095, val loss 2.4220
step 900: train loss 2.4045, val loss 2.4270
step 1200: train loss 2.4021, val loss 2.4205
step 1500: train loss 2.3964, val loss 2.4143
step 1800: train loss 2.3917, val loss 2.4157
step 2100: train loss 2.3943, val loss 2.4085
step 2400: train loss 2.3899, val loss 2.4129
step 2700: train loss 2.3921, val loss 2.4053
step 3000: train loss 2.3819, val loss 2.4074
step 3300: train loss 2.3835, val loss 2.4063
step 3600: train loss 2.3814, val loss 2.4024
step 3900: train loss 2.3777, val loss 2.4000
step 4200: train loss 2.3717, val loss 2.4040
step 4500: train loss 2.3758, val loss 2.4026
step 4800: train loss 2.3765, val loss 2.4015


In [53]:
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


PARWAUS:
Anig sper tlad berer fingthon
Om'l seroreofl thourt ybenidr to theatror; lel,
Haly wathinid nsmeem mee so bedon dqou hnk's is Raver; nowo ilsant Meamans. I YO.

ILIRIEMCUZAELe ta's; pseatur fesray, owhow shisowthere, tallo shackif dor yorando lileriefis vicheprs chintoo wit labelil thinor the, bikerer asefigie tothaun thatino cere'save bar?

Th
Wit aly!

CAK:
B
intou sone.

OD:e-
T rsen mbame tiret tiners! ighouef I'teacofr hy himam istimay aivis et
At I of haty at spo ome gorneran fofe
