In [48]:
%reload_ext autoreload
%autoreload 2

import torch
from data.names_dataset import NamesDataset
from torch.utils.data import DataLoader

# Check for available accelerators
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")
assert device

# Initialize NamesDataset with the detected device
names_dataset = NamesDataset(data_folder="../datasets/names")


# train_dataset, test_dataset = torch.utils.data.random_split(names_dataset, [0.85, 0.15])
train_dataset, test_dataset = torch.utils.data.random_split(names_dataset, [0.85, 0.15])
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Using device: mps
Train dataset size: 17063
Test dataset size: 3011


In [49]:
BATCH_SIZE = 64


def collate_fn(
    batch: list[tuple[torch.Tensor, torch.Tensor]],
) -> list[tuple[torch.Tensor, torch.Tensor]]:
    return batch


train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)


In [None]:
import torch.nn as nn
import torch.nn.functional as F


class NamesClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(NamesClassifier, self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is of shape (batch_size, seq_length, input_size)
        output, hidden = self.rnn(x)
        output = self.h2o(hidden[0])
        # output is of shape (batch_size, output_size)
        output = F.log_softmax(output, dim=1)
        return output


def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        batch_loss = torch.tensor(0.0)
        for input, label in batch:
            output = model(input)
            batch_loss += criterion(output, label.argmax(dim=1))

        total_loss += batch_loss
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(dataloader)


def evaluate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            for input, label in batch:
                output = model(input)
                loss = criterion(output, label.argmax(dim=1))
                predictions = output.argmax(dim=1)
                correct += (predictions == label.argmax(dim=1)).sum().item()
                total_loss += loss.item()

    accuracy = correct / len(dataloader)
    return total_loss / len(dataloader), accuracy

In [51]:
import time

rnn = NamesClassifier(
    input_size=len(names_dataset.index_to_token),
    hidden_size=256,
    output_size=len(names_dataset.countries),
)

print(rnn)
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    start_time = time.perf_counter_ns()
    train_loss = train(rnn, train_dataloader, optimizer, criterion)
    test_loss, accuracy = evaluate(rnn, test_dataloader, criterion)
    end_time = time.perf_counter_ns()
    elapsed_time = (end_time - start_time) / 1e6
    print(
        f"Epoch {epoch + 1}/{num_epochs}, Elapsed Time: {elapsed_time:.2f}ms, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}"
    )

NamesClassifier(
  (rnn): RNN(87, 256, batch_first=True)
  (h2o): Linear(in_features=256, out_features=18, bias=True)
)
Epoch 1/10, Elapsed Time: 11312.15ms, Train Loss: 92.5351, Test Loss: 70.2659, Accuracy: 43.2083
Epoch 2/10, Elapsed Time: 11218.50ms, Train Loss: 66.2782, Test Loss: 58.2939, Accuracy: 45.6458
Epoch 3/10, Elapsed Time: 11283.15ms, Train Loss: 57.5835, Test Loss: 53.1384, Accuracy: 47.2292
Epoch 4/10, Elapsed Time: 11175.58ms, Train Loss: 52.8761, Test Loss: 51.7221, Accuracy: 47.4583
Epoch 5/10, Elapsed Time: 11478.35ms, Train Loss: 49.9978, Test Loss: 47.4878, Accuracy: 48.7083
Epoch 6/10, Elapsed Time: 11251.63ms, Train Loss: 47.1136, Test Loss: 46.3374, Accuracy: 48.8125
Epoch 7/10, Elapsed Time: 11312.57ms, Train Loss: 45.3075, Test Loss: 45.1199, Accuracy: 49.3750
Epoch 8/10, Elapsed Time: 11402.86ms, Train Loss: 42.8366, Test Loss: 45.4286, Accuracy: 49.0625
Epoch 9/10, Elapsed Time: 11372.15ms, Train Loss: 41.5598, Test Loss: 44.5040, Accuracy: 48.6250
Epoch 1