# makemore: becoming a backprop ninja

## 2024 08 20 practice

## 2024 11 26 finish the video 

## Question : how should I use x += y or x = x + y

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures


In [2]:
# read in all the words
words = open('makemore/names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [6]:
vocab_size

27

In [7]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [8]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [9]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # plug out all the correct index probabilities and average

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way # important !!!! otherwise no gradient for the non-leaft tensors
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3614, grad_fn=<NegBackward0>)

In [10]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1/n
# no matter how log probs is confidently wrong , the gradient of loss function wrt. logprobs is -1/n on the correct class

In [11]:
dlogprobs_probs = probs ** -1 
dprobs = dlogprobs * dlogprobs_probs # here a uniform distribution and a extreme value distribution will make different

In [12]:
dcounts_sum_inv = torch.einsum("ij,ij->i",dprobs,counts)[:,None] # dl/db_i = \sum_j dl/dc_ij * dc_ij/db_i from gradient chain rule

In [13]:
dcounts_sum = (-1*counts_sum**(-2))*dcounts_sum_inv

In [14]:
# dcounts should be solved by chain rule
# exp : z = yx, y = 2x; dz = ydx + xdy , dy = 2 dx 

# forward
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv

dcounts_1 = torch.einsum("ij,i->ij",dprobs,counts_sum_inv[...,0]) # first parts of gradients
dcounts_2 = torch.ones_like(counts) * dcounts_sum
dcounts = dcounts_1 + dcounts_2

# Q : take two things into consideration, uniform softmax and a extreme value softmax, what will happen?

In [15]:
# forward
# counts = norm_logits.exp() ; the derivative of exp is exp x 
with torch.no_grad():
    dnorm_logits = dcounts * counts

In [16]:
# forward
# norm_logits = logits - logit_maxes # subtract max for numerical stability

# dlogit_maxes = -1 * torch.einsum("ij->i",dnorm_logits)
with torch.no_grad():
    dlogit_maxes  = -1 * torch.sum(dnorm_logits,1,keepdim=True)


In [17]:
# forward
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
with torch.no_grad():
    # dlogits for two parts
    dlogits1 = torch.ones_like(logits) * dnorm_logits
    dlogits2 = torch.zeros_like(logits)
    # find max_positions
    max_positions = logits.argmax(1,keepdim=False)
    dlogits2[range(n),max_positions] = 1 * dlogit_maxes[...,0]
    dlogits = dlogits1 + dlogits2
    # we have dlogit_maxes 

In [18]:
# IN MLP, we don't need tensor / tensor, we can all get scalar / tensor

# forward
# Linear layer 2
# logits = h @ W2 + b2 # output layer
# don't compute the local gradient
# just give the gradient of the loss wrt. weight and linear

with torch.no_grad():
    dh =  torch.einsum("kj,ij->ki",dlogits,W2) 
    # dh = dlogits @ W2.T
    db2 = torch.sum(dlogits,dim=0,keepdim=True).squeeze()
    # dW2 = torch.einsum("kj,ki->ij",dlogits,h)  may cause numerical bugs
    dW2 = h.T @ dlogits 

In [19]:
# forward

# hpreact = bngain * bnraw + bnbias
# h = torch.tanh(hpreact) # hidden layer
with torch.no_grad():
    dhpreact =  dh * (1-h**2)
    dbngain = torch.sum(dhpreact*bnraw,dim=0,keepdim=True) 
    dbnbias = torch.sum(dhpreact,dim=0,keepdim=True)
    dbnraw = dhpreact * bngain

In [20]:
# forward
# emb = C[Xb] # embed the characters into vectors
# embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# # Linear layer 1
# hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# # BatchNorm layer
# bnmeani = 1/n*hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
with torch.no_grad():
    dbnvar_inv = torch.sum(dbnraw * bndiff,dim=0,keepdim=True)
    dbnvar = dbnvar_inv * (-0.5) * (bnvar + 1e-5)**(-1.5)
    dbndiff2 = (1/(n-1) * dbnvar).repeat(32,1)
    dbndiff = 2 * bndiff * dbndiff2 + dbnraw * bnvar_inv
    dbnmeani = -1 * dbndiff.sum(dim=0,keepdim=True)   
    dhprebn = dbndiff + 1/n * dbnmeani
    dembcat = torch.einsum("ik,jk->ij",dhprebn,W1) 
    dW1 = torch.einsum("ij,ik->jk",embcat,dhprebn)
    db1 = dhprebn.sum(dim=0)
    demb = dembcat.reshape(emb.shape)
    dC = torch.zeros_like(C)
    dC[Xb] = demb

In [25]:
emb.shape, demb.shape,emb.grad.shape

(torch.Size([32, 3, 10]), torch.Size([32, 3, 10]), torch.Size([32, 3, 10]))

In [27]:
dC.shape, C.shape

(torch.Size([27, 10]), torch.Size([27, 10]))

In [21]:
# Exercise 1: backprop through the whole thing manually, 
# backpropagating through exactly all of the variables 
# as they are defined in the forward pass above, one by one

# -----------------
# YOUR CODE HERE :)
# -----------------

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1/n
dprobs = dlogprobs * (probs ** -1) # 对于真正label，如果概率p非常低，那么1/p就会非常大，因此logp会boost这个梯度
dcounts_sum_inv = (dprobs * counts).sum(dim=1,keepdim=True)   # two steps for pytorch : 1. 广播成同样的形状 2.对于同样的形状计算梯度
# element-wise梯度：dprobs * counts 是loss对于广播的之后的矩阵梯度
# 广播、replication梯度：在广播的维度相加


cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bnbias          | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
bnraw   

In [33]:
C.shape

torch.Size([27, 10])

In [34]:
Xb.shape

torch.Size([32, 3])

In [36]:
C[Xb].shape

torch.Size([32, 3, 10])

In [32]:
C.grad

tensor([[-2.9942e-02,  1.3016e-02, -2.7884e-02, -1.8364e-02, -2.2851e-02,
         -2.7160e-03,  1.0079e-02,  3.9186e-02,  5.3068e-03, -1.2683e-02],
        [-1.1553e-03, -1.1990e-02,  2.6647e-02,  5.5990e-04,  2.8722e-02,
         -1.0913e-02, -6.5035e-03, -8.1762e-03, -1.0701e-02, -1.6956e-03],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0329e-03, -4.0306e-03,  2.4115e-04,  6.7376e-04,  1.4454e-03,
          2.1888e-03, -5.8566e-03, -5.8188e-03,  1.8356e-04,  2.2696e-03],
        [ 2.3291e-03,  6.5369e-03, -8.2161e-04,  8.6610e-03, -1.1278e-02,
         -1.2828e-03,  1.0304e-02,  1.5459e-03,  6.5826e-03,  7.1600e-03],
        [-7.0182e-04,  1.2286e-02,  5.2331e-03, -1.4472e-03, -4.6413e-03,
          4.1369e-04,  3.1254e-03,  8.0687e-03, -8.6854e-03,  8.2609e-03],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+0

In [120]:
import torch
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from Llamma3"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )

In [24]:
x = torch.randn(1, 2, 4, 2)
y = repeat_kv(x, 2)

In [25]:
x

tensor([[[[-0.9775, -1.0789],
          [ 0.1624, -1.8405],
          [ 0.9507,  0.3011],
          [-1.2406, -0.5837]],

         [[-1.1423,  1.7853],
          [ 0.0409,  1.1610],
          [-0.2402,  1.9578],
          [ 0.7019,  2.1591]]]])

In [26]:
y

tensor([[[[-0.9775, -1.0789],
          [-0.9775, -1.0789],
          [ 0.1624, -1.8405],
          [ 0.1624, -1.8405],
          [ 0.9507,  0.3011],
          [ 0.9507,  0.3011],
          [-1.2406, -0.5837],
          [-1.2406, -0.5837]],

         [[-1.1423,  1.7853],
          [-1.1423,  1.7853],
          [ 0.0409,  1.1610],
          [ 0.0409,  1.1610],
          [-0.2402,  1.9578],
          [-0.2402,  1.9578],
          [ 0.7019,  2.1591],
          [ 0.7019,  2.1591]]]])

In [27]:
z

tensor([[[[-0.7903,  2.4070, -0.4276,  0.0947, -1.2497],
          [-0.7903,  2.4070, -0.4276,  0.0947, -1.2497],
          [ 0.8955, -0.0571, -1.2833, -0.0313, -0.2610],
          [ 0.8955, -0.0571, -1.2833, -0.0313, -0.2610],
          [ 1.2286, -0.2659, -0.5508,  0.8652,  0.6643],
          [ 1.2286, -0.2659, -0.5508,  0.8652,  0.6643],
          [ 0.2125, -0.7300, -2.1901, -1.4154,  0.2888],
          [ 0.2125, -0.7300, -2.1901, -1.4154,  0.2888]],

         [[ 0.5943,  0.8673,  1.7449, -0.9035, -0.3915],
          [ 0.5943,  0.8673,  1.7449, -0.9035, -0.3915],
          [ 0.8661, -1.0370, -0.2133,  0.4721,  0.2251],
          [ 0.8661, -1.0370, -0.2133,  0.4721,  0.2251],
          [ 0.0478, -0.0921, -1.5668,  1.3102,  0.3000],
          [ 0.0478, -0.0921, -1.5668,  1.3102,  0.3000],
          [ 0.9285,  1.5418,  0.3743,  2.2138, -0.0551],
          [ 0.9285,  1.5418,  0.3743,  2.2138, -0.0551]],

         [[-0.5614,  0.1855,  0.5033,  0.4404,  0.4352],
          [-0.5614,  0.1855

In [28]:
z = torch.repeat_interleave(x, dim=2, repeats=2)

In [29]:
z

tensor([[[[-0.9775, -1.0789],
          [-0.9775, -1.0789],
          [ 0.1624, -1.8405],
          [ 0.1624, -1.8405],
          [ 0.9507,  0.3011],
          [ 0.9507,  0.3011],
          [-1.2406, -0.5837],
          [-1.2406, -0.5837]],

         [[-1.1423,  1.7853],
          [-1.1423,  1.7853],
          [ 0.0409,  1.1610],
          [ 0.0409,  1.1610],
          [-0.2402,  1.9578],
          [-0.2402,  1.9578],
          [ 0.7019,  2.1591],
          [ 0.7019,  2.1591]]]])

In [58]:
AA_FREQ = {'A': 0.07421620506799341,
 'R': 0.05161448614128464,
 'N': 0.044645808512757915,
 'D': 0.05362600083855441,
 'C': 0.02468745716794485,
 'Q': 0.03425965059141602,
 'E': 0.0543119256845875,
 'G': 0.074146941452645,
 'H': 0.026212984805266227,
 'I': 0.06791736761895376,
 'L': 0.09890786849715096,
 'K': 0.05815568230307968,
 'M': 0.02499019757964311,
 'F': 0.04741845974228475,
 'P': 0.038538003320306206,
 'S': 0.05722902947649442,
 'T': 0.05089136455028703,
 'W': 0.013029956129972148,
 'Y': 0.03228151231375858,
 'V': 0.07291909820561925}

In [35]:
# make ESM2 backprop


L = 256
hidden = 1280

pssm = torch.randn(L, 20)
pssm.requires_grad = True
W = torch.randn(20, hidden)
W.requires_grad = True

In [36]:
one_hot = pssm.argmax(1,keepdim=False)
one_hot = F.one_hot(one_hot,20).float()
one_hot.requires_grad = True
activation = one_hot @ W
loss = activation.sum()
# real gradient
loss.backward()

In [14]:
dl_dactivation = torch.ones_like(activation)

In [42]:
dl_dactivation.shape

torch.Size([256, 1280])

In [54]:
done_hot = torch.einsum("kj,ij->ki",dl_dactivation, W)
dW = torch.einsum("kj, ki->ij",dl_dactivation,one_hot)

In [50]:
done_hot 

tensor([[ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        ...,
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610]],
       grad_fn=<ViewBackward0>)

In [39]:
one_hot.grad

tensor([[ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        ...,
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610],
        [ 57.4189,  17.2741,  -7.4375,  ...,   7.2317,  46.7600, -11.4610]])

In [30]:
activation.grad

  activation.grad


In [40]:
one_hot

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)