In [2]:
import torch
import torch.nn.functional as F

# Goal is to create a NN to take trigrams or more
1. Create X with input letters, create Y with output. Need to create itos and stoi
3. Create C which is a lookup table of dimension 2
4. Create a NN that takes that and gives our probabilities of next letter
5. calculate loss and reduce loss

In [3]:
words = open('names.txt', 'r').read().splitlines()
print(words[:5])

['emma', 'olivia', 'ava', 'isabella', 'sophia']


In [4]:
# create itos and stoi
chars = sorted(set(''.join(words)))
itos = {i+1: ch for i, ch in enumerate(chars)}
itos[0] = '.'
stoi = {ch: i for i, ch in itos.items()}



In [5]:
import numpy as np
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  for w in words:

    #print(w)

    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      #print(''.join(itos[i] for i in context), '--->', itos[ix])
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [13]:
emb_size = 10
hidden_size = 300


g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, emb_size), generator=g)
W1 = torch.randn((block_size * emb_size, hidden_size),  generator=g) * ((5/3)/((block_size * emb_size)**0.5))
B1 = torch.randn((hidden_size), generator=g) * 0.01
W2 = torch.randn((hidden_size, 27), generator=g) * 0.01
B2 = torch.randn((27), generator=g) * 0

bngain = torch.ones(((1, hidden_size)))
bnbias = torch.zeros((1, hidden_size))

parameters = [C, W1, B1, W2, B2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

17697


In [14]:
batch_size = 32
n_epochs = 200000

for i in range(n_epochs):
    # forward pass
    #minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))

    emb = C[Xtr[ix]]
    embcat = emb.view(-1, block_size * emb_size)
    hprev = embcat @ W1 + B1
    h = torch.tanh(hprev)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, Ytr[ix])
    if i % 1000 == 0:
        print(f"{i} --> {loss.item()}")
              
    for p in parameters:
        p.grad = None

    loss.backward()
    lr = 0.1 if i < (n_epochs * 0.8) else 0.01
    for p in parameters:
        p.data -= 0.01 * p.grad 
print(loss.item())

0 --> 3.2957422733306885
1000 --> 2.680158853530884
2000 --> 2.3803117275238037
3000 --> 2.5010952949523926
4000 --> 2.270599842071533
5000 --> 2.249739646911621
6000 --> 2.272958517074585
7000 --> 2.0339412689208984
8000 --> 2.260019302368164
9000 --> 2.193295955657959
10000 --> 2.413416624069214
11000 --> 2.163043260574341
12000 --> 2.3071162700653076
13000 --> 2.3425326347351074
14000 --> 1.9942249059677124
15000 --> 2.0277786254882812
16000 --> 2.3108320236206055
17000 --> 2.3720836639404297
18000 --> 2.2831814289093018
19000 --> 2.2127387523651123
20000 --> 2.1985628604888916
21000 --> 2.2543423175811768
22000 --> 2.443128824234009
23000 --> 2.179368734359741
24000 --> 2.1791648864746094
25000 --> 1.9956285953521729
26000 --> 2.217833995819092
27000 --> 2.024867296218872
28000 --> 2.258545160293579
29000 --> 2.4722695350646973
30000 --> 1.99701988697052
31000 --> 1.884708285331726
32000 --> 2.452099561691284
33000 --> 2.0579142570495605
34000 --> 2.117550849914551
35000 --> 2.3904

In [15]:
#beat xDev of 2.17 loss
emb = C[Xdev]
h = torch.tanh(emb.view(-1, block_size * emb_size) @ W1 + B1)

logits = h @ W2 + B2
loss = F.cross_entropy(logits, Ydev)
print(loss.item())

2.119819402694702


In [247]:
emb = C[Xte]
h = torch.tanh(emb.view(-1, block_size * emb_size) @ W1 + B1)
logits = h @ W2 + B2
loss = F.cross_entropy(logits, Yte)
print(loss.item())

2.3086841106414795
