In [None]:
import random

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from nn_zero_to_hero.datasets import WordTokensDataset
from nn_zero_to_hero.loss import calculate_loss
from nn_zero_to_hero.models import WordTokenModel
from nn_zero_to_hero.optimizers import StepBasedLrGDOptimizer
from nn_zero_to_hero.tokens import sample_from_model, tokens_to_int_mapping
from nn_zero_to_hero.trainers import train_model_simple
from nn_zero_to_hero.vizs import plot_embeddings

%load_ext autoreload
# Automatically reload edited modules (e.g. the local nn_zero_to_hero package)
# before each cell executes, so library changes are picked up without a restart.
%autoreload 2
%matplotlib inline

In [None]:
# Context length passed to the dataset, model and sampler
# (presumably the number of preceding characters per example — confirm in WordTokensDataset).
BLOCK_SIZE = 3
SEED = 2147483647

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (CPU and, when present, CUDA). The model and the sampler
# use explicit generators, but any helper that draws from the global generator
# (e.g. mini-batch selection during training, if it does so) is otherwise
# unseeded and a fresh "Run All" would not reproduce the same results.
torch.manual_seed(SEED)

In [None]:
# Load the dataset: one name per line.
# The context manager closes the file handle deterministically (the original
# bare `open(...).read()` left it open until garbage collection), and an explicit
# encoding avoids platform-dependent decoding of non-ASCII names.
# Alternative dataset: "../../data/names_finnish.txt" — lowercase each line and
# drop entries that contain "." before use.
with open("../../data/names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
words[:8]

In [None]:
# Total number of names loaded.
num_words = len(words)
num_words

In [None]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set("".join(words))))
STOI, ITOS = tokens_to_int_mapping(chars)

print(ITOS)

In [None]:
# Deterministic 80% / 10% / 10% train / validation / test split.
# Shuffle a *copy* so `words` keeps its load order. Shuffling `words` in place
# made this cell non-idempotent: re-running it re-seeded and re-shuffled an
# already-shuffled list, producing a different split each time (so names from an
# earlier validation/test set could land in training) and making the earlier
# `words[:8]` output stale. The first-run split is unchanged: same seed, same
# input order, same permutation.
random.seed(42)

shuffled_words = list(words)
random.shuffle(shuffled_words)
n1 = int(0.8 * len(shuffled_words))
n2 = int(0.9 * len(shuffled_words))

train_dataset = WordTokensDataset(shuffled_words[:n1], BLOCK_SIZE, STOI)
validation_dataset = WordTokensDataset(shuffled_words[n1:n2], BLOCK_SIZE, STOI)
test_dataset = WordTokensDataset(shuffled_words[n2:], BLOCK_SIZE, STOI)

In [None]:
# Shapes of the training inputs and targets
# (presumably (n_examples, BLOCK_SIZE) and (n_examples,) — confirm in WordTokensDataset).
train_shapes = (train_dataset.X.shape, train_dataset.Y.shape)
train_shapes

In [None]:
# Build the model with a fixed-seed CPU generator (presumably used for parameter
# initialisation, making it reproducible), then move it to the selected device.
init_generator = torch.Generator().manual_seed(2147483647)
model = WordTokenModel(
    token_count=len(STOI),
    block_size=BLOCK_SIZE,
    embedding_layer_size=5,
    hidden_layer_size=100,
    generator=init_generator,
)
model = model.to(device)

In [None]:
# Train with mini-batches and a step-based learning-rate schedule.
batch_size = 32
epochs = 20

batches_by_epoch = len(train_dataset) // batch_size
scheduled_steps = batches_by_epoch * epochs

# (max_step, lr) pairs — presumably lr applies until max_step is reached and
# None means "no upper bound"; confirm in StepBasedLrGDOptimizer.
lr_schedule = [
    (scheduled_steps * 0.5, 0.1),
    (scheduled_steps * 0.75, 0.01),
    (None, 0.001),
]
optimizer = StepBasedLrGDOptimizer(model.parameters(), max_step_to_lr=lr_schedule)

stats_df = train_model_simple(
    model,
    dataset=train_dataset,
    optimizer=optimizer,
    epochs=epochs,
    batch_size=batch_size,
    device=device,
)

In [None]:
# Training curve. The loss is log10-transformed so later, smaller improvements
# stay visible next to the large early drop; the y-label says so explicitly.
# The trailing `;` suppresses the Axes.set return-value repr.
fig, ax = plt.subplots()
ax.plot(stats_df["step"], np.log10(stats_df["loss"]))
ax.set(title="Training loss per step", xlabel="step", ylabel="log10(loss)");

In [None]:
# Full-dataset cross-entropy on the training and validation splits; a large gap
# between the two indicates over-fitting.
training_loss = calculate_loss(model, train_dataset, F.cross_entropy, device)
validation_loss = calculate_loss(model, validation_dataset, F.cross_entropy, device)
# `.4f` = four decimal places. The previous `4f` meant "minimum width 4" with
# the default 6-decimal precision.
print(f"{training_loss = :.4f}, {validation_loss = :.4f}")

In [None]:
# Visualise the learned table `model.C` (presumably the character embedding
# matrix), labelling each row with its character via ITOS.
embedding_table = model.C
plot_embeddings(embedding_table, ITOS)

In [None]:
# Sample names from the trained model. `g` is re-created and re-seeded every
# time this cell runs, so repeated runs draw from the same random stream.
g = torch.Generator(device).manual_seed(2147483647 + 10)

num_samples = 20
for _sample_idx in range(num_samples):
    name = sample_from_model(
        model,
        block_size=BLOCK_SIZE,
        device=device,
        itos=ITOS,
        generator=g,
    )
    print(name)