In [1]:
import torch
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import CUB200ImageToClass

# Seed torch's global RNG before anything random happens. The shuffled training
# DataLoader draws its sampler seed from this generator, and the model built in
# the next cell presumably uses it for weight init. Without a seed, Restart &
# Run All gives a different run each time.
seed = 0
torch.manual_seed(seed)

batch_size = 16
num_workers = 2

training_data = CUB200ImageToClass(train=True)
test_data = CUB200ImageToClass(train=False)

# Only the training split is shuffled; evaluation iterates in a fixed order.
training_dataloader = DataLoader(training_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)


In [2]:
import torch

from src.concept_bottleneck.networks import get_mlp

# Prefer the GPU whenever one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Build the attributes -> class head directly on the selected device.
model = get_mlp().to(device)


Using device: cuda


In [3]:
import numpy as np

from src.concept_bottleneck.inference import SEQUENTIAL_ATTRIBUTES_TO_CLASS_MODEL_NAME


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, np.int_]],
    trained_image_to_attributes_model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    """Run one training epoch of the attributes -> class model.

    Each image batch is first mapped to attribute predictions by
    ``trained_image_to_attributes_model``. That model is kept in eval mode and
    run without gradient tracking, so only the parameters held by
    ``optimizer`` are updated.

    Args:
        model: Head mapping predicted attributes to class logits; set to train mode.
        dataloader: Yields ``(images, labels)`` batches.
        trained_image_to_attributes_model: Pretrained image -> attributes model;
            set to eval mode and used purely as a fixed feature extractor.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``.
        optimizer: Optimizer stepped once per batch (presumably over
            ``model``'s parameters).
        device: Device string that inputs and labels are moved to.
    """
    model.train()
    trained_image_to_attributes_model.eval()
    size = len(dataloader.dataset)  # type: ignore
    for batch, (x, y) in enumerate(dataloader):
        # The upstream model is not being trained here. Without no_grad,
        # backward() would also run through it, keeping its activations alive
        # and accumulating .grad on its parameters every step (nothing clears
        # them unless the optimizer happens to own those parameters).
        with torch.no_grad():
            x = trained_image_to_attributes_model(x.to(device))
        y = y.to(device)

        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f} [{batch * len(x):>5d}/{size:>5d}]")


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, np.int_]],
    trained_image_to_attributes_model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()
    trained_image_to_attributes_model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = trained_image_to_attributes_model(x.to(device))
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [4]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs
from src.concept_bottleneck.inference import load_image_to_attributes_model

loss_fn = torch.nn.CrossEntropyLoss()
# NOTE(review): with lr=0.1 and momentum=0.9 the saved log shows the batch loss
# jumping from ~8 to ~444-555 within epoch 1. Consider a smaller learning rate
# or rescaling the attribute inputs.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

trained_image_to_attributes_model = load_image_to_attributes_model(device)


def _train_one_epoch(model: torch.nn.Module):
    """Train ``model`` for one epoch on the training split using the shared loss/optimizer."""
    return train(
        model,
        training_dataloader,
        trained_image_to_attributes_model,
        loss_fn,
        optimizer,
        device,
    )


def _evaluate(model: torch.nn.Module, dataloader: DataLoader):
    """Return ``test``'s ``(loss, accuracy)`` for ``model`` on ``dataloader``."""
    return test(model, dataloader, trained_image_to_attributes_model, loss_fn, device)


# Expose the helpers under the callback types that run_epochs expects.
train_fn: TrainFn = _train_one_epoch
test_fn: TestFn = _evaluate

epochs = 200

run_epochs(
    epochs,
    model,
    train_fn,
    test_fn,
    training_dataloader,
    test_dataloader,
    save_name=SEQUENTIAL_ATTRIBUTES_TO_CLASS_MODEL_NAME,
)


Using cache found in /home/shuangwu/.cache/torch/hub/pytorch_vision_v0.10.0


Epoch 1/200-------------------
loss: 8.090341 [    0/ 5994]
loss: 443.980286 [ 1600/ 5994]
loss: 554.696838 [ 3200/ 5994]
loss: 243.043884 [ 4800/ 5994]
Training Loss: 325.3225, Training Accuracy: 15.8825%
Test Loss: 327.2878, Test Accuracy: 15.1536%
Saving model to sequential-attributes-to-class.pth with accuracy 15.1536%
Epoch 2/200-------------------
loss: 326.555908 [    0/ 5994]
loss: 249.687500 [ 1600/ 5994]
loss: 259.589081 [ 3200/ 5994]
loss: 267.599213 [ 4800/ 5994]
Training Loss: 284.8847, Training Accuracy: 20.2536%
Test Loss: 288.2812, Test Accuracy: 19.2440%
Saving model to sequential-attributes-to-class.pth with accuracy 19.2440%
Epoch 3/200-------------------
loss: 103.394791 [    0/ 5994]
loss: 177.530777 [ 1600/ 5994]
loss: 368.142120 [ 3200/ 5994]
loss: 317.048096 [ 4800/ 5994]
Training Loss: 248.4690, Training Accuracy: 22.2389%
Test Loss: 251.3528, Test Accuracy: 20.2969%
Saving model to sequential-attributes-to-class.pth with accuracy 20.2969%
Epoch 4/200----------

OrderedDict([('0.weight',
              tensor([[-5.1838e+00, -3.4911e+01,  1.1484e+01,  ..., -1.7288e+01,
                       -1.9966e+00, -2.2177e+01],
                      [ 1.0490e+01, -4.3466e+01,  9.3212e+00,  ...,  1.3436e+01,
                       -1.3708e+01, -2.7054e+01],
                      [-4.0387e+00, -2.3054e+01, -2.0484e+00,  ..., -1.5062e+01,
                       -1.2643e+01, -1.8658e+01],
                      ...,
                      [ 2.6812e+01,  4.7301e+01,  2.5565e+00,  ...,  3.7157e+01,
                        8.3943e+00, -6.0619e+00],
                      [ 1.1969e+01, -1.1978e+01,  1.2614e+01,  ...,  3.8636e+01,
                        9.8863e+00, -9.4816e+00],
                      [ 6.7023e+00,  2.5132e+01, -7.8330e-01,  ..., -4.1641e-02,
                       -2.1109e+01,  7.7411e+00]], device='cuda:0')),
             ('0.bias',
              tensor([-1.5246e+00, -8.7593e+00, -1.2093e-01, -2.6328e+00,  2.7481e+00,
                       4.6032e