In [34]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
torch.set_printoptions(sci_mode=False)

In [2]:
# Read the dataset: one name per line in names.txt.
# The `with` block closes the file handle as soon as the contents are read,
# instead of leaving it open until garbage collection.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [4]:
# Vocabulary tables. Every distinct character found in `words` gets its
# 1-based rank in sorted order; index 0 is reserved for the '.' marker.
chars = sorted(set(''.join(words)))
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
# Inverse mapping (index -> character), built in the same insertion order.
itos = {index: symbol for symbol, index in stoi.items()}
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [44]:
# Build (context, next-character) training pairs.
# Each word is terminated with '.', and the context window starts out filled
# with index 0, so the first characters of a word are predicted from padding.
block_size = 3  # number of preceding characters used as context
inputs, targets = [], []
for word in words:
    window = [0] * block_size
    for target in (stoi[symbol] for symbol in word + '.'):
        inputs.append(window)
        targets.append(target)
        # slide the window: drop the oldest index, append the newest
        window = [*window[1:], target]
X = torch.tensor(inputs)
Y = torch.tensor(targets)

In [45]:
# Inspect the input tensor: one row per training example, block_size columns of character indices.
X.shape, X.dtype

(torch.Size([228146, 3]), torch.int64)

In [7]:
# Lookup table with 27 rows and 2 columns, drawn from a standard normal.
C = torch.randn(27, 2)

In [8]:
# Display the randomly initialised lookup table.
C

tensor([[ 0.7030,  0.2940],
        [-1.2391,  0.9402],
        [ 0.3882, -1.6560],
        [-0.0316, -1.7858],
        [ 0.6885, -0.5969],
        [-0.3218,  0.0069],
        [-0.3638, -1.0915],
        [ 0.0875, -0.0233],
        [-0.6470,  0.1373],
        [ 0.5179, -0.0314],
        [-1.0424, -0.3409],
        [ 1.1197,  0.8320],
        [ 0.6046,  0.2364],
        [-1.3388, -0.9880],
        [ 1.0588, -0.9405],
        [-0.2343, -2.7502],
        [-0.4643, -0.8355],
        [ 1.7109,  0.0808],
        [ 0.0316,  2.0751],
        [-2.1025, -0.2582],
        [-0.3265,  0.5111],
        [ 0.0050,  1.8499],
        [-0.3591, -0.5080],
        [ 0.3001,  0.4629],
        [-0.2839, -1.5159],
        [ 0.9426, -0.2920],
        [-1.4189,  0.2960]])

In [14]:
# Index C with the whole integer tensor X: every index in X is replaced by the
# corresponding row of C, so the result has X's shape plus C's column dimension.
# NOTE(review): the saved output below shows 32 rows while the X.shape cell
# reports 228146 — presumably recorded against a smaller X; re-run to confirm.
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [15]:
# First-layer weights and bias: 6 inputs -> 100 units, standard-normal init.
W1 = torch.randn(6, 100)
b1 = torch.randn((100,))

In [17]:
# Flatten each example's embeddings into one row of 6 values, then apply the
# affine layer followed by tanh.
emb_flat = emb.view(-1, 6)
h = (emb_flat @ W1 + b1).tanh()

In [18]:
# Check the hidden activations: one row per example, one column per column of W1.
h.shape

torch.Size([32, 100])

In [19]:
# Output-layer weights and bias: 100 hidden units -> 27 scores, standard-normal init.
W2 = torch.randn(100, 27)
b2 = torch.randn((27,))

In [21]:
# Output layer: project the hidden activations through W2 and add the bias b2.
logits = torch.matmul(h, W2) + b2
logits.shape

torch.Size([32, 27])

In [22]:
# Hand-written softmax: exponentiate the logits, then divide each row by its
# own total so every row of `prob` sums to 1.
counts = logits.exp()
row_totals = counts.sum(dim=1, keepdim=True)
prob = counts / row_totals

In [24]:
# Sanity check: after the row-wise normalisation the first row should sum to 1.
prob[0].sum()

tensor(1.)

In [36]:
# Mean negative log-likelihood computed by hand: for every row pick the
# probability assigned to its target index in Y, take the log, then average.
# The row index is derived from Y rather than hard-coded to 32, so the cell
# works for whatever number of examples X/Y currently hold.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(14.7996)

In [37]:
# Gather every tensor of the model and count the total number of scalar entries.
parameters = [C, W1, W2, b1, b2]
n_params = 0
for tensor in parameters:
    n_params += tensor.nelement()
n_params

3481

In [76]:
# Re-create every parameter from one seeded generator so the initial values are
# identical on every run. The draw order (C, W1, b1, W2, b2) determines which
# random numbers each tensor receives, so it must not be changed.
g = torch.Generator().manual_seed(2147483647)
C = torch.randn(27, 2, generator=g, requires_grad=True)
W1 = torch.randn(6, 100, generator=g, requires_grad=True)
b1 = torch.randn((100,), generator=g, requires_grad=True)
W2 = torch.randn(100, 27, generator=g, requires_grad=True)
b2 = torch.randn((27,), generator=g, requires_grad=True)

# All tensors are created with gradient tracking enabled.
parameters = [C, W1, W2, b1, b2]

In [77]:
# Candidate learning rates: 1000 exponents evenly spaced between -3 and 0,
# turned into rates by raising 10 to each exponent.
lre = torch.linspace(start=-3, end=0, steps=1000)
lrs = 10 ** lre
lrs

tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012,
        0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0013, 0.0013,
        0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0014,
        0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017,
        0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019,
        0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0022, 

In [81]:
# Mini-batch gradient descent on the parameters in `parameters`.
lri = []    # reset here; nothing is appended to it in this cell
lossi = []  # reset here; nothing is appended to it in this cell

num_steps = 10000     # number of gradient steps
batch_size = 32       # examples per mini-batch
learning_rate = 0.01  # fixed step size

for i in range(num_steps):
    # Sample the mini-batch from the seeded generator `g` so that a fresh
    # Restart & Run All reproduces the same sequence of batches.
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)

    # forward pass
    emb = C[X[ix]]  # (batch_size, block_size, 2)
    # flatten per example; -1 infers the width instead of hard-coding 6
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (batch_size, 100)
    logits = h @ W2 + b2  # (batch_size, 27)
    loss = F.cross_entropy(logits, Y[ix])

    # backward pass: clear stale gradients before accumulating new ones
    for p in parameters:
        p.grad = None
    loss.backward()

    # update: in-place step under no_grad so autograd does not record it
    # (replaces direct mutation through the discouraged `.data` attribute)
    with torch.no_grad():
        for p in parameters:
            p -= learning_rate * p.grad

# Loss of the final mini-batch only, not of the full dataset.
print(loss.item())

2.3428256511688232


In [82]:
# Loss over the entire dataset (N = X.shape[0] examples), not a mini-batch.
# This is evaluation only, so it runs under no_grad to avoid building an
# autograd graph over every example; the displayed tensor therefore carries
# no grad_fn.
with torch.no_grad():
    emb = C[X]  # (N, block_size, 2)
    # flatten per example; -1 infers the width instead of hard-coding 6
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (N, 100)
    logits = h @ W2 + b2  # (N, 27)
    loss = F.cross_entropy(logits, Y)
loss

tensor(2.3932, grad_fn=<NllLossBackward0>)