In [80]:
import random

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
%matplotlib inline

In [81]:
# Load the corpus: one name per line. The context manager closes the file
# handle as soon as its contents have been read.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [82]:
# Peek at the first five names.
words[0:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [83]:
# Vocabulary: '.' sits at index 0, followed by every distinct character that
# occurs in the names, in sorted order.
chars = ['.'] + sorted(set(''.join(words)))
# character -> integer id
stoi = dict(zip(chars, range(len(chars))))
# integer id -> character (exact inverse of stoi)
itos = dict(zip(stoi.values(), stoi.keys()))

In [84]:
# dataset
context_size = 3
def build_dataset(words):
    """Turn a list of names into (context, next-character) training pairs.

    Each name is terminated with '.', and a window of `context_size` character
    ids (initially all 0, the id of '.') slides over it. Every window position
    contributes one row to X and the id of the character that follows it to Y.

    Args:
        words: iterable of strings whose characters are all keys of `stoi`.

    Returns:
        (X, Y): X has one row of `context_size` ids per example and Y holds the
        matching target ids. Both shapes are printed before returning.
    """
    inputs, targets = [], []
    for word in words:
        window = [0] * context_size  # padding before the first character
        for char in word + '.':
            target = stoi[char]
            inputs.append(window)
            targets.append(target)
            # slide by one: drop the oldest id, append the id just emitted
            window = window[1:] + [target]

    X = torch.tensor(inputs)
    Y = torch.tensor(targets)
    print(X.shape, Y.shape)
    return X, Y

# Shuffle a copy with a dedicated, seeded generator so this cell yields the
# same split every time it runs. Shuffling `words` in place would permute the
# already-shuffled list again on a second run and silently change the split.
shuffled_words = list(words)
random.Random(42).shuffle(shuffled_words)
n1 = int(0.8*len(shuffled_words))
n2 = int(0.9*len(shuffled_words))

# 80% train / 10% dev / 10% test
Xtr, Ytr = build_dataset(shuffled_words[:n1])
Xdev, Ydev = build_dataset(shuffled_words[n1:n2])
Xte, Yte = build_dataset(shuffled_words[n2:])


torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [85]:
# Shapes of the training inputs and targets.
tuple(t.shape for t in (Xtr, Ytr))

(torch.Size([182625, 3]), torch.Size([182625]))

In [86]:
emb_dim = 16
# Embedding table: one emb_dim-wide row per vocabulary entry. The row count is
# taken from stoi so it always matches the ids used to index the table.
C = torch.randn((len(stoi), emb_dim))

In [87]:
# emb = C[X]
# emb.shape

#### Flatten input tensor
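We flatten the embedded input because, for each example, the embeddings form a 2-D tensor of shape (context_size, emb_dim). It needs to be flattened into a single vector of length context_size * emb_dim before it can be fed to the hidden layer, which expects one 1-D feature vector per example.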

In [88]:
# Different ways to flatten
# torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], 1)  # not flexible
# torch.cat(torch.unbind(emb, 1),1).shape  # flexible but not efficient
# emb.reshape((32,-1))  # efficient but may or may not return a copy
# emb_out = emb.view(emb.shape[0], -1) # does not create a copy but needs tensor to be contiguous!

In [89]:
# Width of one flattened context: context_size embeddings of emb_dim values each.
emb_dim * context_size

48

In [90]:
# Model parameters: a tanh hidden layer followed by a linear read-out that
# produces one logit per vocabulary entry.
n_hidden = 200
vocab_size = len(stoi)  # derived from the data rather than a hard-coded 27
W1 = torch.randn((context_size*emb_dim, n_hidden))
b1 = torch.randn(n_hidden)
W2 = torch.randn((n_hidden, vocab_size))
b2 = torch.randn(vocab_size)
parameters = [C, W1, b1, W2, b2]
# Enable gradient tracking on every tensor the training loop updates.
for p in parameters:
    p.requires_grad = True

In [91]:
# Total number of scalar values across all parameter tensors.
sum(p.numel() for p in parameters)

15659

In [92]:
# Training configuration.
epochs = 200_000   # number of iterations of the training loop
et = epochs / 10   # one tenth of the run
lr = 0.1           # base learning rate

In [75]:
# Running log of training: log10 losses, their step indices, and the global
# step counter that the training loop advances.
lossi, stepi = [], []
ep = 0

In [93]:
# Minibatch SGD. Each iteration draws a random batch from the training split,
# runs the forward pass, back-propagates and updates every parameter.
batch_size = 48
for i in range(epochs):
    # mini batch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))

    # forward
    emb = C[Xtr[ix]]
    emb_out = emb.view(emb.shape[0], -1)  # flatten each context into one row
    h = torch.tanh(emb_out @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ytr[ix])

    # Step decay held in a local so the configured `lr` is never overwritten;
    # re-running this cell therefore starts from the base rate again.
    step_lr = lr if i <= 100000 else 0.01

    # backward
    for p in parameters:
        p.grad = None
    loss.backward()

    # update (outside autograd, so the in-place step is not recorded)
    with torch.no_grad():
        for p in parameters:
            p -= step_lr * p.grad

    # `ep` is the global step counter; it is not reset by this cell.
    stepi.append(ep)
    lossi.append(loss.log10().item())
    ep += 1


In [128]:
# dev loss
# Evaluation only: no_grad() avoids building an autograd graph over the split.
with torch.no_grad():
    emb = C[Xdev]
    emb_out = emb.view(emb.shape[0], -1)
    h = torch.tanh(emb_out @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ydev)
loss.item()

2.14638352394104

In [129]:
# train loss
# Evaluation only: no_grad() avoids building an autograd graph over the split.
with torch.no_grad():
    emb = C[Xtr]
    emb_out = emb.view(emb.shape[0], -1)
    h = torch.tanh(emb_out @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ytr)
loss.item()

2.0744006633758545

In [130]:
# test loss
# Evaluation only: no_grad() avoids building an autograd graph over the split.
with torch.no_grad():
    emb = C[Xte]
    emb_out = emb.view(emb.shape[0], -1)
    h = torch.tanh(emb_out @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yte)
loss.item()

2.142019510269165

In [96]:
# plt.figure(figsize=(8,8))
# plt.scatter(C[:,0].data, C[:,1].data, s=200)
# for i in range(C.shape[0]):
#     plt.text(C[i,0].item(), C[i,1].item(), itos[i], ha="center", va='center', color="white")
# plt.grid('minor')

In [118]:
# Sample 10 names from the model, one character at a time.
# Inference only, so gradient tracking is switched off for the whole loop.
with torch.no_grad():
    for i in range(10):
        context = [0]*context_size  # start from the all-'.' context
        output = []
        while True:
            emb = C[context]
            ee = emb.view(1, -1)  # single flattened row for the hidden layer
            h = torch.tanh(ee @ W1 + b1)
            logits = h @ W2 + b2
            p = torch.softmax(logits, 1)
            # draw the next character id from the predicted distribution
            ix = torch.multinomial(p, num_samples=1, replacement=True).item()
            output.append(itos[ix])
            context = context[1:]+[ix]
            # id 0 is '.', the end-of-name token; it is kept in the output
            if ix == 0:
                break
        print(''.join(output))

racler.
kin.
chendonn.
tritty.
corbea.
keny.
zyeadim.
prin.
keevy.
zihance.
