In [1]:
import random
import torch

words = open('names.txt', 'r').read().splitlines()
words = list(set(words))
random.seed(42)
random.shuffle(words)
len(words)

chs = list(set(''.join(words + ['.'])))
chs = sorted(chs, reverse=False)
stoi = {ch: i for i, ch in enumerate(chs)}
itos = {i: ch for i, ch in enumerate(chs)}

# predict next token use previous 3 tokens
X, Y = [], []

for w in words:
    context = '...'
    for ch in w + '.':
        x = [stoi[c] for c in context]
        y = stoi[ch]
        X.append(x)
        Y.append(y)
        context = context[1:] + ch

X = torch.tensor(X)
Y = torch.tensor(Y)
n1, n2  = int(0.8 * len(X)), int(0.9 * len(X))

X_train, X_val, X_test = X.tensor_split([n1, n2])
Y_train, Y_val, Y_test = Y.tensor_split([n1, n2])

X_train.shape, X_val.shape, X_test.shape, Y_train.shape, Y_val.shape, Y_test.shape


(torch.Size([169062, 3]),
 torch.Size([21133, 3]),
 torch.Size([21133, 3]),
 torch.Size([169062]),
 torch.Size([21133]),
 torch.Size([21133]))

# implement backward from scratch


In [449]:
n_embd = 10
n_hidden = 200
vocab_size = 27
block_size = 3

def get_params():
    torch.manual_seed(42)
    C = torch.randn(vocab_size, n_embd)
    w1 = torch.randn(n_embd * block_size, n_hidden) * (n_embd * block_size)**-0.5
    w2 = torch.randn(n_hidden, vocab_size) * (5/3) * (n_hidden)**-0.5 * 0.1 # 0.1 is for less confident at initialization
    b2 = torch.randn(vocab_size) * 0
    bnw = torch.ones(n_hidden)
    bnb = torch.zeros(n_hidden)
    params = [C, w1, w2, b2, bnw, bnb]
    for p in params:
        p.requires_grad = True
    return params

params = get_params()
C, w1, w2, b2, bnw, bnb = params
bs = 32
idx = torch.randint(0, X_train.shape[0], (bs,))
x, y = X_train[idx], Y_train[idx]

## forward and torch backward


In [3]:
# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = (torch.eye(bs) - mean_proj)

# forward
emb = C[x].view(x.shape[0], -1)
emb.retain_grad()
hpreact = emb @ w1
hpreact.retain_grad()
bnmeani = mean_proj @ hpreact
bnmeani.retain_grad()
bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
bnstdi.retain_grad()
hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
hpreact_bn.retain_grad()
h = hpreact_bn.tanh()
h.retain_grad()
logits = h @ w2 + b2
logits.retain_grad()
# 2. loss
exp_l = logits.exp()
exp_l.retain_grad()
count = exp_l.sum(dim=-1, keepdim=True)
count.retain_grad()
probs = exp_l / count
probs.retain_grad()
nlls = -probs.log()
nlls.retain_grad()
loss = nlls[torch.arange(y.shape[0]), y].mean()

# backward
loss.backward()

## manual backward

In [4]:
# buffer grad
nlls_grad = torch.zeros(bs, vocab_size)
probs_grad = torch.zeros(bs, vocab_size)
count_grad = torch.zeros(bs, 1)
exp_l_grad = torch.zeros(bs, vocab_size)
logits_grad = torch.zeros(bs, vocab_size)
h_grad = torch.zeros(bs, n_hidden)
hpreact_bn_grad = torch.zeros(bs, n_hidden)
bnmeani_grad = torch.zeros(1, n_hidden)
bnstdi_grad = torch.zeros(1, n_hidden)
bnvari_grad = torch.zeros(1, n_hidden)
hpreact_grad = torch.zeros(bs, n_hidden)
emb_grad = torch.zeros(bs, n_embd * block_size)
# param grad
C_grad = torch.zeros(vocab_size, n_embd)
w1_grad = torch.zeros(n_embd * block_size, n_hidden)
w2_grad = torch.zeros(n_hidden, vocab_size)
b2_grad = torch.zeros(vocab_size)
bnw_grad = torch.zeros(n_hidden)
bnb_grad = torch.zeros(n_hidden)


In [5]:
# 1. loss
nlls_grad[torch.arange(y.shape[0]), y] = 1 / bs
probs_grad[torch.arange(y.shape[0]), y] = -1 / probs.data[torch.arange(y.shape[0]), y] * nlls_grad[torch.arange(y.shape[0]), y]
count_grad = -(exp_l.data * probs_grad).sum(dim=-1, keepdim=True) / count.data**2
exp_l_grad = probs_grad / count.data + count_grad  # one is from e/c to e, one is from c=\sum e to e
logits_grad = exp_l.data * exp_l_grad

# 2. logits
h_grad = logits_grad @ w2.data.T
hpreact_bn_grad = h_grad * (1 - h.data**2)
# bn
bnmeani_grad = ((-bnw.data / bnstdi.data) * hpreact_bn_grad).sum(dim=0, keepdim=True)
bnstdi_grad = (-((hpreact.data - bnmeani.data) * bnw.data / bnstdi.data**2) * hpreact_bn_grad).sum(dim=0, keepdim=True)
# hpreact
hpreact_grad_mean = bnmeani_grad * torch.ones_like(hpreact.data) / bs
hpreact_grad_std = bnstdi_grad * (1 / 2 / bnstdi.data) * (1 / bs) * (2 * var_proj @ hpreact.data)
hpreact_grad_direct = hpreact_bn_grad * (bnw.data / bnstdi.data)
hpreact_grad = hpreact_grad_mean + hpreact_grad_std + hpreact_grad_direct
# emb
emb_grad = hpreact_grad @ w1.data.T

# 3. params
C_grad.index_add_(dim=0, index=x.view(-1), source=emb_grad.view(-1, n_embd)) # add emb_grad[i] to C[x[i]]
w1_grad = emb.data.T @ hpreact_grad
w2_grad = h.data.T @ logits_grad
b2_grad = logits_grad.sum(dim=0)
bnw_grad = ((hpreact.data - bnmeani.data) / bnstdi.data * hpreact_bn_grad).sum(dim=0)
bnb_grad = hpreact_bn_grad.sum(dim=0)

# check
is_equal1 = [torch.allclose(nlls_grad, nlls.grad), torch.allclose(probs_grad, probs.grad), torch.allclose(count_grad, count.grad), torch.allclose(exp_l_grad, exp_l.grad), torch.allclose(logits_grad, logits.grad)]
is_equal2 = [torch.allclose(h_grad, h.grad), torch.allclose(hpreact_bn_grad, hpreact_bn.grad), torch.allclose(bnmeani_grad, bnmeani.grad), torch.allclose(bnstdi_grad, bnstdi.grad), torch.allclose(hpreact_grad, hpreact.grad), torch.allclose(emb_grad, emb.grad)]
is_equal3 = [torch.allclose(C_grad, C.grad), torch.allclose(w1_grad, w1.grad), torch.allclose(w2_grad, w2.grad), torch.allclose(b2_grad, b2.grad), torch.allclose(bnw_grad, bnw.grad), torch.allclose(bnb_grad, bnb.grad)]
print('same grad for loss calculation:', is_equal1)
print('same grad for logits calculation:', is_equal2)
print('same grad for params:', is_equal3)




same grad for loss calculation: [True, True, True, True, True]
same grad for logits calculation: [True, True, True, True, True, True]
same grad for params: [True, True, True, True, True, True]


# loop

In [28]:
import torch.nn.functional as F

# model
torch.manual_seed(42)
C = torch.randn(vocab_size, n_embd)
w1 = torch.randn(n_embd * block_size, n_hidden) * (n_embd * block_size)**-0.5
w2 = torch.randn(n_hidden, vocab_size) * (5/3) * (n_hidden)**-0.5 * 0.1 # 0.1 is for less confident at initialization
b2 = torch.randn(vocab_size) * 0
bnw = torch.ones(n_hidden)
bnb = torch.zeros(n_hidden)
params = [C, w1, w2, b2, bnw, bnb]
bnmean_running = torch.zeros(n_hidden)
bnstd_running = torch.ones(n_hidden)

# args
bs = 32
n_steps = 10000
ini_lr = 1.0

# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = (torch.eye(bs) - mean_proj)

torch.manual_seed(42)
for step in range(n_steps):
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # ---------------- forward --------------------
    # 1. logits
    emb = C[x].view(x.shape[0], -1)
    hpreact = emb @ w1
    bnmeani = mean_proj @ hpreact
    bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
    hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
    h = hpreact_bn.tanh()
    logits = h @ w2 + b2
    # 2. loss
    exp_l = logits.exp()
    count = exp_l.sum(dim=-1, keepdim=True)
    probs = exp_l / count
    nlls = -probs.log()
    loss = nlls[torch.arange(y.shape[0]), y].mean()
    


    # ---------------- backward --------------------
    # 0. zero grad
    # buffer grad
    nlls_grad = torch.zeros(bs, vocab_size)
    probs_grad = torch.zeros(bs, vocab_size)
    count_grad = torch.zeros(bs, 1)
    exp_l_grad = torch.zeros(bs, vocab_size)
    logits_grad = torch.zeros(bs, vocab_size)
    h_grad = torch.zeros(bs, n_hidden)
    hpreact_bn_grad = torch.zeros(bs, n_hidden)
    bnmeani_grad = torch.zeros(1, n_hidden)
    bnstdi_grad = torch.zeros(1, n_hidden)
    bnvari_grad = torch.zeros(1, n_hidden)
    hpreact_grad = torch.zeros(bs, n_hidden)
    emb_grad = torch.zeros(bs, n_embd * block_size)
    # param grad
    C_grad = torch.zeros(vocab_size, n_embd)
    w1_grad = torch.zeros(n_embd * block_size, n_hidden)
    w2_grad = torch.zeros(n_hidden, vocab_size)
    b2_grad = torch.zeros(vocab_size)
    bnw_grad = torch.zeros(n_hidden)
    bnb_grad = torch.zeros(n_hidden)

    # 1. loss
    nlls_grad[torch.arange(y.shape[0]), y] = 1 / bs
    probs_grad[torch.arange(y.shape[0]), y] = -1 / probs.data[torch.arange(y.shape[0]), y] * nlls_grad[torch.arange(y.shape[0]), y]
    count_grad = -(exp_l.data * probs_grad).sum(dim=-1, keepdim=True) / count.data**2
    exp_l_grad = probs_grad / count.data + count_grad  # one is from e/c to e, one is from c=\sum e to e
    logits_grad = exp_l.data * exp_l_grad

    # 2. logits
    h_grad = logits_grad @ w2.data.T
    hpreact_bn_grad = h_grad * (1 - h.data**2)
    # bn
    bnmeani_grad = ((-bnw.data / bnstdi.data) * hpreact_bn_grad).sum(dim=0, keepdim=True)
    bnstdi_grad = (-((hpreact.data - bnmeani.data) * bnw.data / bnstdi.data**2) * hpreact_bn_grad).sum(dim=0, keepdim=True)
    # hpreact
    hpreact_grad_mean = bnmeani_grad * torch.ones_like(hpreact.data) / bs
    hpreact_grad_std = bnstdi_grad * (1 / 2 / bnstdi.data) * (1 / bs) * (2 * var_proj @ hpreact.data)
    hpreact_grad_direct = hpreact_bn_grad * (bnw.data / bnstdi.data)
    hpreact_grad = hpreact_grad_mean + hpreact_grad_std + hpreact_grad_direct
    # emb
    emb_grad = hpreact_grad @ w1.data.T
    
    # 3. params
    C_grad.index_add_(dim=0, index=x.view(-1), source=emb_grad.view(-1, n_embd)) # add emb_grad[i] to C[x[i]]
    w1_grad = emb.data.T @ hpreact_grad
    w2_grad = h.data.T @ logits_grad
    b2_grad = logits_grad.sum(dim=0)
    bnw_grad = ((hpreact.data - bnmeani.data) / bnstdi.data * hpreact_bn_grad).sum(dim=0)
    bnb_grad = hpreact_bn_grad.sum(dim=0)
    param_grads = [C_grad, w1_grad, w2_grad, b2_grad, bnw_grad, bnb_grad]

    if step % 1000 == 0:
        with torch.no_grad():
            emb = C[X_val].view(X_val.shape[0], -1)
            hpreact = emb @ w1
            hpreact = (hpreact - hpreact.mean(dim=0, keepdim=True)) / hpreact.std(dim=0, keepdim=True) * bnw + bnb
            h = hpreact.tanh()
            logits = h @ w2 + b2
            val_loss = F.cross_entropy(logits, Y_val)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')
    
    # update
    for p, g in zip(params, param_grads):
        p.data -= lr * g
    
    bnmean_running = bnmean_running * 0.99 + bnmeani * 0.01
    bnstd_running = bnstd_running * 0.99 + bnstdi * 0.01
    
    

step: 0, train loss: 3.3074488639831543, val loss: 3.3160440921783447
step: 1000, train loss: 2.583611488342285, val loss: 2.4248502254486084
step: 2000, train loss: 2.4143431186676025, val loss: 2.390408992767334
step: 3000, train loss: 2.1222798824310303, val loss: 2.379809856414795
step: 4000, train loss: 2.1513724327087402, val loss: 2.374274969100952
step: 5000, train loss: 2.349586009979248, val loss: 2.386371374130249
step: 6000, train loss: 2.2481563091278076, val loss: 2.2990658283233643
step: 7000, train loss: 2.0973422527313232, val loss: 2.2958080768585205
step: 8000, train loss: 2.5018582344055176, val loss: 2.2962090969085693
step: 9000, train loss: 2.072721242904663, val loss: 2.288125991821289


In [29]:
w1

tensor([[-0.1794, -0.1739,  0.0228,  ...,  0.0201, -0.1145,  0.1450],
        [-0.3108,  0.1481, -0.2057,  ..., -0.3503,  0.2971,  0.1593],
        [-0.0755,  0.0891, -0.0533,  ..., -0.1936, -0.2113, -0.3970],
        ...,
        [ 0.0905, -0.2106, -0.5778,  ..., -0.1088, -0.1970, -0.3650],
        [-0.1003,  0.1913, -0.4233,  ...,  0.3125,  0.0998, -0.2520],
        [ 0.2542,  0.2179, -0.3316,  ..., -0.4188, -0.2184, -0.4282]])

In [30]:
w1[0,0].item()

-0.17936445772647858

# compare with torch

In [31]:
import torch.nn.functional as F

# model
params = get_params()
C, w1, w2, b2, bnw, bnb = params
bnmean_running = torch.zeros(n_hidden)
bnstd_running = torch.ones(n_hidden)

# args
bs = 32
n_steps = 10000
ini_lr = 1.0

# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = (torch.eye(bs) - mean_proj)

torch.manual_seed(42)
for step in range(n_steps):
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # forward
    emb = C[x].view(x.shape[0], -1)
    hpreact = emb @ w1
    bnmeani = mean_proj @ hpreact
    bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
    hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
    h = hpreact_bn.tanh()
    logits = h @ w2 + b2
    # 2. loss
    exp_l = logits.exp()
    count = exp_l.sum(dim=-1, keepdim=True)
    probs = exp_l / count
    nlls = -probs.log()
    loss = nlls[torch.arange(y.shape[0]), y].mean()

    # backward
    loss.backward()
    if step % 1000 == 0:
        with torch.no_grad():
            emb = C[X_val].view(X_val.shape[0], -1)
            hpreact = emb @ w1
            hpreact = (hpreact - hpreact.mean(dim=0, keepdim=True)) / hpreact.std(dim=0, keepdim=True) * bnw + bnb
            h = hpreact.tanh()
            logits = h @ w2 + b2
            val_loss = F.cross_entropy(logits, Y_val)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')
    
    # update
    for p in params:
        p.data -= lr * p.grad
        p.grad = None
    with torch.no_grad():
        bnmean_running = bnmean_running * 0.99 + bnmeani * 0.01
        bnstd_running = bnstd_running * 0.99 + bnstdi * 0.01
    
    

step: 0, train loss: 3.3074488639831543, val loss: 3.3160440921783447
step: 1000, train loss: 2.583611488342285, val loss: 2.4248502254486084
step: 2000, train loss: 2.4143431186676025, val loss: 2.390408992767334
step: 3000, train loss: 2.1222798824310303, val loss: 2.379809856414795
step: 4000, train loss: 2.1513726711273193, val loss: 2.374274969100952
step: 5000, train loss: 2.349586009979248, val loss: 2.386371374130249
step: 6000, train loss: 2.2481563091278076, val loss: 2.2990663051605225
step: 7000, train loss: 2.0973422527313232, val loss: 2.2958080768585205
step: 8000, train loss: 2.5018584728240967, val loss: 2.2962090969085693
step: 9000, train loss: 2.072721242904663, val loss: 2.288125991821289


In [32]:
w1

tensor([[-0.1794, -0.1739,  0.0228,  ...,  0.0201, -0.1145,  0.1450],
        [-0.3108,  0.1481, -0.2057,  ..., -0.3503,  0.2971,  0.1593],
        [-0.0755,  0.0891, -0.0533,  ..., -0.1936, -0.2113, -0.3970],
        ...,
        [ 0.0905, -0.2106, -0.5778,  ..., -0.1088, -0.1970, -0.3650],
        [-0.1003,  0.1913, -0.4233,  ...,  0.3125,  0.0998, -0.2520],
        [ 0.2542,  0.2179, -0.3316,  ..., -0.4188, -0.2184, -0.4282]],
       requires_grad=True)

In [33]:
w1[0,0].item()

-0.17936447262763977

# Simpler BatchNorm Grad

Let $x\in\mathbb{R}^{n\times d}$, $w\in\mathbb{R}^d$, $b\in\mathbb{R}^d$, $P = I - \frac{1}{n} \mathbf{1} \mathbf{1}^T \in \mathbb{R}^{n\times n}$, then

$$
    v = (\frac{1}{n} x^T P x)~.\text{diag}() \in \mathbb{R}^d
$$

$$
    o = \frac{Px}{\sqrt{v + \epsilon}} w + b \in \mathbb{R}^{n \times d}
$$

Denote $dx$ as grad from the end layer to current layer, $dy/dx$ as grad from next layer to current layer.

$$
    dw =  do \cdot \frac{do}{dw} = \left((Px) * do\right).~\text{sum}(\text{dim=0}) / \sqrt{v + \epsilon} \in \mathbb{R}^d
$$

$$
    db =  do \cdot \frac{do}{db} = \left(do\right).~\text{sum}(\text{dim=0}) \in \mathbb{R}^d
$$

Grad $dx$ is more complex, but if we directly use computation graph(see micrograd), the result will be clear.

$$
    dx = do \cdot \frac{\partial o}{\partial x} + dv \cdot \frac{d v}{d x}
$$

$$
    dv = do \cdot \frac{\partial o}{\partial v}
$$


In [393]:
torch.manual_seed(42)
# use double precision to avoid numerical error
n, d = 100, 10
# params
x = torch.randn(n, d, dtype=torch.float64)
w = torch.randn(d, dtype=torch.float64)
b = torch.randn(d, dtype=torch.float64)
# grad backprop from next layer
do = torch.randn(n, d, dtype=torch.float64)

# -------- manual --------- 
# buffer
P = torch.eye(n, dtype=torch.float64) - torch.ones(n, n, dtype=torch.float64) / n
# forward
v = (x.T @ P @ x).diag() / n  # (d,) # this may waste memory when d is large
o = (P @ x) / v.sqrt().view(1, d) * w.view(1, d) + b.view(1, d)  # (n, d)
# backward
w_grad = ((P @ x) * do).sum(dim=0) / v.sqrt().view(1, d) # dw
b_grad = do.sum(dim=0) # db
o_to_v_grad = ((o - b) * do / (-2 * v)).sum(dim=0) # do * do/dv
v_to_x_grad = o_to_v_grad * (P @ x) * 2 / n # dv * dv / dx
o_to_x_grad = (P @ do) * w.view(1, d) / v.sqrt().view(1, d) # do * do/dx
x_grad = v_to_x_grad + o_to_x_grad # dx

# -------- torch --------- 
xt, wt, bt = x.clone(), w.clone(), b.clone()
for p in [xt, wt, bt]:
    p.requires_grad = True
ot = (xt - xt.mean(dim=0, keepdim=True)) / xt.std(dim=0, keepdim=True, unbiased=False) * wt.view(1, d) + bt.view(1, d)
out = ot * do
out.sum().backward()

# -------- compare -------- 
print('forward pass:')
print(f'o relative error {((ot - o) / ot).abs().max()}')
print('backward pass:')
print(f'b grad relative error: {((bt.grad - b_grad) / bt.grad).abs().max()}')
print(f'w grad relative error: {((wt.grad - w_grad) / wt.grad).abs().max()}')
print(f'x grad relative error: {((xt.grad - x_grad) / xt.grad).abs().max()}')

forward pass:
o relative error 1.3214495959198304e-13
backward pass:
b grad relative error: 0.0
w grad relative error: 2.9079137144240243e-15
x grad relative error: 8.759191542421218e-14


check for each grad:

In [None]:
torch.manual_seed(42)
n, d = 100, 10
x = torch.randn(n, d, requires_grad=True, dtype=torch.float64)
w = torch.randn(d, requires_grad=True, dtype=torch.float64)
b = torch.randn(d, requires_grad=True, dtype=torch.float64)

# forward
v = (x.T @ P @ x).diag() / n  # (d,) # this may waste memory

# backward
dv = torch.randn_like(v)
(v * dv).sum().backward()
v_to_x_grad = dv * (P @ x) * 2 / n

err = ((x.grad - v_to_x_grad) / x.grad).abs().max().item()
print(f'relative error: {err}')


relative error: 2.4119488693758977e-15


In [None]:
torch.manual_seed(45)
n, d = 100, 10
ones = torch.ones(n, dtype=torch.float64)
x = torch.randn(n, d, requires_grad=True, dtype=torch.float64)
w = torch.randn(d, requires_grad=True, dtype=torch.float64)
b = torch.randn(d, requires_grad=True, dtype=torch.float64)
v = torch.rand(d, requires_grad=True, dtype=torch.float64)

# forward
o = (P @ x) / v.sqrt().view(1, d) * w.view(1, d) + b.view(1, d)  # (n, d)

# backward
do = torch.randn_like(o)
(o * do).sum().backward()

o_to_v_grad = ((o - b) * do / (-2 * v)).sum(dim=0)

err = ((v.grad - o_to_v_grad) / v.grad).abs().max().item()
print(f'relative error: {err}')



relative error: 8.061833389428677e-15


In [None]:
torch.manual_seed(45)
n, d = 100, 10
ones = torch.ones(n, dtype=torch.float64)
x = torch.randn(n, d, requires_grad=True, dtype=torch.float64)
w = torch.randn(d, requires_grad=True, dtype=torch.float64)
b = torch.randn(d, requires_grad=True, dtype=torch.float64)
v = torch.rand(d, requires_grad=True, dtype=torch.float64)

# forward
o = (P @ x) / v.sqrt().view(1, d) * w.view(1, d) + b.view(1, d)  # (n, d)

# backward
do = torch.randn_like(o)
(o * do).sum().backward()

o_to_x_grad = (P @ do) * w.view(1, d) / v.sqrt().view(1, d)

err = ((x.grad - o_to_x_grad) / x.grad).abs().max().item()
print(f'relative error: {err}')



relative error: 3.606235033323989e-13


# pytorchify

In [32]:
import torch

class Linear:

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float64, generator=None):
        self.weight = torch.randn(in_features, out_features, dtype=dtype, generator=generator) * (in_features)**-0.5
        self.bias = torch.zeros(out_features, dtype=dtype) * 0 if bias else None
        self.weight_grad = None
        self.bias_grad = None

    def parameters(self):
        if self.bias is not None:
            return [self.weight, self.bias]
        else:
            return [self.weight]
    
    def grads(self):
        if self.bias is not None:
            return [self.weight_grad, self.bias_grad]
        else:
            return [self.weight_grad]
    
    def __call__(self, x):
        if self.bias is not None:
            out = x @ self.weight + self.bias
        else:
            out = x @ self.weight
        return out
    
    def backward(self, x, out, grad):
        """
            Input:
                x: input of current layer
                out: output of current layer
                grad: grad from next layer
            Output:
                x_grad: grad back to previous layer
        """
        x_grad = grad @ self.weight.T
        self.weight_grad = x.T @ grad
        if self.bias is not None:
            self.bias_grad = grad.sum(dim=0)
        return x_grad

class BatchNorm1d:
    def __init__(self, in_features, eps=1e-5, momentum=0.001, dtype=torch.float64): # manual bn need fp64
        self.weight = torch.ones(in_features, dtype=dtype)
        self.bias = torch.zeros(in_features, dtype=dtype)
        self.running_mean = torch.zeros(in_features, dtype=dtype)
        self.running_var = torch.ones(in_features, dtype=dtype)
        self.eps = eps
        self.momentum = momentum
        self._training = True # internal flag
        self.dtype = dtype
        # backward buffer
        self.weight_grad = None
        self.bias_grad = None
        self.v = None
        self.std = None

    def parameters(self):
        return [self.weight, self.bias]
    
    def grads(self):
        return [self.weight_grad, self.bias_grad]
    
    def __call__(self, x):
        b, d = x.shape
        if self._training:
            Pm = torch.ones(b, dtype=self.dtype) / b # mean projection matrix (b, )
            Pv = (torch.eye(b, dtype=self.dtype) - Pm) # var projection matrix (b, b)
            m = Pm @ x # (d,)
            v = (x.T @ Pv @ x).diag() / b  # (d,) # this may waste memory when d is large # TODO: very slow
            with torch.no_grad():
                self.running_mean = self.running_mean * (1 - self.momentum) + m * self.momentum
                self.running_var = self.running_var * (1 - self.momentum) + v * self.momentum
        else:
            m = self.running_mean
            v = self.running_var
        std = (v + self.eps).sqrt()
        out = self.weight * (x - m) / std + self.bias
        if self._training:
            # backward buffer
            self.v = v
            self.std = std
        return out
    
    def backward(self, x, out, grad):
        assert self._training, 'BatchNorm1d is not in training mode'
        b, d = x.shape
        Pm = torch.ones(b, dtype=self.dtype) / b 
        Pv = (torch.eye(b, dtype=self.dtype) - Pm)
        self.weight_grad = ((Pv @ x) * grad).sum(dim=0) / self.std # dw (d,)
        self.bias_grad = grad.sum(dim=0) # db (d,)
        o_to_v_grad = ((out - self.bias) * grad / (-2 * self.std.square())).sum(dim=0) # do * do/dv (d,)
        v_to_x_grad = o_to_v_grad * (Pv @ x) * 2 / b # dv * dv / dx (b, d)
        o_to_x_grad = (Pv @ grad) * self.weight / self.std # do * do/dx (b, d)
        x_grad = v_to_x_grad + o_to_x_grad # dx (b, d)
        return x_grad


class Tanh:
    
    def parameters(self):
        return []
    
    def grads(self):
        return []

    def __call__(self, x):
        out = x.tanh()
        return out
    
    def backward(self, x, out, grad):
        x_grad = grad * (1 - out**2)
        return x_grad



## check batchnorm

In [16]:
torch.manual_seed(42)
dtype = torch.float64
eps = 1e-5
# model
bn = BatchNorm1d(10, dtype=dtype, eps=eps)
# params
x = torch.randn(100, 10, dtype=dtype)
# ------- manual -------
# forward
o = bn(x)
# backward
do = torch.randn_like(o, dtype=dtype)
dx = bn.backward(x, o, do)

# ------- torch -------
import copy
bnt = copy.deepcopy(bn)
xt = x.clone()
xt.requires_grad = True
for p in bnt.parameters():
    p.requires_grad = True
# forward
ot = (xt - xt.mean(dim=0, keepdim=True)) /( xt.var(dim=0, keepdim=True, unbiased=False) + eps).sqrt() * bnt.weight + bnt.bias
# backward
(ot * do).sum().backward()

# -------- compare -------- 
print('forward pass:')
print(f'o relative error: {((ot - o) / ot).abs().max().item()}')
print('backward pass:')
print(f'db relative error: {((bnt.bias_grad - bn.bias_grad) / bnt.bias_grad).abs().max().item()}')
print(f'dw relative error: {((bnt.weight_grad - bn.weight_grad) / bnt.weight_grad).abs().max().item()}')
print(f'dx relative error: {((dx - xt.grad) / xt.grad).abs().max().item()}')

forward pass:
o relative error: 5.99848193077636e-15
backward pass:
db relative error: 0.0
dw relative error: 0.0
dx relative error: 2.7864881282787247e-14


## check mlp

In [445]:
torch.manual_seed(42)
n_embd = 30
n_hidden = 100
bs = 32
dtype = torch.float64
# model
layers = [Linear(n_embd, n_hidden, bias=False, dtype=dtype), BatchNorm1d(n_hidden, dtype=dtype), Tanh()]
for _ in range(70):
    layers.extend([Linear(n_hidden, n_hidden, bias=False, dtype=dtype), BatchNorm1d(n_hidden, dtype=dtype), Tanh()])
params = [p for l in layers for p in l.parameters()]
print(f'number of params: {sum(p.numel() for p in params) / 1e6:.2f}M')
# input
x = torch.randn(bs, n_embd, dtype=dtype, requires_grad=True)

# --- manual ---
# forward
outs = [x]
for l in layers:
    out = l(outs[-1])
    outs.append(out)

# backward
grad = torch.ones(bs, n_hidden)
for i in range(len(layers)-1, -1, -1):
    grad = layers[i].backward(outs[i], outs[i+1], grad)


# --- torch ---
for p in params:
    p.requires_grad = True
h = x
for l in layers:
    h = l(h)
h.sum().backward()

# --- compare ---
print('check grad:')
print(f'[Layer 1] weight grad relative error: {((params[0].grad - layers[0].weight_grad) / params[0].grad).abs().max().item()}')
print(f'x_grad relative error: {((x.grad - grad) / x.grad).abs().max().item()}')

number of params: 0.72M
check grad:
[Layer 1] weight grad relative error: 1.784616404524195e-11
x_grad relative error: 9.523983785609139e-12


## train mlp

In [39]:
import torch.nn.functional as F

g = torch.Generator().manual_seed(42)
n_embd = 10
n_hidden = 200
vocab_size = 27
block_size = 3
n_layer = 5
dtype = torch.float64
eval_interval = 1000

# model
C = torch.randn(vocab_size, n_embd, dtype=dtype, generator=g)
layers = [Linear(n_embd * block_size, n_hidden, bias=False, dtype=dtype, generator=g), BatchNorm1d(n_hidden, dtype=dtype), Tanh()]
for _ in range(n_layer-2):
    layers.extend([Linear(n_hidden, n_hidden, bias=False, dtype=dtype, generator=g), BatchNorm1d(n_hidden, dtype=dtype), Tanh()])
layers.extend([Linear(n_hidden, vocab_size, bias=False, dtype=dtype, generator=g), BatchNorm1d(vocab_size, dtype=dtype)])
params = [C] + [p for l in layers for p in l.parameters()]
print(f'number of params: {sum(p.numel() for p in params) / 1e6:.2f}M')
layers[-1].weight.data *= 0.1 # less confident

# args
bs = 32
n_steps = 10000
ini_lr = 1.0


torch.manual_seed(42)
for step in range(n_steps):
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # forward
    emb = C[x].view(x.shape[0], -1)
    outs = [emb]
    for l in layers:
        out = l(outs[-1])
        outs.append(out)
    logits = outs[-1]
    # 2. loss
    exp_l = logits.exp()
    count = exp_l.sum(dim=-1, keepdim=True)
    probs = exp_l / count
    nlls = -probs.log()
    loss = nlls[torch.arange(y.shape[0]), y].mean()

    # backward
    # 1. zero grad
    nlls_grad = torch.zeros(bs, vocab_size, dtype=dtype)
    probs_grad = torch.zeros(bs, vocab_size, dtype=dtype)
    count_grad = torch.zeros(bs, 1, dtype=dtype)
    exp_l_grad = torch.zeros(bs, vocab_size, dtype=dtype)
    logits_grad = torch.zeros(bs, vocab_size, dtype=dtype)
    emb_grad = torch.zeros(bs, n_embd * block_size, dtype=dtype)
    # param grad
    C_grad = torch.zeros(vocab_size, n_embd, dtype=dtype)
    # 2. backward
    # loss
    nlls_grad[torch.arange(y.shape[0]), y] = 1 / bs
    probs_grad[torch.arange(y.shape[0]), y] = -1 / probs.data[torch.arange(y.shape[0]), y] * nlls_grad[torch.arange(y.shape[0]), y]
    count_grad = -(exp_l.data * probs_grad).sum(dim=-1, keepdim=True) / count.data**2
    exp_l_grad = probs_grad / count.data + count_grad  # one is from e/c to e, one is from c=\sum e to e
    logits_grad = exp_l.data * exp_l_grad
    # layers
    h_grad = logits_grad
    for i in range(len(layers)-1, -1, -1):
        h_grad = layers[i].backward(outs[i], outs[i+1], h_grad)
    # embedding
    emb_grad = h_grad
    C_grad.index_add_(dim=0, index=x.view(-1), source=emb_grad.view(-1, n_embd))

    if step % eval_interval == 0:
        x, y = X_val[:128], Y_val[:128] # TODO: large bs, batchnorm is slow
        emb = C[x].view(x.shape[0], -1)
        h = emb
        for l in layers:
            h = l(h)
        logits = h
        val_loss = F.cross_entropy(logits, y)
        print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')
    
    # update
    param_grads = [C_grad] + [p for l in layers for p in l.grads()]
    for p, g in zip(params, param_grads):
        p.data -= lr * g
    
    

number of params: 0.13M
step: 0, train loss: 3.3138560633196645, val loss: 3.2952505718703575
step: 1000, train loss: 2.2650234162579297, val loss: 2.3929781293073993
step: 2000, train loss: 2.383280954053651, val loss: 2.2941183192098977
step: 3000, train loss: 2.0633511993347464, val loss: 2.366037332412258
step: 4000, train loss: 2.386221372864641, val loss: 2.255767476565825
step: 5000, train loss: 2.131021386443559, val loss: 2.338388504781813
step: 6000, train loss: 2.5602508862779385, val loss: 2.2380155731180147
step: 7000, train loss: 2.3065974647277168, val loss: 2.2459830960273215
step: 8000, train loss: 2.190829772113247, val loss: 2.218119743575988
step: 9000, train loss: 1.930482663268064, val loss: 2.228581415053821


## compare with torch

In [38]:
import torch.nn.functional as F

g = torch.Generator().manual_seed(42)
n_embd = 10
n_hidden = 200
vocab_size = 27
block_size = 3
n_layer = 5
dtype = torch.float64
eval_interval = 1000

# model
C = torch.randn(vocab_size, n_embd, dtype=dtype, generator=g)
layers = [Linear(n_embd * block_size, n_hidden, bias=False, dtype=dtype, generator=g), BatchNorm1d(n_hidden, dtype=dtype), Tanh()]
for _ in range(n_layer-2):
    layers.extend([Linear(n_hidden, n_hidden, bias=False, dtype=dtype, generator=g), BatchNorm1d(n_hidden, dtype=dtype), Tanh()])
layers.extend([Linear(n_hidden, vocab_size, bias=False, dtype=dtype, generator=g), BatchNorm1d(vocab_size, dtype=dtype)])
params = [C] + [p for l in layers for p in l.parameters()]
print(f'number of params: {sum(p.numel() for p in params) / 1e6:.2f}M')
layers[-1].weight.data *= 0.1 # less confident
for p in params:
    p.requires_grad = True

# args
bs = 32
n_steps = 10000
ini_lr = 1.0


torch.manual_seed(42)
for step in range(n_steps):
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # forward
    emb = C[x].view(x.shape[0], -1)
    outs = [emb]
    for l in layers:
        out = l(outs[-1])
        outs.append(out)
    logits = outs[-1]
    # 2. loss
    exp_l = logits.exp()
    count = exp_l.sum(dim=-1, keepdim=True)
    probs = exp_l / count
    nlls = -probs.log()
    loss = nlls[torch.arange(y.shape[0]), y].mean()

    # backward
    loss.backward()

    if step % eval_interval == 0:
        with torch.no_grad():
            x, y = X_val[:128], Y_val[:128] # TODO: large bs, batchnorm is slow
            emb = C[x].view(x.shape[0], -1)
            h = emb
            for l in layers:
                h = l(h)
            logits = h
            val_loss = F.cross_entropy(logits, y)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')
    
    # update
    for p in params:
        p.data -= lr * p.grad
        p.grad = None
    
    

number of params: 0.13M
step: 0, train loss: 3.3138560633196645, val loss: 3.2952505718703575
step: 1000, train loss: 2.2650234162579306, val loss: 2.3929781293073997
step: 2000, train loss: 2.3832809540536495, val loss: 2.2941183192098977
step: 3000, train loss: 2.0633511993347464, val loss: 2.3660373324122586
step: 4000, train loss: 2.3862213728646404, val loss: 2.2557674765658247
step: 5000, train loss: 2.1310213864435594, val loss: 2.338388504781811
step: 6000, train loss: 2.560250886277938, val loss: 2.2380155731180142
step: 7000, train loss: 2.306597464727716, val loss: 2.2459830960273215
step: 8000, train loss: 2.190829772113249, val loss: 2.218119743575987
step: 9000, train loss: 1.9304826632680636, val loss: 2.228581415053821
