In [1]:
from pathlib import Path

In [2]:
# Load the whole corpus as one string. The encoding is pinned so the character
# vocabulary built from it does not depend on the platform's default encoding.
text: str = Path("tiny-shakespeare.txt").read_text(encoding="utf-8")

# Tokenizing

In [3]:
# Character-level vocabulary: every distinct character in the corpus, in sorted
# order so that token ids are stable from run to run.
chars = sorted(set(text))
vocab_size = len(chars)

In [4]:
# Lookup tables in both directions: token id -> character and character -> token id.
itos = {i: ch for i, ch in enumerate(chars)}
stoi = {ch: i for i, ch in enumerate(chars)}

In [5]:
def encode(s: str) -> list[int]:
    """Translate each character of ``s`` into its token id (KeyError for characters outside the vocabulary)."""
    return list(map(stoi.__getitem__, s))

def decode(ints: list[int]) -> str:
    """Translate a sequence of token ids back into the string they represent."""
    return "".join(map(itos.__getitem__, ints))

In [6]:
# Sanity check: decoding an encoded string gives back the original string.
assert decode(encode("yay")) == "yay"

In [7]:
import torch

# The entire corpus as a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# Test / Train Split

In [8]:
# Contiguous split: the first 90% of tokens for training, the final 10% held out.
n = int(0.9 * len(data))
train_data, test_data = data[:n], data[n:]
assert len(train_data) + len(test_data) == len(data)

# Inputs and Targets

In [9]:
block_size = 8 # context length: the maximum number of tokens the model sees when predicting the next one

In [10]:
# One block of block_size tokens packs block_size training examples: each prefix
# of x is a context whose label is the matching element of the shifted copy y.
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t, target in enumerate(y):
    context = x[:t+1]
    print(f"context={context}, target={target}")

context=tensor([18]), target=47
context=tensor([18, 47]), target=56
context=tensor([18, 47, 56]), target=57
context=tensor([18, 47, 56, 57]), target=58
context=tensor([18, 47, 56, 57, 58]), target=1
context=tensor([18, 47, 56, 57, 58,  1]), target=15
context=tensor([18, 47, 56, 57, 58,  1, 15]), target=47
context=tensor([18, 47, 56, 57, 58,  1, 15, 47]), target=58


# Batching

In [11]:
torch.manual_seed(1337)  # make the random batch sampling below reproducible
batch_size, block_size = 4, 8  # sequences per batch, tokens per sequence

from torch import Tensor

def get_batch(
    data: Tensor,
    *,
    batch_size: "int | None" = None,
    block_size: "int | None" = None,
) -> tuple[Tensor, Tensor]:
    """Sample a random batch of (input, target) windows from a 1-D token tensor.

    Args:
        data: 1-D tensor of token ids to sample from.
        batch_size: number of windows to draw. Defaults to the notebook-level
            ``batch_size`` as it is at *call* time, so reassigning that global
            (e.g. before training) still takes effect.
        block_size: length of each window. Defaults to the notebook-level
            ``block_size`` at call time.

    Returns:
        ``(xs, ys)``, each of shape ``(batch_size, block_size)``. ``ys`` is the
        same window as ``xs`` advanced by one token, so ``ys[b, t]`` is the token
        that follows ``xs[b, :t+1]``.

    Raises:
        ValueError: if ``data`` has fewer than ``block_size + 1`` tokens.
    """
    if batch_size is None:
        batch_size = globals()["batch_size"]
    if block_size is None:
        block_size = globals()["block_size"]

    # Every window needs block_size inputs plus one extra token for the final target.
    n_starts = len(data) - block_size
    if n_starts < 1:
        raise ValueError(
            f"data has {len(data)} tokens; need at least block_size + 1 = {block_size + 1}"
        )
    # randint's upper bound is exclusive, so each start leaves room for the shifted targets.
    start_ixs = torch.randint(n_starts, size=(batch_size,))
    xs = torch.stack([data[i:i+block_size] for i in start_ixs])
    ys = torch.stack([data[i+1:i+block_size+1] for i in start_ixs])
    return xs, ys

# Draw one (inputs, targets) batch to inspect: each targets row is the matching
# inputs row advanced by one token.
get_batch(train_data)

(tensor([[24, 43, 58,  5, 57,  1, 46, 43],
         [44, 53, 56,  1, 58, 46, 39, 58],
         [52, 58,  1, 58, 46, 39, 58,  1],
         [25, 17, 27, 10,  0, 21,  1, 54]]),
 tensor([[43, 58,  5, 57,  1, 46, 43, 39],
         [53, 56,  1, 58, 46, 39, 58,  1],
         [58,  1, 58, 46, 39, 58,  1, 46],
         [17, 27, 10,  0, 21,  1, 54, 39]]))

# Efficient Self-Attention Mechanism

In [12]:
import torch.nn as nn
from torch.nn import functional as F

class SelfAttention(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Each position attends only to itself and to earlier positions, which makes
    the head usable for next-token prediction.

    Args:
        head_size: width of the key/query/value projections; also the output width.
        dim_embedding: feature width C of the input.
        block_size: the longest sequence the causal mask covers. Defaults to the
            notebook-level ``block_size`` at construction time.
    """

    def __init__(self, head_size: int, dim_embedding: int, block_size: "int | None" = None):
        super().__init__()
        if block_size is None:
            block_size = globals()["block_size"]
        C = dim_embedding
        self.key = nn.Linear(C, head_size, bias=False)
        self.query = nn.Linear(C, head_size, bias=False)
        self.value = nn.Linear(C, head_size, bias=False)
        # Lower-triangular ones: entry (t, s) is 1 iff position t may attend to position s.
        # Stored as a buffer rather than a trainable parameter.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, xs: Tensor) -> Tensor:
        """Apply causal self-attention.

        Args:
            xs: input of shape (B, T, C), with T no larger than the mask's block_size.

        Returns:
            Tensor of shape (B, T, head_size).

        Raises:
            ValueError: if T exceeds the block_size the mask was built for.
        """
        _, T, _ = xs.shape
        max_T = self.tril.shape[0]
        if T > max_T:
            raise ValueError(f"sequence length {T} exceeds block_size {max_T}")
        k = self.key(xs)    # (B, T, head_size)
        q = self.query(xs)  # (B, T, head_size)
        head_size = k.shape[-1]
        ws = q @ k.transpose(-2, -1)  # (B, T, hs) @ (B, hs, T) ---> (B, T, T)
        # Standard scaled dot-product attention scales by 1/sqrt(d_k), the
        # key/query width, not by the input width C.
        ws = ws * head_size**-0.5

        # Forbid attending to future positions, then normalise each row.
        ws = ws.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        ws = F.softmax(ws, dim=-1)

        return ws @ self.value(xs)  # (B, T, T) @ (B, T, hs) ---> (B, T, hs)

# GPT Model

In [13]:
torch.manual_seed(1337)

class BGLM(nn.Module):
    """Token + position embeddings, one causal self-attention head, and a linear
    projection to next-token logits.

    Args:
        vocab_size: number of distinct tokens.
        dim_embedding: embedding width C, also used as the attention head size.
            Defaults to 32, the value this notebook uses.
    """

    def __init__(self, vocab_size: int, dim_embedding: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim_embedding)
        # One learned vector per position, so inputs may be at most block_size long.
        self.pos_embedding = nn.Embedding(block_size, dim_embedding)
        self.sa_head = SelfAttention(head_size=dim_embedding, dim_embedding=dim_embedding)
        self.lm_head = nn.Linear(dim_embedding, vocab_size)

    def forward(self, xs: Tensor, ys: Tensor = None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            xs: token ids of shape (B, T), with T <= block_size.
            ys: optional target ids of shape (B, T).

        Returns:
            ``(logits, loss)``. Without targets: logits of shape (B, T, vocab_size)
            and ``loss`` None. With targets: logits flattened to
            (B*T, vocab_size) and the cross-entropy loss.
        """
        B, T = xs.shape
        token_emb = self.embedding(xs)  # (B, T, C)
        # Create the position ids on the input's device so the model also runs off-CPU.
        pos_emb = self.pos_embedding(torch.arange(T, device=xs.device))  # (T, C)
        x = token_emb + pos_emb  # broadcasts to (B, T, C)
        x = self.sa_head(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if ys is None:
            return logits, None

        B, T, V = logits.shape
        logits = logits.view(B * T, V)
        ys = ys.view(B * T)
        loss = F.cross_entropy(logits, ys)
        return logits, loss

    @torch.no_grad()
    def generate(self, xs: Tensor, n: int) -> Tensor:
        """Extend each sequence in ``xs`` (shape (B, T)) by ``n`` sampled tokens.

        Only the last ``block_size`` tokens are fed to the model at each step,
        because the position embedding has exactly ``block_size`` entries.
        Runs under ``torch.no_grad()``: sampling needs no gradients, so no
        autograd graph is recorded for the forward passes.

        Returns:
            Tensor of shape (B, T + n).
        """
        for _ in range(n):
            xs_next = self.generate1(xs[:, -block_size:])
            xs = torch.cat((xs, xs_next), dim=1)  # (B, T+1)
        return xs

    @torch.no_grad()
    def generate1(self, xs: Tensor) -> Tensor:
        """Sample one next token per sequence from the distribution at the last position.

        Returns:
            Tensor of token ids with shape (B, 1).
        """
        logits, _ = self(xs)
        last_timestep = logits[:, -1, :]  # (B, vocab_size)
        probs = F.softmax(last_timestep, dim=-1)  # (B, vocab_size)
        return torch.multinomial(probs, num_samples=1)  # (B, 1)

m = BGLM(vocab_size)

In [14]:
# Draw a training batch to sanity-check the untrained model on.
xs_, ys_ = get_batch(train_data)

In [15]:
# With targets, the forward pass returns (flattened logits, loss); generate1 then
# samples one next token per sequence, giving a (batch_size, 1) tensor.
logits, loss = m(xs_, ys_)
m.generate1(xs_)

tensor([[41],
        [41],
        [42],
        [50]])

# Text Generation
As the sample below shows, the untrained model still produces essentially random characters.

In [16]:
# Start generation from a single newline token: a batch of one sequence, shape (1, 1).
initial_x = torch.tensor([[stoi["\n"]]])
new_x = m.generate(initial_x, n=100)
decode(new_x[0].tolist())

"\njVJDq:X&edpv,b? rPDACszAS-nkNch-Nryw:$jupUj\n T'PG&GFGP !&:aLjKL$u.qrCpIadhkIXtRBEtnxE:cTmFXpOPq&aZ!Q"

# Evaluate Model

In [17]:
def evaluate_model(m: BGLM, n_batches: int = 200) -> dict[str, float]:
    """Estimate the model's average loss on the train and test splits.

    Args:
        m: model to evaluate.
        n_batches: number of random batches averaged per split.

    Returns:
        ``{"train_loss": float, "test_loss": float}``.
    """
    return dict(
        train_loss=avg_loss(m, train_data, n_batches),
        test_loss=avg_loss(m, test_data, n_batches),
    )

@torch.no_grad()
def avg_loss(m: BGLM, data: Tensor, n_batches: int = 200) -> float:
    """Mean loss of ``m`` over ``n_batches`` random batches drawn from ``data``.

    ``m`` is switched to eval mode for the measurement and then returned to
    whichever mode it was in before, even if a batch raises.

    Raises:
        ValueError: if ``n_batches`` is less than 1.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be >= 1, got {n_batches}")
    was_training = m.training
    m.eval()
    try:
        losses = torch.stack([_loss(m, data) for _ in range(n_batches)])
    finally:
        m.train(was_training)
    return losses.mean().item()

def _loss(m: BGLM, data: Tensor) -> Tensor:
    """Loss of ``m`` on one random batch from ``data``, as a scalar tensor."""
    xs, ys = get_batch(data)
    _, loss = m(xs, ys)
    return loss

# Baseline before training. For reference, predicting a uniform distribution over
# the vocabulary would give a cross-entropy of ln(vocab_size).
evaluate_model(m)

{'train_loss': 4.202554225921631, 'test_loss': 4.200745582580566}

# Train the model

In [18]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# get_batch reads the notebook-level batch_size when it is called, so this larger
# value applies both to the training batches and to the batches evaluate_model draws.
batch_size = 32
n_iters = 10000
n_evals = 10
# Integer reporting interval. The previous float interval (n_iters / 10) misfires
# when n_iters is not a multiple of 10: e.g. 25 / 10 = 2.5 only matches every 5th step.
# max(1, ...) avoids modulo by zero for very short runs.
eval_interval = max(1, n_iters // n_evals)
for idx in range(n_iters):
    xs_, ys_ = get_batch(train_data)
    _, loss = m(xs_, ys_)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if idx % eval_interval == 0:
        print(evaluate_model(m))

{'train_loss': 4.188299179077148, 'test_loss': 4.18881368637085}
{'train_loss': 2.5251500606536865, 'test_loss': 2.534402847290039}
{'train_loss': 2.4450864791870117, 'test_loss': 2.453768491744995}
{'train_loss': 2.4123029708862305, 'test_loss': 2.428366184234619}
{'train_loss': 2.3954243659973145, 'test_loss': 2.409773349761963}
{'train_loss': 2.402365207672119, 'test_loss': 2.4012603759765625}
{'train_loss': 2.3836829662323, 'test_loss': 2.4104301929473877}
{'train_loss': 2.377206325531006, 'test_loss': 2.390448808670044}
{'train_loss': 2.375992774963379, 'test_loss': 2.378732204437256}
{'train_loss': 2.3648550510406494, 'test_loss': 2.38360857963562}


In [19]:
# Sample 500 new tokens from the trained model, again starting from the single
# newline token in initial_x, and print the decoded text.
new_x = m.generate(initial_x, n=500)
print(decode(new_x[0].tolist()))


Tisheassy, save, ome int fr, aspoef
Fo'lsam illvestimy Polod easene atu ly avent, sosu coly, sba sha, llafr ly riothe:
Mand hear ad
Mel, thal anderd tw bough, sheasul,
Af a blu. Shon:
Merd;
Whanchour gren alt keasn he cagtern momousle,
Loll chisu ghrer ame mesend tind be se thas glok kee.

OO,
Porf toen ithis sourst. I ske, stawnd; dso yory fa tad;-
Ber ticontt amon Edingeanch mud ingo.


J'oniceaced by, hourarson to do--
An meand thare tn yo sous, flen ack hre'nt whee, Min hed san dyo es othe m
