In [1]:
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import kagglehub
import torch
import os
%matplotlib inline

In [2]:
def download_dataset() -> str:
    path = kagglehub.dataset_download("rishitjakharia/names-txt")
    return path # for easier access later

In [3]:
dataset_path = download_dataset() + "/names.txt"

Downloading from https://www.kaggle.com/api/v1/datasets/download/rishitjakharia/names-txt?dataset_version_number=1...


100%|██████████| 113k/113k [00:00<00:00, 52.8MB/s]

Extracting files...





In [4]:
words = open(dataset_path, 'r').read().splitlines()

In [5]:
def make_int_char_maps() -> tuple[dict, dict]:
  chars = sorted(list(set(''.join(words))))
  ctoi = {c: i + 1 for i, c in enumerate(chars)}
  ctoi['.'] = 0
  itoc = {i: c for c, i in ctoi.items()}
  return ctoi, itoc

In [6]:
ctoi, itoc = make_int_char_maps()

In [7]:
print(itoc)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [8]:
print(ctoi)

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}


In [9]:
vocab_size = len(ctoi)
print(vocab_size)

27


In [10]:
def build_dataset(inp_words) -> tuple[torch.tensor, torch.tensor]:

  block_size = 3 # context length -> how many chars does it take to predict the next
  inp, target = [], []
  for w in inp_words:
    context = [0] * block_size

    for ch in w + '.':
      ix = ctoi[ch]
      inp.append(context)
      target.append(ix)
      context = context[1:] + [ix] # crop and append

  inp = torch.tensor(inp)
  target = torch.tensor(target)
  return inp, target

In [11]:
import random
random.seed(42)

random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

In [12]:
# to later compare gradients we computed to Pytorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [13]:
n_emb = 10
n_context = 3
n_hidden = 64

g = torch.Generator().manual_seed(214783647)
C = torch.randn((vocab_size, n_emb), generator=g)

# first layer                                           # kaiming init
W1 = torch.randn((n_emb * n_context, n_hidden), generator=g) * (5/3) / ((n_emb * 3) ** 0.5)
B1 = torch.randn(n_hidden, generator=g) * 0.1 # these biases dont do anything since we saw they get subtracted in batch normlizing but its here for fun

# second layer
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
B2 = torch.randn(vocab_size, generator=g) * 0.1

# batchnorm params
bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.zeros((1, n_hidden)) * 0.1

parameters = [C, W1, B1, W2, B2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True




4137


In [14]:
B = 32 # batch size
# forward pass
Xb = Xtr[:B]
Yb = Ytr[:B]

emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

# linear layer 1
hprebn = embcat @ W1 + B1
# batch norm
bnmeani = 1 / emb.shape[0] * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvar = 1 / (emb.shape[0] - 1) * bndiff2.sum(0, keepdim=True)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bndiff_norm = bndiff * bnvar_inv
bnraw = bndiff_norm * bngain + bnbias
hpreact = bngain * bnraw + bnbias
# non linearity
h = torch.tanh(hpreact)
# second linear layer
logits = h @ W2 + B2
# cross entropy loss
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract the max so we dont get inf in counts
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(Xb.shape[0]), Yb].mean()

# pytorch backwards
for p in parameters:
  p.grad = None

for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff_norm, bndiff, hprebn, bnmeani,
          embcat, emb]:
  t.retain_grad()

loss.backward()
loss



tensor(3.4207, grad_fn=<NegBackward0>)

In [43]:
from ctypes import c_size_t
# Ecercise 1
# Calculating all of the deriatives manually to do backprop

# L = loss
# dL/dL = 1

# dL/dlogprobs
# loss = -(a + b + c ... ) /B
#      = -a/B -b/B -c/B ...
#      = dL/da = -1/B, dL/db = -1/B ...
# adding just propates the gradient
# dl/dlogprobs = -1/B

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(B), Yb] = -1/B

# dL/dprobs = dL/dlogprobs * dlogprobs/dprobs -> chain rule
#           = -1/B         * ???
# logprobs = probs.log()
# logprobs = ln(probs)
# dlogprobs/dprobs = 1/probs

# dL/dprobs = dlogprobs * 1/probs
dprobs = dlogprobs * 1/probs

# probs = counts * counts_sum_inv

# p  = a * b
# a[3x3] * b[3x1]
# a11*b1 a12*b1 a13*b1
# a21*b2 a22*b2 a23*b2
# a31*b3 a23*b3 a33*b3
# p[3x3]

# since we reuse b1, b2, b3 we need to add their gradients and keep the original shape [32x1]
# dL/dcsi = (counts * dprobs).sum(1, keepdim=true)
dcsi= (counts * dprobs).sum(1, keepdim=True)

# refering to our p = a * b
# a represents counts
# dL/da = dL/dp * dp/da
#       = dprobs * dcsi
dcounts = dprobs * counts_sum_inv

# counts_sum_inv = counts_sum ** -1
# dcounts_sum = dL/dcsi * -1 * counts_sum**-2
dcounts_sum = dcsi * (-1*counts_sum**-2)

# counts_sum = counts.sum(1, keepdims=True)
# addition pretty much carries the gradient over
# counts_sum = c1 + c2 + c3 ...
# dL/dc1 = dL/dcounts_sum * dcounts_sum/dc1
#        = dcounts_sum * 1
# we want to add to the already existing dcounts since its apart of multiple nodes in the prop tree

dcounts += dcounts_sum

cmp('dlogprobs', dlogprobs, logprobs)
cmp('dprobs', dprobs, probs)
cmp('dcsi', dcsi, counts_sum_inv)
cmp('dcounts_sum', dcounts_sum, counts_sum)
cmp('dcounts', dcounts, counts)



#counts_sum = counts.sum(1, keepdims=True)
#counts_sum_inv = counts_sum ** -1


dlogprobs       | exact: True  | approximate: True  | maxdiff: 0.0
dprobs          | exact: True  | approximate: True  | maxdiff: 0.0
dcsi            | exact: True  | approximate: True  | maxdiff: 0.0
dcounts_sum     | exact: True  | approximate: True  | maxdiff: 0.0
dcounts         | exact: True  | approximate: True  | maxdiff: 0.0
