In [1]:
import torch
from torch.nn import functional as funct
g = torch.Generator().manual_seed(14442)

import pandas as pd
import altair as alt

import random
random.seed(30)

In [2]:
# Load the training names, one per line. The `with` block guarantees the file
# handle is closed, and an explicit encoding avoids platform-dependent defaults.
with open("names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()

In [3]:
# Vocabulary: every character that occurs in the names plus ".", which serves as
# the start/end-of-name token. "." is added explicitly instead of relying on it
# appearing as a join separator, which only happens when there are >= 2 words.
# Sorting makes the character -> index assignment deterministic.
alphabet = sorted(set("".join(words)) | {"."})
char_to_idx = {c: i for i, c in enumerate(alphabet)}
idx_to_char = {i: c for i, c in enumerate(alphabet)}

In [4]:
context_size = 3    # number of preceding characters fed to the model
emb_length = 60     # dimensionality of each character embedding
batch_size = 100    # minibatch size used during training
num_neurons = 1000  # width of the hidden layer

def build_dataset(words: list, stoi=None, block_size=None):
    """Turn a list of words into (context, next-character) training pairs.

    Every word is terminated with "." and each of its positions yields one
    example: the indices of the preceding ``block_size`` characters (left-padded
    with the index of ".") and the index of the character that follows.

    Args:
        words: the words to encode.
        stoi: character -> index mapping; defaults to the global ``char_to_idx``.
        block_size: context length; defaults to the global ``context_size``.

    Returns:
        (xs, ys): ``xs`` is a long tensor of shape (N, block_size) and ``ys`` a
        long tensor of shape (N,), where N is the total number of characters
        (including the terminating ".") across all words.
    """
    if stoi is None:
        stoi = char_to_idx
    if block_size is None:
        block_size = context_size

    # Pad with the actual index of "." rather than assuming it is 0.
    pad_idx = stoi["."]
    xs, ys = [], []
    for w in words:
        context = [pad_idx] * block_size
        for ch in w + ".":
            idx = stoi[ch]
            xs.append(context)
            ys.append(idx)
            context = context[1:] + [idx]  # slide the window (builds a new list)

    # Explicit dtype/shape so an empty word list still yields (0, block_size) / (0,)
    # long tensors instead of a 1-D float tensor.
    xs = torch.tensor(xs, dtype=torch.long).view(-1, block_size)
    ys = torch.tensor(ys, dtype=torch.long)
    return xs, ys

In [5]:
# Shuffle a *copy* of the names with a dedicated, locally seeded RNG. Re-running
# this cell therefore always reproduces the same split and never reshuffles
# `words`. A reshuffle would leak val/test names into an already-trained model.
# Seed 30 matches the global `random.seed(30)` in the first cell, so on a fresh
# kernel this yields the same permutation as shuffling with the global RNG did.
split_rng = random.Random(30)
shuffled_words = words[:]
split_rng.shuffle(shuffled_words)

# 80% train / 10% validation / 10% test.
n1 = int(0.8 * len(shuffled_words))
n2 = int(0.9 * len(shuffled_words))
x_train, y_train = build_dataset(shuffled_words[:n1])
x_val, y_val = build_dataset(shuffled_words[n1:n2])
x_test, y_test = build_dataset(shuffled_words[n2:])

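### Initial setting
- Set $b_2 = 0$ and scale $W_2$ by a small constant so the initial logits are small and the predicted distribution is close to uniform. This pulls the initial loss toward $-\ln(1/27) \approx 3.30$, the loss of a uniform guess over the 27 characters, instead of the much larger loss caused by confidently wrong predictions.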

### Kaiming init
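- Scale the initial weights by $g/\sqrt{n_{\text{in}}}$, where $n_{\text{in}}$ is the layer's fan-in and $g = 5/3$ is the gain for tanh. This keeps the preactivations at order-one scale, so the non-linearity does not saturate at initialization.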

### Batchnorm
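- During training, standardizes each hidden unit's preactivation over the current batch (zero mean, unit variance), then applies a learnable scale and shift. It does not make the distribution Gaussian; it only fixes its first two moments.
- Makes the bias of the preceding layer redundant: subtracting the batch mean cancels any constant offset.
- Couples the examples within a batch and needs running mean/std estimates at inference time, which makes it cumbersome to use in deep networks.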


In [6]:
# Character embedding lookup table: one emb_length-dim vector per character.
C = torch.randn((len(alphabet), emb_length), generator=g)

# Hidden layer, Kaiming-initialized for tanh: std = (5/3) / sqrt(fan_in), so the
# preactivations start at order-one scale and tanh is not saturated at init.
fan_in = context_size * emb_length
W1 = torch.randn((fan_in, num_neurons), generator=g) * (5/3) / fan_in**0.5
# No hidden bias b1: BatchNorm subtracts the batch mean, which would cancel it.

# Output layer. A small W2 and a zero b2 shrink the initial logits, pulling the
# initial loss toward that of a uniform prediction over the vocabulary.
W2 = torch.randn((num_neurons, len(alphabet)), generator=g) * 0.1
b2 = torch.zeros(len(alphabet))  # sized from the vocabulary, not a hard-coded 27

# Learnable BatchNorm gain and shift, initialized to the identity transform.
bnscale = torch.ones((1, num_neurons))
bnshift = torch.zeros((1, num_neurons))

params = [C, W1, W2, b2, bnscale, bnshift]

for p in params:
    p.requires_grad = True

In [7]:
# Running BatchNorm statistics: an exponential moving average of the per-batch
# mean/std, used instead of batch statistics at evaluation and sampling time.
bn_mean = torch.zeros((1, num_neurons))
bn_std = torch.ones((1, num_neurons))

max_steps = 100000
lr_decay_step = 30000  # learning rate drops from 0.1 to 0.001 at this step

for i in range(max_steps):
    # Draw minibatch indices from the seeded generator so the run is reproducible.
    batch = torch.randint(0, x_train.shape[0], (batch_size,), generator=g)
    embedding = C[x_train[batch]]
    # No "+ b1": the batch-mean subtraction below would cancel a bias anyway.
    hpreact = embedding.view(-1, context_size * emb_length) @ W1

    #---------------------------------------------------------------------------------------------------------

    # BatchNorm: standardize each hidden unit over the batch, then scale and shift.
    bn_mean_i = hpreact.mean(0, keepdim=True)
    bn_std_i = hpreact.std(0, keepdim=True)
    h1 = bnscale * (hpreact - bn_mean_i) / (bn_std_i + 0.01) + bnshift

    # Update the running statistics outside the autograd graph.
    with torch.no_grad():
        bn_mean = 0.999 * bn_mean + 0.001 * bn_mean_i
        bn_std = 0.999 * bn_std + 0.001 * bn_std_i

    #---------------------------------------------------------------------------------------------------------

    # Non-linearity and output layer
    h1 = torch.tanh(h1)
    logits = h1 @ W2 + b2

    #---------------------------------------------------------------------------------------------------------

    # Mean negative log-likelihood of the true next character.
    loss = funct.cross_entropy(logits, y_train[batch])

    #---------------------------------------------------------------------------------------------------------

    if i % 2500 == 0:
        print(f"Iter {i} \t|\t Loss: {loss.item():.5f}")

    for p in params:
        p.grad = None

    loss.backward()

    lr = 0.1 if i < lr_decay_step else 0.001

    # Plain SGD step; no_grad keeps the update itself out of the autograd graph.
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad

Iter 0 	|	 Loss: 4.94409
Iter 2500 	|	 Loss: 2.05342
Iter 5000 	|	 Loss: 2.18867
Iter 7500 	|	 Loss: 2.24115
Iter 10000 	|	 Loss: 2.34288
Iter 12500 	|	 Loss: 2.30169
Iter 15000 	|	 Loss: 1.91409
Iter 17500 	|	 Loss: 2.16541
Iter 20000 	|	 Loss: 1.96124
Iter 22500 	|	 Loss: 2.29328
Iter 25000 	|	 Loss: 2.10862
Iter 27500 	|	 Loss: 1.91253
Iter 30000 	|	 Loss: 2.26748
Iter 32500 	|	 Loss: 1.79838
Iter 35000 	|	 Loss: 2.04707
Iter 37500 	|	 Loss: 2.16278
Iter 40000 	|	 Loss: 1.95799
Iter 42500 	|	 Loss: 2.14786
Iter 45000 	|	 Loss: 1.81286
Iter 47500 	|	 Loss: 1.87609
Iter 50000 	|	 Loss: 1.95449
Iter 52500 	|	 Loss: 1.99352
Iter 55000 	|	 Loss: 1.99787
Iter 57500 	|	 Loss: 1.92874
Iter 60000 	|	 Loss: 2.24800
Iter 62500 	|	 Loss: 2.05478
Iter 65000 	|	 Loss: 1.94512
Iter 67500 	|	 Loss: 1.89472
Iter 70000 	|	 Loss: 2.01362
Iter 72500 	|	 Loss: 1.94915
Iter 75000 	|	 Loss: 1.87865
Iter 77500 	|	 Loss: 2.03388
Iter 80000 	|	 Loss: 2.26627
Iter 82500 	|	 Loss: 2.08557
Iter 85000 	|	 Loss: 

In [8]:
@torch.no_grad
def eval_loss(split:str):
    """Compute the mean cross-entropy of the model on a whole data split.

    BatchNorm uses the running statistics (``bn_mean`` / ``bn_std``) collected
    during training, not statistics computed from the split itself.

    Args:
        split: one of "train", "val" or "test" (anything else raises KeyError).

    Returns:
        A 0-dim tensor holding the loss.
    """
    splits = {
        "train": (x_train, y_train),
        "val": (x_val, y_val),
        "test": (x_test, y_test),
    }
    inputs, targets = splits[split]

    flat_emb = C[inputs].view(-1, context_size * emb_length)
    normed = bnscale * (flat_emb @ W1 - bn_mean) / (bn_std + 0.01) + bnshift
    logits = torch.tanh(normed) @ W2 + b2
    return funct.cross_entropy(logits, targets)

In [9]:
eval_loss("train"), eval_loss("val")

(tensor(1.9783), tensor(2.0638))

In [10]:
from sklearn.manifold import TSNE

# Project the learned character embeddings to 2-D for plotting.
# t-SNE is stochastic; a fixed random_state makes the figure reproducible.
# perplexity must stay below the number of points (one per character).
C_emb = TSNE(n_components=2, perplexity=5, random_state=0).fit_transform(C.detach().numpy())

In [11]:
# One row per character with its 2-D t-SNE coordinates. (Named emb_df rather
# than `repr`, which would shadow the Python builtin for the rest of the session.)
emb_df = pd.DataFrame({"x1": C_emb[:, 0], "x2": C_emb[:, 1], "char": alphabet})

scatter = alt.Chart(emb_df).mark_circle(size=60).encode(
    x='x1',
    y='x2',
    tooltip=['char']
).properties(
    width=500,
    height=500
).interactive()

# Label each point with its character, offset slightly to the right of the dot.
chars = scatter.mark_text(
    align='left',
    baseline='middle',
    dx=7
).encode(
    text='char'
)

(scatter + chars).properties(title="t-SNE of learned character embeddings")

### Sampling

In [12]:
# Generate 20 names: start from an all-"." context, repeatedly sample the next
# character from the model, and stop when "." is produced again.
end_idx = char_to_idx["."]  # "." pads the initial context and terminates a name
max_name_len = 50           # safety cap in case "." is never sampled

with torch.no_grad():  # inference only: no autograd graph is needed
    for _ in range(20):
        res = ""
        cont = [end_idx] * context_size
        while len(res) < max_name_len:
            emb = C[torch.tensor(cont)]
            hpreact = emb.view(-1, context_size * emb_length) @ W1
            # Normalize with the running statistics collected during training.
            h = bnscale * (hpreact - bn_mean) / (bn_std + 0.01) + bnshift
            h = torch.tanh(h)
            logits = h @ W2 + b2
            probs = funct.softmax(logits, dim=1)
            idx = torch.multinomial(probs, 1, replacement=True, generator=g).item()
            if idx == end_idx:
                break

            res += idx_to_char[idx]
            cont = cont[1:] + [idx]

        print(res)

sai
rem
zaleigh
daciony
acean
bryson
temillouis
sha
kylah
nyelfo
sry
jamena
mae
doriel
eleia
jevelia
dio
abrinley
drea
shreecilphoella
