In [1]:
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import (
    CUB200AttributesToClass,
    NUM_ATTRIBUTES,
    NUM_CLASSES,
)

batch_size = 4

training_data = CUB200AttributesToClass(train=True)
test_data = CUB200AttributesToClass(train=False)

training_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)


In [2]:
import torch
from torchvision.ops import MLP

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = MLP(
    in_channels=NUM_ATTRIBUTES,
    hidden_channels=[256, NUM_CLASSES],
    dropout=0.5,
    inplace=False,
).to(device)


Using device: cuda


In [3]:
import numpy.typing as npt
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    model.train()
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [4]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

epochs = 500

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}-------------------")

    train(model, training_dataloader, loss_fn, optimizer, device)

    training_loss, training_acc = test(model, training_dataloader, loss_fn, device)
    print(
        f"Training Loss: {training_loss:.4f}, Training Accuracy: {100 * training_acc:>0.4f}%"
    )

    test_loss, test_acc = test(model, test_dataloader, loss_fn, device)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {100 * test_acc:>0.4f}%")


Epoch 1/500-------------------
Training Loss: 5.2945, Training Accuracy: 0.5506%
Test Loss: 5.2945, Test Accuracy: 0.5005%
Epoch 2/500-------------------
Training Loss: 5.2866, Training Accuracy: 0.7674%
Test Loss: 5.2879, Test Accuracy: 0.7767%
Epoch 3/500-------------------
Training Loss: 5.2793, Training Accuracy: 1.1845%
Test Loss: 5.2816, Test Accuracy: 1.2081%
Epoch 4/500-------------------
Training Loss: 5.2719, Training Accuracy: 1.6517%
Test Loss: 5.2754, Test Accuracy: 1.5878%
Epoch 5/500-------------------
Training Loss: 5.2643, Training Accuracy: 2.0020%
Test Loss: 5.2689, Test Accuracy: 1.8122%
Epoch 6/500-------------------
Training Loss: 5.2565, Training Accuracy: 2.7861%
Test Loss: 5.2623, Test Accuracy: 2.2610%
Epoch 7/500-------------------
Training Loss: 5.2480, Training Accuracy: 3.2366%
Test Loss: 5.2552, Test Accuracy: 2.6407%
Epoch 8/500-------------------
Training Loss: 5.2387, Training Accuracy: 3.8872%
Test Loss: 5.2470, Test Accuracy: 3.1757%
Epoch 9/500-----