In [1]:
import torch

In [3]:
words = open('names.txt','r').read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [4]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

Create the dataset

In [17]:
xs,ys = [],[]

for w in words:
    chs=['.']+list(w)+['.']
    for ch1,ch2 in zip(chs,chs[1:]):
        ix1=stoi[ch1]
        ix2=stoi[ch2]
        xs.append(ix1)
        ys.append(ix2)
xs=torch.tensor(xs)
ys=torch.tensor(ys)

In [18]:
print(f'{xs=}')
print(f'{ys=}')

xs=tensor([ 0,  5, 13,  ..., 25, 26, 24])
ys=tensor([ 5, 13, 13,  ..., 26, 24,  0])


In [19]:
print(f'Number of examples {len(xs)}')

Number of examples 228146


intitialise the network

In [None]:
# randomly initialize 27 neurons' weights. each neuron receives 27 inputs

g = torch.Generator().manual_seed(2147483647)
W = torch.rand((27,27),generator=g,requires_grad=True)

Optimisation loop

In [21]:
import torch.nn.functional as F

In [32]:
for k in range(100):  # Run the training loop for 100 iterations

    # === FORWARD PASS ===

    # Convert input character indices into one-hot encoded vectors.
    # xs: tensor of shape (batch_size,), each value is an index from 0 to 26
    # One-hot encoding gives shape: (batch_size, 27)
    # Example: index 2 -> [0, 0, 1, 0, ..., 0]
    xenc = F.one_hot(xs, num_classes=27).float()

    # Compute logits by matrix multiplication
    # W: weight matrix of shape (27, 27)
    # xenc @ W → (batch_size, 27), each row is a raw prediction (unnormalized)
    logits = xenc @ W

    # === SOFTMAX FUNCTION ===

    # Convert logits into exponentiated values (numerator of softmax)
    # This is done instead of calling torch.softmax to reinforce the manual math
    counts = logits.exp()

    # Normalize to get probabilities across each row (i.e., per sample)
    # Sum across dim=1 to get the denominator of softmax
    # probs: shape (batch_size, 27), each row sums to 1
    probs = counts / counts.sum(1, keepdims=True)

    # === LOSS CALCULATION ===

    # Negative log-likelihood:
    # For each sample, pick the probability corresponding to the correct next character (ys)
    # torch.arange(len(xs)) creates row indices: [0, 1, 2, ..., batch_size-1]
    # probs[rows, ys] gives the predicted prob for the true label
    # .log(): log-likelihood → negative for loss
    # .mean(): average over batch
    # Add L2 regularization: encourages smaller weights to reduce overfitting
    loss = -probs[torch.arange(len(xs)), ys].log().mean() + 0.01 * (W ** 2).mean()

    # Print loss at each iteration to monitor training
    print(loss.item())

    # === BACKWARD PASS ===

    # Zero out previous gradients manually (no optimizer used)
    W.grad = None

    # Backpropagate to compute gradient of loss with respect to W
    loss.backward()

    # === WEIGHT UPDATE ===

    # Perform stochastic gradient descent manually
    # W.grad: gradient of loss w.r.t. W
    # Learning rate = 50 (large, to speed up convergence in small examples)
    # .data is used to perform in-place updates without interfering with autograd
    W.data -= 50 * W.grad


2.4818949699401855
2.4818878173828125
2.4818809032440186
2.4818737506866455
2.4818668365478516
2.4818599224090576
2.4818530082702637
2.4818460941314697
2.4818389415740967
2.48183274269104
2.481825828552246
2.4818193912506104
2.4818124771118164
2.4818062782287598
2.481799364089966
2.4817934036254883
2.4817867279052734
2.4817802906036377
2.481774091720581
2.4817678928375244
2.481761932373047
2.481755256652832
2.4817490577697754
2.481743574142456
2.4817371368408203
2.481731414794922
2.4817254543304443
2.481719493865967
2.4817137718200684
2.481707811355591
2.4817020893096924
2.481696367263794
2.4816904067993164
2.481685161590576
2.4816792011260986
2.4816739559173584
2.48166823387146
2.4816627502441406
2.4816575050354004
2.481651782989502
2.4816465377807617
2.4816410541534424
2.481635570526123
2.481630325317383
2.4816253185272217
2.4816200733184814
2.481614828109741
2.48160982131958
2.48160457611084
2.481599807739258
2.4815945625305176
2.4815895557403564
2.4815845489501953
2.481579542160034