In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [2]:
# read in all the words (one name per line)
# `with` guarantees the file handle is closed; a bare open(...).read() leaks it
with open('names.txt', 'r') as names_file:
  words = names_file.read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the vocabulary of characters and mappings to/from integers
# index 0 is reserved for '.', the start/end-of-word token; the dataset's characters
# take indices 1..len(chars) in sorted order
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  """Turn a list of words into (context, next-character) training pairs.

  Each word starts from a window of `block_size` copies of index 0 ('.') and is
  terminated with '.'. For every character, the current window becomes one row
  of X and the character's own index becomes the matching entry of Y; the window
  then slides forward by one. Prints and returns the tensors (X, Y).
  """
  contexts, targets = [], []
  for word in words:
    window = [0] * block_size
    for idx in (stoi[ch] for ch in word + '.'):
      contexts.append(window)
      targets.append(idx)
      window = window[1:] + [idx]  # slide: drop the oldest index, append the newest

  X = torch.tensor(contexts)
  Y = torch.tensor(targets)
  print(X.shape, Y.shape)
  return X, Y

import random

# Shuffle a *copy* of `words` so this cell is idempotent. Shuffling `words` in place
# meant that re-running the cell reshuffled an already-shuffled list and silently
# produced a different train/dev/test split. The permutation depends only on the
# seed and the list length, so the first run yields exactly the same split as before.
random.seed(42)
words_shuffled = list(words)
random.shuffle(words_shuffled)
n1 = int(0.8*len(words_shuffled))
n2 = int(0.9*len(words_shuffled))

Xtr,  Ytr  = build_dataset(words_shuffled[:n1])     # 80%
Xdev, Ydev = build_dataset(words_shuffled[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words_shuffled[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  """Compare a manually derived gradient `dt` with autograd's `t.grad` and print one report line.

  The line says whether the two are bit-exact, whether they agree within
  `torch.allclose` tolerances, and gives the largest absolute difference.

  Two cases are reported explicitly instead of being compared:
  - `t.grad` is None (e.g. retain_grad() was not called on a non-leaf tensor).
  - The shapes differ. Broadcasting would otherwise let a (1, 64) gradient
    "exactly match" a (64,) one, or pair up the wrong elements for (64, 1)
    vs (1, 64).
  """
  if t.grad is None:
    print(f'{s:15s} | no autograd gradient (t.grad is None; was retain_grad()/backward() called?)')
    return
  if dt.shape != t.grad.shape:
    print(f'{s:15s} | shape mismatch: manual {tuple(dt.shape)} vs autograd {tuple(t.grad.shape)}')
    return
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [6]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1: scaled by gain 5/3 (the recommended gain for tanh) over sqrt(fan_in)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
# Drawn from `g` like every other parameter. Previously these used the unseeded global
# RNG, so bngain/bnbias (and therefore the loss) changed on every Restart & Run All.
# Note: this also advances `g`, so the minibatch sampled later differs from older runs.
bngain = torch.randn((1, n_hidden),               generator=g)*0.1 + 1.0
bnbias = torch.randn((1, n_hidden),               generator=g)*0.1

# Note: I am initializing many of these parameters in non-standard ways
# because sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [7]:
# Sample one minibatch of training examples; drawing from `g` keeps it reproducible.
batch_size = 32
n = batch_size  # short alias used throughout the forward and manual backward passes
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)  # random row indices into Xtr/Ytr
Xb = Xtr[ix]  # batch inputs (index contexts)
Yb = Ytr[ix]  # batch targets (next-character indices)

In [8]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time
# Each intermediate is bound to its own global name so the manual backward pass in a
# later cell can reference it and compare against the .grad autograd stores for it.

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
# per-feature statistics over the batch dimension (dim 0); keepdim=True keeps shape (1, n_hidden) for broadcasting
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# + 1e-5 keeps the inverse square root finite if a feature's batch variance is 0
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
# keepdims is accepted by torch as an alias of keepdim; result shape is (n, 1)
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
# pick each row's log-probability of its correct next character, average, negate
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
# these are non-leaf tensors: without retain_grad() their .grad stays None after backward(),
# and cmp() in the manual-backprop cell needs them
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3382, grad_fn=<NegBackward0>)

In [9]:
# shapes involved in the broadcast multiply probs = counts * counts_sum_inv
tuple(t.shape for t in (counts, counts_sum, counts_sum_inv, probs))

(torch.Size([32, 27]),
 torch.Size([32, 1]),
 torch.Size([32, 1]),
 torch.Size([32, 27]))

In [23]:
# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one

'''
For example, the mean of a, b, c is
mean = (a + b + c) / 3, and the NLL loss is its negation:
loss = -(a + b + c) / 3
loss = -1/3a + -1/3b + -1/3c
dloss/da = -1/3
dloss/db = -1/3
dloss/dc = -1/3

Therefore, in general dloss/da = -1/n
'''
'''
Additionally, logprobs is [32x27] while Yb is [32]: the loss only plucks out one entry
per row (the log-probability of the correct next character), i.e. 32 values.
So what happens to the gradient of the other entries of logprobs?
They don't contribute to the loss, so their gradient is 0.
'''
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n # (Derivative of -logprobs[range(n), Yb].mean() based on above logic)

dprobs = (1.0 / probs) * dlogprobs # Derivative of probs.log() (d log(x)/dx = 1/x) (* dlogprobs because of chain rule)

"""
Derivative of probs = counts * counts_sum_inv. counts and counts_sum_inv have different shapes:
counts.shape = [32x27], counts_sum_inv.shape = [32x1]
so this multiplication takes place in 2 steps:
1) Broadcasting (replication) of counts_sum_inv across every column of counts, and
2) Element-wise multiplication
So while backpropagating, we first handle step 2), the derivative of the multiplication:
if a = b * c,
then da/dc = b
so the local derivative w.r.t. counts_sum_inv is counts, multiplied by dprobs because of the chain rule:
counts * dprobs

After this we handle step 1). When a single value is used multiple times (here: replicated
across the 27 columns), the gradients from every use are summed. Therefore,
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True) # keepdims=True because we want to retain the [32x1] shape
"""
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv # Derivative of counts_sum**-1 (d(x**-1)/dx = -(x**-2))
"""
Derivative of counts_sum = counts.sum(1, keepdims=True) (29:00)
Addition routes the incoming gradient unchanged to each of its inputs (local derivative 1), so every
element in a row of counts receives that row's dcounts_sum, broadcast across the 27 columns.
+= because dcounts was already partly computed above (counts feeds both probs and counts_sum),
so we need to accumulate rather than overwrite the gradient.
"""
dcounts += torch.ones_like(counts) * dcounts_sum

# Derivative of norm_logits.exp(): d(e^x)/dx = e^x. We already computed norm_logits.exp() as counts,
# so we reuse counts instead of recomputing it.
dnorm_logits = counts * dcounts
"""
Derivative of norm_logits = logits - logit_maxes
norm_logits.shape = [32, 27]
logits.shape = [32, 27]
logit_maxes.shape = [32, 1]
So there is broadcasting happening in the subtraction.
Meaning the operation takes place as follows:
c11 c12 c13 = a11 a12 a13   b1
c21 c22 c23 = a21 a22 a23 - b2
c31 c32 c33 = a31 a32 a33   b3
so e.g. c32 = a32 - b3

This tells us the local derivative w.r.t. each a is 1, e.g. dc13/da13 = 1,
and the local derivative w.r.t. each b is -1, e.g. dc12/db1 = -1.

Therefore the gradient of c flows unchanged to a and flows negated to b. Because each b is
broadcast across a whole row, b's gradient is the (negated) sum over that row.
Also, dlogits is only partially computed here: logits has two branches coming out of it in the
forward pass (this subtraction and logits.max()), so the second contribution is added below.
"""
dlogits = dnorm_logits.clone() # Cloning for safety
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)

"""
Derivative of logit_maxes = logits.max(1, keepdim=True).values

When we perform max() in pytorch, it plucks out the max value of each row of the matrix and also tells us the index
of the plucked out value. These indices are very important in the backward pass because the gradient has to be
placed into the matrix at exactly those positions.

The local derivative of each plucked out value w.r.t. the logit it came from is 1 (and 0 for every other logit
in that row). Therefore,
dlogits = (local derivative, i.e. 1 at the argmax position) * dlogit_maxes
But because dlogits already holds a gradient from the previous operation, we do +=

To scatter the values into the matrix at the right indices, we use one-hot encoding, where the max value's index
is 1 and the rest are 0. There are other approaches as well (40:18)
num_classes=logits.shape[1] means that each one-hot row has length logits.shape[1], i.e. 27

Also keep in mind that dlogit_maxes is [32, 1], so when we apply the chain rule below, dlogit_maxes will
broadcast and be 're-routed' to whichever entries are 'turned on' in the one-hot matrix.

To visualise run: plt.imshow(F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]))
"""
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes

"""
Full explanation (42:00) .... Important
Derivative of logits = h @ W2 + b2
For a matrix multiplication with a broadcast bias, e.g.
d = a @ b + c     (a: [n, k], b: [k, m], c: [m], d: [n, m])
with dl/dd the incoming gradient (same shape as d):
dl/da = dl/dd @ b.T      -> [n, m] @ [m, k] = [n, k]
dl/db = a.T @ dl/dd      -> [k, n] @ [n, m] = [k, m]
dl/dc = dl/dd.sum(0)     (c is broadcast over the rows, so its gradients are summed over dim 0)
"""
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)

"""
Derivative of h = torch.tanh(hpreact)
if
a = tanh(z)
then da/dz = 1 - a**2
"""
dhpreact = (1 - h**2) * dh

"""
Derivative of hpreact = bngain * bnraw + bnbias
hpreact.shape = [32, 64]
bngain.shape = [1, 64]
bnraw.shape = [32, 64]
bnbias.shape = [1, 64]
so bngain and bnbias are being broadcast across the 32 rows.
So,
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
* dhpreact because of the chain rule, and .sum(0, keepdim=True) because bngain was replicated vertically,
so its gradients are summed vertically (keepdim=True keeps the [1, 64] shape)

dbnraw = bngain * dhpreact (chain rule)

dbnbias = dhpreact.sum(0, keepdim=True)
The gradient flows unchanged to the bias (local derivative 1) and, because the bias was broadcast,
we sum it vertically.
"""
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)

"""
Derivative of bnraw = bndiff * bnvar_inv
bnraw.shape = [32, 64]
bndiff.shape = [32, 64]
bnvar_inv.shape = [1, 64]
So bnvar_inv is being broadcast

Keep in mind that bndiff has 2 branches coming out of it (bnraw and bndiff2). So this dbndiff is
incomplete for now; the bndiff2 branch is accumulated further below.
dbndiff = bnvar_inv * dbnraw

dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
"""
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)

"""
Derivative of bnvar_inv = (bnvar + 1e-5)**-0.5
Power rule: if a = x**k then da/dx = k * x**(k-1)
here k = -0.5 and x = bnvar + 1e-5 (and d(bnvar + 1e-5)/dbnvar = 1), so
dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv
"""
dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv

"""
Derivative of bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)
bndiff2.shape = [32, 64], bnvar.shape = [1, 64]
The sum over dim 0 routes each column's gradient unchanged to all 32 rows (broadcast back),
scaled by the constant 1/(n-1).
"""
dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar

"""
Derivative of bndiff2 = bndiff**2, i.e. d(x**2)/dx = 2*x.
This is the second branch out of bndiff, so we accumulate with += to complete dbndiff.
"""
dbndiff += (2*bndiff) * dbndiff2

cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
# TODO: not derived yet -- uncomment each line once its gradient is computed above
# cmp('bnmeani', dbnmeani, bnmeani)
# cmp('hprebn', dhprebn, hprebn)
# cmp('embcat', dembcat, embcat)
# cmp('W1', dW1, W1)
# cmp('b1', db1, b1)
# cmp('emb', demb, emb)
# cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff: