In [None]:
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import (
    CUB200AttributesToClass,
    NUM_ATTRIBUTES,
    NUM_CLASSES,
)

# Mini-batch size shared by the training and the evaluation loader.
batch_size = 4

# The two splits are selected through the dataset's `train` flag.
training_data = CUB200AttributesToClass(train=True)
test_data = CUB200AttributesToClass(train=False)

# Only the training loader shuffles; the test loader keeps the default order.
training_dataloader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)


In [None]:
import torch
from torchvision.ops import MLP

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# MLP taking NUM_ATTRIBUTES inputs with layer widths [256, NUM_CLASSES],
# placed on the device selected above.
model = MLP(
    in_channels=NUM_ATTRIBUTES,
    hidden_channels=[256, NUM_CLASSES],
)
model = model.to(device)


In [None]:
import numpy.typing as npt
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float64], np.int_]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    """Run one optimisation pass over every batch of ``dataloader``.

    The model is put into training mode. For each batch, the inputs and
    targets are moved to ``device``, the loss is computed with ``loss_fn``,
    and one ``optimizer`` step is taken.

    Args:
        model: Network whose parameters are being optimised.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable mapping ``(model_output, targets)`` to a loss
            tensor that supports ``backward()``.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device identifier passed to ``Tensor.to``.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Drop gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float64], np.int_]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [None]:
# Loss and optimiser shared by every epoch below.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

epochs = 100

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}-------------------")

    # One optimisation pass over the training loader.
    train(
        model=model,
        dataloader=training_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )

    # Score the updated model on the training loader (no parameter updates).
    training_loss, training_acc = test(
        model=model,
        dataloader=training_dataloader,
        loss_fn=loss_fn,
        device=device,
    )
    print(
        f"Training Loss: {training_loss:.4f}, Training Accuracy: {100 * training_acc:>0.4f}%"
    )

    # Score the same model on the test loader.
    test_loss, test_acc = test(
        model=model,
        dataloader=test_dataloader,
        loss_fn=loss_fn,
        device=device,
    )
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {100 * test_acc:>0.4f}%")
