In [None]:
import torch

from src.concept_bottleneck.dataset import CUB200AttributesToClass

# Both splits are parameterised the same way: samples and targets as torch tensors.
_AttributesDataset = CUB200AttributesToClass[torch.Tensor, torch.Tensor]


def _load_split(train: bool) -> _AttributesDataset:
    """Build one split of the CUB200AttributesToClass dataset.

    Args:
        train: Forwarded to the dataset's ``train`` flag; ``True`` for the
            training split, ``False`` for the test split.

    Returns:
        The dataset with ``torch.from_numpy`` applied to each sample and the
        target shifted down by one.
    """
    return CUB200AttributesToClass(
        train=train,
        # presumably the raw samples are NumPy arrays — confirm in the dataset class
        transform=torch.from_numpy,  # type: ignore
        target_transform=lambda x: x - 1,  # from 1-indexed to 0-indexed
    )


training_data: _AttributesDataset = _load_split(train=True)
test_data: _AttributesDataset = _load_split(train=False)


In [None]:
from torch.utils.data import DataLoader

# Loader settings shared by the training and test splits.
batch_size = 4
num_workers = 1
_loader_options = {"batch_size": batch_size, "num_workers": num_workers}

# Only the training loader shuffles; the test loader keeps the DataLoader default.
training_dataloader = DataLoader(training_data, shuffle=True, **_loader_options)
test_dataloader = DataLoader(test_data, **_loader_options)


In [None]:
from torchvision.ops import MLP

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Maps the attribute vector straight to one output per class.
# NOTE(review): with a single entry in hidden_channels, torchvision's MLP
# appears to reduce to one linear layer (no hidden activation) — confirm
# against the torchvision.ops.MLP documentation if a non-linear model is intended.
model = MLP(
    in_channels=training_data.num_attributes,
    hidden_channels=[training_data.num_classes],
).to(device)


In [None]:
def train(
    dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
):
    """Run one optimisation pass over every batch in ``dataloader``.

    Args:
        dataloader: Yields ``(X, y)`` batches. ``X`` is cast to ``torch.float``
            before the forward pass; ``y`` is passed to ``loss_fn`` unchanged
            apart from the device move.
        model: Network to optimise. It is switched to training mode and its
            parameters are updated in place through ``optimizer``.
        loss_fn: Callable taking ``(logits, y)`` and returning a scalar loss.
        optimizer: Optimizer that steps the model's parameters.

    Returns:
        None. The loss of every 1000th batch (starting with the first) is
        printed together with the number of samples consumed before it.
    """
    # Follow the model's own device instead of a notebook-level global so the
    # function does not depend on which other cells have been executed.
    # A parameter-free model falls back to the CPU.
    device = next(model.parameters(), torch.empty(0)).device
    size = len(dataloader.dataset)  # type: ignore
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device=device, dtype=torch.float)
        y = y.to(device)

        logits = model(X)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            # Use a separate name for the scalar so the loss tensor is not shadowed.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")


In [None]:
def test(
    dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
):
    size = len(dataloader.dataset)  # type: ignore
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(torch.float).to(device), y.to(device)  # type: ignore
            logits = model(X)
            test_loss += loss_fn(logits, y).item()
            correct += (torch.argmax(logits, dim=1) == y).sum().item()
    test_loss /= num_batches
    accuracy = correct / size
    return test_loss, accuracy


In [None]:
def save_model(model: torch.nn.Module, filename: str):
    """Write the model's state dict to ``<filename>.pth`` and print a confirmation.

    Args:
        model: Module whose ``state_dict()`` is serialised.
        filename: Target path without the ``.pth`` extension.
    """
    weights = model.state_dict()
    torch.save(weights, f"{filename}.pth")
    print(f"Saved PyTorch Model State to {filename}.pth")


def load_model(model: torch.nn.Module, filename: str):
    """Load a state dict from ``<filename>.pth`` into ``model`` and set eval mode.

    Args:
        model: Module whose parameters are overwritten in place. Its
            architecture must match the checkpoint.
        filename: Checkpoint path without the ``.pth`` extension.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing file
        or mismatching keys.
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" lets a checkpoint written on a GPU machine be read
    # where CUDA is unavailable; load_state_dict then copies the values into
    # the model's existing parameters, so the model stays on its current device.
    state_dict = torch.load(f"{filename}.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    print(f"Loaded PyTorch Model State from {filename}.pth")
    model.eval()


In [None]:
def run_epoch(epochs: int):
    """Train the notebook-level ``model`` for up to ``epochs`` epochs.

    Despite the name, this drives the whole training loop. Each epoch runs one
    pass of ``train`` over ``training_dataloader``, then evaluates with
    ``test`` on both the training and the test loader and prints the metrics.

    A checkpoint named ``mlp_model_<t>`` is saved on every 50th epoch index
    (0, 50, 100, ...) whose test accuracy exceeds 0.85. The loop stops early
    once test accuracy exceeds 0.98.

    Args:
        epochs: Maximum number of epochs to run.

    Returns:
        Four lists with one entry per completed epoch, in this order:
        train losses, train accuracies, test losses, test accuracies.
    """
    criterion = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Metric histories, appended to once per completed epoch.
    loss_on_train: list[float] = []
    acc_on_train: list[float] = []
    loss_on_test: list[float] = []
    acc_on_test: list[float] = []

    for t in range(epochs):
        print(f"\nEpoch {t+1}\n-------------------------------")
        train(training_dataloader, model, criterion, sgd)

        # Re-score the training split with the weights as they stand after this epoch.
        train_loss, train_accuracy = test(training_dataloader, model, criterion)
        print(
            f"Train Accuracy: {(100 * train_accuracy):>0.10f}%, Avg loss: {train_loss:>8f}"
        )
        loss_on_train.append(train_loss)
        acc_on_train.append(train_accuracy)

        test_loss, test_accuracy = test(test_dataloader, model, criterion)
        print(
            f"Test Accuracy: {(100 * test_accuracy):>0.10f}%, Avg loss: {test_loss:>8f}"
        )
        loss_on_test.append(test_loss)
        acc_on_test.append(test_accuracy)

        # Periodic checkpoint, taken only once the model is reasonably accurate.
        on_checkpoint_epoch = t % 50 == 0
        if on_checkpoint_epoch and test_accuracy > 0.85:
            save_model(model, f"mlp_model_{t}")

        # Early stop once the test accuracy target has been passed.
        if test_accuracy > 0.98:
            print("Reached 98% accuracy so cancelling training")
            break

    print("Done!")
    return loss_on_train, acc_on_train, loss_on_test, acc_on_test


In [None]:
import json

# Run the full training loop, then store the final weights.
train_losses, train_accuracies, test_losses, test_accuracies = run_epoch(epochs=5000)
save_model(model, "mlp_model_final")

# Persist each metric history to its own JSON file in the working directory.
_metric_files = {
    "train_losses.json": train_losses,
    "train_accuracies.json": train_accuracies,
    "test_losses.json": test_losses,
    "test_accuracies.json": test_accuracies,
}
for _json_path, _history in _metric_files.items():
    with open(_json_path, "w") as f:
        json.dump(_history, f)
