In [None]:
%reload_ext autoreload
%autoreload 2

import torch
from data.names_dataset import NamesDataset
from torch.utils.data import DataLoader

# Use the current accelerator when torch reports one, otherwise fall back to CPU.
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")
# Explicit None check: states the actual intent instead of relying on truthiness.
assert device is not None, "could not resolve a torch device"
# NOTE(review): `device` is only printed here; nothing in the visible cells moves
# the dataset, the model, or the batches onto it — confirm whether that is intended.

# Build the dataset from the names folder (the device is NOT passed to it).
names_dataset = NamesDataset(data_folder="../datasets/names")

# Seed the split with a dedicated generator so the held-out set is identical
# across kernel restarts; otherwise test metrics are not comparable between runs.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    names_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
# Number of dataset items grouped into each batch by the DataLoaders below.
BATCH_SIZE = 256


def collate_fn(
    batch: list[tuple[torch.Tensor, torch.Tensor]],
) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """Identity collate: hand the list of ``(input, label)`` pairs back untouched.

    Passing this to a ``DataLoader`` bypasses the default collation, so each
    batch stays a plain Python list of per-item tuples rather than being
    stacked into a single tensor.

    Args:
        batch: The items the loader gathered for one batch.

    Returns:
        The very same list object that was passed in.
    """
    return batch


# Training loader: shuffling enabled, batches kept as lists of items via collate_fn.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
)

# Evaluation loader: same batching, but iteration order is left unshuffled.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=False,
)


In [None]:
import torch.nn as nn
import torch.nn.functional as F


class NamesClassifier(nn.Module):
    """Single-layer RNN followed by a linear head that emits class log-probabilities.

    Args:
        input_size: Number of features per sequence step.
        hidden_size: Width of the RNN hidden state.
        output_size: Number of classes to score.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a sequence from the RNN's final hidden state.

        Args:
            x: Either a batched ``(batch_size, seq_length, input_size)`` tensor
                or an unbatched ``(seq_length, input_size)`` tensor.

        Returns:
            Log-probabilities of shape ``(batch_size, output_size)`` for batched
            input, or ``(output_size,)`` for unbatched input.
        """
        # Only the final hidden state is needed; the per-step outputs are discarded.
        _, hidden = self.rnn(x)
        # hidden has a leading (num_layers) axis; the RNN has one layer, so take it.
        logits = self.h2o(hidden[-1])
        # Normalise over the class axis. Using the last dimension (instead of a
        # hard-coded dim=1) keeps this valid for unbatched input as well.
        return F.log_softmax(logits, dim=-1)


def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 2.0,
):
    """Run one training pass over ``dataloader`` and return the mean loss per item.

    Each batch yielded by ``dataloader`` must be an iterable of ``(input, label)``
    pairs. The label is reduced with ``argmax(dim=1)`` before being handed to
    ``criterion``, so it is expected to be one-hot style along dim 1. The losses of
    all items in a batch are summed, back-propagated once, gradient-clipped, and
    applied with a single optimizer step.

    Args:
        model: Module to optimise; switched to training mode.
        dataloader: Iterable of batches, each a sequence of ``(input, label)`` pairs.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss taking ``(model_output, class_indices)``.
        max_grad_norm: Gradient-norm clipping threshold applied before each step.

    Returns:
        Sum of all per-item losses divided by the number of items actually
        processed (``0.0`` when the loader yields nothing).
    """
    model.train()
    total_loss = 0.0
    # Count what was really seen instead of relying on a notebook-global dataset,
    # so the average is right for whichever loader is passed in.
    num_items = 0
    for batch in dataloader:
        if not batch:
            # Nothing to back-propagate for an empty batch.
            continue

        # Start from a Python float so the accumulated loss tensor lives on the
        # same device as the model output rather than being pinned to the CPU.
        batch_loss = 0.0
        for features, label in batch:
            output = model(features)
            batch_loss = batch_loss + criterion(output, label.argmax(dim=1))
            num_items += 1

        total_loss += batch_loss.item()
        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / max(num_items, 1)


def evaluate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module):
    """Compute the mean loss and accuracy of ``model`` over ``dataloader``.

    Each batch yielded by ``dataloader`` must be an iterable of ``(input, label)``
    pairs; the label is reduced with ``argmax(dim=1)`` to obtain class indices.
    Runs under ``torch.no_grad()`` with the model in evaluation mode.

    Args:
        model: Module to evaluate.
        dataloader: Iterable of batches, each a sequence of ``(input, label)`` pairs.
        criterion: Loss taking ``(model_output, class_indices)``.

    Returns:
        ``(mean_loss, accuracy)``. The loss is averaged over the items processed
        and the accuracy over the label rows scored; both are ``0.0`` when the
        loader yields nothing.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    # Track what was actually evaluated rather than reading a notebook-global
    # dataset length, so any loader (train, test, subset) gives correct metrics.
    num_items = 0
    num_labels = 0
    with torch.no_grad():
        for batch in dataloader:
            for features, label in batch:
                target = label.argmax(dim=1)
                output = model(features)
                total_loss += criterion(output, target).item()
                correct += (output.argmax(dim=1) == target).sum().item()
                num_items += 1
                num_labels += target.numel()

    return total_loss / max(num_items, 1), correct / max(num_labels, 1)

In [None]:
import time

# Model sized from the dataset: one input feature per token, one output per country.
rnn = NamesClassifier(
    input_size=len(names_dataset.index_to_token),
    hidden_size=128,
    output_size=len(names_dataset.countries),
)

print(rnn)
# NLLLoss pairs with the log-probabilities the classifier returns.
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)

print("Starting training...")

# Per-epoch history, consumed by the plotting cell.
train_losses = []
test_losses = []
accuracies = []
num_epochs = 25
for epoch in range(num_epochs):
    start_time = time.perf_counter_ns()
    train_loss = train(rnn, train_dataloader, optimizer, criterion)
    test_loss, accuracy = evaluate(rnn, test_dataloader, criterion)
    end_time = time.perf_counter_ns()
    elapsed_time = (end_time - start_time) / 1e9  # nanoseconds -> seconds
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    accuracies.append(accuracy)
    # This line is printed after the epoch has finished, so progress is the share
    # of epochs completed: (epoch + 1) / num_epochs, reaching 100% on the last one.
    print(
        f"{epoch}  ({(epoch + 1) / num_epochs:.0%}) \t{elapsed_time:.2f}s\tTrain Loss: {train_loss:.2f}\tTest Loss: {test_loss:.2f}\tAccuracy: {accuracy:.2f}"
    )

print("Training complete.")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Plot the per-epoch training history on a single set of axes.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train Loss")
ax.plot(test_losses, label="Test Loss")
ax.plot(accuracies, label="Accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss / accuracy")
ax.set_title("Training history")
# Without a legend the label= arguments above are never shown and the three
# curves cannot be told apart.
ax.legend()
plt.show()
