In [27]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


In [28]:
words = open('names.txt', 'r').read().splitlines()

chars = sorted(list(set(''.join(words))))

stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [29]:
block_size = 3

def build_dataset(words):

  X, Y = [], []
  for w in words:
    #print(w)
    context = [0]* block_size

    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      #print(''.join(itos[i] for i in context),'--->', itos[ix])
      context = context[1:]+[ix]

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y



import random
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))


Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

print(Xtr[:5])

torch.Size([182597, 3]) torch.Size([182597])
torch.Size([22868, 3]) torch.Size([22868])
torch.Size([22681, 3]) torch.Size([22681])
tensor([[ 0,  0,  0],
        [ 0,  0, 11],
        [ 0, 11,  5],
        [11,  5, 14],
        [ 5, 14, 26]])


In [30]:
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [31]:
#building MLP
n_embd = 10 #dimensionality of character embedding vectors
n_hidden = 200 #no. of neurons in hidden layer of MLP

C = torch.rand((vocab_size,n_embd))
W1 = torch.randn((n_embd*block_size,n_hidden)) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden) * 0.01  #b1 redundant because of batchnormalisation bias
W2 = torch.randn((n_hidden,vocab_size)) *0.1
b2 = torch.randn(vocab_size) * 0.1


bngain = torch.randn((1, n_hidden))*0.1+1.0
bnbias = torch.randn((1, n_hidden))*0.1  #non-zero to unmask errors in calculating gradients w.r.t biases.

#bnmean_running = torch.zeros((1, n_hidden))
#bngain_running = torch.ones((1, n_hidden))

parameters = [C, W1 , W2, b2, bngain, bnbias]


for p in parameters:
  p.requires_grad = True

In [32]:
batch_size = 32
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
Xb, Yb = Xtr[ix], Ytr[ix]




In [47]:
#minibatch construct
ix = torch.randint(0,Xtr.shape[0],(batch_size,))

Xb, Yb = Xtr[ix], Ytr[ix]

emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

#Linear Layer 1
hprebn = embcat @ W1 + b1

#batchnorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)

bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0,keepdim=True)
bnvar_inv = (bnvar+1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
#non_linearity
h = torch.tanh(hpreact)
#linear layer 2
logits = h @ W2 + b2 #output layer , [32,27]
#cross entropy loss - subtract max neuron activation in layer from each activation

logit_maxes = logits.max(1, keepdim=True).values



#logit_maxes = logits.max(1, keepdim=True).values #  [32, 1]
norm_logits = logits - logit_maxes  # [32,27]
#exp each value to remove negatives
counts = norm_logits.exp()   # [32,27]
counts_sum = counts.sum(1, keepdims=True) # [32,1]
counts_sum_inv = counts_sum**-1
#convert each value to prob by dividing by sum of activations across one layer of neurons
probs = counts * counts_sum_inv
#log each value
logprobs = probs.log() # [32,27]

#range(n) is a list from 1 to 32.  Yb is list of the correct vocab indices for given input context. 
#we pick out the log prob of each correct output for each of the 32 inputs, add them and find the mean (equivalent to multiplying the raw values and diving by 32)
#if correct output indices has high probability assigned to it by neural net, then just the above would give big value close to 0
#if probability is low (which is bad we want the probs of correct values to be high), then above method gives very negative value
#but backpropagation minimises loss for best results so we take -loss for which lower negative values indicate high confidence.

loss = -logprobs[range(n), Yb].mean()


for p in parameters:
    p.grad=None

for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()

loss.backward()
loss

tensor(3.8917, grad_fn=<NegBackward0>)

In [None]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
dprobs = dlogprobs * 1/probs
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)  #if f(x,y,z,l) = xy + zy + ly , del f/del y = x+z+l
dcounts = dprobs*counts_sum_inv
dcounts_sum = -counts_sum**-2*dcounts_sum_inv
dcounts += torch.ones_like(counts)*dcounts_sum
dnorm_logits = dcounts*counts
dlogits = dnorm_logits.clone()
dlogit_maxes =  (-dnorm_logits).sum(1, keepdim=True)
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
dh = W2.sum(1, keepdim=True)

cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
