In [1]:
from pathlib import Path

In [2]:
# Decode the corpus explicitly as UTF-8: without `encoding`, Path.read_text()
# falls back to the platform's locale encoding, which differs between machines.
text: str = Path("tiny-shakespeare.txt").read_text(encoding="utf-8")

# Tokenizing

In [3]:
chars = sorted(set(text))  # vocabulary: every distinct character, in sorted order
vocab_size = len(chars)

In [4]:
# Bidirectional lookup tables between token ids and characters.
itos = {idx: ch for idx, ch in enumerate(chars)}
stoi = {ch: idx for idx, ch in enumerate(chars)}

In [5]:
def encode(s: str) -> list[int]:
    """Convert each character of ``s`` to its integer id (KeyError if unseen)."""
    return list(map(stoi.__getitem__, s))

def decode(ints: list[int]) -> str:
    """Inverse of ``encode``: turn integer ids back into a string."""
    pieces = [itos[token_id] for token_id in ints]
    return "".join(pieces)

In [6]:
assert decode(encode("yay")) == "yay"  # encode/decode must round-trip

In [7]:
import torch

# The whole corpus as a 1-D tensor of token ids. The dtype is stated explicitly:
# nn.Embedding needs integer indices, and inference would give float32 for an
# empty list.
data = torch.tensor(encode(text), dtype=torch.long)

# Test / Train Split

In [8]:
n = int(0.9 * len(data))  # split point: first 90% of tokens for training, last 10% held out
train_data, test_data = data[:n], data[n:]
assert len(train_data) + len(test_data) == len(data)

# Inputs and Targets

In [9]:
block_size = 8 # or context length: the most tokens of context used to predict the next one

In [10]:
# One window of block_size tokens holds block_size training examples:
# each prefix x[:t+1] is a context whose target is the following token y[t].
x, y = train_data[:block_size], train_data[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"context={context}, target={target}")

context=tensor([18]), target=47
context=tensor([18, 47]), target=56
context=tensor([18, 47, 56]), target=57
context=tensor([18, 47, 56, 57]), target=58
context=tensor([18, 47, 56, 57, 58]), target=1
context=tensor([18, 47, 56, 57, 58,  1]), target=15
context=tensor([18, 47, 56, 57, 58,  1, 15]), target=47
context=tensor([18, 47, 56, 57, 58,  1, 15, 47]), target=58


# Batching

In [11]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

from torch import Tensor

def get_batch(data: Tensor) -> tuple[Tensor, Tensor]:
    """Sample a random batch of (input, target) windows from a 1-D token tensor.

    Reads the module-level ``batch_size`` and ``block_size`` at call time, so
    reassigning those globals in later cells changes the batch shape.

    Args:
        data: 1-D tensor of token ids.

    Returns:
        xs: (batch_size, block_size) windows ``data[i:i+block_size]``.
        ys: the same windows shifted one token right, i.e. next-token targets.

    Raises:
        ValueError: if ``data`` has no room for one window plus its target.
    """
    # Valid starts are 0 .. len(data)-block_size-1, so the target slice
    # data[i+1:i+block_size+1] never runs off the end.
    n_starts = len(data) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"data has {len(data)} tokens; need more than block_size={block_size}"
        )
    start_ixs = torch.randint(n_starts, size=(batch_size,))
    xs = torch.stack([data[i:i+block_size] for i in start_ixs])
    ys = torch.stack([data[i+1:i+block_size+1] for i in start_ixs])
    return xs, ys

get_batch(train_data)  # peek at one batch: each row of ys is the matching row of xs shifted left by one token

(tensor([[24, 43, 58,  5, 57,  1, 46, 43],
         [44, 53, 56,  1, 58, 46, 39, 58],
         [52, 58,  1, 58, 46, 39, 58,  1],
         [25, 17, 27, 10,  0, 21,  1, 54]]),
 tensor([[43, 58,  5, 57,  1, 46, 43, 39],
         [53, 56,  1, 58, 46, 39, 58,  1],
         [58,  1, 58, 46, 39, 58,  1, 46],
         [17, 27, 10,  0, 21,  1, 54, 39]]))

In [12]:
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)  # reseed before building the model so its random initialization is reproducible

class BGLM(nn.Module):
    """Bigram language model.

    The embedding table is (vocab_size, vocab_size): looking up token ``i``
    yields a row that is used directly as the unnormalized logits for the next
    token, so each prediction depends only on the current token.

    Args:
        vocab_size: number of distinct tokens.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)

    def forward(self, xs: Tensor, ys: "Tensor | None" = None):
        """Score next-token predictions.

        Args:
            xs: (B, T) integer token ids.
            ys: optional (B, T) integer next-token targets.

        Returns:
            ``(logits, loss)``. Without ``ys``: logits are (B, T, C) and loss is
            ``None``. With ``ys``: logits are flattened to (B*T, C) and loss is
            the mean cross-entropy over all B*T positions.
        """
        logits = self.embedding(xs)  # (B, T, C)
        if ys is None:
            return logits, None

        B, T, C = logits.shape
        # F.cross_entropy expects (N, C) logits and (N,) targets. reshape (unlike
        # view) also accepts non-contiguous targets, e.g. sliced or transposed ones.
        logits = logits.reshape(B * T, C)
        ys = ys.reshape(B * T)
        loss = F.cross_entropy(logits, ys)
        return logits, loss

    @torch.no_grad()
    def generate(self, xs: Tensor, n: int) -> Tensor:
        """Extend each sequence in ``xs`` (B, T) with ``n`` sampled tokens.

        Runs without autograd tracking, since sampling needs no gradients.

        Returns:
            (B, T + n) tensor whose first T columns are ``xs``.
        """
        for _ in range(n):
            xs = self.generate1(xs)
        return xs

    @torch.no_grad()
    def generate1(self, xs: Tensor) -> Tensor:
        """Sample one next token per sequence and append it: (B, T) -> (B, T+1)."""
        logits, _ = self(xs)
        last_timestep = logits[:, -1, :]  # (B, C): prediction after the final token
        probs = F.softmax(last_timestep, dim=-1)  # (B, C)
        xs_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        return torch.cat((xs, xs_next), dim=1)  # (B, T+1)

m = BGLM(vocab_size)  # fresh, untrained bigram model over the character vocabulary

In [13]:
xs_, ys_ = get_batch(train_data)  # one sample batch to sanity-check the model's forward pass and sampling

In [14]:
xs_  # sampled inputs: batch_size rows of block_size token ids

tensor([[ 6,  0, 21, 44,  1, 61, 43,  1],
        [58, 52, 43, 57, 57,  2,  1, 57],
        [ 1, 59, 52, 39, 41, 46, 47, 52],
        [43, 42, 50, 39, 56,  6,  1, 50]])

In [15]:
logits, loss = m(xs_, ys_)  # with targets: flattened (B*T, C) logits plus a scalar loss
m.generate1(xs_)  # one sampling step appends one token per row (the extra last column)

tensor([[ 6,  0, 21, 44,  1, 61, 43,  1,  7],
        [58, 52, 43, 57, 57,  2,  1, 57, 64],
        [ 1, 59, 52, 39, 41, 46, 47, 52, 29],
        [43, 42, 50, 39, 56,  6,  1, 50, 39]])

# Text Generation
The model is still untrained, so the text it generates below is essentially random characters.

In [16]:
# A single (1, 1) sequence seeded with the newline token.
initial_x = torch.tensor([[stoi["\n"]]])
new_x = m.generate(initial_x, n=100)
decode(new_x[0].tolist())

"\noRFJa!JKmRjtXzfN:CERiC-KuDHoiMIB!o3QHN\n,SPyiFhRKuxZOMsB-ZJhsucL:wfzLSPyZalylgQUEU cLq,SqV&vW:hhir'q?"

# Train the model

In [17]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# get_batch reads this global at call time, so training uses 32-sequence batches.
batch_size = 32
for _ in range(10_000):
    optimizer.zero_grad(set_to_none=True)  # drop the previous step's gradients
    xs_, ys_ = get_batch(train_data)
    logits, loss = m(xs_, ys_)
    loss.backward()
    optimizer.step()

# Loss on the final mini-batch only: a noisy estimate of the training loss.
print(loss.item())

2.4290781021118164


In [18]:
# Sample 500 tokens from the trained model, starting from the same newline seed.
new_x = m.generate(initial_x, n=500)
print(decode(new_x[0].tolist()))


Ar bet!
AMAs sod ke alved.
Thup stheve de t
I: ir w, l me sie hend lor ito'l an e

I:
Gochosen ea ar btamandd halind
Aust, plt t wadzotl
I bel qungnqthoth he m he de avellis k'l, tond sorangr?

the tousButhe bott oze, t s d je hid t his Inces I my ig t
Ril'swoll e pupat inouleacands-beriqu heamer te
Wht s

MI wect!-lltherotheve t fe;
WAnd pporury t s ld tathat, ir V:
A thesecin teot tit ado ilorer.
Ply, d'stacoes, ld omat mealellly yererer EMEvesa! ie IZKI pave mautoofareanerllleyomerer but?
The
