In [None]:
# load data
import random
import time

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Read the corpus: one word per line. The context manager closes the file
# handle deterministically instead of leaving it open until garbage collection,
# and the explicit encoding keeps the result independent of the platform locale.
with open("names.txt", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
print(len(words))

# Vocabulary: every distinct character occurring in the corpus, in sorted order.
chars = sorted(set("".join(words)))

# Character -> integer id. Ids start at 1 because id 0 is assigned to ".",
# the marker appended to every word when the dataset is built.
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0

# Inverse mapping: integer id -> character.
itos = {i: s for s, i in stoi.items()}


In [None]:
# params

# Device selection. `prefer_cuda` is the manual switch (set it to False to force
# the CPU); `use_cuda` is only True when a CUDA device is actually usable, so
# the notebook also runs on machines without a GPU instead of failing at the
# first tensor allocation.
prefer_cuda = True
use_cuda = prefer_cuda and torch.cuda.is_available()
if use_cuda:
    torch.set_default_device('cuda:0')
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    # only relevant for Ampere or later GPUs
    # torch.backends.cuda.matmul.allow_tf32 = True
else:
    torch.set_default_device('cpu')

# Width of one character embedding vector (previously tried: 2, 10).
embed_size = 30
# Number of units in the hidden layer (previously tried: 100, 200).
hidden_size = 500
# Number of examples drawn per SGD step (previously tried: 32).
minibatch_size = 1900
# Context length: how many preceding characters predict the next one
# (previously tried: 3).
block_size = 4
# Initial learning rate, and the factor it is multiplied by whenever the
# training loss went up between two evaluations (previously tried: 0.99).
init_lr = 0.15
lr_decay = 0.95
# Weight of the mean-squared-weight penalty added to the loss
# (previously tried: 0.005).
regularization_loss_factor = 0.05

# Training runs `outer_i` rounds of `inner_i` SGD steps each; losses are
# evaluated and logged once per round.
outer_i = 100
# mbs = [1800 + i * 10 for i in range(outer_i)]
inner_i = 1000


In [None]:
# initialize model
def build_dataset(words, context_len=None, char_to_ix=None):
    """Turn words into (context, next-character) training examples.

    Each word is extended with a trailing "." and scanned left to right. For
    every character one example is emitted: the ids of the `context_len`
    preceding characters (padded on the left with id 0 at the start of a word)
    and the id of the character itself as the label.

    Args:
        words: iterable of strings; every character (and ".") must be a key
            of the character mapping.
        context_len: number of preceding character ids per example. Defaults
            to the module-level `block_size`.
        char_to_ix: mapping from character to integer id. Defaults to the
            module-level `stoi`.

    Returns:
        (X, Y): X is an int64 tensor of shape (N, context_len) holding the
        contexts, Y an int64 tensor of shape (N,) holding the labels, where
        N = sum(len(w) + 1 for w in words). An empty `words` yields
        correctly shaped empty tensors.

    Raises:
        KeyError: if a character is missing from the mapping.
    """
    # Fall back to the notebook-level configuration when not given explicitly.
    if context_len is None:
        context_len = block_size
    if char_to_ix is None:
        char_to_ix = stoi

    # inputs, labels
    contexts, targets = [], []

    for w in words:
        context = [0] * context_len
        for ch in w + ".":
            ix = char_to_ix[ch]
            contexts.append(context)
            targets.append(ix)
            # Slide the window: drop the oldest id, append the newest. This
            # builds a new list, so the one already stored is not modified.
            context = context[1:] + [ix]

    # Explicit dtype and shape keep the result well-formed even when there are
    # no examples (torch.tensor([]) alone would be a float tensor of shape (0,)).
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(targets), context_len)
    Y = torch.tensor(targets, dtype=torch.long)
    return X, Y


# | training split| dev/validation split| test split |
# |            80%|                  10%| 10%        |
# dev split used for developing hyperparameters
random.seed(42)
# Shuffle a copy rather than `words` itself: shuffling in place would permute
# an already-permuted list when this cell is re-run, giving a different split
# each time. With the copy, the split is the same on every run.
shuffled_words = list(words)
random.shuffle(shuffled_words)
n1 = int(0.8 * len(shuffled_words))
n2 = int(0.9 * len(shuffled_words))

Xtr, Ytr = build_dataset(shuffled_words[:n1])
Xdev, Ydev = build_dataset(shuffled_words[n1:n2])
Xte, Yte = build_dataset(shuffled_words[n2:])

# Seeded generator on the same device the tensors are created on.
g = torch.Generator("cuda:0" if use_cuda else "cpu").manual_seed(2147483647)

# Number of distinct token ids (all corpus characters plus "."), derived from
# the vocabulary instead of being hard-coded.
vocab_size = len(stoi)

# NOTE(review): C is drawn from a normal distribution (torch.randn) while the
# weights and biases below are drawn uniformly from [0, 1) (torch.rand), i.e.
# they are all non-negative. That may be unintentional - confirm before
# changing, since it alters training results.

# embedding layer: one row of `embed_size` values per token id
C = torch.randn(vocab_size, embed_size, generator=g)
# hidden layer: consumes the concatenated embeddings of one context
W1 = torch.rand(block_size * embed_size, hidden_size, generator=g)
b1 = torch.rand(hidden_size, generator=g)
# output layer: one logit per token id
W2 = torch.rand(hidden_size, vocab_size, generator=g)
b2 = torch.rand(vocab_size, generator=g)

parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True
print(
    f"parameters: {sum(p.nelement() for p in parameters)}")

# Training state. It lives here, not in the training cell, so that re-running
# the training cell continues from the current state instead of resetting it.
last_loss = None  # training loss of the previous evaluation, None before the first
lr = init_lr      # current learning rate
lrs = []          # learning rate (scaled for plotting) per evaluation
losstr = []       # log of the training loss per evaluation
lossdev = []      # log of the dev loss per evaluation


In [None]:
# train


def regularized_loss(X, Y):
    """Forward pass plus loss for a batch of examples.

    Args:
        X: integer tensor of contexts, `block_size` token ids per row.
        Y: integer tensor with the label (next token id) for each row of X.

    Returns:
        0-dim tensor: cross-entropy of the predictions against Y plus
        `regularization_loss_factor` times the sum of the mean squared
        entries of W1 and W2. Gradients are tracked unless the caller
        disables them.
    """
    emb = C[X]  # (rows, block_size, embed_size)
    # emb.view is equivalent to torch.cat(torch.unbind(emb, 1), 1), but much more efficient
    # giving -1 as the first dimension makes torch infer it from the other dimensions
    h = torch.tanh(emb.view(-1, block_size * embed_size) @ W1 + b1)
    logits = h @ W2 + b2
    # F.cross_entropy is equivalent to the manual form
    #   counts = logits.exp(); prob = counts / counts.sum(1, keepdim=True)
    #   -prob[torch.arange(len(Y)), Y].log().mean()
    loss = F.cross_entropy(logits, Y)
    # add regularization loss
    return loss + regularization_loss_factor * ((W1**2).mean() + (W2**2).mean())


def evaluate_loss(X, Y):
    """Same value as `regularized_loss(X, Y)`, computed without autograd.

    Used for the whole-split evaluations, where no backward pass follows and
    keeping the intermediate activations for a graph would only cost memory.
    """
    with torch.no_grad():
        return regularized_loss(X, Y)


# dump initial loss
print(f"training set:      {evaluate_loss(Xtr, Ytr).item():.4f}")
print(f"dev set:           {evaluate_loss(Xdev, Ydev).item():.4f}")
print(f"learning rate:     {lr:.4f}")

batch_size = minibatch_size
for _ in range(outer_i):
    begin = time.time()
    for _ in range(inner_i):
        # minibatch: draw example indices from the seeded generator so that a
        # run is reproducible
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)

        # forward pass
        loss = regularized_loss(Xtr[ix], Ytr[ix])

        # backward pass
        for p in parameters:
            p.grad = None
        loss.backward()

        # update: plain SGD step, performed outside of autograd
        with torch.no_grad():
            for p in parameters:
                p -= lr * p.grad

    # Evaluate on the complete training and dev splits once per round.
    losst = evaluate_loss(Xtr, Ytr)
    losstr.append(losst.log().item())

    lossd = evaluate_loss(Xdev, Ydev)
    lossdev.append(lossd.log().item())

    # * 10 to look better on the plot
    lrs.append(lr * 10)

    # Decay the learning rate whenever the training loss went up compared to
    # the previous round. `is not None` (rather than truthiness) so that only
    # the very first round is skipped.
    if last_loss is not None and last_loss < losst.item():
        lr *= lr_decay

    last_loss = losst.item()

    # Wall-clock time of the round (SGD steps plus the two evaluations).
    # time.time() differences are in seconds, so label them as such.
    t = time.time() - begin
    print(f"t={t:8.4f}s", end=", ")
    ns_per_sample = (t / (inner_i * batch_size)) * 1_000_000_000
    print(
        f"per sample={ns_per_sample:6.2f}ns", end=", ")
    print(f"fit={(lossd.item() / losst.item()) * 100:6.2f}% losst={losst.item():8.4f} lossd={lossd.item():8.4f} lr={lr:.4f}")

# dump final loss
loss = evaluate_loss(Xtr, Ytr)
print(f"training set:      {loss.item():.4f}")

loss = evaluate_loss(Xdev, Ydev)
print(f"dev set:           {loss.item():.4f}")
print(f"learning rate:     {lr:.4f}")

# One point per round. The loss curves hold log(loss), as stored above.
plt.plot(lrs, "b", label="learning rate x 10")
plt.plot(losstr, "r", label="log(train loss)")
plt.plot(lossdev, "g", label="log(dev loss)")
fit = torch.tensor(lossdev, device='cpu') / torch.tensor(losstr, device='cpu')
plt.plot(fit, "y", label="log(dev loss) / log(train loss)")
plt.xlabel("round")
plt.legend()


In [None]:
# inspect

def dump_c():
    """Scatter-plot embedding dimensions 0 and 1 of every token.

    Draws one marker per row of the embedding matrix `C`, positioned by the
    row's first two values and labelled with the corresponding character from
    `itos`. Any further embedding dimensions are not shown.
    """
    # matplotlib works on host arrays: detach from autograd and copy to the CPU
    # first, otherwise plotting fails when C lives on a CUDA device.
    coords = C[:, :2].detach().cpu().numpy()

    plt.figure(figsize=(8, 8))
    plt.scatter(coords[:, 0], coords[:, 1], s=200)
    for i in range(coords.shape[0]):
        plt.text(float(coords[i, 0]), float(coords[i, 1]), itos[i],
                 ha="center", va="center", color="white")
    # The first positional argument of plt.grid is the on/off flag; the former
    # "minor" string only acted as a truthy value, so pass the flag explicitly.
    plt.grid(True)


def sample(count=20, seed=2147483647):
    """Generate words from the current model and print them, one per line.

    Starting from a context of `block_size` zeros, the model's output
    distribution is sampled one token at a time, feeding each sampled id back
    into the context, until id 0 is drawn. The printed string includes the
    character for that final id 0.

    Args:
        count: number of words to generate (default 20).
        seed: seed of the generator used for sampling; the same seed and
            parameters give the same words (default 2147483647).
    """
    g = torch.Generator("cuda:0" if use_cuda else "cpu").manual_seed(seed)

    # Pure inference: no autograd graph is needed for any of this.
    with torch.no_grad():
        for _ in range(count):
            out = []
            context = [0] * block_size  # "...." - block_size padding ids
            while True:
                # Leading dimension 1: a batch holding just this one context.
                emb = C[torch.tensor([context])]  # (1, block_size, embed_size)
                h = torch.tanh(emb.view(1, -1) @ W1 + b1)
                logits = h @ W2 + b2

                probs = F.softmax(logits, dim=1)
                ix = torch.multinomial(probs, num_samples=1, generator=g).item()
                context = context[1:] + [ix]
                out.append(ix)
                if ix == 0:
                    break
            print(''.join(itos[i] for i in out))

# dump_c() plots only embedding dimensions 0 and 1, so it shows the complete
# embedding only when embed_size == 2; it is left disabled here.
# dump_c()
# Print words generated by the model.
sample()
