In [None]:
# Reload edited project modules automatically before each cell executes, so
# changes to imported .py files are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Interactive matplotlib figures (the `widget` backend is provided by ipympl).
%matplotlib widget

import torch

# Seed torch's global random number generator so that random operations in
# later cells are repeatable when the notebook is run top to bottom.
torch.manual_seed(42)


In [None]:
from data.names_dataset import NamesDataset, NamesDataSource
from torch.utils.data import DataLoader


# Load the names data source from disk (with Unicode normalization enabled),
# wrap it in a NamesDataset, and print the fields of the first sample as a
# quick sanity check.
names_data_source = NamesDataSource.load(
    normalize_unicode=True,
    data_folder="../datasets/names",
)
names_dataset = NamesDataset(names_data_source)

sample = names_dataset[0]
for field in ("country", "name", "country_tensor", "name_tensor"):
    # Produces the same text as f"{sample.<field>=}": the expression,
    # an equals sign, then repr() of the value.
    print(f"sample.{field}={getattr(sample, field)!r}")


In [None]:
# Split the dataset into train (85%) and test (15%) subsets.
#
# A dedicated, explicitly seeded generator is used instead of torch's global
# RNG: the split then no longer depends on how much of the global random
# stream earlier cells have consumed, so re-running this cell (or running
# cells out of order) always yields the same train/test membership.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    names_dataset,
    [0.85, 0.15],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Every sample must land in exactly one of the two subsets.
assert len(train_dataset) + len(test_dataset) == len(names_dataset)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


class NamesClassifierRNN(nn.Module):
    """
    Single-layer nn.RNN encoder followed by a linear classification head.

    Dimension legend used in the shape comments below:
        D: input_size
        H: hidden_size
        C: output_size
        S: sequence_length
        N: batch_size
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        # Recurrent encoder: consumes [S, N, D]; its final hidden state has
        # shape [1, N, H] (one layer, one direction).
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size)

        # Classification head: [N, H] -> [N, C]
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute unnormalized class scores for a sequence-first batch.

        x: shape [S, N, D]
        returns: shape [N, C]
        """
        # Only the final hidden state is needed; the per-step outputs
        # (first element of the returned pair) are discarded.
        final_hidden = self.rnn(x)[1]

        # Drop the leading layer axis ([1, N, H] -> [N, H]) and project the
        # result onto the classes.
        return self.fc(final_hidden[0])


class NamesClassifierLSTM(nn.Module):
    """
    Bidirectional multi-layer LSTM encoder, dropout, and a linear
    classification head.

    Accepts both batched [S, N, D] and unbatched [S, D] input.

    Dimension legend used in the shape comments below:
        D: input_size
        H: hidden_size
        C: output_size
        S: sequence_length
        N: batch_size
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        num_layers: int = 2,
        dropout: float = 0.5,
    ):
        """
        input_size: size D of each input step
        hidden_size: size H of the LSTM hidden state (per direction)
        output_size: number of classes C
        num_layers: number of stacked LSTM layers (default 2)
        dropout: drop probability applied to the concatenated final hidden
            state before the linear head (default 0.5)
        """
        super().__init__()

        # lstm: [S, N, D] -> hidden [num_layers * 2, N, H]
        # (the factor 2 is the number of directions; with the defaults the
        # leading dimension is 4).
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=False,
            num_layers=num_layers,
            bidirectional=True,
        )

        # dropout: [N, H * 2] -> [N, H * 2]
        self.dropout = nn.Dropout(p=dropout)

        # fc: [N, H * 2] -> [N, C]
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute unnormalized class scores.

        x: shape [S, N, D] (or [S, D] for a single unbatched sequence)
        returns: shape [N, C] (or [C] for unbatched input)
        """
        # hidden: [num_layers * num_directions, N, H]
        _lstm_output, (hidden, _cell) = self.lstm(x)

        # The last two entries are the final layer's forward and backward
        # states. Concatenate along the feature axis; dim=-1 (rather than
        # dim=1) keeps this valid when the batch axis is absent.
        # bidirectional_hidden_state: [N, H * 2]
        bidirectional_hidden_state = torch.cat((hidden[-2], hidden[-1]), dim=-1)

        # dropout_output: [N, H * 2]
        dropout_output = self.dropout(bidirectional_hidden_state)

        # output: [N, C]
        output = self.fc(dropout_output)
        return output


# Smoke test: push one name through a freshly constructed (untrained)
# classifier to confirm the tensor shapes line up end to end. The predicted
# country is not meaningful at this point.
test_model = NamesClassifierRNN(
    input_size=names_data_source.num_vocab,
    output_size=names_data_source.num_classes,
    hidden_size=64,
)
print(test_model)

test_input = names_dataset.name_to_one_hot("John")
print(test_input)

test_output = test_model(test_input)
print(test_output)

# Index of the highest-scoring class, mapped back to its country label.
country_idx = test_output.argmax()
print(names_data_source.countries[country_idx])


In [None]:
from common.batch_learner import Batch


def collate_fn(batch) -> Batch:
    """
    Gather the per-sample tensors of a batch into a single Batch.

    batch: list[Sample]
    returns: Batch whose `inputs` is the list of each sample's name_tensor and
        whose `labels` is the list of each sample's country_tensor, both in
        the order the samples appear in `batch`.
    """
    name_tensors = [item.name_tensor for item in batch]
    country_tensors = [item.country_tensor for item in batch]
    return Batch(inputs=name_tensors, labels=country_tensors)


# Sanity check: collate the first ten samples and print the resulting Batch.
# NOTE(review): relies on NamesDataset supporting slice indexing — confirm.
print(collate_fn(names_dataset[:10]))


In [None]:
import torch
from common.batch_learner import BatchLearner
from common.metrics import (
    AccuracyMetric,
    ConfusionMatrixMetric,
    Metric,
    PrecisionMetric,
    RecallMetric,
    F1ScoreMetric,
)

# Training hyperparameters.
BATCH_SIZE = 64  # samples per DataLoader batch
LEARNING_RATE = 0.001  # learning rate handed to Adam
HIDDEN_SIZE = 64  # hidden state size of the RNN
NUM_EPOCHS = 50  # passed to learner.fit as num_epochs
PATIENCE = 5  # passed to learner.fit as patience (presumably early stopping — confirm in BatchLearner)

# Model under training: input size is the vocabulary size, output size is the
# number of classes.
rnn = NamesClassifierRNN(
    input_size=names_data_source.num_vocab,
    hidden_size=HIDDEN_SIZE,
    output_size=names_data_source.num_classes,
)
# Cross-entropy loss over the model's raw class scores, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LEARNING_RATE)

# BatchLearner bundles model, optimizer and loss; it is used below for
# fitting, final evaluation and prediction.
learner = BatchLearner(
    model=rnn,
    optimizer=optimizer,
    criterion=criterion,
)

def _build_metrics():
    """Create one fresh [accuracy, precision, recall, F1] list of metric objects."""
    return [
        AccuracyMetric(num_classes=names_data_source.num_classes),
        PrecisionMetric(classes=names_data_source.countries),
        RecallMetric(classes=names_data_source.countries),
        F1ScoreMetric(classes=names_data_source.countries),
    ]


# Training and evaluation each get their own metric instances. The individual
# metrics are also bound to names so later cells can plot them.
train_metrics: list[Metric] = _build_metrics()
(
    train_accuracy_metric,
    train_precision_metric,
    train_recall_metric,
    train_f1_metric,
) = train_metrics

eval_metrics: list[Metric] = _build_metrics()
(
    eval_accuracy_metric,
    eval_precision_metric,
    eval_recall_metric,
    eval_f1_metric,
) = eval_metrics


def _make_dataloader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook's BATCH_SIZE and collate_fn."""
    return DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


# Only the training loader shuffles; the evaluation loader keeps a fixed order.
train_dataloader = _make_dataloader(train_dataset, shuffle=True)
eval_dataloader = _make_dataloader(test_dataset, shuffle=False)

# Run the training loop. `fit` returns two loss histories (train and eval),
# which are plotted against epochs in a later cell; the metric objects passed
# in are the same instances that later cells plot.
print("Starting training...")
train_losses, eval_losses = learner.fit(
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    num_epochs=NUM_EPOCHS,
    patience=PATIENCE,
    train_metrics=train_metrics,
    eval_metrics=eval_metrics,
)
print("Training completed.")


In [None]:
import matplotlib.pyplot as plt

# Confusion matrix over the country classes, constructed with normalize=True.
confusion_matrix_metric = ConfusionMatrixMetric(
    classes=names_data_source.countries, normalize=True
)


# NOTE(review): this loader covers the *entire* dataset (train + test
# samples), so the confusion matrix below includes data the model was trained
# on — confirm this is intended rather than using test_dataset.
full_dataloader = DataLoader(
    dataset=names_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
# One evaluation pass that feeds the confusion-matrix metric; the returned
# value is kept as final_loss.
final_loss = learner.final_eval(
    dataloader=full_dataloader, metrics=[confusion_matrix_metric]
)

# Draw the confusion matrix on its own figure.
_, ax = plt.subplots(figsize=(12, 8))
confusion_matrix_metric.plot(ax, "Confusion Matrix")

# Accuracy and loss curves for train and test, drawn on one shared axes.
_, ax = plt.subplots(figsize=(5, 5))
ax.set(title="Loss and Accuracy", xlabel="Epoch")

train_accuracy_metric.plot(ax, "Train Accuracy")
eval_accuracy_metric.plot(ax, "Test Accuracy")

for loss_history, loss_label in (
    (train_losses, "Train Loss"),
    (eval_losses, "Test Loss"),
):
    ax.plot(loss_history, label=loss_label)

ax.legend()


# One figure per per-class metric, each overlaying the train and eval results.
for figure_label, train_metric, eval_metric in (
    ("Precision", train_precision_metric, eval_precision_metric),
    ("Recall", train_recall_metric, eval_recall_metric),
    ("F1", train_f1_metric, eval_f1_metric),
):
    _, ax = plt.subplots(figsize=(5, 5), label=figure_label)
    train_metric.plot_classes(ax, "Train")
    eval_metric.plot_classes(ax, "Eval")


In [None]:
# Inspect the weights of the trained classifier's linear head.
print(rnn.fc)
print(rnn.fc.weight.data)
# .grad stays None until a backward pass has populated it, so this guards
# against inspecting a model that has not been through training.
assert rnn.fc.weight.grad is not None

# Distribution of the weight values. torch.histogram returns one more bin
# edge than counts, so the counts are plotted against each bin's left edge.
hist, bin_edges = torch.histogram(rnn.fc.weight.data)
f, ax = plt.subplots(figsize=(5, 5))
ax.plot(bin_edges[:-1], hist)

# Boolean mask of the weight matrix: entries whose magnitude exceeds 0.03.
plt.figure(figsize=(10, 5))
plt.imshow(rnn.fc.weight.abs() > 0.03)


In [None]:
# Ask the learner for its top-3 predictions for a single name.
likelihoods, indices = learner.predict_topk(
    names_dataset.name_to_one_hot("Albert"), k=3
)
# Print each returned likelihood next to the country its index maps to.
for likelihood, country_idx in zip(likelihoods, indices):
    print(f"{likelihood:.2f} {names_data_source.countries[country_idx]}")

# Sum of the three returned likelihoods only (not of all classes).
print(f"Total likelihood: {likelihoods.sum().item():.2f}")
