In [1]:
import random
import torch

# Load the corpus: one name per line. The context manager closes the file handle.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()
# De-duplicate, then sort. Iteration order of a set of strings changes between
# interpreter runs (string hashing is randomized), so without the sort the seeded
# shuffle below would still produce a different order after every kernel restart.
words = sorted(set(words))
random.seed(42)
random.shuffle(words)

# Character vocabulary; '.' is used both as start padding and as the end-of-name marker.
chs = sorted(set(''.join(words + ['.'])))
stoi = {ch: i for i, ch in enumerate(chs)}
itos = {i: ch for i, ch in enumerate(chs)}

# Build (context, target) examples: predict the next token from the previous 3 tokens.
X, Y = [], []

for w in words:
    context = '...'
    for ch in w + '.':
        X.append([stoi[c] for c in context])
        Y.append(stoi[ch])
        context = context[1:] + ch  # slide the window one character to the right

X = torch.tensor(X)
Y = torch.tensor(Y)
# 80% / 10% / 10% train / val / test boundaries. The split is over examples, not names,
# so examples that come from one name can end up in different splits.
n1, n2 = int(0.8 * len(X)), int(0.9 * len(X))

X_train, X_val, X_test = X.tensor_split([n1, n2])
Y_train, Y_val, Y_test = Y.tensor_split([n1, n2])

X_train.shape, X_val.shape, X_test.shape, Y_train.shape, Y_val.shape, Y_test.shape


(torch.Size([169062, 3]),
 torch.Size([21133, 3]),
 torch.Size([21133, 3]),
 torch.Size([169062]),
 torch.Size([21133]),
 torch.Size([21133]))

# implement backward from scratch


In [449]:
# Model hyperparameters (read as globals by get_params below).
n_embd = 10       # width of one token embedding (columns of C)
n_hidden = 200    # number of hidden units
vocab_size = 27   # number of distinct tokens (rows of C, columns of w2)
block_size = 3    # number of context tokens fed to the model

def get_params():
    """Create a freshly initialised parameter list ``[C, w1, w2, b2, bnw, bnb]``.

    The RNG is re-seeded with 42 on every call, so repeated calls return equal values.
    All six tensors are returned with ``requires_grad`` enabled.
    """
    torch.manual_seed(42)
    fan_in = n_embd * block_size
    tanh_gain = 5 / 3

    # Random draws happen in the order C, w1, w2, b2; the order fixes the values.
    C = torch.randn(vocab_size, n_embd)
    w1 = torch.randn(fan_in, n_hidden) * fan_in**-0.5
    # the trailing 0.1 makes the model less confident at initialization
    w2 = torch.randn(n_hidden, vocab_size) * tanh_gain * n_hidden**-0.5 * 0.1
    # all-zero bias; the draw is kept because it advances the RNG state
    b2 = torch.randn(vocab_size) * 0
    # batch-norm scale and shift start as the identity transform
    bnw, bnb = torch.ones(n_hidden), torch.zeros(n_hidden)

    params = [C, w1, w2, b2, bnw, bnb]
    for p in params:
        p.requires_grad_(True)
    return params

# Fresh parameters, unpacked into the names used by the forward/backward cells.
C, w1, w2, b2, bnw, bnb = params = get_params()
# One mini-batch of training examples, drawn uniformly with replacement.
bs = 32
idx = torch.randint(low=0, high=X_train.shape[0], size=(bs,))
x = X_train[idx]
y = Y_train[idx]

## forward and torch backward


In [3]:
# Constant helper matrices: batch statistics written as matrix products.
mean_proj = torch.ones(1, bs) / bs      # (1, bs): left-multiplying averages over the batch
var_proj = torch.eye(bs) - mean_proj    # (bs, bs): left-multiplying subtracts the batch mean

# ---- forward: embedding -> linear -> batch norm -> tanh -> linear ----
emb = C[x].view(x.shape[0], -1)
hpreact = emb @ w1
bnmeani = mean_proj @ hpreact
bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
h = hpreact_bn.tanh()
logits = h @ w2 + b2
# ---- loss: softmax cross-entropy, spelled out one operation at a time ----
exp_l = logits.exp()
count = exp_l.sum(dim=-1, keepdim=True)
probs = exp_l / count
nlls = -probs.log()
loss = nlls[torch.arange(y.shape[0]), y].mean()

# Ask autograd to keep .grad on every intermediate tensor so the hand-written
# gradients can be compared against them after backward().
for _t in (emb, hpreact, bnmeani, bnstdi, hpreact_bn, h, logits, exp_l, count, probs, nlls):
    _t.retain_grad()

# backward
loss.backward()

## manual backward

In [4]:
# Zero-initialised gradient buffers for the intermediates, grouped by shape.
# one row per example, one column per vocabulary entry
nlls_grad, probs_grad, exp_l_grad, logits_grad = (torch.zeros(bs, vocab_size) for _ in range(4))
# one value per example
count_grad = torch.zeros(bs, 1)
# one row per example, one column per hidden unit
h_grad, hpreact_bn_grad, hpreact_grad = (torch.zeros(bs, n_hidden) for _ in range(3))
# one value per hidden unit (batch statistics)
bnmeani_grad, bnstdi_grad, bnvari_grad = (torch.zeros(1, n_hidden) for _ in range(3))
# flattened embeddings of the whole context
emb_grad = torch.zeros(bs, n_embd * block_size)

# Zero-initialised gradient buffers for the parameters (same shapes as the parameters).
C_grad = torch.zeros(vocab_size, n_embd)
w1_grad = torch.zeros(n_embd * block_size, n_hidden)
w2_grad = torch.zeros(n_hidden, vocab_size)
b2_grad, bnw_grad, bnb_grad = torch.zeros(vocab_size), torch.zeros(n_hidden), torch.zeros(n_hidden)


In [5]:
# Hand-written backward pass for the forward cell above, walked in reverse order.
# Each *_grad tensor is d(loss)/d(tensor); `.data` reads values without touching autograd.

# 1. loss
# loss is the mean over the batch of nlls[i, y[i]], so only those entries get a gradient, 1/bs each.
# nlls_grad and probs_grad are written in place and rely on being zero-initialised beforehand.
nlls_grad[torch.arange(y.shape[0]), y] = 1 / bs
# nlls = -log(probs)  =>  d nlls / d probs = -1 / probs
probs_grad[torch.arange(y.shape[0]), y] = -1 / probs.data[torch.arange(y.shape[0]), y] * nlls_grad[torch.arange(y.shape[0]), y]
# probs = exp_l / count  =>  d probs / d count = -exp_l / count**2; count is broadcast along the row, so sum over it
count_grad = -(exp_l.data * probs_grad).sum(dim=-1, keepdim=True) / count.data**2
exp_l_grad = probs_grad / count.data + count_grad  # one is from e/c to e, one is from c=\sum e to e
# exp_l = exp(logits)  =>  d exp_l / d logits = exp_l
logits_grad = exp_l.data * exp_l_grad

# 2. logits
# logits = h @ w2 + b2
h_grad = logits_grad @ w2.data.T
# h = tanh(hpreact_bn)  =>  derivative is 1 - h**2
hpreact_bn_grad = h_grad * (1 - h.data**2)
# bn
# hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb; the statistics are broadcast over
# the batch, so their gradients are summed over dim 0
bnmeani_grad = ((-bnw.data / bnstdi.data) * hpreact_bn_grad).sum(dim=0, keepdim=True)
bnstdi_grad = (-((hpreact.data - bnmeani.data) * bnw.data / bnstdi.data**2) * hpreact_bn_grad).sum(dim=0, keepdim=True)
# hpreact
# hpreact reaches hpreact_bn along three paths: through the batch mean, through the batch std,
# and directly through the numerator; the three contributions are added.
hpreact_grad_mean = bnmeani_grad * torch.ones_like(hpreact.data) / bs
# bnstdi = sqrt(mean((var_proj @ hpreact)**2)): factors are d sqrt, d mean, d square, in that order.
# Strictly the last factor is var_proj.T @ (var_proj @ hpreact); var_proj is the centering
# matrix (symmetric and idempotent), so that collapses to var_proj @ hpreact.
hpreact_grad_std = bnstdi_grad * (1 / 2 / bnstdi.data) * (1 / bs) * (2 * var_proj @ hpreact.data)
hpreact_grad_direct = hpreact_bn_grad * (bnw.data / bnstdi.data)
hpreact_grad = hpreact_grad_mean + hpreact_grad_std + hpreact_grad_direct
# emb
# hpreact = emb @ w1
emb_grad = hpreact_grad @ w1.data.T

# 3. params
# index_add_ accumulates in place: C_grad must be all zeros before this line runs, so re-run the
# buffer cell before re-running this one.
C_grad.index_add_(dim=0, index=x.view(-1), source=emb_grad.view(-1, n_embd)) # add emb_grad[i] to C[x[i]]
w1_grad = emb.data.T @ hpreact_grad
w2_grad = h.data.T @ logits_grad
b2_grad = logits_grad.sum(dim=0)
# (hpreact - bnmeani) / bnstdi is the normalised activation that bnw multiplies
bnw_grad = ((hpreact.data - bnmeani.data) / bnstdi.data * hpreact_bn_grad).sum(dim=0)
bnb_grad = hpreact_bn_grad.sum(dim=0)

# check
# Compare every hand-written gradient with the one autograd stored (torch.allclose, default tolerances).
is_equal1 = [torch.allclose(nlls_grad, nlls.grad), torch.allclose(probs_grad, probs.grad), torch.allclose(count_grad, count.grad), torch.allclose(exp_l_grad, exp_l.grad), torch.allclose(logits_grad, logits.grad)]
is_equal2 = [torch.allclose(h_grad, h.grad), torch.allclose(hpreact_bn_grad, hpreact_bn.grad), torch.allclose(bnmeani_grad, bnmeani.grad), torch.allclose(bnstdi_grad, bnstdi.grad), torch.allclose(hpreact_grad, hpreact.grad), torch.allclose(emb_grad, emb.grad)]
is_equal3 = [torch.allclose(C_grad, C.grad), torch.allclose(w1_grad, w1.grad), torch.allclose(w2_grad, w2.grad), torch.allclose(b2_grad, b2.grad), torch.allclose(bnw_grad, bnw.grad), torch.allclose(bnb_grad, bnb.grad)]
print('same grad for loss calculation:', is_equal1)
print('same grad for logits calculation:', is_equal2)
print('same grad for params:', is_equal3)




same grad for loss calculation: [True, True, True, True, True]
same grad for logits calculation: [True, True, True, True, True, True]
same grad for params: [True, True, True, True, True, True]


# loop

In [28]:
import torch.nn.functional as F

# model: same initialisation as get_params() (which seeds the RNG with 42), detached from
# autograd because every gradient in this cell is computed by hand
params = [p.detach() for p in get_params()]
C, w1, w2, b2, bnw, bnb = params
# running batch-norm statistics; they are tracked below but the validation pass does not read them
bnmean_running = torch.zeros(n_hidden)
bnstd_running = torch.ones(n_hidden)

# args
bs = 32
n_steps = 10000
ini_lr = 1.0

# buffer
mean_proj = torch.ones(1, bs) / bs      # (1, bs): left-multiplying averages over the batch
var_proj = (torch.eye(bs) - mean_proj)  # (bs, bs): left-multiplying subtracts the batch mean
rows = torch.arange(bs)                 # row index used to pick entry [i, y[i]]; same every step

torch.manual_seed(42)
for step in range(n_steps):
    # step decay: full learning rate for the first half of training, a tenth of it afterwards
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # ---------------- forward --------------------
    # 1. logits
    emb = C[x].view(x.shape[0], -1)
    hpreact = emb @ w1
    bnmeani = mean_proj @ hpreact
    bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
    hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
    h = hpreact_bn.tanh()
    logits = h @ w2 + b2
    # 2. loss
    exp_l = logits.exp()
    count = exp_l.sum(dim=-1, keepdim=True)
    probs = exp_l / count
    nlls = -probs.log()
    loss = nlls[rows, y].mean()

    # ---------------- backward --------------------
    # 0. zero grad
    # Only the three buffers that are written in place need to start from zero;
    # every other gradient below is produced by a plain assignment.
    nlls_grad = torch.zeros(bs, vocab_size)
    probs_grad = torch.zeros(bs, vocab_size)
    C_grad = torch.zeros(vocab_size, n_embd)

    # 1. loss
    # loss is the batch mean of nlls[i, y[i]], so each selected entry receives 1/bs
    nlls_grad[rows, y] = 1 / bs
    probs_grad[rows, y] = -1 / probs.data[rows, y] * nlls_grad[rows, y]
    count_grad = -(exp_l.data * probs_grad).sum(dim=-1, keepdim=True) / count.data**2
    exp_l_grad = probs_grad / count.data + count_grad  # one is from e/c to e, one is from c=\sum e to e
    logits_grad = exp_l.data * exp_l_grad

    # 2. logits
    h_grad = logits_grad @ w2.data.T
    hpreact_bn_grad = h_grad * (1 - h.data**2)
    # bn
    bnmeani_grad = ((-bnw.data / bnstdi.data) * hpreact_bn_grad).sum(dim=0, keepdim=True)
    bnstdi_grad = (-((hpreact.data - bnmeani.data) * bnw.data / bnstdi.data**2) * hpreact_bn_grad).sum(dim=0, keepdim=True)
    # hpreact: contributions through the batch mean, the batch std, and the direct path
    hpreact_grad_mean = bnmeani_grad * torch.ones_like(hpreact.data) / bs
    hpreact_grad_std = bnstdi_grad * (1 / 2 / bnstdi.data) * (1 / bs) * (2 * var_proj @ hpreact.data)
    hpreact_grad_direct = hpreact_bn_grad * (bnw.data / bnstdi.data)
    hpreact_grad = hpreact_grad_mean + hpreact_grad_std + hpreact_grad_direct
    # emb
    emb_grad = hpreact_grad @ w1.data.T

    # 3. params
    C_grad.index_add_(dim=0, index=x.view(-1), source=emb_grad.view(-1, n_embd)) # add emb_grad[i] to C[x[i]]
    w1_grad = emb.data.T @ hpreact_grad
    w2_grad = h.data.T @ logits_grad
    b2_grad = logits_grad.sum(dim=0)
    bnw_grad = ((hpreact.data - bnmeani.data) / bnstdi.data * hpreact_bn_grad).sum(dim=0)
    bnb_grad = hpreact_bn_grad.sum(dim=0)
    param_grads = [C_grad, w1_grad, w2_grad, b2_grad, bnw_grad, bnb_grad]

    if step % 1000 == 0:
        # Validation uses its own *_val names so it never clobbers the training-step tensors.
        # NOTE: it normalises with the validation set's own statistics, and Tensor.std applies
        # Bessel's correction by default, unlike bnstdi in the training forward pass. This is
        # kept as is so the printed numbers stay comparable with the autograd cell below.
        with torch.no_grad():
            emb_val = C[X_val].view(X_val.shape[0], -1)
            hpreact_val = emb_val @ w1
            hpreact_val = (hpreact_val - hpreact_val.mean(dim=0, keepdim=True)) / hpreact_val.std(dim=0, keepdim=True) * bnw + bnb
            h_val = hpreact_val.tanh()
            logits_val = h_val @ w2 + b2
            val_loss = F.cross_entropy(logits_val, Y_val)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')

    # update (plain SGD)
    for p, g in zip(params, param_grads):
        p.data -= lr * g

    # exponential moving average of the batch statistics
    bnmean_running = bnmean_running * 0.99 + bnmeani * 0.01
    bnstd_running = bnstd_running * 0.99 + bnstdi * 0.01
    
    

step: 0, train loss: 3.3074488639831543, val loss: 3.3160440921783447
step: 1000, train loss: 2.583611488342285, val loss: 2.4248502254486084
step: 2000, train loss: 2.4143431186676025, val loss: 2.390408992767334
step: 3000, train loss: 2.1222798824310303, val loss: 2.379809856414795
step: 4000, train loss: 2.1513724327087402, val loss: 2.374274969100952
step: 5000, train loss: 2.349586009979248, val loss: 2.386371374130249
step: 6000, train loss: 2.2481563091278076, val loss: 2.2990658283233643
step: 7000, train loss: 2.0973422527313232, val loss: 2.2958080768585205
step: 8000, train loss: 2.5018582344055176, val loss: 2.2962090969085693
step: 9000, train loss: 2.072721242904663, val loss: 2.288125991821289


In [29]:
# Inspect the first-layer weights after training with the hand-written gradients.
w1

tensor([[-0.1794, -0.1739,  0.0228,  ...,  0.0201, -0.1145,  0.1450],
        [-0.3108,  0.1481, -0.2057,  ..., -0.3503,  0.2971,  0.1593],
        [-0.0755,  0.0891, -0.0533,  ..., -0.1936, -0.2113, -0.3970],
        ...,
        [ 0.0905, -0.2106, -0.5778,  ..., -0.1088, -0.1970, -0.3650],
        [-0.1003,  0.1913, -0.4233,  ...,  0.3125,  0.0998, -0.2520],
        [ 0.2542,  0.2179, -0.3316,  ..., -0.4188, -0.2184, -0.4282]])

In [30]:
# One entry at full printed precision, for comparison with the autograd-trained run.
w1[0,0].item()

-0.17936445772647858

# compare with torch

In [31]:
import torch.nn.functional as F

# model: fresh parameters that autograd tracks (get_params enables requires_grad)
C, w1, w2, b2, bnw, bnb = params = get_params()
bnmean_running = torch.zeros(n_hidden)
bnstd_running = torch.ones(n_hidden)

# args
bs = 32
n_steps = 10000
ini_lr = 1.0

# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = torch.eye(bs) - mean_proj

torch.manual_seed(42)
for step in range(n_steps):
    # step decay: full rate for the first half of training, a tenth afterwards
    if step < n_steps // 2:
        lr = ini_lr
    else:
        lr = ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x = X_train[idx]
    y = Y_train[idx]

    # forward: embedding -> linear -> batch norm -> tanh -> linear
    emb = C[x].view(x.shape[0], -1)
    hpreact = emb @ w1
    bnmeani = mean_proj @ hpreact
    bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
    hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
    h = hpreact_bn.tanh()
    logits = h @ w2 + b2
    # loss: softmax cross-entropy, one operation at a time
    exp_l = logits.exp()
    count = exp_l.sum(dim=-1, keepdim=True)
    probs = exp_l / count
    nlls = -probs.log()
    loss = nlls[torch.arange(y.shape[0]), y].mean()

    # backward: autograd fills p.grad for every parameter
    loss.backward()

    # periodic validation, before this step's update is applied
    if step % 1000 == 0:
        with torch.no_grad():
            emb = C[X_val].view(X_val.shape[0], -1)
            hpreact = emb @ w1
            val_mean = hpreact.mean(dim=0, keepdim=True)
            val_std = hpreact.std(dim=0, keepdim=True)
            hpreact = (hpreact - val_mean) / val_std * bnw + bnb
            h = hpreact.tanh()
            logits = h @ w2 + b2
            val_loss = F.cross_entropy(logits, Y_val)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')

    # update: plain SGD on the raw data, then drop the gradient so the next backward starts clean
    for p in params:
        p.data.sub_(lr * p.grad)
        p.grad = None
    # exponential moving average of the batch statistics, kept out of the autograd graph
    with torch.no_grad():
        bnmean_running = 0.99 * bnmean_running + 0.01 * bnmeani
        bnstd_running = 0.99 * bnstd_running + 0.01 * bnstdi
    
    

step: 0, train loss: 3.3074488639831543, val loss: 3.3160440921783447
step: 1000, train loss: 2.583611488342285, val loss: 2.4248502254486084
step: 2000, train loss: 2.4143431186676025, val loss: 2.390408992767334
step: 3000, train loss: 2.1222798824310303, val loss: 2.379809856414795
step: 4000, train loss: 2.1513726711273193, val loss: 2.374274969100952
step: 5000, train loss: 2.349586009979248, val loss: 2.386371374130249
step: 6000, train loss: 2.2481563091278076, val loss: 2.2990663051605225
step: 7000, train loss: 2.0973422527313232, val loss: 2.2958080768585205
step: 8000, train loss: 2.5018584728240967, val loss: 2.2962090969085693
step: 9000, train loss: 2.072721242904663, val loss: 2.288125991821289


In [32]:
# Inspect the first-layer weights after training with autograd.
w1

tensor([[-0.1794, -0.1739,  0.0228,  ...,  0.0201, -0.1145,  0.1450],
        [-0.3108,  0.1481, -0.2057,  ..., -0.3503,  0.2971,  0.1593],
        [-0.0755,  0.0891, -0.0533,  ..., -0.1936, -0.2113, -0.3970],
        ...,
        [ 0.0905, -0.2106, -0.5778,  ..., -0.1088, -0.1970, -0.3650],
        [-0.1003,  0.1913, -0.4233,  ...,  0.3125,  0.0998, -0.2520],
        [ 0.2542,  0.2179, -0.3316,  ..., -0.4188, -0.2184, -0.4282]],
       requires_grad=True)

In [33]:
# Same entry as printed after the manual-gradient run, at full printed precision.
w1[0,0].item()

-0.17936447262763977

# simpler grad
## BatchNorm 

Let $x\in\mathbb{R}^{n\times d}$, $w\in\mathbb{R}^d$, $b\in\mathbb{R}^d$, define $\bar{x} = x\text{.mean}(\text{dim=0})$ then

$$
    o = \frac{x - \bar{x}}{\sqrt{(x - \bar{x})^2\text{.mean(dim=0)} + \epsilon}} w + b \in \mathbb{R}^{n \times d}
$$

Note: as in torch's batch-norm forward pass, the variance is the biased estimate (divide by $n$; no Bessel correction).

Let $s = \sqrt{(x - \bar{x})^2\text{.mean(dim=0)} + \epsilon}$ and $x_{\text{norm}} = \frac{x - \bar{x}}{s}$.

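Notation: $dx$ denotes the gradient of the final loss with respect to $x$ (back-propagated from the end of the network to the current layer), and $dy/dx$ denotes the local derivative of a layer's output $y$ with respect to its input $x$.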

$$
    dw =  do \cdot \frac{do}{dw} = \left(x_{\text{norm}} * do\right).~\text{sum}(\text{dim=0}) \in \mathbb{R}^d
$$

$$
    db =  do \cdot \frac{do}{db} = \left(do\right).~\text{sum}(\text{dim=0}) \in \mathbb{R}^d
$$

The gradient $dx$ is more involved. Differentiating through the computation graph at the scalar level, and then collecting the result back into tensor operations and simplifying algebraically, gives

$$
    dx = \left(
            s * do - s * do\text{.mean(dim=0)} - \frac{1}{s} * (x - \bar{x}) * ((x - \bar{x}) * do).~\text{mean}(\text{dim=0})
    \right) * w * \frac{1}{s^2}
$$

Combining like terms and reusing quantities already computed in the forward pass, we get

$$
\begin{aligned}
    dx &= \left(
            (do - do\text{.mean(dim=0)}) - \frac{x - \bar{x}}{s} * \left(\frac{x - \bar{x}}{s} * do\right).~\text{mean}(\text{dim=0})
    \right) * w * \frac{1}{s} \\
    &= \left(
            (do - \frac{db}{n}) - x_\text{norm} * \left(\frac{dw}{n}\right)
    \right) * w * \frac{1}{s}
\end{aligned}
$$

## LayerNorm

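Almost the same as BatchNorm, except that $\bar{x}$ and $s$ are computed over the last dim (`dim=-1`) of each sample, while $w, b \in \mathbb{R}^d$ are shared across all leading dims. Their gradients therefore sum over every dim except the last:

$$
    dw =  do \cdot \frac{do}{dw} = \left(x_{\text{norm}} * do\right).~\text{sum}(\text{dim}=[0, 1, \dots, \text{ndim}-2]) \in \mathbb{R}^d
$$

$$
    db =  do \cdot \frac{do}{db} = \left(do\right).~\text{sum}(\text{dim}=[0, 1, \dots, \text{ndim}-2]) \in \mathbb{R}^d
$$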

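Because $w$ varies along the normalized (last) dim, it cannot be factored out of the means over that dim the way it is in BatchNorm; it has to be applied to $do$ first. With $g = do * w$ (and the means taken with `keepdim=True`):

$$
    dx = \left(
            (g - g\text{.mean(dim=-1)}) - \frac{x - \bar{x}}{s} * \left(\frac{x - \bar{x}}{s} * g\right).~\text{mean}(\text{dim=-1})
    \right) * \frac{1}{s}
$$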


# pytorchify

In [15]:
from tiny_torch import *

## check cross entropy loss

In [9]:
# Compare the gradient of the hand-written CrossEntropyLoss with torch.nn.CrossEntropyLoss
# on the same random logits and targets.
torch.manual_seed(42)  # seeded like the other check cells, so the printed error is reproducible

# --- manual ---
loss_fn = CrossEntropyLoss()
x = torch.randn(100, 10, dtype=torch.float64)
y = torch.randint(0, 10, (100,))
loss = loss_fn(x, y)
x_grad = loss_fn.backward(grad=1.0) # last layer, dloss=1.0


# --- torch ---
import torch.nn as nn
# separate names, so loss_fn / loss keep referring to the manual implementation
loss_fn_t = nn.CrossEntropyLoss()
xt, yt = x.clone(), y.clone()
xt.requires_grad = True
loss_t = loss_fn_t(xt, yt)
loss_t.backward()
print(f'x grad relative error: {((xt.grad - x_grad) / xt.grad).abs().max().item()}')

x grad relative error: 6.827384472361946e-16


## check batchnorm

In [10]:
def _max_rel_err(ref, val):
    """Largest element-wise |ref - val| / |ref|, as a Python float."""
    return ((ref - val) / ref).abs().max().item()

torch.manual_seed(42)
dtype = torch.float64
eps = 1e-5

# ------- manual -------
bn = BatchNorm1d(10, dtype=dtype, eps=eps)
x = torch.randn(100, 10, dtype=dtype)
o = bn(x)                                # forward
do = torch.randn_like(o, dtype=dtype)    # random upstream gradient, shared by both implementations
dx = bn.backward(do)                     # backward

# ------- torch -------
import torch.nn as nn
bnt = nn.BatchNorm1d(10, dtype=dtype, eps=eps)
# start the reference layer from the same affine parameters
bnt.weight.data = bn.weight.data
bnt.bias.data = bn.bias.data
xt = x.clone().requires_grad_(True)
ot = bnt(xt)                             # forward
ot.backward(do)                          # backward with the same upstream gradient

# -------- compare --------
print('forward pass:')
print(f'o relative error: {_max_rel_err(ot, o)}')
print('backward pass:')
print(f'db relative error: {_max_rel_err(bnt.bias.grad, bn.bias_grad)}')
print(f'dw relative error: {_max_rel_err(bnt.weight.grad, bn.weight_grad)}')
print(f'dx relative error: {_max_rel_err(xt.grad, dx)}')


forward pass:
o relative error: 5.649903956929978e-15
backward pass:
db relative error: 4.587874512547105e-16
dw relative error: 1.4906969562791864e-15
dx relative error: 1.8066190659672137e-14


## check layernorm

In [11]:
torch.manual_seed(42)
dtype = torch.float64
eps = 1e-5

# ------- manual -------
ln = LayerNorm(10, dtype=dtype, eps=eps)
x = torch.randn(3, 32, 100, 10, dtype=dtype)   # three leading dims, normalized over the last
o = ln(x)
do = torch.randn_like(o, dtype=dtype)          # random upstream gradient, shared by both implementations
dx = ln.backward(do)

# ------- torch -------
import torch.nn as nn
lnt = nn.LayerNorm(10, dtype=dtype, eps=eps)
# NOTE(review): presumably ln.weight is still at its initial value here; if that value is
# uniform, this check does not exercise a weight that varies along the normalized dim - confirm.
lnt.weight.data = ln.weight.data
lnt.bias.data = ln.bias.data
xt = x.clone().requires_grad_(True)
ot = lnt(xt)
ot.backward(do)

# -------- compare --------
o_err = ((ot - o) / ot).abs().max().item()
db_err = ((lnt.bias.grad - ln.bias_grad) / lnt.bias.grad).abs().max().item()
dw_err = ((lnt.weight.grad - ln.weight_grad) / lnt.weight.grad).abs().max().item()
dx_err = ((dx - xt.grad) / xt.grad).abs().max().item()
print('forward pass:')
print(f'o relative error: {o_err}')
print('backward pass:')
print(f'db relative error: {db_err}')
print(f'dw relative error: {dw_err}')
print(f'dx relative error: {dx_err}')



forward pass:
o relative error: 1.1410527296324882e-11
backward pass:
db relative error: 8.099100649135957e-15
dw relative error: 6.349328555149738e-15
dx relative error: 7.678927882277915e-12


## check mlp

In [11]:
# Check the hand-written backward of a deep Linear -> BatchNorm1d -> Tanh stack against autograd.
# NOTE: this cell rebinds the globals n_embd, n_hidden and bs used by earlier cells.
torch.manual_seed(42)
n_embd = 30
n_hidden = 100
bs = 32
dtype = torch.float64
# model
layers = [Linear(n_embd, n_hidden, bias=False, dtype=dtype), BatchNorm1d(n_hidden, dtype=dtype), Tanh()]
for _ in range(70):
    layers.extend([Linear(n_hidden, n_hidden, bias=False, dtype=dtype), BatchNorm1d(n_hidden, dtype=dtype), Tanh()])
params = [p for l in layers for p in l.parameters()]
print(f'number of params: {sum(p.numel() for p in params) / 1e6:.2f}M')
# input
x = torch.randn(bs, n_embd, dtype=dtype, requires_grad=True)

# --- manual ---
# forward
h = x
for l in layers:
    h = l(h)

# backward
# The scalar objective is h.sum(), so the upstream gradient is all ones. It is created in the
# model's dtype so no layer's backward ever sees mixed float32/float64 operands.
grad = torch.ones(bs, n_hidden, dtype=dtype)
for l in reversed(layers):
    grad = l.backward(grad)


# --- torch ---
# Re-run the same layers with autograd tracking their parameters.
for p in params:
    p.requires_grad = True
h = x
for l in layers:
    h = l(h)
h.sum().backward()

# --- compare ---
print('check grad:')
print(f'[Layer 1] weight grad relative error: {((params[0].grad - layers[0].weight_grad) / params[0].grad).abs().max().item()}')
print(f'x_grad relative error: {((x.grad - grad) / x.grad).abs().max().item()}')

number of params: 0.72M
check grad:
[Layer 1] weight grad relative error: 6.348547564942509e-12
x_grad relative error: 1.2315982383522732e-11


# train mlp and compare with torch
The two training runs match exactly, up to tiny floating-point error.

In [2]:
from tiny_torch import *

class MLP(Module):
    """Character-level MLP language model assembled from tiny_torch layers.

    Layout: Embedding -> Flatten -> [Linear -> BatchNorm1d -> Tanh] * (n_layer - 1)
    -> Linear -> BatchNorm1d, where the last BatchNorm1d outputs ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, block_size, n_embd, n_hidden, n_layer, dtype=torch.float64, generator=None):
        """Build the layer stack.

        Args:
            vocab_size: number of tokens; size of the embedding table and of the output.
            block_size: number of context tokens the model consumes.
            n_embd: embedding width per token.
            n_hidden: width of every hidden Linear layer.
            n_layer: number of Linear layers in total (first, hidden, and output).
            dtype: dtype passed to every parameterised layer.
            generator: passed to the Embedding and Linear layers.
        """
        layers = [
            Embedding(vocab_size, n_embd, dtype=dtype, generator=generator),
            Flatten(),
            Linear(n_embd * block_size, n_hidden, bias=False, dtype=dtype, generator=generator),
            BatchNorm1d(n_hidden, dtype=dtype),
            Tanh(),
        ]
        for _ in range(n_layer-2):
            layers.extend([Linear(n_hidden, n_hidden, bias=False, dtype=dtype, generator=generator), BatchNorm1d(n_hidden, dtype=dtype), Tanh()])
        layers.extend([Linear(n_hidden, vocab_size, bias=False, dtype=dtype, generator=generator), BatchNorm1d(vocab_size, dtype=dtype)])
        # Scale down the weight of the final BatchNorm1d so the initial logits are small.
        layers[-1].weight.data *= 0.1
        self.net = Sequential(layers)
        self.block_size = block_size

    def parameters(self):
        """Return the parameters of the wrapped Sequential."""
        return self.net.parameters()

    def grads(self):
        """Return the gradients held by the wrapped Sequential."""
        return self.net.grads()

    def __call__(self, x):
        """Forward pass through the whole stack."""
        return self.net(x)

    def backward(self, grad):
        """Back-propagate ``grad`` through the stack and return what the first layer returns."""
        grad = self.net.backward(grad)
        return grad # None

    def eval(self):
        """Clear the ``_training`` flag on every layer."""
        for l in self.net.layers:
            l._training = False

    def train(self):
        """Set the ``_training`` flag on every layer."""
        for l in self.net.layers:
            l._training = True

    def generate(self, s, max_new_tokens, do_sample=True, temperature=1.0):
        """Extend the string ``s`` by up to ``max_new_tokens`` characters.

        Generation stops early once the end marker '.' is produced. The method only uses
        ``self.block_size`` and ``self(...)``, because the notebook also binds it to a
        torch ``nn.Sequential`` model. It reads the notebook-level ``stoi`` / ``itos`` tables.

        Args:
            s: seed string; its length must equal ``block_size``.
            max_new_tokens: upper bound on the number of generated characters.
            do_sample: sample from the distribution if True, otherwise take the argmax.
            temperature: logits are multiplied by ``1 / temperature`` before the softmax.

        Returns:
            The seed string followed by the generated characters.
        """
        assert isinstance(s, str), 'str in, str out'
        assert len(s) == self.block_size, 'input string length must be equal to block size'
        # Look the end marker up instead of assuming it has index 0 in the vocabulary.
        end_token = stoi['.']
        x = torch.tensor([[stoi[ch] for ch in s]])
        for _ in range(max_new_tokens):
            cond = x[:, -self.block_size:]  # the model only sees the last block_size tokens
            logits = self(cond) * (1 / temperature)
            probs = logits.softmax(dim=-1)
            if do_sample:
                next_x = torch.multinomial(probs, num_samples=1)
            else:
                next_x = probs.argmax(dim=-1, keepdim=True)
            x = torch.cat([x, next_x], dim=-1)
            if next_x.item() == end_token:
                break
        s = ''.join([itos[idx.item()] for idx in x[0]])
        return s



In [12]:
import types
import torch.nn as nn
import torch.nn.functional as F

block_size = 3
n_embd = 10
n_hidden = 100
vocab_size = 27
n_layer = 5
dtype = torch.float64
eval_interval = 1000
n_steps = 200000
bs = 32
ini_lr = 1.0
eps = 1e-5
momentum = 0.001

torch.manual_seed(42)
# original model
model = MLP(vocab_size, block_size, n_embd, n_hidden, n_layer, dtype)
# torch model with the same layer layout
layers = [
    nn.Embedding(vocab_size, n_embd), nn.Flatten(),
    nn.Linear(n_embd * block_size, n_hidden, bias=False), nn.BatchNorm1d(n_hidden, eps=eps, momentum=momentum), nn.Tanh(),
]
for _ in range(n_layer-2):
    layers.extend([
        nn.Linear(n_hidden, n_hidden, bias=False), nn.BatchNorm1d(n_hidden, eps=eps, momentum=momentum), nn.Tanh(),
    ])
layers.extend([
    nn.Linear(n_hidden, vocab_size, bias=False), nn.BatchNorm1d(vocab_size, eps=eps, momentum=momentum),
])
model_t = nn.Sequential(*layers).to(dtype)
# copy parameters from the manual model into the torch model
manual_params = model.parameters()
torch_params = list(model_t.parameters())
assert len(manual_params) == len(torch_params), 'the two models must expose the same number of parameters'
for i, (p, p_t) in enumerate(zip(manual_params, torch_params)):
    if i == 0 or p.ndim == 1: # i=0 is embedding layer
        p_t.data = p.data.clone() # clone to avoid inplace operation
    else:
        # nn.Linear stores its weight as (out_features, in_features); the manual weight is
        # transposed to match
        p_t.data = p.data.clone().T
# add generate method to model_t
model_t.block_size = block_size
model_t.generate = types.MethodType(model.generate.__func__, model_t)
# optimizer
optimizer = SGD(model, ini_lr)
optimizer_t = torch.optim.SGD(model_t.parameters(), lr=ini_lr)
# loss
loss_fn = CrossEntropyLoss()

model_t.train()
model.train()
for step in range(n_steps):
    # step decay: full learning rate for the first 75% of training, a tenth afterwards
    lr = ini_lr if step < int(n_steps * 0.75) else ini_lr / 10
    optimizer.lr = lr
    for param_group in optimizer_t.param_groups:
        param_group['lr'] = lr
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # ----- torch -----
    # forward
    logits_t = model_t(x)
    loss_t = F.cross_entropy(logits_t, y)
    # backward
    loss_t.backward()
    # update
    optimizer_t.step()
    optimizer_t.zero_grad()

    # ----- manual -----
    # forward
    logits = model(x)
    loss = loss_fn(logits, y)
    # backward
    # since the grad buffer is stored in the model class, backward must be called immediately
    # after forward; otherwise the buffer is overwritten by the next forward
    h_grad = loss_fn.backward(grad=1.0) # last layer, dloss=1.0
    model.backward(h_grad)
    # update
    optimizer.step()
    optimizer.zero_grad()

    # eval
    if step % eval_interval == 0:
        # val loss is actually one step later than train loss
        model_t.eval()
        model.eval()
        with torch.no_grad():
            val_logits_t = model_t(X_val)
            val_loss_t = F.cross_entropy(val_logits_t, Y_val)
            print(f'model_t step: {step}, train loss: {loss_t.item()}, val loss: {val_loss_t.item()}')
        val_logits = model(X_val)
        val_loss = loss_fn(val_logits, Y_val)
        print(f'model   step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')
        # reuse the validation logits computed above instead of running both models again
        print(f"eval model forward diff: {(val_logits - val_logits_t).abs().max()}")
        rm, rv = model.net.layers[3].running_mean, model.net.layers[3].running_var
        rmt, rvt = model_t[3].running_mean, model_t[3].running_var
        print(f"eval model running mean diff: {(rm - rmt).abs().max()}")
        print(f"eval model running var diff: {(rv - rvt).abs().max()}")
        model_t.train()
        model.train()

        # x, y still hold this step's training batch.
        # NOTE: a train-mode forward makes torch's BatchNorm1d update its running statistics
        # once more; presumably the manual BatchNorm1d does the same, keeping the two models
        # in step - confirm against tiny_torch.
        with torch.no_grad():
            train_diff = (model(x) - model_t(x)).abs().max()
        print(f"train model forward diff: {train_diff}")
        print()




model_t step: 0, train loss: 3.306549425058978, val loss: 3.26912796100787
model   step: 0, train loss: 3.3065494250589778, val loss: 3.2691279610078707
eval model forward diff: 1.0408340855860843e-16
eval model running mean diff: 4.0657581468206416e-20
eval model running var diff: 0.0
train model forward diff: 5.551115123125783e-16

model_t step: 1000, train loss: 3.0082452560929775, val loss: 2.6382328068110885
model   step: 1000, train loss: 3.008245256092977, val loss: 2.6382328068110703
eval model forward diff: 9.00790553259867e-12
eval model running mean diff: 2.8245461525244764e-13
eval model running var diff: 1.120215031846783e-13
train model forward diff: 3.3674868449296014e-12

model_t step: 2000, train loss: 2.4426375385086634, val loss: 2.405136117561469
model   step: 2000, train loss: 2.4426375385087455, val loss: 2.4051361175614887
eval model forward diff: 8.146372465489549e-12
eval model running mean diff: 6.818989817247711e-13
eval model running var diff: 5.601075159233

In [14]:
# Test loss of both models in eval mode.
model.eval()
test_loss = loss_fn(model(X_test), Y_test).item()
print(f'model  test loss: {test_loss}')
model_t.eval()
with torch.no_grad():  # inference only: no autograd graph over the whole test set
    test_loss_t = F.cross_entropy(model_t(X_test), Y_test).item()
print(f'model_t test loss: {test_loss_t}')
print()

# Sample 10 names from each model, re-seeding before each group so that both
# models draw from the same random stream.
torch.manual_seed(42)
for _ in range(10):
    out = model.generate('.' * block_size, max_new_tokens=10, do_sample=True, temperature=0.5)
    print(out)
print()
torch.manual_seed(42)
with torch.no_grad():
    for _ in range(10):
        out = model_t.generate('.' * block_size, max_new_tokens=10, do_sample=True, temperature=0.5)
        print(out)




model  test loss: 2.109030602058709
model_t test loss: 2.1090306021799923

...anuel.
...avann.
...aarian.
...dan.
...shan.
...silvin.
...alaya.
...jermann.
...elianna.
...anna.

...anuel.
...avann.
...aarian.
...dan.
...shan.
...silvin.
...alaya.
...jermann.
...elianna.
...anna.
