In [1]:
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import (
    CUB200AttributesToClass,
    NUM_ATTRIBUTES,
    NUM_CLASSES,
)

# Number of samples per mini-batch, shared by both loaders below.
batch_size = 4

# Split selected with train=True, served through a loader created with shuffle=True.
training_data = CUB200AttributesToClass(train=True)
training_dataloader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)

# Split selected with train=False; its loader keeps the default (unshuffled) ordering.
test_data = CUB200AttributesToClass(train=False)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)


In [2]:
import torch

from src.concept_bottleneck.networks import get_mlp

# Use the GPU when torch reports CUDA as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Build the MLP and move it onto the selected device.
model = get_mlp().to(device)


Using device: cuda


In [3]:
import numpy.typing as npt
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    """Make a single pass over ``dataloader``, updating ``model`` after every batch.

    The model is switched to training mode first. For each ``(inputs, targets)``
    batch, both tensors are moved to ``device``, the loss of the model output
    against the targets is computed with ``loss_fn``, and one optimiser step is
    taken on freshly reset gradients.

    Args:
        model: Network to optimise; modified in place.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable mapping ``(model_output, targets)`` to a loss tensor.
        optimizer: Optimiser that updates the model's parameters.
        device: Device identifier passed to ``Tensor.to`` for every batch.

    Returns:
        None.
    """
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        batch_loss = loss_fn(model(inputs), targets)

        # Drop gradients left over from the previous batch so they do not
        # accumulate into this step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [4]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs

# Objective and optimiser for the model created earlier in the notebook.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


def train_fn(model: torch.nn.Module):
    """Call ``train`` on ``model`` with the notebook's training loader, loss, optimiser and device."""
    return train(model, training_dataloader, loss_fn, optimizer, device)


def test_fn(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[npt.NDArray[np.float32], np.int_]],
):
    """Call ``test`` on ``model`` and ``dataloader`` with the notebook's loss and device."""
    return test(model, dataloader, loss_fn, device)


# Number of epochs handed to run_epochs.
epochs = 100

# Kept as the cell's final expression so the notebook still displays its return value.
run_epochs(
    epochs,
    model,
    train_fn,
    test_fn,
    training_dataloader,
    test_dataloader,
    save_name="attributes-to-class.pth",
)


Epoch 1/100-------------------
Training Loss: 4.8809, Training Accuracy: 18.6019%
Test Loss: 4.9268, Test Accuracy: 15.0673%
Saving model to attributes-to-class.pth with accuracy 15.0673%
Epoch 2/100-------------------
Training Loss: 4.5035, Training Accuracy: 32.1655%
Test Loss: 4.5937, Test Accuracy: 23.9731%
Saving model to attributes-to-class.pth with accuracy 23.9731%
Epoch 3/100-------------------
Training Loss: 4.1775, Training Accuracy: 38.8722%
Test Loss: 4.3097, Test Accuracy: 29.0645%
Saving model to attributes-to-class.pth with accuracy 29.0645%
Epoch 4/100-------------------
Training Loss: 3.8949, Training Accuracy: 43.1431%
Test Loss: 4.0673, Test Accuracy: 32.8098%
Saving model to attributes-to-class.pth with accuracy 32.8098%
Epoch 5/100-------------------
Training Loss: 3.6495, Training Accuracy: 46.8468%
Test Loss: 3.8601, Test Accuracy: 35.2434%
Saving model to attributes-to-class.pth with accuracy 35.2434%
Epoch 6/100-------------------
Training Loss: 3.4352, Traini

OrderedDict([('0.weight',
              tensor([[-0.0922, -0.3742,  0.1346,  ..., -0.2687,  0.2704, -0.2033],
                      [ 0.1067, -0.3059,  0.4043,  ..., -0.0886, -0.0688, -0.7138],
                      [-0.0604, -0.3150,  0.1318,  ..., -0.2120, -0.5559,  0.1888],
                      ...,
                      [ 0.0651,  0.9468, -0.1142,  ...,  0.9651, -0.1223, -0.5563],
                      [-0.2269, -0.2144, -0.0398,  ...,  0.3109,  0.0655,  0.0720],
                      [-0.0551,  0.0073, -0.0054,  ..., -0.1103, -0.3590, -0.1296]],
                     device='cuda:0')),
             ('0.bias',
              tensor([ 0.6068,  0.6178,  0.3081,  0.1719,  0.2887,  0.3213,  0.1441,  0.8954,
                       0.4529, -0.0213, -0.1170, -0.0119, -0.2155, -0.0587, -0.1556,  0.2981,
                       0.2109,  0.3282, -0.2287, -0.2762, -0.0366,  0.4565,  0.3167,  0.5228,
                       0.1946, -0.3027,  0.1153, -0.1625, -0.6111, -0.1824, -0.1918,  0.0805,
  