In [None]:
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import CUB200ImageToClass

# Mini-batch size and number of loader workers, shared by both loaders below.
batch_size = 16
num_workers = 2

# Two CUB200ImageToClass datasets, distinguished only by the `train` flag.
training_data = CUB200ImageToClass(train=True)
test_data = CUB200ImageToClass(train=False)

# Only the training loader shuffles; the test loader keeps DataLoader's
# default (non-shuffled) ordering.
training_dataloader = DataLoader(
    training_data,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [None]:
import torch

from src.concept_bottleneck.networks import get_mlp

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Build the network returned by get_mlp() and move it onto the chosen device.
model = get_mlp().to(device)


In [None]:
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, np.int_]],
    trained_image_to_attributes_model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Each batch is first passed through ``trained_image_to_attributes_model``;
    the sigmoid of its output is the input to ``model``. Only ``model`` is
    optimised: the attribute network is kept in eval mode and is evaluated
    without gradient tracking, so no gradients are computed for or stored on
    its parameters.

    Args:
        model: Network being trained; switched to training mode.
        dataloader: Yields ``(x, y)`` batches; ``x`` is fed to the attribute
            network and ``y`` is the target handed to ``loss_fn``.
        trained_image_to_attributes_model: Frozen network producing the
            pre-sigmoid inputs of ``model``; switched to eval mode.
        loss_fn: Called as ``loss_fn(logits, y)``; must return a scalar tensor.
        optimizer: Optimizer stepped once per batch.
        device: Device string that every batch is moved to.

    Returns:
        None. The loss is printed every 100 batches together with the number
        of samples consumed before the reported batch.
    """
    model.train()
    trained_image_to_attributes_model.eval()
    size = len(dataloader.dataset)  # type: ignore
    seen = 0  # samples consumed before the current batch
    for batch, (x, y) in enumerate(dataloader):
        batch_len = len(x)

        # The attribute network is not being optimised, so keep autograd off
        # here. Otherwise backward() would traverse it and accumulate
        # gradients on its parameters that nothing ever zeroes.
        with torch.no_grad():
            attributes = torch.sigmoid(trained_image_to_attributes_model(x.to(device)))
        y = y.to(device)

        logits = model(attributes)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # A running count stays correct even if the logged batch is a
            # short final batch (batch * len(x) would under-report there).
            print(f"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]")
        seen += batch_len


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, np.int_]],
    trained_image_to_attributes_model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()
    trained_image_to_attributes_model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = trained_image_to_attributes_model(x.to(device))
            y = y.to(device)

            logits = model(torch.sigmoid(x))
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [None]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs, MODEL_PATH
from src.concept_bottleneck.inference import (
    load_image_to_attributes_model,
    INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME,
    SEQUENTIAL_ATTRIBUTES_TO_CLASS_MODEL_NAME,
)

# Loss and optimizer for `model`; only model.parameters() are handed to SGD.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.05)

# Image-to-attributes network obtained from load_image_to_attributes_model
# for the selected device; its output feeds `model` inside train()/test().
trained_image_to_attributes_model = load_image_to_attributes_model(
    INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME,
    device,
)

# Adapters binding the fixed arguments of train()/test() so that only the
# model (and, for testing, the dataloader) has to be supplied by the caller.
train_fn: TrainFn = lambda model: train(
    model=model,
    dataloader=training_dataloader,
    trained_image_to_attributes_model=trained_image_to_attributes_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
)
test_fn: TestFn = lambda model, dataloader: test(
    model=model,
    dataloader=dataloader,
    trained_image_to_attributes_model=trained_image_to_attributes_model,
    loss_fn=loss_fn,
    device=device,
)

# Number of epochs passed to run_epochs.
epochs = 300


def on_better_accuracy(model: torch.nn.Module, accuracy: float):
    """Announce and persist ``model``'s weights.

    Prints the target file name with ``accuracy`` shown as a percentage, then
    saves ``model.state_dict()`` to
    ``MODEL_PATH / SEQUENTIAL_ATTRIBUTES_TO_CLASS_MODEL_NAME``.

    Passed to ``run_epochs`` as its last argument; presumably invoked whenever
    a new best accuracy is reached (per the name; confirm in ``run_epochs``).
    """
    print(
        f"Saving model to {SEQUENTIAL_ATTRIBUTES_TO_CLASS_MODEL_NAME} with accuracy {100 * accuracy:>0.4f}%"
    )
    weights = model.state_dict()
    checkpoint_path = MODEL_PATH / SEQUENTIAL_ATTRIBUTES_TO_CLASS_MODEL_NAME
    torch.save(weights, checkpoint_path)


# Start the run: hand the model, the train/test adapters, both dataloaders and
# the checkpoint callback to run_epochs for `epochs` epochs.
# NOTE(review): run_epochs is imported from src.concept_bottleneck.train and
# not shown here; presumably it alternates train_fn/test_fn per epoch and
# calls on_better_accuracy when accuracy improves — confirm in its definition.
run_epochs(
    epochs,
    model,
    train_fn,
    test_fn,
    training_dataloader,
    test_dataloader,
    on_better_accuracy,
)
