# Makemore : Becoming a backprop ninja

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
#taking the dataset
names = open("names.txt","r").read().splitlines()
print(len(names))
print(max(len(w) for w in names))
print(names[:5])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia']


In [3]:
#building the vocabulary for characters and mapping to/from integer
chars =sorted(list((set("".join(names)))))
string_int = {s:i+1 for i,s in enumerate(chars)}
string_int["."]=0
int_string = {i:s for s,i in string_int.items()}
vocab_size = len(int_string)
print(int_string,vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'} 27


In [4]:
# building the data
block_size = 3
def building_dataset(names):
    x,y=[],[]
    for name in names:
        context =[0]*block_size
        for ch in name+".":
            ix = string_int[ch]
            x.append(context)
            y.append(ix)
            context = context[1:]+[ix]
    x,y =torch.tensor(x),torch.tensor(y)
    print(x.shape,y.shape)
    return x,y
import random 
random.seed(32)
random.shuffle(names)
n1 = int(0.8*len(names))
n2 = int(0.9*len(names))
xtr ,ytr = building_dataset(names[:n1])#80% of traning data
xdev,ydev=building_dataset(names[n1:n2])
xte,yte=building_dataset(names[:n2])

torch.Size([182408, 3]) torch.Size([182408])
torch.Size([22794, 3]) torch.Size([22794])
torch.Size([205202, 3]) torch.Size([205202])


In [5]:
#Now the boilerplate is done lets get to the action

In [6]:
n_embd =10
n_hidden =64
g= torch.Generator().manual_seed(2139344)
C = torch.randn((vocab_size,n_embd),generator=g)
#Layer 1
w1 = torch.randn((n_embd*block_size,n_hidden),generator=g)*(5/3)/((n_embd*block_size)**0.5)
b1 = torch.randn(n_hidden,generator=g)
#Layer 2 
w2 = torch.randn((n_hidden,vocab_size),generator=g)*0.1
b2 = torch.randn(vocab_size,generator=g)*0.1
#BatchNorm parameters
bngain= torch.randn((1,n_hidden))*0.1+1.0
bnbias = torch.randn((1,n_hidden))*0.1

#Note: I am initializating many of these parameters in non-standard ways
#because sometimes initialization with e.g all zeros could mask an incorrect 
#implimentation of the backward pass

parameters = [C,w1,b1,w2,b2,bngain,bnbias]
print(sum(p.nelement() for p in parameters))#no of parameters in total
for p in parameters:
    p.requires_grad=True

4137


In [7]:
batch_size =32
n= batch_size# a shorter  variable also ,for convineience 
#construct a mini batch
ix = torch.randint(0,xtr.shape[0],(batch_size,),generator=g)
xb,yb =xtr[ix],ytr[ix] #batch x,y

In [8]:
#forward pass but more expanded 
emb =C[xb]
embcat = emb.view(emb.shape[0],n_embd*block_size)#concatinate the vector
#Linear layer
hprebn = embcat@w1+b1
#BatchNorm layer
bnmeani = 1/n*hprebn.sum(0,keepdim=True)
bndiff=hprebn-bnmeani #taking the mean - value
bndiff2 = bndiff**2 
bnvariance = 1/(n-1)*(bndiff2).sum(0,keepdim=True) #note :in basic standard deviation we use ((mean-x)^2/n)but Bessel's correction(dividing by n-1,not n)
bnvariance_inv = (bnvariance +1e-5)**-0.5
bnraw =bndiff*bnvariance_inv
hpreact= bngain*bnraw+bnbias

#Non linearity
h = torch.tanh(hpreact)
#Linear layer 2
logits = h@w2+b2
#cross entrophy loss (same as F.cross_entrophy(logits,yb)  but in more open)
logit_maxes = logits.max(1,keepdim=True).values
norm_logits = logits-logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum=counts.sum(1,keepdims=True)
counts_sum_inv = counts_sum**-1 # if i use (1/counts_sum) instead then i can't get backprop to be bit exact
probs=counts*counts_sum_inv
probs=counts*counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n),yb].mean()
#pytorch backward pass
for p in parameters:
    p.grad =None
for t in [logprobs,probs,counts,counts_sum,counts_sum_inv,norm_logits,logit_maxes,logits,h,hpreact,bnraw,bnvariance_inv,
         bnvariance,bndiff2,bndiff,hprebn,hprebn,bnmeani,embcat,emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3135, grad_fn=<NegBackward0>)

In [9]:
#utility function we will use later when comparing manual gradiet to pytoarch gradient 
def cmp(s,dt,t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt,t.grad)
    maxdiff = (dt-t.grad).abs().max().item()
    print(f"{s:15s} | exact:{str(ex):5s} | approximate : {str(app):5s} | maxdiff: {maxdiff}")

In [16]:
#Excerise 1 : we have to do the dy/dx for each of the loss wrt each of the thing which happen
# for loss wrt logprobs the dloss/dlogprob is -1/n so 
dlogprobs = torch.zeros_like((logprobs))
dlogprobs[range(n),yb] = -1/n
cmp("logprobs",dlogprobs,logprobs)
#wrt probs 
dprobs = 1/probs*dlogprobs
cmp("dprobs",dprobs,probs)
#wrt count_sum_inv 
dcountsum_inv = (counts*dprobs).sum(1,keepdims=True)
cmp("count_sum_inv",dcountsum_inv,counts_sum_inv)
#wrt counts
dcounts = counts_sum_inv*dprobs
cmp("counts",dcounts,counts) #there is a reason output is false because the counts is depending in more places and then added up but we'll come to that later
#wrt count_sum**
dcount_sum = -(1/counts_sum**2)*dcountsum_inv
cmp("count_sum",dcount_sum,counts_sum)
#wrt wrt counts
dcounts += torch.ones_like(counts)*dcount_sum # i did the += because i want it to add with the previous counts which was false
cmp("counts",dcounts,counts) # as we can see its true now
#wrt norm_logits
dnorm_logits = norm_logits.exp()*dcounts
cmp("norm_logits",dnorm_logits,norm_logits)
#wrt logit maxes 
dlogits = dnorm_logits.clone()
cmp("logits",dlogits,logits)# the reason it is false because it is dependent in more than one place
dlogit_max = (-dnorm_logits).sum(1,keepdims=True)
cmp("dlogit_max",dlogit_max,logit_maxes)
#wrt logits
dlogits += F.one_hot(logits.max(1).indices,num_classes=logits.shape[1])*dlogit_max
cmp("logits",dlogits,logits)
#wrt h
dh = dlogits @ w2.T
cmp("h",dh,h)
#wrt w2
dw2 =h.T @ dlogits 
cmp("w2",dw2,w2)
#wrt b2
db2 = dlogits.sum(0)
cmp("b2",db2,b2)
#wrt hpreact
dhpreact = (1.0 - h**2)*dh
cmp("hpreact",dhpreact,hpreact)
#wrt bngain
dbngain = (bnraw*dhpreact).sum(0,keepdims=True)
cmp("gain",bngain,bngain)
#wrt bnraw
dbnraw = (bngain*dhpreact)
cmp("bnbias",bnraw,bnraw)
#wrt bnbias
dbnbias = dhpreact.sum(0,keepdims=True)
cmp("bnbias",bnbias,bnbias)
#wrt bndiff
dbndiff = bnvariance_inv*dbnraw
cmp("bndiff",dbndiff,bndiff)
#wrt bnvariance
dbnvariance_inv = (bndiff*dbnraw).sum(0,keepdims=True)
cmp("bnvariance_inv",dbnvariance_inv,bnvariance_inv)
#wrt bnvariance
dbnvariance = (-0.5*(bnvariance +1e-5)**-1.5)*dbnvariance_inv
cmp("bnvariance",dbnvariance,bnvariance)
#wrt bndiff2 
dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2)*dbnvariance
cmp("bndiff2",dbndiff2,bndiff2)
#wrt bndiff
dbndiff += (2*bndiff)*dbndiff2
cmp("bndiff",dbndiff,bndiff)
#wrt hprebn
dhprebn = dbndiff.clone()
cmp("hprebn",dhprebn,hprebn)
#wrt dbnmeani
dbnmeani =(-torch.ones_like(bndiff)*dbndiff).sum(0,keepdims=True)
cmp("bnmeani",dbnmeani,bnmeani)
#wrt hprebn
dhprebn += 1.0/n*(torch.ones_like(hprebn)*dbnmeani)
cmp("hprebn",dhprebn,hprebn)
#wrt emcat
dembcat = dhprebn @w1.T
cmp("embcat",dembcat,embcat)
#wrt w1
dw1 = embcat.T@dhprebn 
cmp("w1",dw1,w1)
#wrt b1
db1 =  dhprebn.sum(0)
cmp("b1",db1,b1)
#wrt emb
demb = dembcat.view(emb.shape)
cmp("emb",demb,emb)
#wrt C
dC = torch.zeros_like(C)
for k in range(xb.shape[0]):
    for j in range(xb.shape[1]):
        ix = xb[k,j]
        dC[ix] += demb[k,j]
cmp("C",dC,C)

logprobs        | exact:True  | approximate : True  | maxdiff: 0.0
dprobs          | exact:True  | approximate : True  | maxdiff: 0.0
count_sum_inv   | exact:True  | approximate : True  | maxdiff: 0.0
counts          | exact:False | approximate : False | maxdiff: 0.004510226659476757
count_sum       | exact:True  | approximate : True  | maxdiff: 0.0
counts          | exact:True  | approximate : True  | maxdiff: 0.0
norm_logits     | exact:True  | approximate : True  | maxdiff: 0.0
logits          | exact:False | approximate : True  | maxdiff: 4.6566128730773926e-09
dlogit_max      | exact:True  | approximate : True  | maxdiff: 0.0
logits          | exact:True  | approximate : True  | maxdiff: 0.0
h               | exact:True  | approximate : True  | maxdiff: 0.0
w2              | exact:True  | approximate : True  | maxdiff: 0.0
b2              | exact:True  | approximate : True  | maxdiff: 0.0
hpreact         | exact:False | approximate : True  | maxdiff: 4.656612873077393e-10
gain    

In [None]:
#Exercise 2 in this we'll have to do the dloss/dlogits directly withoud using all those elements we were using 
