In [1]:
import random
import torch

words = open('names.txt', 'r').read().splitlines()
words = list(set(words))
random.seed(42)
random.shuffle(words)
len(words)

chs = list(set(''.join(words + ['.'])))
chs = sorted(chs, reverse=False)
stoi = {ch: i for i, ch in enumerate(chs)}
itos = {i: ch for i, ch in enumerate(chs)}

# predict next token use previous 3 tokens
X, Y = [], []

for w in words:
    context = '...'
    for ch in w + '.':
        x = [stoi[c] for c in context]
        y = stoi[ch]
        X.append(x)
        Y.append(y)
        context = context[1:] + ch

X = torch.tensor(X)
Y = torch.tensor(Y)
n1, n2  = int(0.8 * len(X)), int(0.9 * len(X))

X_train, X_val, X_test = X.tensor_split([n1, n2])
Y_train, Y_val, Y_test = Y.tensor_split([n1, n2])

X_train.shape, X_val.shape, X_test.shape, Y_train.shape, Y_val.shape, Y_test.shape


(torch.Size([169062, 3]),
 torch.Size([21133, 3]),
 torch.Size([21133, 3]),
 torch.Size([169062]),
 torch.Size([21133]),
 torch.Size([21133]))

# implement backward from scratch


In [76]:
n_embd = 10
n_hidden = 200
vocab_size = 27
block_size = 3

def get_params():
    torch.manual_seed(42)
    C = torch.randn(vocab_size, n_embd)
    w1 = torch.randn(n_embd * block_size, n_hidden) * (n_embd * block_size)**-0.5
    w2 = torch.randn(n_hidden, vocab_size) * (5/3) * (n_hidden)**-0.5 * 0.1 # 0.1 is for less confident at initialization
    b2 = torch.randn(vocab_size) * 0
    bnw = torch.ones(n_hidden)
    bnb = torch.zeros(n_hidden)
    params = [C, w1, w2, b2, bnw, bnb]
    for p in params:
        p.requires_grad = True
    return params

params = get_params()
C, w1, w2, b2, bnw, bnb = params
bs = 32
idx = torch.randint(0, X_train.shape[0], (bs,))
x, y = X_train[idx], Y_train[idx]

## forward and torch backward


In [71]:
# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = (torch.eye(bs) - mean_proj)

# forward
emb = C[x].view(x.shape[0], -1)
emb.retain_grad()
hpreact = emb @ w1
hpreact.retain_grad()
bnmeani = mean_proj @ hpreact
bnmeani.retain_grad()
bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
bnstdi.retain_grad()
hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
hpreact_bn.retain_grad()
h = hpreact_bn.tanh()
h.retain_grad()
logits = h @ w2 + b2
logits.retain_grad()
# 2. loss
exp_l = logits.exp()
exp_l.retain_grad()
count = exp_l.sum(dim=-1, keepdim=True)
count.retain_grad()
probs = exp_l / count
probs.retain_grad()
nlls = -probs.log()
nlls.retain_grad()
loss = nlls[torch.arange(y.shape[0]), y].mean()

# backward
loss.backward()

## manual backward

In [72]:
def init_grad():
    # buffer grad
    nlls_grad = torch.zeros(bs, vocab_size)
    probs_grad = torch.zeros(bs, vocab_size)
    count_grad = torch.zeros(bs, 1)
    exp_l_grad = torch.zeros(bs, vocab_size)
    logits_grad = torch.zeros(bs, vocab_size)
    h_grad = torch.zeros(bs, n_hidden)
    hpreact_bn_grad = torch.zeros(bs, n_hidden)
    bnmeani_grad = torch.zeros(1, n_hidden)
    bnstdi_grad = torch.zeros(1, n_hidden)
    bnvari_grad = torch.zeros(1, n_hidden)
    hpreact_grad = torch.zeros(bs, n_hidden)
    emb_grad = torch.zeros(bs, n_embd * block_size)
    # param grad
    C_grad = torch.zeros(vocab_size, n_embd)
    w1_grad = torch.zeros(n_embd * block_size, n_hidden)
    w2_grad = torch.zeros(n_hidden, vocab_size)
    b2_grad = torch.zeros(vocab_size)
    bnw_grad = torch.zeros(n_hidden)
    bnb_grad = torch.zeros(n_hidden)
    buffer_grads = [
            nlls_grad, probs_grad, count_grad, exp_l_grad, logits_grad, 
            h_grad, hpreact_bn_grad, bnmeani_grad, bnstdi_grad, bnvari_grad, hpreact_grad, emb_grad
        ]
    param_grads = [C_grad, w1_grad, w2_grad, b2_grad, bnw_grad, bnb_grad]
    return buffer_grads, param_grads

def zero_grad():
    for g in buffer_grads:
        g.fill_(0)
    for g in param_grads:
        g.fill_(0)

buffer_grads, param_grads = init_grad()
(
    nlls_grad, probs_grad, count_grad, exp_l_grad, logits_grad, 
    h_grad, hpreact_bn_grad, bnmeani_grad, bnstdi_grad, bnvari_grad, hpreact_grad, emb_grad
) = buffer_grads
C_grad, w1_grad, w2_grad, b2_grad, bnw_grad, bnb_grad = param_grads

In [73]:
def backward(check=False):
    # 1. loss
    nlls_grad[torch.arange(y.shape[0]), y] = 1 / bs
    probs_grad[torch.arange(y.shape[0]), y] = -1 / probs.data[torch.arange(y.shape[0]), y] * nlls_grad[torch.arange(y.shape[0]), y]
    count_grad = -(exp_l.data * probs_grad).sum(dim=-1, keepdim=True) / count.data**2
    exp_l_grad = probs_grad / count.data + count_grad  # one is from e/c to e, one is from c=\sum e to e
    logits_grad = exp_l.data * exp_l_grad

    # 2. logits
    h_grad = logits_grad @ w2.data.T
    hpreact_bn_grad = h_grad * (1 - h.data**2)
    # bn
    bnmeani_grad = ((-bnw.data / bnstdi.data) * hpreact_bn_grad).sum(dim=0, keepdim=True)
    bnstdi_grad = (-((hpreact.data - bnmeani.data) * bnw.data / bnstdi.data**2) * hpreact_bn_grad).sum(dim=0, keepdim=True)
    # hpreact
    hpreact_grad_mean = bnmeani_grad * torch.ones_like(hpreact.data) / bs
    hpreact_grad_std = bnstdi_grad * (1 / 2 / bnstdi.data) * (1 / bs) * (2 * var_proj @ hpreact.data)
    hpreact_grad_direct = hpreact_bn_grad * (bnw.data / bnstdi.data)
    hpreact_grad = hpreact_grad_mean + hpreact_grad_std + hpreact_grad_direct
    # emb
    emb_grad = hpreact_grad @ w1.data.T
    
    # 3. params
    C_grad.index_add_(dim=0, index=x.view(-1), source=emb_grad.view(-1, n_embd)) # add emb_grad[i] to C[x[i]]
    w1_grad = emb.data.T @ hpreact_grad
    w2_grad = h.data.T @ logits_grad
    b2_grad = logits_grad.sum(dim=0)
    bnw_grad = ((hpreact.data - bnmeani.data) / bnstdi.data * hpreact_bn_grad).sum(dim=0)
    bnb_grad = hpreact_bn_grad.sum(dim=0)
    if check:
        is_equal1 = [torch.allclose(nlls_grad, nlls.grad), torch.allclose(probs_grad, probs.grad), torch.allclose(count_grad, count.grad), torch.allclose(exp_l_grad, exp_l.grad), torch.allclose(logits_grad, logits.grad)]
        is_equal2 = [torch.allclose(h_grad, h.grad), torch.allclose(hpreact_bn_grad, hpreact_bn.grad), torch.allclose(bnmeani_grad, bnmeani.grad), torch.allclose(bnstdi_grad, bnstdi.grad), torch.allclose(hpreact_grad, hpreact.grad), torch.allclose(emb_grad, emb.grad)]
        is_equal3 = [torch.allclose(C_grad, C.grad), torch.allclose(w1_grad, w1.grad), torch.allclose(w2_grad, w2.grad), torch.allclose(b2_grad, b2.grad), torch.allclose(bnw_grad, bnw.grad), torch.allclose(bnb_grad, bnb.grad)]
        return all(is_equal1), all(is_equal2), all(is_equal3)

equal1, equal2, equal3 = backward(check=True)
print('same grad for loss calculation:', equal1)
print('same grad for logits calculation:', equal2)
print('same grad for params:', equal3)




same grad for loss calculation: True
same grad for logits calculation: True
same grad for params: True


In [88]:
import torch.nn.functional as F

# model
params = get_params()
C, w1, w2, b2, bnw, bnb = params
bnmean_running = torch.zeros(n_hidden)
bnstd_running = torch.ones(n_hidden)
buffer_grads, param_grads = init_grad()
(
    nlls_grad, probs_grad, count_grad, exp_l_grad, logits_grad, 
    h_grad, hpreact_bn_grad, bnmeani_grad, bnstdi_grad, bnvari_grad, hpreact_grad, emb_grad
) = buffer_grads
C_grad, w1_grad, w2_grad, b2_grad, bnw_grad, bnb_grad = param_grads

# args
bs = 32
n_steps = 10000
ini_lr = 1.0

# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = (torch.eye(bs) - mean_proj)

torch.manual_seed(42)
for step in range(n_steps):
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # forward
    emb = C[x].view(x.shape[0], -1)
    emb.retain_grad()
    hpreact = emb @ w1
    hpreact.retain_grad()
    bnmeani = mean_proj @ hpreact
    bnmeani.retain_grad()
    bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
    bnstdi.retain_grad()
    hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
    hpreact_bn.retain_grad()
    h = hpreact_bn.tanh()
    h.retain_grad()
    logits = h @ w2 + b2
    logits.retain_grad()
    # 2. loss
    exp_l = logits.exp()
    exp_l.retain_grad()
    count = exp_l.sum(dim=-1, keepdim=True)
    count.retain_grad()
    probs = exp_l / count
    probs.retain_grad()
    nlls = -probs.log()
    nlls.retain_grad()
    loss = nlls[torch.arange(y.shape[0]), y].mean()

    # backward
    loss.backward() # for compare, need to implement before manual backward()
    equal1, equal2, equal3 = backward(check=True)
    if step % 1000 == 0:
        with torch.no_grad():
            emb = C[X_val].view(X_val.shape[0], -1)
            hpreact = emb @ w1
            hpreact = (hpreact - hpreact.mean(dim=0, keepdim=True)) / hpreact.std(dim=0, keepdim=True) * bnw + bnb
            h = hpreact.tanh()
            logits = h @ w2 + b2
            val_loss = F.cross_entropy(logits, Y_val)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}, grad_equal: {all([equal1, equal2, equal3])}')
    
    # update
    for p, g in zip(params, param_grads):
        p.data -= lr * g
        zero_grad()
        p.grad = None # for compare
    with torch.no_grad():
        bnmean_running = bnmean_running * 0.99 + bnmeani * 0.01
        bnstd_running = bnstd_running * 0.99 + bnstdi * 0.01
    
    

step: 0, train loss: 3.318601369857788, val loss: 3.314592123031616, grad_equal: True
step: 1000, train loss: 3.2841665744781494, val loss: 3.279362678527832, grad_equal: True
step: 2000, train loss: 3.237403392791748, val loss: 3.261941432952881, grad_equal: True
step: 3000, train loss: 3.278920888900757, val loss: 3.250934600830078, grad_equal: True
step: 4000, train loss: 3.2361202239990234, val loss: 3.2448129653930664, grad_equal: True
step: 5000, train loss: 3.265164375305176, val loss: 3.2416939735412598, grad_equal: True
step: 6000, train loss: 3.2153725624084473, val loss: 3.241469621658325, grad_equal: True
step: 7000, train loss: 3.263246536254883, val loss: 3.2412421703338623, grad_equal: True
step: 8000, train loss: 3.2219350337982178, val loss: 3.241028070449829, grad_equal: True
step: 9000, train loss: 3.2198879718780518, val loss: 3.2408359050750732, grad_equal: True


# compare with torch

In [85]:
import torch.nn.functional as F

# model
params = get_params()
C, w1, w2, b2, bnw, bnb = params
bnmean_running = torch.zeros(n_hidden)
bnstd_running = torch.ones(n_hidden)

# args
bs = 32
n_steps = 10000
ini_lr = 1.0

# buffer
mean_proj = torch.ones(1, bs) / bs
var_proj = (torch.eye(bs) - mean_proj)

torch.manual_seed(42)
for step in range(n_steps):
    lr = ini_lr if step < n_steps // 2 else ini_lr / 10
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]

    # forward
    emb = C[x].view(x.shape[0], -1)
    emb.retain_grad()
    hpreact = emb @ w1
    hpreact.retain_grad()
    bnmeani = mean_proj @ hpreact
    bnmeani.retain_grad()
    bnstdi = (var_proj @ hpreact).square().mean(dim=0, keepdim=True).sqrt()
    bnstdi.retain_grad()
    hpreact_bn = (hpreact - bnmeani) / bnstdi * bnw + bnb
    hpreact_bn.retain_grad()
    h = hpreact_bn.tanh()
    h.retain_grad()
    logits = h @ w2 + b2
    logits.retain_grad()
    # 2. loss
    exp_l = logits.exp()
    exp_l.retain_grad()
    count = exp_l.sum(dim=-1, keepdim=True)
    count.retain_grad()
    probs = exp_l / count
    probs.retain_grad()
    nlls = -probs.log()
    nlls.retain_grad()
    loss = nlls[torch.arange(y.shape[0]), y].mean()

    # backward
    loss.backward() # for compare, need to implement before manual backward()
    if step % 1000 == 0:
        with torch.no_grad():
            emb = C[X_val].view(X_val.shape[0], -1)
            hpreact = emb @ w1
            hpreact = (hpreact - hpreact.mean(dim=0, keepdim=True)) / hpreact.std(dim=0, keepdim=True) * bnw + bnb
            h = hpreact.tanh()
            logits = h @ w2 + b2
            val_loss = F.cross_entropy(logits, Y_val)
            print(f'step: {step}, train loss: {loss.item()}, val loss: {val_loss.item()}')
    
    # update
    for p in params:
        p.data -= lr * p.grad
        p.grad = None # for compare
    with torch.no_grad():
        bnmean_running = bnmean_running * 0.99 + bnmeani * 0.01
        bnstd_running = bnstd_running * 0.99 + bnstdi * 0.01
    
    

step: 0, train loss: 3.318601369857788, val loss: 3.314592123031616
step: 1000, train loss: 2.234222173690796, val loss: 2.458853006362915
step: 2000, train loss: 2.4280223846435547, val loss: 2.4311578273773193
step: 3000, train loss: 2.7830185890197754, val loss: 2.400797128677368
step: 4000, train loss: 2.239293336868286, val loss: 2.3843557834625244
step: 5000, train loss: 2.6441760063171387, val loss: 2.380204200744629
step: 6000, train loss: 2.298123598098755, val loss: 2.313563823699951
step: 7000, train loss: 2.212418794631958, val loss: 2.3081648349761963
step: 8000, train loss: 2.199744701385498, val loss: 2.3080670833587646
step: 9000, train loss: 2.272951364517212, val loss: 2.3036866188049316
