In [1]:
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [2]:
with open('names.txt', 'r') as fp:
    words = fp.read().splitlines()
    
print(words[:5])

['emma', 'olivia', 'ava', 'isabella', 'sophia']


In [3]:
chars = sorted(list(set("".join(words))))
char_to_i = {char:i+1 for i, char in enumerate(chars)}
char_to_i['.'] = 0
i_to_char = {i:s for s, i in char_to_i.items()}
print(i_to_char)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [4]:
import random
random.seed(42)
random.shuffle(words)

In [5]:
ngram = 8

def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * ngram
        for ch in w + '.':
            ix = char_to_i[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    
    print(X.shape, Y.shape)
    return X, Y

n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

X_train, Y_train = build_dataset(words[:n1])
X_dev, Y_dev = build_dataset(words[n1:n2])
X_test, Y_test = build_dataset(words[n2:])


torch.Size([182625, 8]) torch.Size([182625])
torch.Size([22655, 8]) torch.Size([22655])
torch.Size([22866, 8]) torch.Size([22866])


In [7]:
from utils import Linear, BatchNorm1d, Tanh, Embedding, FlattenConsecutive, Sequential

In [8]:
torch.manual_seed(42)

<torch._C.Generator at 0x25385b2bf90>

In [46]:
ngram = 8
vocab_len = len(i_to_char)
embed_dim = 10
hidden_dim = 68

model = Sequential([
    Embedding(vocab_len, embed_dim),
    FlattenConsecutive(2), Linear(embed_dim * 2, hidden_dim, bias=False), BatchNorm1d(hidden_dim), Tanh(),
    FlattenConsecutive(2), Linear(hidden_dim * 2, hidden_dim, bias=False), BatchNorm1d(hidden_dim), Tanh(),
    FlattenConsecutive(2), Linear(hidden_dim * 2, hidden_dim, bias=False), BatchNorm1d(hidden_dim), Tanh(),
    Linear(hidden_dim, vocab_len)
])


# with torch.no_grad():
#     layers[-1].weight *= 0.1

parameters = model.parameters()
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

22397


In [47]:
ix = torch.randint(0, X_train.shape[0], (4,))
Xb, Yb = X_train[ix], Y_train[ix]
logits = model(Xb)
print(Xb.shape)

torch.Size([4, 8])


In [48]:
for layer in model.layers:
    print(layer.__class__.__name__,tuple(layer.out.shape))

Embedding (4, 8, 10)
FlattenConsecutive (4, 4, 20)
Linear (4, 4, 68)
BatchNorm1d (4, 4, 68)
Tanh (4, 4, 68)
FlattenConsecutive (4, 2, 136)
Linear (4, 2, 68)
BatchNorm1d (4, 2, 68)
Tanh (4, 2, 68)
FlattenConsecutive (4, 136)
Linear (4, 68)
BatchNorm1d (4, 68)
Tanh (4, 68)
Linear (4, 27)


In [49]:
max_steps = 200000
batch_size = 32
lossi = []

ud = []
for i in range(max_steps):
    
    # get a minibatch
    idxs = torch.randint(0, X_train.shape[0], (batch_size,))
    X, Y = X_train[idxs], Y_train[idxs]
    
    logits = model(X)
    loss = F.cross_entropy(logits, Y)
    
    for p in parameters:
        p.grad = None
    loss.backward()
    
    lr = 0.1 if i < 150000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad
        
    if i % 10000 == 0:
        print(f"loss {i}/{max_steps}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    break

loss 0/200000: 3.5670


In [None]:
plt.plot(torch.tensor(lossi).view(-1, 1000).mean(1))

In [59]:
for layer in model.layers:
    layer.training = False
    
@torch.no_grad()
def split_loss(split):
    X, Y = {
            'train': (X_train, Y_train),
            'dev': (X_dev, Y_dev),
            'test': (X_test, Y_test)
           }[split]

    logits = model(X)
    loss = F.cross_entropy(logits, Y)
    print(f"{split}: {loss.item()}")
    
split_loss('train')
split_loss('dev')

train: 3.53708815574646
dev: 3.538801431655884


In [60]:
for _ in range(20):
    out = []
    context = [0] * ngram
    while True:
        logits = model(torch.tensor([context]))
        probs = F.softmax(logits, dim=1)
        next_idx = torch.multinomial(probs, num_samples=1).item()
        context = context[1:] + [next_idx]
        out.append(i_to_char[next_idx])
        if next_idx == 0:
            break
    print(''.join(out))      
        

pj.
sbw.
snswtaihbtfiiflwlj.
baajnildgcxscepmgrlmxsogfwinqozjtvw.
limecgjatqpc.
dkgujsplsodoctlfktqvwxqfw.
dszixsxaugzakwuaufuzpcznjkzrzohlhdivsqbtfscturiztcuyxlqfnvzyqzspdrcivqztbpkwzsdrfmwkntdcfxgbxdkjph.
sciqxhwtazwyjbsz.
arwehaayrdwwx.
jqmcphmyqgnnifyidprqmzltklkkbwrivvstbsrhwh.
lqislpdjadn.
pxyhkudcgxz.
eqlafpxkhlwspjxiiesxftnvjeyjatuzpuxhoeolafefkcylmxyhdbfxolidzqmc.
ebogsshwtlllh.
jjlrqlwbtx.
jstkxkifxvqyzdrhkakrgabrhmxjbtcfhlpojqnpskovcqtisixqyrmccwbtqxqvjh.
xrqhmdbbaflilmvazsoiubftqscgpqlmfzafwddhlckfbipoihlqebjs.
qsuonnlcybalftsseqgfpaeybnfcdfcqrfvujkyqzrrlhlrmexqouqgcscuysiordlfltafvxkjblfmvkfbzfybulr.
xwigmrwbukwlfehwcfqfla.
ql.
