In [1]:
import random

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
torch.manual_seed(42); # seed rng for reproducibility
from minitorch import Linear
from minitorch import BatchNorm1d
from minitorch import Tanh
from minitorch import Embedding
from minitorch import Flatten
from minitorch import FlattenConsecutive
from minitorch import Sequential

In [2]:
# read in all the words (one name per line); the `with` block guarantees the file is closed
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()
print(len(words))                  # number of names
print(max(len(w) for w in words))  # length of the longest name
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the vocabulary of characters and mappings to/from integers.
# Letters get indices 1..N in sorted order; index 0 is reserved for '.',
# which is used both as the start-of-name padding and the end-of-name marker.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}  # inverse mapping: index -> character
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build the dataset
block_size = 8 # context length: how many characters do we take to predict the next one?

def build_dataset(words, block_size=block_size):
    """Turn a list of words into (context, next-character) training pairs.

    Each word is terminated with '.', and every character of the word (including
    that terminator) yields one example. The example's input is the indices of the
    preceding `block_size` characters, left-padded with 0 (the index of '.').

    Args:
        words: iterable of strings whose characters are all keys of the
            module-level `stoi` mapping (an unknown character raises KeyError).
        block_size: context length; defaults to the module-level `block_size`
            as it was when this function was defined.

    Returns:
        X: int64 tensor of shape (num_examples, block_size) with context indices.
        Y: int64 tensor of shape (num_examples,) with the target indices.
    """
    X, Y = [], []

    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append

    # Explicit dtype + reshape keep the int64 (N, block_size) layout even when
    # `words` is empty: torch.tensor([]) would otherwise be a 1-D float tensor.
    num_examples = len(X)
    X = torch.tensor(X, dtype=torch.long).reshape(num_examples, block_size)
    Y = torch.tensor(Y, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

# Shuffle a *copy* of the word list so this cell is idempotent. Shuffling `words`
# in place means a re-run (even after re-seeding) permutes an already-shuffled
# list and silently produces a different train/dev/test split.
random.seed(42)
shuffled_words = list(words)
random.shuffle(shuffled_words)
n1 = int(0.8*len(shuffled_words))
n2 = int(0.9*len(shuffled_words))

Xtr,  Ytr  = build_dataset(shuffled_words[:n1])     # 80%
Xdev, Ydev = build_dataset(shuffled_words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(shuffled_words[n2:])     # 10%

torch.Size([182625, 8]) torch.Size([182625])
torch.Size([22655, 8]) torch.Size([22655])
torch.Size([22866, 8]) torch.Size([22866])


In [5]:
n_emb = 15      # embedding dimension per character
n_hidden = 200  # width of the hidden layer

# embed each context character, flatten the context, then one hidden layer
# (Linear -> BatchNorm1d -> Tanh) followed by an output projection to vocab logits
model = Sequential([
    Embedding(vocab_size, n_emb),
    Flatten(),
    Linear(n_emb * block_size, n_hidden),
    BatchNorm1d(n_hidden),
    Tanh(),
    Linear(n_hidden, vocab_size),
])

# damp the output layer's initial weights, shrinking their contribution to the initial logits
with torch.no_grad():
    output_layer = model.layers[-1]
    output_layer.weight *= 0.1

parameters = model.parameters()
n_params = sum(param.nelement() for param in parameters)
print("Number of Parameters: ", n_params)
for param in parameters:
    param.requires_grad = True

Number of Parameters:  30205


In [6]:
iterations = 100000
batch_size = 32
lr_initial, lr_final = 0.1, 0.01  # step-decay learning-rate schedule
lr_decay_step = 60000             # iteration at which lr drops to lr_final
lossi = []

# (Re-)enter training mode. The evaluation cell sets `training = False` on every
# layer, so without this a re-run of this cell would train with BatchNorm in
# inference mode.
for layer in model.layers:
    layer.training = True

for i in range(iterations):
    # sample a random minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size, ))
    Xbatch, Ybatch = Xtr[ix], Ytr[ix]

    # forward pass
    logits = model(Xbatch)
    loss = F.cross_entropy(logits, Ybatch)

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # plain SGD update
    lr = lr_initial if i < lr_decay_step else lr_final
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    if i % 10000 == 0:
        print(f'{i:7d}/{iterations:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

      0/ 100000: 3.2967
  10000/ 100000: 1.9924
  20000/ 100000: 1.9962
  30000/ 100000: 2.1464
  40000/ 100000: 2.2175
  50000/ 100000: 1.9653
  60000/ 100000: 2.2802
  70000/ 100000: 2.3400
  80000/ 100000: 2.2383
  90000/ 100000: 2.0205


In [7]:
# put layers into eval mode (needed for batchnorm especially)
for layer in model.layers:
    layer.training = False

# Find the BatchNorm1d layer by type rather than by position (index 3), so this
# inspection keeps working if layers are added to or reordered in the model.
batchnorm = next(lyr for lyr in model.layers if isinstance(lyr, BatchNorm1d))
batchnorm.running_var.shape

torch.Size([1, 200])

In [8]:
# evaluate the loss
@torch.no_grad()  # evaluation only: skip autograd bookkeeping
def split_loss(split):
    """Print the mean cross-entropy of `model` on one data split.

    split: 'train', 'val' or 'test'; any other name raises KeyError.
    """
    data_by_split = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }
    inputs, targets = data_by_split[split]
    split_ce = F.cross_entropy(model(inputs), targets)
    print(split, split_ce.item())

for split_name in ('train', 'val'):
    split_loss(split_name)

train 1.9203920364379883
val 2.0220963954925537


In [9]:
num_samples = 20

# Sampling feeds a batch of a single context, for which BatchNorm batch statistics
# are degenerate, so force eval mode here. That way the cell does not depend on the
# eval-mode cell having run first.
for layer in model.layers:
    layer.training = False

# Generation never calls backward(), so skip building autograd graphs. This does
# not change RNG consumption, so the samples are the same.
with torch.no_grad():
    for _ in range(num_samples):
        out = []
        context = [0] * block_size  # start from all-'.' (index 0) padding

        while True:
            logits = model(torch.tensor([context]))
            probs = F.softmax(logits, dim=1)

            # draw the next character index from the predicted distribution
            ix = torch.multinomial(probs, num_samples=1).item()
            context = context[1:] + [ix]
            out.append(ix)

            if ix == 0:  # '.' ends the name
                break

        print(''.join(itos[i] for i in out))

ahrovellam.
daily.
barbesta.
da.
jasin.
rirnola.
lovelle.
emerson.
ana.
aaristays.
alex.
indra.
dalanne.
oakhar.
solany.
geovanno.
quelle.
arnson.
dayad.
dlace.
