In [None]:
%reload_ext autoreload
%autoreload 2

import torch
from data.names_dataset import NamesDataset
from torch.utils.data import DataLoader

# Build the NamesDataset from the names data folder.
names_dataset = NamesDataset(data_folder="../datasets/names")

# Seed the split so that "Restart Kernel -> Run All" always yields the same
# train/test membership; otherwise metrics are not comparable between runs.
SPLIT_SEED = 42

# 85% / 15% train/test split (fractions are resolved by random_split itself).
train_dataset, test_dataset = torch.utils.data.random_split(
    names_dataset,
    [0.85, 0.15],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [30]:
import time
import math
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


class NamesClassifier(nn.Module):
    """Single-layer LSTM classifier that maps a sequence to class logits.

    The sequence is fed through an LSTM, dropout is applied to the final
    hidden state, and a linear layer projects it to ``output_size`` logits.

    Args:
        input_size: Size of each time step's feature vector.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of logits produced per sequence.
        dropout: Probability used by the dropout layer applied to the final
            hidden state. Defaults to 0.5, the previously hard-coded value.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout: float = 0.5):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=False,
        )
        self.h2o = nn.Linear(hidden_size, output_size)
        # Regularises the sequence summary; inactive in eval() mode.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for ``x``.

        Args:
            x: Tensor of shape ``(seq_length, batch_size, input_size)``, or
                ``(seq_length, input_size)`` for a single unbatched sequence.

        Returns:
            Logits of shape ``(batch_size, output_size)``, or
            ``(output_size,)`` for unbatched input.
        """
        # Only the final hidden state is needed; per-step outputs and the
        # cell state are discarded.
        _lstm_output, (hidden, _cell) = self.lstm(x)
        # hidden[-1] selects the last (here: only) layer's final hidden state.
        last_hidden_state = hidden[-1]
        last_hidden_state = self.dropout(last_hidden_state)
        return self.h2o(last_hidden_state)


def collate_fn(batch):
    inputs, labels = zip(*batch)
    inputs_padded = pad_sequence(list(inputs), batch_first=False, padding_side="left")
    labels = torch.stack(labels)
    return inputs_padded, labels

In [None]:
import matplotlib.pyplot as plt
from common.learner import Learner
from common.metrics import AccuracyMetric, ConfusionMatrixMetric, LossMetric, Metric

# ---- Hyperparameters --------------------------------------------------------
BATCH_SIZE = 64
LEARNING_RATE = 0.001
HIDDEN_SIZE = 64
NUM_EPOCHS = 50
# NOTE(review): PATIENCE is not referenced anywhere in this cell — confirm
# whether early stopping was intended to be wired into the learner.
PATIENCE = 5

# ---- Data loaders -----------------------------------------------------------
# Only the training loader shuffles; both pad their batches through collate_fn.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)

# ---- Model, loss, optimizer -------------------------------------------------
num_classes = len(names_dataset.countries)

# Input width is the number of entries in index_to_token; one logit per country.
rnn = NamesClassifier(
    input_size=len(names_dataset.index_to_token),
    hidden_size=HIDDEN_SIZE,
    output_size=num_classes,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LEARNING_RATE)

learner = Learner(model=rnn, optimizer=optimizer, criterion=criterion)

# ---- Metrics ----------------------------------------------------------------
# The confusion-matrix metrics are constructed but deliberately left out of
# the lists handed to the learner below.
train_loss_metric = LossMetric(num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE)
train_accuracy_metric = AccuracyMetric(
    num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, num_classes=num_classes
)
train_confusion_matrix_metric = ConfusionMatrixMetric(
    num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, num_classes=num_classes
)
train_metrics: list[Metric] = [train_loss_metric, train_accuracy_metric]

eval_loss_metric = LossMetric(num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE)
eval_accuracy_metric = AccuracyMetric(
    num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, num_classes=num_classes
)
eval_confusion_matrix_metric = ConfusionMatrixMetric(
    num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, num_classes=num_classes
)
eval_metrics: list[Metric] = [eval_loss_metric, eval_accuracy_metric]

# ---- Training ---------------------------------------------------------------
print("Starting training...")
learner.fit(train_dataloader, test_dataloader, NUM_EPOCHS, train_metrics, eval_metrics)
print("Training completed.")

# ---- Learning curves --------------------------------------------------------
# Loss and accuracy series share one set of axes, indexed by epoch.
plt.figure()
plt.plot(train_loss_metric.epoch_losses, label="Train Loss")
plt.plot(eval_loss_metric.epoch_losses, label="Test Loss")
plt.plot(train_accuracy_metric.epoch_corrects, label="Train Accuracy")
plt.plot(eval_accuracy_metric.epoch_corrects, label="Test Accuracy")
plt.xlabel("Epoch")
plt.legend()
plt.show()

In [None]:
# Dump the raw per-epoch loss series: training first, then evaluation.
print(train_loss_metric.epoch_losses, eval_loss_metric.epoch_losses, sep="\n")

In [None]:
def predict(
    model: nn.Module,
    input: str,
    max_padding: int = 20,
    dataset=None,
) -> list[str]:
    """Classify a name repeatedly while appending more trailing zero rows.

    For ``i`` in ``0 .. max_padding - 1`` the encoded name is extended with
    ``i`` all-zero rows at the END of the sequence and passed to ``model``.
    Each predicted label is printed together with the tensor shape, which
    shows how sensitive the model is to trailing padding.

    Args:
        model: Model to query; it is switched to eval mode.
        input: The name to classify.
        max_padding: Number of padded variants to try (default 20).
        dataset: Object providing ``name_to_tensor`` and ``countries``.
            Defaults to the notebook-level ``names_dataset``.

    Returns:
        The predicted label for each padding amount, in order.
    """
    source = names_dataset if dataset is None else dataset
    predictions: list[str] = []

    model.eval()
    with torch.no_grad():
        # assumes name_to_tensor returns a 2-D (seq_length, input_size) tensor
        # — TODO confirm against NamesDataset
        base_tensor = source.name_to_tensor(input)
        for i in range(max_padding):
            # Create the padding on the same device as the encoded name so
            # torch.cat does not fail for non-CPU tensors.
            padding = torch.zeros(
                (i, base_tensor.shape[1]),
                dtype=torch.float32,
                device=base_tensor.device,
            )
            tensor = torch.cat([base_tensor, padding], dim=0)
            output = model(tensor)
            country = source.countries[output.argmax().item()]
            print(f"{tensor.shape=} {country}")
            predictions.append(country)

    return predictions


# predict() prints one line per padding amount itself, so it is not wrapped in
# print() — that would only add a redundant line for its return value.
predict(rnn, "Hai")