<a href="https://colab.research.google.com/github/samitha278/multilayer-perceptron-char-lm/blob/main/mlp_backprop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
import torch
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
%matplotlib inline

In [33]:
# download the names.txt file from github
!wget https://github.com/samitha278/multilayer-perceptron-char-lm/blob/main/data/names.txt

--2025-08-23 07:07:55--  https://github.com/samitha278/multilayer-perceptron-char-lm/blob/main/data/names.txt
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘names.txt.4’

names.txt.4             [ <=>                ] 560.09K  --.-KB/s    in 0.04s   

2025-08-23 07:07:55 (14.2 MB/s) - ‘names.txt.4’ saved [573536]



In [34]:
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [35]:
#character mapping
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


### Build the dataset

In [31]:
block_size = 3 # context length

def build_dataset(words):
  X, Y = [], []

  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix]

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y


random.seed(278)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182418, 3]) torch.Size([182418])
torch.Size([22885, 3]) torch.Size([22885])
torch.Size([22843, 3]) torch.Size([22843])


### Build MLP

In [37]:
n_embd = 10      # dimensionality of the character embedding vectors
n_hidden = 64    # number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)
C  = torch.randn((vocab_size, n_embd),            generator=g)


# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 useless because of BN


# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1


# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1



parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
  p.requires_grad = True

4137


## Forward pass

In [50]:
batch_size = 32
n = batch_size
#minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

In [51]:

emb = C[Xb]                         # characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors


# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation

# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # Bessel's correction
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv

hpreact = bngain * bnraw + bnbias


# # hidden layer - Non-linearity
h = torch.tanh(hpreact)


# Linear layer 2
logits = h @ W2 + b2 # output layer

# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability

counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1

probs = counts * counts_sum_inv
logprobs = probs.log()

loss = -logprobs[range(n), Yb].mean()



In [52]:
# PyTorch backward pass
for p in parameters:
  p.grad = None

for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.4314, grad_fn=<NegBackward0>)

### Comparing manual gradients to PyTorch gradients

In [53]:
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

## Backprop - manually



In [49]:
logprobs.shape

torch.Size([32, 27])

In [48]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n),Yb] = -1.0/n



dprobs = dlogprobs * (1.0/probs)

dcounts_sum_inv = (dprobs * counts).sum(1,keepdims = True)


dcounts_sum = dcounts_sum_inv * (-counts_sum**-2)




# dcounts = dprobs * counts_sum_inv
# dnorm_logits = dcounts * norm_logits.exp()
# dlogit_maxes = dnorm_logits * -1
# dlogits = dnorm_logits
# dh = dlogits @ W2.T
# dW2 = dlogits.T  @ h
# db2 = dlogits





cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)




logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0


In [69]:
dcounts_sum_inv

tensor([[-0.4334],
        [-0.2831],
        [-0.3073],
        [-0.3577],
        [-0.2564],
        [-0.2807],
        [-0.3888],
        [-0.3340],
        [-0.3460],
        [-0.3888],
        [-0.2885],
        [-0.3860],
        [-0.2924],
        [-0.3351],
        [-0.4162],
        [-0.2858],
        [-0.3605],
        [-0.3847],
        [-0.4073],
        [-0.2855],
        [-0.4097],
        [-0.3888],
        [-0.3572],
        [-0.3888],
        [-0.3288],
        [-0.3043],
        [-0.2233],
        [-0.2933],
        [-0.4156],
        [-0.3675],
        [-0.3095],
        [-0.3148]], grad_fn=<SumBackward1>)

In [71]:
dprobs * counts

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, -0.4966,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -0.1824,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0585,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.5979,  0.0000

In [68]:
dlogprobs * (1.0/probs) * counts

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, -0.3573,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -0.2302,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2569,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.3806,  0.0000

In [66]:
counts_sum_inv

tensor([[0.0875],
        [0.1358],
        [0.1217],
        [0.0821],
        [0.0817],
        [0.0994],
        [0.0902],
        [0.0875],
        [0.0875],
        [0.0883],
        [0.1153],
        [0.1017],
        [0.1138],
        [0.1308],
        [0.1140],
        [0.0875],
        [0.0752],
        [0.0981],
        [0.0875],
        [0.1152],
        [0.0915],
        [0.0906],
        [0.0794],
        [0.1103],
        [0.0752],
        [0.0949],
        [0.0875],
        [0.0956],
        [0.0828],
        [0.0875],
        [0.1197],
        [0.0977]], grad_fn=<PowBackward0>)

In [62]:
dlogprobs * (1.0/probs)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, -0.6500,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -1.1565,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -3.3393,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1434,  0.0000

In [55]:
probs

tensor([[0.0138, 0.0120, 0.0415, 0.0158, 0.0339, 0.0179, 0.0819, 0.0431, 0.0461,
         0.0515, 0.0173, 0.0304, 0.0384, 0.0428, 0.0165, 0.0505, 0.0624, 0.0357,
         0.0481, 0.0316, 0.0476, 0.0358, 0.0875, 0.0140, 0.0295, 0.0248, 0.0295],
        [0.0165, 0.0164, 0.0411, 0.0212, 0.0666, 0.0143, 0.0278, 0.0434, 0.0351,
         0.0316, 0.0159, 0.0265, 0.0253, 0.0933, 0.0305, 0.0467, 0.0252, 0.0279,
         0.0447, 0.0803, 0.1358, 0.0144, 0.0272, 0.0180, 0.0168, 0.0270, 0.0306],
        [0.0278, 0.0476, 0.0203, 0.0923, 0.0334, 0.0817, 0.0297, 0.0286, 0.0256,
         0.0338, 0.0504, 0.0182, 0.0174, 0.0170, 0.0094, 0.0130, 0.0187, 0.0220,
         0.0292, 0.0298, 0.0298, 0.0388, 0.0192, 0.0833, 0.1217, 0.0431, 0.0184],
        [0.0521, 0.0212, 0.0233, 0.0580, 0.0362, 0.0289, 0.0184, 0.0588, 0.0284,
         0.0362, 0.0293, 0.0219, 0.0276, 0.0522, 0.0273, 0.0347, 0.0311, 0.0602,
         0.0370, 0.0614, 0.0821, 0.0369, 0.0403, 0.0175, 0.0196, 0.0265, 0.0325],
        [0.0534, 0.0205,

In [56]:
1.0/probs

tensor([[ 72.7181,  83.3676,  24.0873,  63.4545,  29.4621,  55.8846,  12.2099,
          23.1999,  21.6764,  19.4205,  57.7382,  32.8719,  26.0383,  23.3784,
          60.4254,  19.7833,  16.0331,  27.9914,  20.8009,  31.6195,  21.0260,
          27.8977,  11.4348,  71.4264,  33.8710,  40.2572,  33.9266],
        [ 60.4595,  60.9931,  24.3405,  47.2623,  15.0223,  69.6892,  36.0246,
          23.0296,  28.5194,  31.6781,  63.0219,  37.7442,  39.5026,  10.7175,
          32.7386,  21.4296,  39.6087,  35.8761,  22.3791,  12.4589,   7.3654,
          69.3613,  36.7584,  55.5133,  59.4823,  37.0096,  32.6858],
        [ 36.0143,  20.9920,  49.2662,  10.8387,  29.9098,  12.2391,  33.7222,
          34.9772,  39.0321,  29.6144,  19.8375,  54.9362,  57.3698,  58.8050,
         106.8560,  77.0627,  53.5612,  45.4278,  34.2152,  33.6132,  33.5854,
          25.7472,  52.1565,  12.0091,   8.2201,  23.1878,  54.3903],
        [ 19.1960,  47.0836,  42.9058,  17.2405,  27.6279,  34.5551,  54.2981,
