In [None]:
%reload_ext autoreload
%autoreload 2

import torch
from data.names_dataset import NamesDataset
from torch.utils.data import DataLoader

# Fix all randomness up front so Restart & Run All reproduces the same split,
# the same initial weights and the same DataLoader shuffle order.
SEED = 42
torch.manual_seed(SEED)

# Detect an available accelerator, falling back to the CPU.
# NOTE(review): `device` is only reported here; no visible cell moves the
# model or the batches onto it.
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")
# Presumably narrows the Optional returned by current_accelerator() for type checkers.
assert device

# Load the names dataset.
names_dataset = NamesDataset(data_folder="../datasets/names")

# 85/15 train/test split. A dedicated seeded generator keeps the split stable
# even if the number of random draws made before this line changes.
train_dataset, test_dataset = torch.utils.data.random_split(
    names_dataset,
    [0.85, 0.15],
    generator=torch.Generator().manual_seed(SEED),
)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


class NamesClassifier(nn.Module):
    """Single-layer LSTM classifier that maps a character sequence to class logits.

    The LSTM reads time-major input (``batch_first=False``). Its top-layer
    final hidden state is projected by a linear layer to one score per class.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Time-major: input is (seq_length, batch_size, input_size).
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
        # Hidden-to-output projection producing one logit per class.
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of sequences.

        Args:
            x: time-major tensor of shape (seq_length, batch_size, input_size).
               ``nn.LSTM`` also accepts a ``PackedSequence`` here.

        Returns:
            Logits of shape (batch_size, output_size).
        """
        # Only the final hidden state h_n is needed; the per-step outputs and
        # the cell state are discarded.
        _, (h_n, _) = self.lstm(x)
        # For batched input h_n is (num_layers, batch, hidden); take the top layer.
        return self.h2o(h_n[-1])


def collate_fn(batch):
    """Collate ``(input, label)`` samples into a packed LSTM batch.

    Each ``input`` has its axis 1 squeezed away and each ``label`` its axis 0.
    NOTE(review): this suggests inputs are (seq_len, 1, vocab) tensors and
    labels are one-element tensors; confirm against NamesDataset.

    Names have different lengths. If they were zero-padded and the padded
    tensor fed to the LSTM, its final hidden state for every shorter name
    would be the state *after* the padding steps, not after the name's last
    character. The sequences are therefore packed. ``nn.LSTM`` accepts a
    ``PackedSequence`` directly and then reports each sequence's hidden state
    at its own last real timestep, in the original batch order. Every
    sequence must be non-empty.

    Returns:
        Tuple of (PackedSequence of the inputs, long tensor of labels with
        shape (batch_size,)).
    """
    sequences = [sample.squeeze(1) for sample, _ in batch]
    targets = [target.squeeze(0) for _, target in batch]
    # enforce_sorted=False: batches arrive in arbitrary (shuffled) order,
    # not sorted by length; PyTorch sorts internally and restores the order.
    packed_inputs = nn.utils.rnn.pack_sequence(sequences, enforce_sorted=False)
    labels = torch.stack(targets).to(torch.long)
    return packed_inputs, labels


def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: module to train; it is put in training mode.
        dataloader: iterable of ``(inputs, labels)`` batches. Both items must
            support ``.to(device)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function. Assumed to average over the batch
            (``reduction="mean"``, the default for ``nn.CrossEntropyLoss``).

    Returns:
        Loss averaged over all samples seen, so a short final batch weighs no
        more than its size. Returns ``nan`` if the loader yields no samples.
    """
    model.train()
    # Send each batch to wherever the model's parameters live, so the loop
    # keeps working if the model is moved to an accelerator.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))
    total_loss = 0.0
    total_samples = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass for the entire batch
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns a batch *mean*; re-weight by batch size so the
        # epoch figure is a true per-sample average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        return float("nan")
    return total_loss / total_samples


def evaluate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module):
    """Evaluate ``model`` without gradients and return ``(mean_loss, accuracy)``.

    Args:
        model: module to evaluate; it is put in eval mode.
        dataloader: iterable of ``(inputs, labels)`` batches. Both items must
            support ``.to(device)``.
        criterion: loss function. Assumed to average over the batch
            (``reduction="mean"``, the default for ``nn.CrossEntropyLoss``).

    Returns:
        ``(loss, accuracy)``, both averaged per sample. The prediction is the
        arg-max over dim 1 of the model output. Returns ``(nan, nan)`` if the
        loader yields no samples.
    """
    model.eval()
    # Keep batches on the same device as the model's parameters.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))
    total_loss = 0.0
    correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # Re-weight the batch-mean loss by batch size so a short final
            # batch does not count as much as a full one.
            batch_size = labels.size(0)
            total_loss += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        return float("nan"), float("nan")
    return total_loss / total_samples, correct / total_samples


def predict(model: nn.Module, input: str) -> str:
    """Return the country label ``model`` assigns to the name ``input``.

    The name is encoded with the notebook-level ``names_dataset.name_to_tensor``.
    It is run through the model without gradients, and the arg-max class index
    is mapped back through ``names_dataset.countries``.

    NOTE(review): ``argmax(dim=1)`` assumes ``name_to_tensor`` returns a
    batched single-sequence tensor, so the model output is (1, n_classes);
    confirm against NamesDataset.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    tensor = names_dataset.name_to_tensor(input)
    # Match the model's device so prediction still works after model.to(...).
    device = next((p.device for p in model.parameters()), torch.device("cpu"))
    tensor = tensor.to(device)
    with torch.no_grad():
        logits = model(tensor)
    return names_dataset.countries[logits.argmax(dim=1).item()]

In [None]:
import time
import matplotlib.pyplot as plt

BATCH_SIZE = 256
NUM_EPOCHS = 20
LEARNING_RATE = 0.005
HIDDEN_SIZE = 128


train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)


# One input feature per vocabulary token, one output logit per country.
rnn = NamesClassifier(
    input_size=len(names_dataset.index_to_token),
    hidden_size=HIDDEN_SIZE,
    output_size=len(names_dataset.countries),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LEARNING_RATE)

print("Starting training...")
print(rnn)

train_losses = []
test_losses = []
accuracies = []
for epoch in range(NUM_EPOCHS):
    start_time = time.perf_counter()
    train_loss = train(rnn, train_dataloader, optimizer, criterion)
    test_loss, accuracy = evaluate(rnn, test_dataloader, criterion)
    elapsed_time = time.perf_counter() - start_time
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    accuracies.append(accuracy)
    # Report completed epochs (1-based) so the final line reads 100%.
    completed = epoch + 1
    print(
        f"{completed}  ({completed / NUM_EPOCHS:.0%}) \t{elapsed_time:.2f}s\tTrain Loss: {train_loss:.2f}\tTest Loss: {test_loss:.2f}\tAccuracy: {accuracy:.2f}"
    )

print("Training complete.")

# Losses and accuracy live on different scales, so plot them on separate axes,
# each with a legend so the curves can be told apart.
epoch_numbers = range(1, NUM_EPOCHS + 1)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))
loss_ax.plot(epoch_numbers, train_losses, label="Train Loss")
loss_ax.plot(epoch_numbers, test_losses, label="Test Loss")
loss_ax.set(title="Cross-entropy loss per epoch", xlabel="Epoch", ylabel="Loss")
loss_ax.legend()
acc_ax.plot(epoch_numbers, accuracies, label="Accuracy")
acc_ax.set(title="Test accuracy per epoch", xlabel="Epoch", ylabel="Accuracy")
acc_ax.legend()
fig.tight_layout()
plt.show()

# Spot-check a few names from different origins.
for sample_name in ["John", "Maria", "Yuki", "Sven", "Hai", "Vivian"]:
    print(f"{sample_name} -> {predict(rnn, sample_name)}")