In [1]:
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import (
    CUB200AttributesToClass,
    NUM_ATTRIBUTES,
    NUM_CLASSES,
)

batch_size = 4

training_data = CUB200AttributesToClass(train=True)
test_data = CUB200AttributesToClass(train=False)

training_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)


In [2]:
import torch
from torchvision.ops import MLP

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = MLP(
    in_channels=NUM_ATTRIBUTES,
    hidden_channels=[NUM_CLASSES],
).to(device)


Using device: cuda


In [3]:
import numpy.typing as npt
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    model.train()
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [4]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_fn: TrainFn = lambda model: train(
    model, training_dataloader, loss_fn, optimizer, device
)
test_fn: TestFn = lambda model, dataloader: test(model, dataloader, loss_fn, device)

epochs = 100

run_epochs(
    epochs,
    model,
    train_fn,
    test_fn,
    training_dataloader,
    test_dataloader,
    save_name="attributes-to-class.pth",
)


Epoch 1/100-------------------
Training Loss: 2.7229, Training Accuracy: 47.0637%
Test Loss: 3.2976, Test Accuracy: 30.8595%
Saving model to attributes-to-class.pth with accuracy 30.8595%
Epoch 2/100-------------------
Training Loss: 2.0053, Training Accuracy: 62.8795%
Test Loss: 2.8224, Test Accuracy: 39.1094%
Saving model to attributes-to-class.pth with accuracy 39.1094%
Epoch 3/100-------------------
Training Loss: 1.6295, Training Accuracy: 69.2359%
Test Loss: 2.6060, Test Accuracy: 41.7328%
Saving model to attributes-to-class.pth with accuracy 41.7328%
Epoch 4/100-------------------
Training Loss: 1.3737, Training Accuracy: 73.8572%
Test Loss: 2.4832, Test Accuracy: 43.3207%
Saving model to attributes-to-class.pth with accuracy 43.3207%
Epoch 5/100-------------------
Training Loss: 1.1960, Training Accuracy: 76.8769%
Test Loss: 2.4110, Test Accuracy: 44.4598%
Saving model to attributes-to-class.pth with accuracy 44.4598%
Epoch 6/100-------------------
Training Loss: 1.0583, Traini

OrderedDict([('0.weight',
              tensor([[-0.8163, -1.4979, -0.0926,  ..., -0.9268,  0.3362, -0.3397],
                      [-0.6414, -3.0858,  0.2448,  ..., -0.1183, -0.3665, -1.7412],
                      [-2.0676, -2.9883, -0.1602,  ..., -0.9300, -2.5451,  0.3916],
                      ...,
                      [ 0.9792,  1.6769, -4.3831,  ...,  3.5402, -0.9775, -2.5518],
                      [-4.0948, -2.0613, -0.3954,  ...,  0.2688, -1.4746, -0.4283],
                      [-1.1528, -1.6002,  0.2506,  ..., -1.4795, -1.5225, -0.1176]],
                     device='cuda:0')),
             ('0.bias',
              tensor([-7.0941e-02, -3.2560e-01, -8.5631e-02, -5.7371e-01, -5.1653e-01,
                      -3.6832e-01, -5.5124e-01, -2.8476e-01,  2.5166e-01, -7.4276e-01,
                       1.2000e-01, -7.2834e-01, -4.2010e-01, -1.0039e-01, -4.1842e-01,
                       7.8252e-03, -5.0166e-01, -1.1946e-01, -1.1594e-01, -1.6041e-01,
                      -6.3405e