In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

# --- 1. Data Setup ---
# One name per line. Read inside a `with` so the file is closed promptly, and
# decode explicitly as UTF-8 instead of relying on the platform default.
with open('names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()
chars = sorted(set(''.join(words)))  # distinct characters across all names
# Characters get indices 1..len(chars); index 0 is reserved for '.', which
# build_dataset uses both as left padding and as the end-of-name target.
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}  # inverse mapping: index -> character
vocab_size = len(itos)
block_size = 3  # number of preceding characters used as context

def build_dataset(words, context_size=None, vocab=None):
    """Turn names into (context, next-character) example pairs.

    Each name's context starts as all-zero indices and the name is terminated
    with '.', so every character (including the terminator) becomes one example
    whose input is the indices of the ``context_size`` preceding characters.

    Args:
        words: iterable of name strings; every character must be a key of
            the vocabulary.
        context_size: number of preceding characters per example. Defaults to
            the module-level ``block_size`` (read at call time).
        vocab: character -> index mapping. Defaults to the module-level
            ``stoi`` (read at call time).

    Returns:
        ``(X, Y)``: ``X`` is an int64 tensor of shape
        ``(num_examples, context_size)`` holding context indices, ``Y`` is an
        int64 tensor of shape ``(num_examples,)`` holding target indices. Both
        are empty but correctly typed and shaped when ``words`` is empty.

    Raises:
        KeyError: if a character is missing from the vocabulary.
    """
    n = block_size if context_size is None else context_size
    lookup = stoi if vocab is None else vocab
    X, Y = [], []
    for w in words:
        context = [0] * n
        for ch in w + '.':
            ix = lookup[ch]
            X.append(context)
            Y.append(ix)
            # Slide the window; this builds a new list, so rows already
            # appended to X are not mutated.
            context = context[1:] + [ix]
    # Explicit dtype/shape: torch.tensor([]) would otherwise give a float32
    # tensor of shape (0,), which cannot be used to index the embedding table.
    return (torch.tensor(X, dtype=torch.long).view(len(X), n),
            torch.tensor(Y, dtype=torch.long))

# Shuffle the names reproducibly, then cut them 80% / 10% / 10% into
# train / dev / test splits before expanding each split into examples.
random.seed(42)
random.shuffle(words)
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))
Xtr, Ytr = build_dataset(words[:n1])      # training split
Xdev, Ydev = build_dataset(words[n1:n2])  # dev split
Xte, Yte = build_dataset(words[n2:])      # held-out test split

# --- 2. Modular Layers ---
class Linear:
    """Fully connected layer computing ``x @ weight (+ bias)``.

    Weights are drawn from a standard normal and divided by sqrt(fan_in), so a
    unit-variance input yields roughly unit-variance outputs. No
    nonlinearity-specific gain is applied here.
    """

    def __init__(self, fan_in, fan_out, bias=True, generator=None):
        """Create a ``(fan_in, fan_out)`` weight and an optional zero bias.

        Args:
            fan_in: number of input features.
            fan_out: number of output features.
            bias: whether to include a bias vector (initialized to zeros).
            generator: optional ``torch.Generator`` for reproducible
                initialization; ``None`` uses PyTorch's global RNG, which is
                the previous behavior.
        """
        self.weight = torch.randn((fan_in, fan_out), generator=generator) / (fan_in ** 0.5)
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        # The result is also stored on the layer, so the latest activation
        # remains accessible after the forward pass.
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        """Return the tensors to optimize: the weight, plus the bias if present."""
        return [self.weight] + ([] if self.bias is None else [self.bias])

class BatchNorm1d:
    """Batch normalization over the first (batch) dimension of the input.

    In training mode (``training = True``) the layer normalizes with the
    current batch's mean and variance and folds them into exponential running
    averages. Otherwise it normalizes with those running averages. Unlike
    ``torch.nn.BatchNorm1d``, which normalizes with the biased variance, this
    layer uses the unbiased estimate (``Tensor.var`` default) throughout.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps            # added to the variance before the sqrt to avoid division by zero
        self.momentum = momentum  # weight given to the newest batch in the running averages
        self.training = True
        # Learnable affine parameters: per-feature scale and shift.
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # Running statistics, updated in training mode and used otherwise.
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        if self.training:
            # The unbiased variance of a single sample is NaN (division by N-1 = 0),
            # which would silently poison the output and the running stats.
            if x.shape[0] < 2:
                raise ValueError(
                    f"BatchNorm1d needs more than 1 sample per batch in training mode, got {x.shape[0]}"
                )
            # Reduce over the batch dimension without keepdim so the statistics
            # keep the same shape as the running buffers (the original keepdim
            # version turned them from (dim,) into (1, dim) after one step).
            # They still broadcast against the input.
            xmean = x.mean(0)
            xvar = x.var(0)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        if self.training:
            # Running stats are bookkeeping, not part of the autograd graph.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        """Return the learnable tensors (gamma, beta); running stats are excluded."""
        return [self.gamma, self.beta]

class Tanh:
    """Parameter-free element-wise hyperbolic tangent activation."""

    def __call__(self, x):
        # The activation is also kept on the layer so it remains accessible
        # after the forward pass.
        result = x.tanh()
        self.out = result
        return result

    def parameters(self):
        """Return an empty list: this layer has nothing to train."""
        return list()

# --- 3. Model & Training ---
n_embd = 10     # dimensionality of each character embedding
n_hidden = 100  # width of every hidden layer
g = torch.Generator().manual_seed(2147483647)  # seeds C and minibatch sampling
C = torch.randn((vocab_size, n_embd), generator=g)  # embedding table, one row per character

# 6-layer deep network: five Linear -> BatchNorm1d -> Tanh hidden stages, then a
# Linear -> BatchNorm1d output stage that produces vocab_size logits. The Linear
# layers omit the bias because the BatchNorm1d right after each one subtracts
# the batch mean anyway (its beta acts as the bias).
layers = [
    Linear(block_size * n_embd, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size),
]

with torch.no_grad():
    # Shrink the output BatchNorm's scale so the initial logits are small and
    # the first predictions are less confident.
    layers[-1].gamma *= 0.1
    # Apply the tanh gain 5/3 to Linear weights. layers[:-1] drops only the
    # final BatchNorm1d, so the output Linear is scaled as well.
    for layer in layers[:-1]:
        if isinstance(layer, Linear):
            layer.weight *= 5/3

parameters = [C] + [p for layer in layers for p in layer.parameters()]
for p in parameters:
    p.requires_grad = True

max_steps = 200000
batch_size = 32
lossi, ud = [], []  # lossi: log10 loss per step; ud is not filled in this cell

for i in range(max_steps):
    # Sample a random minibatch of training examples.
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # Forward pass.
    emb = C[Xb]  # (batch_size, block_size, n_embd)
    x = emb.view(-1, block_size * n_embd)  # concatenate the context embeddings
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, Yb)

    # Backward pass.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Plain SGD with a 10x step decay halfway through training (step 100000
    # for the current max_steps). The in-place update is done under no_grad
    # rather than through the discouraged `.data` attribute.
    lr = 0.1 if i < max_steps // 2 else 0.01
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    # Track stats.
    if i % 10000 == 0:
        print(f'{i}: {loss.item():.4f}')
    lossi.append(torch.log10(loss).item())

0: 3.2653
10000: 2.3222
20000: 1.9460
30000: 1.9646
40000: 2.3770
50000: 2.3421
60000: 1.7678
70000: 2.2212
80000: 2.1977
90000: 1.8840
100000: 1.6919
110000: 2.1216
120000: 1.7353
130000: 1.7076
140000: 2.0491
150000: 1.8994
160000: 2.0667
170000: 2.3399
180000: 2.0589
190000: 1.9954
