In [1]:
import torch
import torch.nn.functional as F

# Goal is to create a NN to take trigrams or more
1. Create X with input letters, create Y with output. Need to create itos and stoi
3. Create C which is a lookup table of dimension 2
4. Create a NN that takes that and gives our probabilities of next letter
5. calculate loss and reduce loss

In [181]:
words = open('names.txt', 'r').read().splitlines()
print(words[:5])

['emma', 'olivia', 'ava', 'isabella', 'sophia']


In [182]:
# create itos and stoi
chars = sorted(set(''.join(words)))
itos = {i+1: ch for i, ch in enumerate(chars)}
itos[0] = '.'
stoi = {ch: i for i, ch in itos.items()}



In [248]:
import numpy as np
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  for w in words:

    #print(w)

    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      #print(''.join(itos[i] for i in context), '--->', itos[ix])
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182484, 3]) torch.Size([182484])
torch.Size([22869, 3]) torch.Size([22869])
torch.Size([22793, 3]) torch.Size([22793])


In [249]:
emb_size = 10
hidden_size = 300


g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, emb_size), generator=g)
W1 = torch.randn((block_size * emb_size, hidden_size), generator=g)
B1 = torch.randn((hidden_size), generator=g)
W2 = torch.randn((hidden_size, 27), generator=g)
B2 = torch.randn((27), generator=g)

parameters = [C, W1, B1, W2, B2]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

17697


In [250]:
batch_size = 32
n_epochs = 178001

for i in range(n_epochs):
    # forward pass
    #minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))

    emb = C[Xtr[ix]]
    h = torch.tanh(emb.view(-1, block_size * emb_size) @ W1 + B1)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, Ytr[ix])
    if i % 1000 == 0:
        print(f"{i} --> {loss.item()}")
              
    for p in parameters:
        p.grad = None

    loss.backward()
    #if i < int(n_epochs * 0.25):
    #    lr = 0.1
    #elif i < int(n_epochs * 0.75):
    #    lr = 0.01
    #else:
    #    lr = 0.0001
    lr = 0.1 if i < (n_epochs * 0.8) else 0.01
    for p in parameters:
        p.data -= 0.01 * p.grad 
print(loss.item())

0 --> 27.76041030883789
1000 --> 7.383129596710205
2000 --> 4.967247009277344
3000 --> 5.519577503204346
4000 --> 3.9406542778015137
5000 --> 3.1676886081695557
6000 --> 3.921692371368408
7000 --> 2.874091863632202
8000 --> 4.2571001052856445
9000 --> 3.5421793460845947
10000 --> 2.91940975189209
11000 --> 3.549651622772217
12000 --> 3.01814341545105
13000 --> 2.8955399990081787
14000 --> 3.216179370880127
15000 --> 3.2225229740142822
16000 --> 2.525787830352783
17000 --> 2.3461453914642334
18000 --> 2.481290817260742
19000 --> 3.1850271224975586
20000 --> 2.931797981262207
21000 --> 2.186738967895508
22000 --> 2.3855080604553223
23000 --> 2.3878979682922363
24000 --> 2.40989089012146
25000 --> 2.6075611114501953
26000 --> 2.828981399536133
27000 --> 2.7231318950653076
28000 --> 2.896237850189209
29000 --> 2.7958455085754395
30000 --> 3.064750909805298
31000 --> 2.4138758182525635
32000 --> 2.2294819355010986
33000 --> 2.2644565105438232
34000 --> 2.461005210876465
35000 --> 2.37145805

In [251]:
#beat xDev of 2.17 loss
emb = C[Xdev]
h = torch.tanh(emb.view(-1, block_size * emb_size) @ W1 + B1)

logits = h @ W2 + B2
loss = F.cross_entropy(logits, Ydev)
print(loss.item())

2.3074467182159424


In [247]:
emb = C[Xte]
h = torch.tanh(emb.view(-1, block_size * emb_size) @ W1 + B1)
logits = h @ W2 + B2
loss = F.cross_entropy(logits, Yte)
print(loss.item())

2.3086841106414795
