In [11]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt 
%matplotlib inline 

In [12]:
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [13]:
len(words)

32033

In [15]:
#build vocab with numbers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.']=0
print(stoi)
itos = {i:s for s, i in stoi.items()}
itos

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}


{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [23]:
block_size = 3 # context length 
X, Y = [], []

for w in words[:5]:
    print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(itos[i] for i in context), '-->', itos[ix])
        context = context[1:] + [ix] # crop and append 
        
X = torch.tensor(X)
Y = torch.tensor(Y)    

emma
... --> e
..e --> m
.em --> m
emm --> a
mma --> .
olivia
... --> o
..o --> l
.ol --> i
oli --> v
liv --> i
ivi --> a
via --> .
ava
... --> a
..a --> v
.av --> a
ava --> .
isabella
... --> i
..i --> s
.is --> a
isa --> b
sab --> e
abe --> l
bel --> l
ell --> a
lla --> .
sophia
... --> s
..s --> o
.so --> p
sop --> h
oph --> i
phi --> a
hia --> .


In [24]:
X.shape, Y.shape

(torch.Size([32, 3]), torch.Size([32]))

## Embedding lookup table 

In [26]:
C = torch.randn((27, 2))

In [34]:
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([-1.3542,  0.5460])

In [39]:
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [40]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)

In [45]:
torch.cat([emb[:, 0,:], emb[:, 1,:], emb[:, 2,:]], 1).shape

torch.Size([32, 6])

In [49]:
torch.cat(torch.unbind(emb, 1), 1).shape

torch.Size([32, 6])

In [53]:
torch.cat(torch.unbind(emb, 1), 1) == emb.view(32, 6)

tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, T

In [63]:
logits1 = emb.view(emb.shape[0], -1) @ W1 + b1 #ensure correct broadcasting
logits1

tensor([[ 0.4033,  0.8159, -2.0899,  ..., -0.6351, -0.6248, -0.7896],
        [-0.9831,  0.8320, -1.4370,  ..., -0.6641, -1.4439,  2.5688],
        [-0.4639, -0.1428, -2.3741,  ...,  0.4687,  2.0136, -2.4000],
        ...,
        [-1.3665, -1.1130, -2.5003,  ..., -0.4460, -0.3268,  1.2497],
        [-0.7607,  2.3863, -1.3500,  ...,  0.6987,  4.5260, -3.5474],
        [-0.2602,  0.7923, -2.5721,  ..., -1.4039,  0.9577, -2.3219]])

In [64]:
h = torch.tanh(logits)

In [65]:
h.shape

torch.Size([32, 100])

In [66]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

In [67]:
logits = h @ W2 + b2 
logits.shape

torch.Size([32, 27])

In [68]:
counts = logits.exp()
prob = counts/counts.sum(1, keepdims=True)

In [73]:
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [76]:
# nll loss
loss = -prob[torch.arange(32), Y].log().mean()
loss

tensor(14.4074)

In [None]:
# reproduce and clean up 

