# makemore: becoming a backprop ninja

## Imports

In [2]:
import torch 
import matplotlib.pyplot as plt 
import torch.nn.functional as F 
import numpy as np 
import random

%matplotlib inline 
%load_ext autoreload
%autoreload 2

In [3]:
SEED = 2**31 - 1  # = 2147483647, the largest signed 32-bit integer; seeds the torch.Generator used for init and minibatch sampling

## Data loading

In [4]:
# One name per line. The `with` block closes the file handle as soon as it has been read
# (a bare `open(...).read()` leaves closing the handle to the garbage collector).
with open('name.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [5]:
# Character vocabulary: every distinct character gets an id starting at 1 (sorted order);
# '.' is inserted last with id 0 and serves as the start/end marker in build_dataset.
chars = sorted({c for w in words for c in w})
stoi = {c: idx for idx, c in enumerate(chars, start=1)}
stoi['.'] = 0
# Inverse mapping id -> character (same key order as stoi).
itos = {idx: c for c, idx in stoi.items()}
print(stoi, itos)

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0} {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [12]:
# Dataset builder configuration:
#   block_size -> context length (how many previous characters predict the next one)
#   vocab_size -> number of distinct token ids, including '.'
block_size, vocab_size = 3, len(itos)

def build_dataset(words, context_size=None, char_to_ix=None):
    """Turn a list of words into (context, next-character) training pairs.

    Each word starts from a context of ``context_size`` zeros (id 0 is '.')
    and is terminated with '.', so every character plus the end marker becomes
    one target whose input is the ``context_size`` preceding ids.

    Args:
        words: iterable of strings.
        context_size: context length; defaults to the notebook-level ``block_size``.
        char_to_ix: char -> id mapping; defaults to the notebook-level ``stoi``.

    Returns:
        X: int64 tensor of shape (N, context_size) with the context ids.
        Y: int64 tensor of shape (N,) with the target ids.

    Raises:
        KeyError: if a word contains a character missing from the mapping.
    """
    # Fall back to the notebook globals only when no explicit value was passed.
    context_size = block_size if context_size is None else context_size
    char_to_ix = stoi if char_to_ix is None else char_to_ix

    X, Y = [], []
    for w in words:
        context = [0] * context_size
        for ch in w + '.':
            ix = char_to_ix[ch]
            X.append(context)
            Y.append(ix)
            # build a new list (not in-place) so earlier rows of X are not mutated
            context = context[1:] + [ix]

    # Explicit dtype + reshape: with no words, torch.tensor([]) would otherwise be a
    # float tensor of shape (0,) instead of an int64 tensor of shape (0, context_size).
    X = torch.tensor(X, dtype=torch.long).reshape(len(X), context_size)
    Y = torch.tensor(Y, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

# Shuffle a *copy* so this cell is idempotent: re-running it reproduces the same split
# instead of reshuffling an already-shuffled `words` (on a fresh kernel the permutation
# is identical to shuffling `words` in place).
# NOTE(review): seeded with 42 rather than SEED; kept so the split sizes below stay reproducible.
shuffled_words = words[:]
random.seed(42)
random.shuffle(shuffled_words)
n1 = int(0.8 * len(shuffled_words))
n2 = int(0.9 * len(shuffled_words))

Xtr, Ytr = build_dataset(shuffled_words[:n1])       # 80% train
Xdev, Ydev = build_dataset(shuffled_words[n1:n2])   # 10% dev
Xte, Yte = build_dataset(shuffled_words[n2:])       # 10% test


torch.Size([182580, 3]) torch.Size([182580])
torch.Size([22767, 3]) torch.Size([22767])
torch.Size([22799, 3]) torch.Size([22799])


## Utility functions

In [13]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
    """Print how closely a manually computed gradient matches autograd's.

    Args:
        s: label printed in the first column.
        dt: manually computed gradient tensor.
        t: tensor whose ``.grad`` was populated by ``loss.backward()``.

    Prints one line with exact (bitwise) equality, approximate equality
    (``torch.allclose`` default tolerances) and the max absolute difference.
    A shape mismatch is reported explicitly and counts as neither exact nor
    approximate. Without this check, elementwise comparison broadcasts
    (e.g. a (1, 27) gradient against a (27,) bias) and hides the mismatch.

    Raises:
        ValueError: if ``t.grad`` is None (e.g. ``retain_grad()`` was not
            called on a non-leaf tensor before ``backward()``).
        RuntimeError: if the two shapes cannot be broadcast together.
    """
    if t.grad is None:
        raise ValueError(f'{s}: t.grad is None - call retain_grad() before backward()')
    shape_ok = dt.shape == t.grad.shape
    ex = shape_ok and torch.all(dt == t.grad).item()
    app = shape_ok and torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    line = f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}'
    if not shape_ok:
        line += f' | shape mismatch: manual {tuple(dt.shape)} vs autograd {tuple(t.grad.shape)}'
    print(line)

## Model 

In [14]:
g = torch.Generator().manual_seed(SEED)  # single RNG stream shared by every init draw below

# Hyper-parameters
n_embb = 10      # size of each character embedding
n_hidden = 64    # neurons in the hidden layer

# NOTE: every tensor below draws from `g`, so the ORDER of these lines fixes the initial values.
C = torch.randn((vocab_size, n_embb), generator=g)  # embedding table, one row per character id

# Layer 1: weights scaled by the tanh gain 5/3 divided by sqrt(fan_in)
fan_in = n_embb * block_size
W1 = torch.randn((fan_in, n_hidden), generator=g) * (5/3) / fan_in ** 0.5
# kept for the exercise; the following BatchNorm subtracts the batch mean, cancelling any constant shift
b1 = torch.randn(n_hidden, generator=g) * 0.1

# Layer 2: the small 0.1 scale keeps the initial logits small
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# BatchNorm scale/shift, initialised near the identity transform (gain ~1, bias ~0) plus a little noise
bngain = torch.randn((1, n_hidden), generator=g) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.numel() for p in parameters))  # total number of trainable scalars

for p in parameters:
    p.requires_grad_()
    

4137


In [15]:
batch_size = 32
n = batch_size  # short alias used throughout the forward/backward formulas

# Sample batch_size row indices (independently, so repeats are possible) from the training set
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]
Xb.shape, C.shape

(torch.Size([32, 3]), torch.Size([27, 10]))

In [17]:
# Forward pass, deliberately split into many small named steps so each one can be
# backpropagated by hand in Exercise 1. Shape comments assume n = batch_size = 32.
emb = C[Xb]                                                         # (32, 3, 10) embedding of each context character
embcat = emb.view(emb.shape[0], -1)                                 # (32, 30) concatenate the 3 embeddings per example
# layer 1
hprebn = embcat @ W1 + b1                                           # (32, 30) @ (30, 64) + (64,) = (32, 64)
# BatchNorm layer, using statistics of the current batch
bnmeani = 1/n*hprebn.sum(0, keepdim=True)                           # (1, 64) per-feature batch mean
bndiff =  hprebn - bnmeani                                          # (32, 64)
bndiff2 = bndiff**2                                                 # (32, 64)
bnvar = 1/(n-1)*bndiff2.sum(0, keepdim=True)                        # (1, 64) unbiased variance (Bessel's correction: divide by n-1)
bnvar_inv = (bnvar + 1e-5)**-0.5                                    # (1, 64) 1/sqrt(var + eps); eps avoids division by zero
bnraw = bndiff * bnvar_inv                                          # (32, 64) * (1, 64) = (32, 64)
hpreact = bngain * bnraw + bnbias                                   # (1, 64) * (32, 64) + (1, 64) = (32, 64)

# Non-linearity
h = torch.tanh(hpreact)                                             # (32, 64)
# Linear layer 2
logits = h @ W2 + b2                                                # (32, 64) @ (64, 27) + (27,) = (32, 27)
# Cross-entropy loss, written out step by step
logit_maxes = logits.max(1, keepdim=True).values                    # (32, 1) row max, subtracted so exp() cannot overflow
norm_logits = logits - logit_maxes                                  # (32, 27)
counts = norm_logits.exp()                                          # (32, 27)
counts_sum = counts.sum(1, keepdim=True)                            # (32, 1)
counts_sum_inv = counts_sum ** -1                                   # (32, 1)
probs = counts * counts_sum_inv                                     # (32, 27) softmax probabilities
logprobs = probs.log()                                              # (32, 27)
loss = -logprobs[range(n), Yb].mean()                               # mean negative log-prob of the correct next character Yb in each row

# backward() accumulates into .grad, so clear the parameters' gradients from any previous run
for p in parameters:
    p.grad = None

# retain_grad() keeps .grad on these non-leaf intermediates so cmp() can compare them with the manual gradients
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss.item()

3.359337329864502

### Exercise 1: backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one

In [None]:


# Manual backward pass: one gradient per forward-pass intermediate, in reverse order.
# Each `dX` holds dloss/dX and has the same shape as `X`.

# loss = -mean(logprobs[range(n), Yb]): only the n selected entries affect the loss,
# each with weight -1/n; every other entry of logprobs gets zero gradient.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
# d/dx log(x) = 1/x: a low probability on the correct class produces a large gradient.
dprobs = (1.0/probs) * dlogprobs

# probs = counts * counts_sum_inv, where counts_sum_inv (n, 1) is broadcast across the columns.
# A broadcast operand is reused in every column, so its gradient is the sum over that axis:
#   c = a * b, a (3, 3), b (3, 1)  ->  dc/db_i = sum_j a_ij
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = (counts_sum_inv * dprobs)
# counts_sum_inv = counts_sum**-1  ->  d/dx x**-1 = -x**-2
dcounts_sum = (-counts_sum **-2 ) * dcounts_sum_inv
# counts_sum = counts.sum(1): every element of a row receives that row's gradient
dcounts += torch.ones_like(counts) * dcounts_sum
# counts = exp(norm_logits): exp is its own derivative, and `counts` already holds exp(norm_logits)
dnorm_logits = counts * dcounts
# norm_logits = logits - logit_maxes (broadcast over columns): local grad -1, summed per row
dlogit_maxes = (-dnorm_logits).sum(1,keepdim=True)
# logits feeds both norm_logits (local grad 1) and max(); max() routes gradient only to the
# arg-max position of each row, hence the one-hot mask
dlogits = dnorm_logits.clone()
dlogits += F.one_hot(logits.max(dim=1).indices, num_classes=logits.shape[1]) * dlogit_maxes

# logits = h @ W2 + b2:
#   dh  = dlogits @ W2^T    (n, 27) @ (27, 64) -> (n, 64)
#   dW2 = h^T @ dlogits     (64, n) @ (n, 27) -> (64, 27)
#   db2 = sum over batch    b2 was broadcast over rows; shape (27,) to match b2 itself
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)

# h = tanh(hpreact)  ->  d tanh(x)/dx = 1 - tanh(x)**2 = 1 - h**2
dhpreact = (1 - h**2) * dh

# hpreact = bngain * bnraw + bnbias; bngain and bnbias are (1, 64), broadcast over the batch
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnbias = (dhpreact).sum(0, keepdim=True)
dbnraw = (bngain * dhpreact)
# bnraw = bndiff * bnvar_inv: first contribution to dbndiff (the second arrives via bndiff2)
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# bnvar_inv = (bnvar + 1e-5)**-0.5  ->  -0.5 * (bnvar + 1e-5)**-1.5
dbnvar = (-0.5* (bnvar + 1e-5)**-1.5) * dbnvar_inv
# bnvar = 1/(n-1) * bndiff2.sum(0): every element receives 1/(n-1) times its column's gradient
dbndiff2 = ((1/(n-1))* torch.ones_like(bndiff2)) * dbnvar
# bndiff2 = bndiff**2: second contribution to dbndiff
dbndiff += (2 *bndiff) * dbndiff2
# bndiff = hprebn - bnmeani: the local derivative w.r.t. bnmeani is -1 (hence the minus sign),
# and because bnmeani (1, 64) was broadcast over the batch its gradient is the column sum.
# keepdim=True keeps the (1, 64) shape of bnmeani.
dbnmeani = (-dbndiff).sum(0, keepdim=True)
# hprebn feeds bndiff directly (local grad 1) and bnmeani = mean(hprebn) (local grad 1/n per element)
dhprebn = 1.0/n * torch.ones_like(hprebn) * dbnmeani + dbndiff

# hprebn = embcat @ W1 + b1 (same pattern as layer 2)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
# embcat = emb.view(n, -1): undo the reshape
demb = dembcat.view(emb.shape)
# emb = C[Xb]: add each row of demb back into the table row it was read from; rows of C used
# several times in the batch accumulate several contributions.
# `char_ix` (not `ix`) so the module-level minibatch indices `ix` are not overwritten.
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for i in range(Xb.shape[1]):
        char_ix = Xb[k, i]
        dC[char_ix] += demb[k, i]

# Compare every manual gradient with the autograd reference, in backward-pass order.
for _label, _manual_grad, _tensor in [
    ('logprobs', dlogprobs, logprobs),
    ('probs', dprobs, probs),
    ('counts_sum_inv', dcounts_sum_inv, counts_sum_inv),
    ('counts_sum', dcounts_sum, counts_sum),
    ('counts', dcounts, counts),
    ('norm_logits', dnorm_logits, norm_logits),
    ('logit_maxes', dlogit_maxes, logit_maxes),
    ('logits', dlogits, logits),
    ('h', dh, h),
    ('W2', dW2, W2),
    ('b2', db2, b2),
    ('hpreact', dhpreact, hpreact),
    ('bngain', dbngain, bngain),
    ('bnbias', dbnbias, bnbias),
    ('bnraw', dbnraw, bnraw),
    ('bnvar_inv', dbnvar_inv, bnvar_inv),
    ('bnvar', dbnvar, bnvar),
    ('bndiff2', dbndiff2, bndiff2),
    ('bndiff', dbndiff, bndiff),
    ('bnmeani', dbnmeani, bnmeani),
    ('hprebn', dhprebn, hprebn),
    ('embcat', dembcat, embcat),
    ('W1', dW1, W1),
    ('b1', db1, b1),
    ('emb', demb, emb),
    ('C', dC, C),
]:
    cmp(_label, _manual_grad, _tensor)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff: