In [159]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [160]:
# read in all the words (one name per line).
# The context manager closes the file handle as soon as the contents are read,
# instead of leaving it open until garbage collection.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [161]:
# peek at the first five names to confirm the file was parsed line by line
words[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [162]:
# vocabulary: every distinct character in the dataset, plus '.' as the boundary token

chars = sorted(set(''.join(words)))
stoi = dict(zip(chars, range(1, len(chars) + 1)))  # characters are numbered from 1 ...
stoi['.'] = 0  # ... which leaves index 0 for the '.' start/end marker
itos = dict((ix, ch) for ch, ix in stoi.items())  # inverse lookup: index -> character
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [163]:
# assemble (context, next-character) training pairs

block_size = 3 # number of preceding characters used to predict the next one

contexts, targets = [], []
for w in words[:5]:
    print(w)
    # every word starts from a context made only of index 0 (the '.' token)
    context = [0] * block_size
    for ch in w + '.':  # the appended '.' yields one example whose target is the end marker
        ix = stoi[ch]
        contexts.append(context)
        targets.append(ix)
        shown = ''.join(itos[i] for i in context)
        print(shown, '--->', itos[ix])
        # slide the window: drop the oldest index, append the newest. A new list is
        # built each time, so the list already stored in `contexts` is not modified.
        context = [*context[1:], ix]
X = torch.tensor(contexts)
Y = torch.tensor(targets)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [164]:
# sanity check: one row of block_size indices per example in X, one integer label per example in Y
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [165]:
# display the context windows: each row holds the character indices that precede one target
X

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0],
        [ 0,  0, 15],
        [ 0, 15, 12],
        [15, 12,  9],
        [12,  9, 22],
        [ 9, 22,  9],
        [22,  9,  1],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  1, 22],
        [ 1, 22,  1],
        [ 0,  0,  0],
        [ 0,  0,  9],
        [ 0,  9, 19],
        [ 9, 19,  1],
        [19,  1,  2],
        [ 1,  2,  5],
        [ 2,  5, 12],
        [ 5, 12, 12],
        [12, 12,  1],
        [ 0,  0,  0],
        [ 0,  0, 19],
        [ 0, 19, 15],
        [19, 15, 16],
        [15, 16,  8],
        [16,  8,  9],
        [ 8,  9,  1]])

In [166]:
# display the target character index for each context row of X
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [167]:
# embedding lookup table: 27 rows (one per index in the vocabulary), 2 random values per row
C = torch.randn(27, 2)
C

tensor([[-1.1232,  0.5990],
        [ 0.1396,  1.8691],
        [-0.7492,  1.1291],
        [ 1.6471,  0.2546],
        [ 0.4301, -1.7655],
        [-0.8806, -0.1010],
        [ 1.8179,  1.5176],
        [-1.2494, -0.5209],
        [-0.0534,  1.0390],
        [-0.2567, -0.0047],
        [ 0.0551, -0.3240],
        [ 2.2759, -0.9436],
        [-0.7089, -0.2337],
        [-1.0022,  0.9037],
        [-0.0350, -1.1494],
        [-1.9505, -0.4709],
        [ 0.2653, -1.4764],
        [ 0.4577,  0.5937],
        [ 1.7221,  1.2049],
        [-0.6718,  1.0432],
        [-0.5086,  0.2634],
        [ 2.3453, -0.6539],
        [ 0.5631,  1.7680],
        [-0.7234,  2.4806],
        [ 1.4669,  1.7232],
        [ 1.1773,  0.1628],
        [-2.2183, -0.8590]])

In [168]:
# select row 5 of C with a matrix multiply: a one-hot vector with a 1 at position 5
# picks out exactly that row. one_hot returns integers, so cast to float before `@ C`.
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([-0.8806, -0.1010])

In [169]:
# plain integer indexing returns row 5 of C directly
C[5]

tensor([-0.8806, -0.1010])

In [170]:
# a Python list of indices gathers several rows at once
C[[5, 6, 7]]

tensor([[-0.8806, -0.1010],
        [ 1.8179,  1.5176],
        [-1.2494, -0.5209]])

In [171]:
# an integer tensor works as an index too; a repeated index returns the same row again
C[torch.tensor([5, 6, 7, 7])]

tensor([[-0.8806, -0.1010],
        [ 1.8179,  1.5176],
        [-1.2494, -0.5209],
        [-1.2494, -0.5209]])

In [172]:
# indexing with the 2-D integer tensor X looks up a row of C for every entry of X
C[X]

tensor([[[-1.1232,  0.5990],
         [-1.1232,  0.5990],
         [-1.1232,  0.5990]],

        [[-1.1232,  0.5990],
         [-1.1232,  0.5990],
         [-0.8806, -0.1010]],

        [[-1.1232,  0.5990],
         [-0.8806, -0.1010],
         [-1.0022,  0.9037]],

        [[-0.8806, -0.1010],
         [-1.0022,  0.9037],
         [-1.0022,  0.9037]],

        [[-1.0022,  0.9037],
         [-1.0022,  0.9037],
         [ 0.1396,  1.8691]],

        [[-1.1232,  0.5990],
         [-1.1232,  0.5990],
         [-1.1232,  0.5990]],

        [[-1.1232,  0.5990],
         [-1.1232,  0.5990],
         [-1.9505, -0.4709]],

        [[-1.1232,  0.5990],
         [-1.9505, -0.4709],
         [-0.7089, -0.2337]],

        [[-1.9505, -0.4709],
         [-0.7089, -0.2337],
         [-0.2567, -0.0047]],

        [[-0.7089, -0.2337],
         [-0.2567, -0.0047],
         [ 0.5631,  1.7680]],

        [[-0.2567, -0.0047],
         [ 0.5631,  1.7680],
         [-0.2567, -0.0047]],

        [[ 0.5631,  1

In [173]:
# the result has X's shape followed by the embedding dimension
C[X].shape

torch.Size([32, 3, 2])

In [174]:
# emb holds the embedding of every context position of every example
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [175]:
# flatten each example by concatenating the embeddings of its three context positions
# side by side along dim 1
torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=1)

tensor([[-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990, -0.8806, -0.1010],
        [-1.1232,  0.5990, -0.8806, -0.1010, -1.0022,  0.9037],
        [-0.8806, -0.1010, -1.0022,  0.9037, -1.0022,  0.9037],
        [-1.0022,  0.9037, -1.0022,  0.9037,  0.1396,  1.8691],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.9505, -0.4709],
        [-1.1232,  0.5990, -1.9505, -0.4709, -0.7089, -0.2337],
        [-1.9505, -0.4709, -0.7089, -0.2337, -0.2567, -0.0047],
        [-0.7089, -0.2337, -0.2567, -0.0047,  0.5631,  1.7680],
        [-0.2567, -0.0047,  0.5631,  1.7680, -0.2567, -0.0047],
        [ 0.5631,  1.7680, -0.2567, -0.0047,  0.1396,  1.8691],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990,  0.1396,  1.8691],
        [-1.1232,  0.5990,  0.1396,  1.8691,  0.5631,  1.7680],
        [ 0.1396,  1.8691,  0.5631,  1.7

In [176]:
# same concatenation without hard-coding the number of context positions:
# unbind splits emb along dim 1 into a tuple of slices
torch.cat(torch.unbind(emb, 1), 1)

tensor([[-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990, -0.8806, -0.1010],
        [-1.1232,  0.5990, -0.8806, -0.1010, -1.0022,  0.9037],
        [-0.8806, -0.1010, -1.0022,  0.9037, -1.0022,  0.9037],
        [-1.0022,  0.9037, -1.0022,  0.9037,  0.1396,  1.8691],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.9505, -0.4709],
        [-1.1232,  0.5990, -1.9505, -0.4709, -0.7089, -0.2337],
        [-1.9505, -0.4709, -0.7089, -0.2337, -0.2567, -0.0047],
        [-0.7089, -0.2337, -0.2567, -0.0047,  0.5631,  1.7680],
        [-0.2567, -0.0047,  0.5631,  1.7680, -0.2567, -0.0047],
        [ 0.5631,  1.7680, -0.2567, -0.0047,  0.1396,  1.8691],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990,  0.1396,  1.8691],
        [-1.1232,  0.5990,  0.1396,  1.8691,  0.5631,  1.7680],
        [ 0.1396,  1.8691,  0.5631,  1.7

In [177]:
# reinterpret emb as 32 rows of 6 values; the batch size 32 is hard-coded here
emb.view(32, 6)

tensor([[-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990, -0.8806, -0.1010],
        [-1.1232,  0.5990, -0.8806, -0.1010, -1.0022,  0.9037],
        [-0.8806, -0.1010, -1.0022,  0.9037, -1.0022,  0.9037],
        [-1.0022,  0.9037, -1.0022,  0.9037,  0.1396,  1.8691],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.9505, -0.4709],
        [-1.1232,  0.5990, -1.9505, -0.4709, -0.7089, -0.2337],
        [-1.9505, -0.4709, -0.7089, -0.2337, -0.2567, -0.0047],
        [-0.7089, -0.2337, -0.2567, -0.0047,  0.5631,  1.7680],
        [-0.2567, -0.0047,  0.5631,  1.7680, -0.2567, -0.0047],
        [ 0.5631,  1.7680, -0.2567, -0.0047,  0.1396,  1.8691],
        [-1.1232,  0.5990, -1.1232,  0.5990, -1.1232,  0.5990],
        [-1.1232,  0.5990, -1.1232,  0.5990,  0.1396,  1.8691],
        [-1.1232,  0.5990,  0.1396,  1.8691,  0.5631,  1.7680],
        [ 0.1396,  1.8691,  0.5631,  1.7

In [178]:
# hidden layer parameters: 6 inputs per example mapped to 100 hidden units
W1 = torch.randn(6, 100)
b1 = torch.randn((100,))

In [179]:
# hidden layer activations: flatten each example's embeddings into a row of 6 values
# (-1 lets view infer the number of rows), apply the affine map, then squash with tanh.
# b1 has shape (100,) and is broadcast across the rows.
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)

In [180]:
# inspect the hidden activations; tanh keeps every entry between -1 and 1
h

tensor([[-0.4022, -0.7067, -0.7679,  ...,  0.9510,  0.9940, -0.1875],
        [-0.9140, -0.9431, -0.8712,  ...,  0.9637,  0.9867, -0.0159],
        [-0.4428, -0.4182, -0.6355,  ...,  0.9807,  0.9928,  0.2494],
        ...,
        [-0.7353,  0.9663, -0.8908,  ...,  0.0693,  0.9973, -0.7554],
        [ 0.9067,  0.8773, -0.7843,  ..., -0.9963,  0.9938, -0.9359],
        [ 0.9514,  0.4171,  0.5301,  ...,  0.9999,  0.9219,  0.9407]])

In [181]:
# one row of 100 hidden activations per example
h.shape

torch.Size([32, 100])

In [182]:
# output layer parameters: 100 hidden units mapped to 27 scores per example
W2 = torch.randn(100, 27)
b2 = torch.randn((27,))

In [183]:
# output layer: an unnormalised score (logit) for each of the 27 classes, per example
logits = h @ W2 + b2

In [184]:
# one row of 27 logits per example
logits.shape

torch.Size([32, 27])

In [185]:
# exponentiate the logits to get positive, unnormalised "counts"
# NOTE(review): exp() of a large logit can overflow in float32; for real training prefer
# F.cross_entropy, which takes the raw logits directly.
counts = logits.exp()

In [186]:
# normalise each row so it sums to 1; keepdim=True keeps the row sums as a column,
# so the division broadcasts row by row
probs = counts / counts.sum(1, keepdim=True)

In [187]:
# same shape as logits: one probability per class for each example
probs.shape

torch.Size([32, 27])

In [188]:
# check the normalisation: the probabilities of the first example should sum to 1
probs[0].sum()

tensor(1.0000)

In [189]:
# mean negative log-likelihood of the correct next character.
# Row i of probs is read at column Y[i]. The row index range comes from Y instead of
# a hard-coded 32, so the cell still works when the number of examples changes.
-probs[torch.arange(Y.shape[0]), Y].log().mean()

tensor(18.5443)

In [190]:
# --- now made respectable :) ---

In [223]:
# build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?


def build_dataset(words, block_size, verbose=True):
    """Turn a list of words into (context, next-character) training tensors.

    Relies on the notebook-level `stoi` / `itos` mappings for encoding and display.

    Args:
        words: iterable of strings; a '.' end marker is appended to each word.
        block_size: number of preceding character indices in each context row.
        verbose: when True, print each word and every context ---> target pair.

    Returns:
        (X, Y): X has one row of `block_size` indices per example and Y holds the
        index of the character that follows each context.
    """
    contexts, targets = [], []
    for w in words:
        if verbose:
            print(w)
        # every word starts from a context made only of index 0 (the '.' token)
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            contexts.append(context)
            targets.append(ix)
            if verbose:
                print(''.join(itos[i] for i in context), '--->', itos[ix])
            # slide the window; a new list is built, so stored contexts are not modified
            context = context[1:] + [ix]
    return torch.tensor(contexts), torch.tensor(targets)


X, Y = build_dataset(words[:5], block_size)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [224]:
# sanity check: one context row in X and one label in Y per example
X.shape, Y.shape

(torch.Size([32, 3]), torch.Size([32]))

In [225]:
# seeded generator so the random initial parameters are the same on every run
g = torch.Generator()
g.manual_seed(2147483647)
# the tensors are drawn from g in exactly this order: C, W1, b1, W2, b2
C, W1, b1, W2, b2 = (
    torch.randn(shape, generator=g)
    for shape in [(27, 2), (6, 100), (100,), (100, 27), (27,)]
)
parameters = [C, W1, b1, W2, b2]

In [226]:
# total number of scalar values across all parameter tensors
sum(p.numel() for p in parameters)

3481

In [227]:
# turn on gradient tracking so loss.backward() fills in .grad for every parameter
for p in parameters:
    p.requires_grad_(True)

In [228]:
# gradient descent on the whole (tiny) dataset
num_steps = 1000
learning_rate = 0.1

for _ in range(num_steps):
    # forward pass: embedding lookup -> tanh hidden layer -> logits
    emb = C[X]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    # F.cross_entropy replaces the manual exp() / row-normalise / -log(...).mean()
    # sequence and works on the raw logits directly
    loss = F.cross_entropy(logits, Y)
    print(loss.item())
    # backward pass: clear stale gradients, then backpropagate
    for p in parameters:
        p.grad = None
    loss.backward()
    # update: step in place under no_grad so autograd does not record the update
    # (preferred over writing through `p.data`, which bypasses autograd tracking)
    with torch.no_grad():
        for p in parameters:
            p += -learning_rate * p.grad

17.76971435546875
13.656402587890625
11.298770904541016
9.452458381652832
7.984263896942139
6.891323089599609
6.100015640258789
5.452036380767822
4.8981523513793945
4.414664268493652
3.9858498573303223
3.6028308868408203
3.2621421813964844
2.96138072013855
2.6982970237731934
2.469712734222412
2.271660566329956
2.1012842655181885
1.957176923751831
1.8374860286712646
1.7380965948104858
1.6535117626190186
1.5790901184082031
1.5117673873901367
1.449605107307434
1.3913123607635498
1.3359928131103516
1.2830535173416138
1.232191801071167
1.1833821535110474
1.1367993354797363
1.092665195465088
1.0510929822921753
1.012027621269226
0.9752710461616516
0.9405569434165955
0.9076131582260132
0.8761925101280212
0.8460896015167236
0.8171363472938538
0.7891995310783386
0.7621749639511108
0.7359815835952759
0.710558295249939
0.6858615279197693
0.6618656516075134
0.6385658979415894
0.6159821152687073
0.5941662192344666
0.5732107162475586
0.553256630897522
0.5344885587692261
0.5171172618865967
0.501331806

In [229]:
# loss from the final forward pass of the loop (computed before the last parameter update)
loss

tensor(0.2561, grad_fn=<NllLossBackward0>)

In [230]:
# logits left over from the final forward pass: one row of 27 scores per example
logits.shape

torch.Size([32, 27])

In [234]:
# inspect the raw scores from the final forward pass
logits

tensor([[ 3.7421e+00,  1.3235e+01,  5.3940e+00,  4.0988e+00, -7.6535e-01,
          1.3222e+01, -1.8069e+01,  5.7285e+00, -3.5979e+00,  1.3118e+01,
          6.8001e+00, -3.3704e+00, -1.8751e-01,  7.3862e+00,  3.0223e-01,
          1.3197e+01,  3.6976e+00,  6.6991e+00, -7.2902e-01,  1.3335e+01,
         -2.3995e+00,  4.7283e+00,  7.0238e+00, -6.4949e+00,  6.4408e+00,
         -5.4583e+00,  5.9195e+00],
        [-9.0184e+00,  7.5491e+00,  1.9528e+00,  1.4233e+00,  5.6973e+00,
          9.9765e+00, -2.4487e+00,  4.2988e+00,  8.6141e+00, -3.7782e+00,
          1.0311e+01, -9.6241e-01,  1.2333e+01,  1.7791e+01,  5.8880e+00,
          6.2631e+00,  8.7805e+00,  1.9645e+00,  7.4622e+00,  9.4323e+00,
          1.9641e+00, -9.0605e+00, -9.9813e+00, -1.1730e+01,  1.6828e+00,
         -4.0623e+00,  6.2619e+00],
        [ 8.0140e+00,  1.1974e+01,  1.0477e+01,  5.6248e+00, -1.9768e+00,
          5.6115e+00, -9.9594e+00, -8.7409e+00, -1.1285e+01,  1.4137e+01,
          5.7153e+00,  1.1738e+01,  5.32

In [233]:
# per-row maximum logit (values) and its column (indices), i.e. the predicted character index
logits.max(1)

torch.return_types.max(
values=tensor([13.3347, 17.7906, 20.6013, 20.6118, 16.7355, 13.3347, 15.9986, 14.1725,
        15.9149, 18.3614, 15.9397, 20.9265, 13.3347, 17.1088, 17.1319, 20.0600,
        13.3347, 16.5889, 15.1016, 17.0579, 18.5863, 15.9671, 10.8740, 10.6872,
        15.5056, 13.3347, 16.1793, 16.9743, 12.7427, 16.2007, 19.0847, 16.0194],
       grad_fn=<MaxBackward0>),
indices=tensor([19, 13, 13,  1,  0, 19, 12,  9, 22,  9,  1,  0, 19, 22,  1,  0, 19, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0]))

In [232]:
# true labels, for comparison with the predicted indices from logits.max(1)
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])