In [1]:
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# read in all the words

words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
# build hr vocab of charactersandmappings to/from integers
chars = sorted(list(set(''.join(words))))   
stoi = { ch:i for i,ch in enumerate(chars) }
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{0: '.', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k', 11: 'l', 12: 'm', 13: 'n', 14: 'o', 15: 'p', 16: 'q', 17: 'r', 18: 's', 19: 't', 20: 'u', 21: 'v', 22: 'w', 23: 'x', 24: 'y', 25: 'z'}


In [4]:
#build the dataset

block_size = 3
X, Y = [], []

for w in words:
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        # print(''.join([itos[i] for i in context]), '--->', itos[ix])
        context = context[1:] + [ix]


X = torch.tensor(X)
Y = torch.tensor(Y)


In [5]:

X.shape, Y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [6]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2]

In [7]:
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True



3481


In [29]:
for _ in range(100):

    # minibatch construct
    ix = torch.randint(0, X.shape[0], (32,))


    # forward pass
    emb = C[X[ix]] # (32, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
    logits = h @ W2 + b2 # (32, 27)
    loss = F.cross_entropy(logits, Y[ix])
    
    for p in parameters:
        p.grad = None
    loss.backward()
    #update
    for p in parameters:
        p.data += -0.1 * p.grad
print(loss.item())

2.3691954612731934
2.9099276065826416
2.112687110900879
2.653762102127075
2.891634225845337
2.188856840133667
2.711789846420288
2.416246175765991
2.073697328567505
2.641990900039673
2.3897008895874023
2.2222719192504883
2.378560781478882
2.243832588195801
2.863234758377075
2.4920461177825928
2.9885661602020264
2.65461802482605
2.509366989135742
2.546501398086548
2.3439981937408447
2.411078453063965
2.2023608684539795
2.3983614444732666
2.0666842460632324
2.2146198749542236
3.1297576427459717
2.64968204498291
2.4706151485443115
2.5774712562561035
2.682732582092285
2.315634250640869
2.1759328842163086
2.331610918045044
2.3621528148651123
2.7122483253479004
2.4999375343322754
3.1129531860351562
2.4066853523254395
2.404245138168335
2.157064199447632
2.0063774585723877
2.4333345890045166
2.527406692504883
2.3567981719970703
2.6762256622314453
2.6340646743774414
2.535440683364868
2.434802770614624
2.480407238006592
2.5462303161621094
2.015171527862549
2.6523659229278564
2.654015064239502
2.4

In [9]:
torch.randint(0, X.shape[0], (32, ))

tensor([133484,  94788, 114358, 141683, 111286,  93049,  65770,   7480, 187921,
        153476,  64192,  43697, 133952,  10552,  89086, 166542,  83006,  82803,
         14094, 174110,  24291,  52362, 123879, 164543,   5667,  58433, 209722,
        203952,  24106, 158902, 121734,  43743])