In [469]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [470]:
# load the training names, one per line; `with` guarantees the file handle is closed
with open('names.txt', 'r', encoding='utf-8') as f:
  words = f.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [471]:
# number of names (lines) read from names.txt
len(words)

32033

In [472]:
# vocabulary: every distinct character that occurs in the names, sorted;
# letters get indices 1..N and '.' (the start/end-of-name token) gets index 0.
# Deriving it from the data avoids a KeyError later for characters outside a-z.
chars = sorted(set(''.join(words)))
assert '.' not in chars, "'.' is reserved as the start/end token"
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [473]:
# build the dataset
# For each name, slide a window of `block_size` character indices over it
# (starting from an all-'.' window, ending with the '.' terminator) and record
# the pair (window -> index of the next character).

block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], []
for w in words[:5]:
  print(w)
  context = [0] * block_size
  for ix in (stoi[ch] for ch in w + '.'):
    X.append(context)
    Y.append(ix)
    window = ''.join(itos[i] for i in context)
    print(f'{window} ---> {itos[ix]}')
    context = [*context[1:], ix]  # drop the oldest index, append the newest

X, Y = torch.tensor(X), torch.tensor(Y)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [474]:
# sanity-check the dataset tensors: X is (examples, block_size), Y is (examples,), both integer indices
X.shape, X.dtype, Y.shape, Y.dtype


(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [475]:
# inputs: one row per example, holding the block_size character indices of its context
X

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0],
        [ 0,  0, 15],
        [ 0, 15, 12],
        [15, 12,  9],
        [12,  9, 22],
        [ 9, 22,  9],
        [22,  9,  1],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  1, 22],
        [ 1, 22,  1],
        [ 0,  0,  0],
        [ 0,  0,  9],
        [ 0,  9, 19],
        [ 9, 19,  1],
        [19,  1,  2],
        [ 1,  2,  5],
        [ 2,  5, 12],
        [ 5, 12, 12],
        [12, 12,  1],
        [ 0,  0,  0],
        [ 0,  0, 19],
        [ 0, 19, 15],
        [19, 15, 16],
        [15, 16,  8],
        [16,  8,  9],
        [ 8,  9,  1]])

In [476]:
# targets: Y[i] is the index of the character that follows the context in X[i]
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [477]:
SEED = 2147483647  # fixed seed so every random draw below is reproducible
g = torch.Generator().manual_seed(SEED)

In [478]:
# embedding table: one 2-D vector per vocabulary entry (rows sized from the vocabulary, not a magic 27)
C = torch.randn((len(stoi), 2), generator=g)

In [479]:
# embedding table: row k is the 2-D vector for character index k
C

tensor([[ 1.5674, -0.2373],
        [-0.0274, -1.1008],
        [ 0.2859, -0.0296],
        [-1.5471,  0.6049],
        [ 0.0791,  0.9046],
        [-0.4713,  0.7868],
        [-0.3284, -0.4330],
        [ 1.3729,  2.9334],
        [ 1.5618, -1.6261],
        [ 0.6772, -0.8404],
        [ 0.9849, -0.1484],
        [-1.4795,  0.4483],
        [-0.0707,  2.4968],
        [ 2.4448, -0.6701],
        [-1.2199,  0.3031],
        [-1.0725,  0.7276],
        [ 0.0511,  1.3095],
        [-0.8022, -0.8504],
        [-1.8068,  1.2523],
        [ 0.1476, -1.0006],
        [-0.5030, -1.0660],
        [ 0.8480,  2.0275],
        [-0.1158, -1.2078],
        [-1.0406, -1.5367],
        [-0.5132,  0.2961],
        [-1.4904, -0.2838],
        [ 0.2569,  0.2130]])

In [480]:
# embedding of character index 5 ('e' in itos), via direct row indexing
C[5]

tensor([-0.4713,  0.7868])

In [481]:
# a one-hot row vector times C selects row 5 of C: the same result as C[5]
F.one_hot(torch.tensor(5), num_classes=C.shape[0]).float() @ C

tensor([-0.4713,  0.7868])

In [482]:
# embed every context index at once: X (examples, block_size) indexes rows of C -> (examples, block_size, 2)
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [483]:
# embedded contexts, shape (examples, block_size, 2); the full dump is long, so emb[:5] would be enough to inspect
emb

tensor([[[ 1.5674, -0.2373],
         [ 1.5674, -0.2373],
         [ 1.5674, -0.2373]],

        [[ 1.5674, -0.2373],
         [ 1.5674, -0.2373],
         [-0.4713,  0.7868]],

        [[ 1.5674, -0.2373],
         [-0.4713,  0.7868],
         [ 2.4448, -0.6701]],

        [[-0.4713,  0.7868],
         [ 2.4448, -0.6701],
         [ 2.4448, -0.6701]],

        [[ 2.4448, -0.6701],
         [ 2.4448, -0.6701],
         [-0.0274, -1.1008]],

        [[ 1.5674, -0.2373],
         [ 1.5674, -0.2373],
         [ 1.5674, -0.2373]],

        [[ 1.5674, -0.2373],
         [ 1.5674, -0.2373],
         [-1.0725,  0.7276]],

        [[ 1.5674, -0.2373],
         [-1.0725,  0.7276],
         [-0.0707,  2.4968]],

        [[-1.0725,  0.7276],
         [-0.0707,  2.4968],
         [ 0.6772, -0.8404]],

        [[-0.0707,  2.4968],
         [ 0.6772, -0.8404],
         [-0.1158, -1.2078]],

        [[ 0.6772, -0.8404],
         [-0.1158, -1.2078],
         [ 0.6772, -0.8404]],

        [[-0.1158, -1

In [484]:
# hidden layer: its input is the block_size embeddings of one example concatenated,
# i.e. block_size * embedding-dim features (3 * 2 = 6 here), derived rather than hard-coded
n_hidden = 100  # number of hidden units
W1 = torch.randn((block_size * C.shape[1], n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g)

In [485]:
# flatten each example's (block_size, embedding-dim) embeddings into one row, then apply the tanh layer;
# keeping the row count explicit avoids view(-1, 6) silently regrouping rows if the sizes change
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)

In [486]:
# hidden activations, one row per example; tanh keeps every value within [-1, 1]
h

tensor([[-0.9348,  1.0000,  0.9258,  ...,  0.9786, -0.1926,  0.9515],
        [ 0.2797,  0.9997,  0.7675,  ...,  0.9929,  0.9992,  0.9981],
        [-0.9960,  1.0000, -0.8694,  ..., -0.5159, -1.0000, -0.0069],
        ...,
        [-0.9996,  1.0000, -0.9273,  ..., -0.9999, -0.9974, -0.9970],
        [-0.9043,  1.0000,  0.9868,  ..., -0.7859, -0.4819,  0.9981],
        [-0.9048,  1.0000,  0.9553,  ...,  0.9866,  1.0000,  0.9907]])

In [487]:
# expect (examples, 100 hidden units)
h.shape

torch.Size([32, 100])

In [488]:
# output layer: maps the hidden units (W1's output width) to one score per vocabulary character
W2 = torch.randn((W1.shape[1], len(stoi)), generator=g)
b2 = torch.randn(len(stoi), generator=g)

In [489]:
# output layer: one unnormalized score (logit) per vocabulary character, for each example
logits = h @ W2 + b2

In [490]:
# expect (examples, 27): one score per vocabulary character
logits.shape

torch.Size([32, 27])

In [491]:
# raw, unnormalized output scores; they are exponentiated and normalized into probabilities below
logits

tensor([[ 2.0191e+00, -5.0126e-01,  1.3371e+01, -1.3467e+00, -4.9756e+00,
         -9.9992e+00, -1.9701e+01,  1.0128e+01,  2.6051e+00,  1.9866e+01,
          2.1664e+01, -4.1630e+00, -6.2530e-01, -5.7732e-01,  1.2552e+01,
         -1.1293e+00,  2.3699e+00,  5.7146e+00,  6.7237e+00,  2.0819e+00,
         -2.7059e+00,  2.2633e+00, -4.2412e+00, -2.8045e+00,  1.0247e+01,
          2.8311e+00,  9.2402e+00],
        [-6.9784e+00,  8.3823e-01,  9.4671e+00, -1.1949e+00,  7.0614e+00,
          9.3584e-01, -6.3802e+00,  1.1896e+01,  1.3653e+01, -1.6366e+00,
          2.6264e+01,  2.5903e+00, -2.4583e+00, -1.1175e+00,  1.2883e+01,
          2.3099e-02,  8.1749e+00,  3.2475e+00,  8.1620e+00,  9.2478e+00,
         -5.2149e+00, -1.0714e+01, -1.6384e+01, -2.3681e+00,  5.7656e+00,
          4.1368e+00,  1.2497e+01],
        [ 4.1821e+00, -3.6009e+00,  1.7097e+01,  1.2406e+00, -5.6116e+00,
         -2.4239e-01, -4.6956e+00, -1.1157e+01, -9.5652e+00,  2.5903e+01,
          1.4100e+01,  1.0274e+01,  3.27

In [492]:
# exponentiate to get unnormalized counts. Subtracting each row's max logit first keeps
# exp() from overflowing to inf (which would turn probs into nan) when logits are large;
# the per-row shift cancels out when the rows are normalized into probabilities.
counts = (logits - logits.max(1, keepdim=True).values).exp()
counts.shape

torch.Size([32, 27])

In [493]:
# normalize every row of counts to sum to 1: a distribution over the next character
probs = counts.div(counts.sum(dim=1, keepdim=True))

In [494]:
# same shape as logits: one probability per (example, character)
probs.shape

torch.Size([32, 27])

In [495]:
# every row of probs must be a probability distribution; check all rows, not just the first
assert torch.allclose(probs.sum(1), torch.ones(probs.shape[0])), "rows of probs do not sum to 1"
probs[0].sum()

tensor(1.)

In [496]:
# one row index per example (sized from the data, not a hard-coded 32), to pair with the targets in Y
torch.arange(Y.shape[0])

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

In [497]:
# targets again, to line up with the row indices shown above: the loss picks probs[i, Y[i]] for each row i
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [498]:
# negative log-likelihood: for each example i take the probability assigned to its true
# next character Y[i], then average -log(p) over all examples (count taken from Y, not hard-coded)
loss = -probs[torch.arange(Y.shape[0]), Y].log().mean()
loss


tensor(17.7697)