In [81]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [82]:
# read in all the words (one name per line); the `with` block guarantees the file handle is closed
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [83]:
# peek at the first few names to confirm the file loaded as expected
words[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [84]:
# build the vocabulary of characters and mappings to/from integers
# letters get indices 1..N in sorted order; '.' is added last with index 0 (the name start/end marker)

chars = sorted(set(''.join(words)))
stoi = {**{ch: idx for idx, ch in enumerate(chars, start=1)}, '.': 0}
itos = {idx: ch for ch, idx in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [85]:
# build the dataset: for each of the first five names, emit every
# (previous block_size characters -> next character) training pair and print it

block_size = 3 # context length: how many characters do we take to predict the next one?

contexts, targets = [], []
for name in words[:5]:
    print(name)
    window = [0] * block_size  # left-pad with index 0 ('.') so the first character still has a full context
    for target in (stoi[c] for c in name + '.'):  # trailing '.' marks the end of the name
        contexts.append(window)
        targets.append(target)
        print(''.join(itos[k] for k in window), '--->', itos[target])
        window = [*window[1:], target]  # slide: drop the oldest index, append the newest
X = torch.tensor(contexts)
Y = torch.tensor(targets)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [86]:
# sanity check: X has one row of block_size context indices per example, Y one target index per example
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [87]:
# integer context windows; 0 is the '.' padding/marker index
X

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0],
        [ 0,  0, 15],
        [ 0, 15, 12],
        [15, 12,  9],
        [12,  9, 22],
        [ 9, 22,  9],
        [22,  9,  1],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  1, 22],
        [ 1, 22,  1],
        [ 0,  0,  0],
        [ 0,  0,  9],
        [ 0,  9, 19],
        [ 9, 19,  1],
        [19,  1,  2],
        [ 1,  2,  5],
        [ 2,  5, 12],
        [ 5, 12, 12],
        [12, 12,  1],
        [ 0,  0,  0],
        [ 0,  0, 19],
        [ 0, 19, 15],
        [19, 15, 16],
        [15, 16,  8],
        [16,  8,  9],
        [ 8,  9,  1]])

In [88]:
# integer targets: the index of the character that follows each context row of X
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [89]:
# embedding lookup table: one 2-D vector per vocabulary entry (26 letters + '.')
# seed the global RNG so C, and the randn-initialised weights created after it,
# are reproducible under Restart & Run All
torch.manual_seed(2147483647)
C = torch.randn((27, 2))
C

tensor([[-1.3928,  1.0513],
        [ 0.9430, -0.7710],
        [-0.5663,  0.3142],
        [-0.5587, -0.6742],
        [ 1.2576, -0.5143],
        [ 1.0726, -0.7035],
        [ 1.3266,  0.6150],
        [-1.4680,  0.6049],
        [ 1.5821, -0.6199],
        [ 0.7130,  0.3641],
        [-0.8849,  0.2152],
        [-1.1632, -3.3464],
        [-0.9703, -1.4101],
        [ 1.5312,  0.7711],
        [-0.6469,  0.0836],
        [-0.7248, -0.4552],
        [-0.5576, -0.8076],
        [ 1.5122, -1.5371],
        [-0.4425,  1.1260],
        [-0.0539,  1.1328],
        [-1.2659,  0.8406],
        [ 0.5055, -0.3863],
        [ 0.4232,  0.4980],
        [ 1.6621,  0.1320],
        [-0.2332, -0.7729],
        [-1.3071,  1.7987],
        [-0.4243, -0.4944]])

In [90]:
# a one-hot row for index 5 multiplied by C picks out row 5 of C (same values as C[5] below),
# which is why an embedding lookup can be done by plain indexing
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([ 1.0726, -0.7035])

In [91]:
# direct row indexing: same values as the one-hot matmul above
C[5]

tensor([ 1.0726, -0.7035])

In [92]:
# indexing with a list of ints returns one row of C per index
C[[5, 6, 7]]

tensor([[ 1.0726, -0.7035],
        [ 1.3266,  0.6150],
        [-1.4680,  0.6049]])

In [93]:
# indexing with an integer tensor works too, and repeated indices repeat rows
C[torch.tensor([5, 6, 7, 7])]

tensor([[ 1.0726, -0.7035],
        [ 1.3266,  0.6150],
        [-1.4680,  0.6049],
        [-1.4680,  0.6049]])

In [94]:
# index with the whole context tensor: every integer in X is replaced by its row of C
C[X]

tensor([[[-1.3928,  1.0513],
         [-1.3928,  1.0513],
         [-1.3928,  1.0513]],

        [[-1.3928,  1.0513],
         [-1.3928,  1.0513],
         [ 1.0726, -0.7035]],

        [[-1.3928,  1.0513],
         [ 1.0726, -0.7035],
         [ 1.5312,  0.7711]],

        [[ 1.0726, -0.7035],
         [ 1.5312,  0.7711],
         [ 1.5312,  0.7711]],

        [[ 1.5312,  0.7711],
         [ 1.5312,  0.7711],
         [ 0.9430, -0.7710]],

        [[-1.3928,  1.0513],
         [-1.3928,  1.0513],
         [-1.3928,  1.0513]],

        [[-1.3928,  1.0513],
         [-1.3928,  1.0513],
         [-0.7248, -0.4552]],

        [[-1.3928,  1.0513],
         [-0.7248, -0.4552],
         [-0.9703, -1.4101]],

        [[-0.7248, -0.4552],
         [-0.9703, -1.4101],
         [ 0.7130,  0.3641]],

        [[-0.9703, -1.4101],
         [ 0.7130,  0.3641],
         [ 0.4232,  0.4980]],

        [[ 0.7130,  0.3641],
         [ 0.4232,  0.4980],
         [ 0.7130,  0.3641]],

        [[ 0.4232,  0

In [95]:
# X's shape with a trailing embedding dimension appended
C[X].shape

torch.Size([32, 3, 2])

In [96]:
# embed every context character: emb[i, j] == C[X[i, j]]
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [97]:
# place the three per-position embeddings side by side, one row per example
# (hard-codes three context positions, i.e. block_size == 3)
torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=1)

tensor([[-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513,  1.0726, -0.7035],
        [-1.3928,  1.0513,  1.0726, -0.7035,  1.5312,  0.7711],
        [ 1.0726, -0.7035,  1.5312,  0.7711,  1.5312,  0.7711],
        [ 1.5312,  0.7711,  1.5312,  0.7711,  0.9430, -0.7710],
        [-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513, -0.7248, -0.4552],
        [-1.3928,  1.0513, -0.7248, -0.4552, -0.9703, -1.4101],
        [-0.7248, -0.4552, -0.9703, -1.4101,  0.7130,  0.3641],
        [-0.9703, -1.4101,  0.7130,  0.3641,  0.4232,  0.4980],
        [ 0.7130,  0.3641,  0.4232,  0.4980,  0.7130,  0.3641],
        [ 0.4232,  0.4980,  0.7130,  0.3641,  0.9430, -0.7710],
        [-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513,  0.9430, -0.7710],
        [-1.3928,  1.0513,  0.9430, -0.7710,  0.4232,  0.4980],
        [ 0.9430, -0.7710,  0.4232,  0.4

In [98]:
# same concatenation without listing positions: unbind splits dim 1 into a tuple, cat rejoins along dim 1
torch.cat(torch.unbind(emb, 1), 1)

tensor([[-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513,  1.0726, -0.7035],
        [-1.3928,  1.0513,  1.0726, -0.7035,  1.5312,  0.7711],
        [ 1.0726, -0.7035,  1.5312,  0.7711,  1.5312,  0.7711],
        [ 1.5312,  0.7711,  1.5312,  0.7711,  0.9430, -0.7710],
        [-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513, -0.7248, -0.4552],
        [-1.3928,  1.0513, -0.7248, -0.4552, -0.9703, -1.4101],
        [-0.7248, -0.4552, -0.9703, -1.4101,  0.7130,  0.3641],
        [-0.9703, -1.4101,  0.7130,  0.3641,  0.4232,  0.4980],
        [ 0.7130,  0.3641,  0.4232,  0.4980,  0.7130,  0.3641],
        [ 0.4232,  0.4980,  0.7130,  0.3641,  0.9430, -0.7710],
        [-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513,  0.9430, -0.7710],
        [-1.3928,  1.0513,  0.9430, -0.7710,  0.4232,  0.4980],
        [ 0.9430, -0.7710,  0.4232,  0.4

In [99]:
# same flattening as a view of emb's storage; infer the feature size instead of hard-coding
# 32 (number of examples) and 6 (block_size * embedding dim) so it survives dataset/context changes
emb.view(emb.shape[0], -1)

tensor([[-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513,  1.0726, -0.7035],
        [-1.3928,  1.0513,  1.0726, -0.7035,  1.5312,  0.7711],
        [ 1.0726, -0.7035,  1.5312,  0.7711,  1.5312,  0.7711],
        [ 1.5312,  0.7711,  1.5312,  0.7711,  0.9430, -0.7710],
        [-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513, -0.7248, -0.4552],
        [-1.3928,  1.0513, -0.7248, -0.4552, -0.9703, -1.4101],
        [-0.7248, -0.4552, -0.9703, -1.4101,  0.7130,  0.3641],
        [-0.9703, -1.4101,  0.7130,  0.3641,  0.4232,  0.4980],
        [ 0.7130,  0.3641,  0.4232,  0.4980,  0.7130,  0.3641],
        [ 0.4232,  0.4980,  0.7130,  0.3641,  0.9430, -0.7710],
        [-1.3928,  1.0513, -1.3928,  1.0513, -1.3928,  1.0513],
        [-1.3928,  1.0513, -1.3928,  1.0513,  0.9430, -0.7710],
        [-1.3928,  1.0513,  0.9430, -0.7710,  0.4232,  0.4980],
        [ 0.9430, -0.7710,  0.4232,  0.4

In [100]:
# hidden layer parameters: 6 inputs (3 context characters x 2-D embedding) -> 100 units;
# W1 is drawn before b1, matching the original RNG call order
W1, b1 = torch.randn((6, 100)), torch.randn(100)

In [101]:
# hidden activations: flatten each example's context embeddings into one row, then affine + tanh.
# view(emb.shape[0], -1) keeps exactly one row per example; the old view(-1, 6) could silently
# regroup elements across examples if block_size or the embedding size changed, whereas now a
# mismatch with W1 raises a shape error in the matmul.
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)

In [102]:
# tanh outputs lie in [-1, 1]; many values near +/-1 indicate saturated units
h

tensor([[-0.6077, -0.9967,  0.2421,  ..., -0.9888,  0.9809, -1.0000],
        [-0.9440, -1.0000, -1.0000,  ...,  0.9355,  0.9812, -0.9935],
        [ 0.9796, -0.8199, -0.9855,  ..., -0.5294,  0.9797, -0.9994],
        ...,
        [-0.4682, -0.6108, -0.2797,  ...,  0.9957,  0.9868,  0.9934],
        [ 0.9846,  0.9969,  0.9694,  ..., -0.7001,  0.7028,  0.9628],
        [ 0.4133,  0.4860, -0.6045,  ...,  0.9828, -0.9780,  1.0000]])

In [103]:
# one row of hidden activations per example
h.shape

torch.Size([32, 100])

In [104]:
# output layer parameters: 100 hidden units -> 27 logits (one per vocabulary entry);
# W2 is drawn before b2, matching the original RNG call order
W2, b2 = torch.randn((100, 27)), torch.randn(27)

In [105]:
# unnormalised scores for the next character, one row per example, one column per vocabulary entry
logits = h @ W2 + b2

In [106]:
# (number of examples, vocabulary size)
logits.shape

torch.Size([32, 27])

In [107]:
# exponentiate logits into unnormalised counts. Subtracting each row's max first keeps exp()
# from overflowing to inf on large logits; the shift cancels when the rows are normalised,
# so the resulting probs are mathematically unchanged.
counts = (logits - logits.max(1, keepdim=True).values).exp()

In [108]:
# normalise every row of counts to sum to 1: a probability distribution over the next character
# (keepdim=True keeps the row sums as a column so the division broadcasts across each row)
probs = counts.div(counts.sum(dim=1, keepdim=True))

In [109]:
# same shape as logits: one distribution per example
probs.shape

torch.Size([32, 27])

In [110]:
# sanity check: each row of probs should sum to 1
probs[0].sum()

tensor(1.0000)

In [112]:
# negative log-likelihood of the correct next character, averaged over all examples.
# Row indices come from torch.arange(len(Y)) rather than a hard-coded 32, so this works for any
# dataset size. (F.cross_entropy(logits, Y) computes the same quantity in a numerically stabler way.)
-probs[torch.arange(len(Y)), Y].log().mean()

tensor(16.7067)