In [1]:
from torch.utils.data import DataLoader
from torchvision import transforms

from src.concept_bottleneck.dataset import (
    CUB200ImageToAttributes,
    NUM_ATTRIBUTES,
)

# Loader settings shared by the training and the test pipeline.
batch_size = 16
num_workers = 2


# Training-time augmentation: a random 299-pixel crop and a random horizontal
# flip, followed by tensor conversion and per-channel normalisation.
# The mean/std values look like the usual ImageNet channel statistics —
# NOTE(review): confirm they match what the pretrained backbone expects.
training_preprocess = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=299),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Build both dataset splits first, then wrap each one in a loader.
training_data = CUB200ImageToAttributes(train=True, transform=training_preprocess)
# NOTE(review): no transform is passed for the test split; presumably the
# dataset falls back to its own default preprocessing — confirm that it resizes
# and normalises consistently with the training pipeline.
test_data = CUB200ImageToAttributes(train=False)

# Only the training loader shuffles; the test loader keeps the dataset order.
training_dataloader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    num_workers=num_workers,
)


In [2]:
import torch
from src.concept_bottleneck.networks import get_inception

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Build the network, then move its parameters and buffers to the chosen device.
model: torch.nn.Module = get_inception()
model = model.to(device)


Using cuda device


Using cache found in /home/shuangwu/.cache/torch/hub/pytorch_vision_v0.10.0


In [3]:
import numpy.typing as npt
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, npt.NDArray[np.float32]]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    """Run one optimisation pass over ``dataloader``.

    The model is switched to training mode and is expected to return a pair
    ``(logits, aux_logits)`` for every batch. The objective is the main loss
    plus 0.4 times the auxiliary loss, both computed with ``loss_fn`` against
    the same targets.

    Args:
        model: Network to optimise; its forward pass must yield two outputs.
        dataloader: Yields ``(inputs, targets)`` batches.
        loss_fn: Criterion applied to each of the two outputs.
        optimizer: Optimiser that updates the model parameters.
        device: Device string that every batch is moved to.

    Returns:
        None. The loss is printed every 100 batches.
    """
    aux_weight = 0.4  # contribution of the auxiliary head to the objective

    model.train()
    num_samples = len(dataloader.dataset)  # type: ignore
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()

        main_out, aux_out = model(inputs)
        main_loss = loss_fn(main_out, targets)
        aux_loss = loss_fn(aux_out, targets)
        loss = main_loss + aux_weight * aux_loss

        loss.backward()
        optimizer.step()

        # Only report progress on every 100th batch.
        if step % 100 != 0:
            continue
        current = loss.item()
        seen = step * len(inputs)
        print(f"loss: {current:>7f} [{seen:>5d}/{num_samples:>5d}]")


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, npt.NDArray[np.float32]]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()
    test_loss = 0
    correct = 0
    total_correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()

            correct_attributes = (
                ((torch.sigmoid(logits) >= 0.5) == (y >= 0.5)).sum().item()
            )
            correct += correct_attributes / NUM_ATTRIBUTES

            total_correct += ( # Count the number of images with all attributes correct
                torch.all((torch.sigmoid(logits) >= 0.5) == (y >= 0.5), dim=1)
                .sum()
                .item()
            )

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore
    total_accuracy = total_correct / len(dataloader.dataset)  # type: ignore
    print(f"Total accuracy: {total_accuracy:>0.10f}%")

    return test_loss, accuracy


In [4]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs, MODEL_PATH
from src.concept_bottleneck.inference import INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME

# Binary cross-entropy on raw logits, optimised with SGD plus momentum.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)


def train_fn(model: torch.nn.Module):
    """Run one training pass with the notebook's training loader, loss and optimizer."""
    return train(model, training_dataloader, loss_fn, optimizer, device)


def test_fn(model: torch.nn.Module, dataloader: DataLoader):
    """Evaluate ``model`` on ``dataloader`` with the notebook's loss and device."""
    return test(model, dataloader, loss_fn, device)


# Number of epochs handed to run_epochs.
epochs = 20


def on_better_accuracy(model: torch.nn.Module, accuracy: float):
    """Announce and save the weights of ``model``.

    Passed to ``run_epochs`` as a callback; presumably it is invoked whenever
    the accuracy improves — confirm against ``run_epochs``.

    Args:
        model: Network whose ``state_dict`` is written to disk.
        accuracy: Accuracy as a fraction; it is multiplied by 100 for display.
    """
    accuracy_percent = 100 * accuracy
    print(
        f"Saving model to {INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME} with accuracy {accuracy_percent:>0.4f}%"
    )

    # Any file already at the checkpoint path is overwritten.
    weights = model.state_dict()
    checkpoint_path = MODEL_PATH / INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME
    torch.save(weights, checkpoint_path)


# Start the run: epoch count, model, the two step callables, both loaders,
# and the checkpoint callback, in the positional order used by this notebook.
run_epochs(
    epochs, model,
    train_fn, test_fn,
    training_dataloader, test_dataloader,
    on_better_accuracy,
)


Epoch 1/20-------------------
loss: 0.993363 [    0/ 5994]
loss: 0.365139 [ 1600/ 5994]
loss: 0.328849 [ 3200/ 5994]
loss: 0.325951 [ 4800/ 5994]
Total accuracy: 0.0000000000%
Training Loss: 0.2250, Training Accuracy: 91.1533%
Total accuracy: 0.0000000000%
Test Loss: 0.2244, Test Accuracy: 91.1861%
Saving model to independent-image-to-attributes.pth with accuracy 91.1861%
Epoch 2/20-------------------
loss: 0.336230 [    0/ 5994]
loss: 0.316452 [ 1600/ 5994]
loss: 0.367877 [ 3200/ 5994]
loss: 0.335515 [ 4800/ 5994]
Total accuracy: 0.0000000000%
Training Loss: 0.2151, Training Accuracy: 91.4797%
Total accuracy: 0.0000000000%
Test Loss: 0.2161, Test Accuracy: 91.4603%
Saving model to independent-image-to-attributes.pth with accuracy 91.4603%
Epoch 3/20-------------------
loss: 0.357367 [    0/ 5994]
loss: 0.284756 [ 1600/ 5994]
loss: 0.276843 [ 3200/ 5994]
loss: 0.308982 [ 4800/ 5994]
Total accuracy: 0.0000000000%
Training Loss: 0.2076, Training Accuracy: 91.7353%
Total accuracy: 0.00000

OrderedDict([('Conv2d_1a_3x3.conv.weight',
              tensor([[[[-2.0966e-01, -3.7044e-01, -1.1657e-01],
                        [-1.4970e-01, -2.8173e-01, -8.8491e-02],
                        [-2.5879e-02, -5.0224e-02, -3.3978e-02]],
              
                       [[ 2.2028e-01,  2.3109e-01,  6.9276e-02],
                        [ 1.3729e-01,  1.3355e-01,  3.3492e-02],
                        [-1.9928e-02,  8.8372e-03,  4.7262e-02]],
              
                       [[ 1.3894e-01,  2.6443e-01, -3.2625e-02],
                        [ 9.5345e-02,  1.8760e-01, -1.0267e-02],
                        [-1.4566e-01, -3.0897e-02, -9.6407e-02]]],
              
              
                      [[[ 2.7257e-03,  3.5081e-02, -6.7276e-04],
                        [ 3.4907e-02,  9.1352e-02,  1.9268e-02],
                        [-7.4279e-03,  1.1542e-03, -3.1608e-02]],
              
                       [[ 1.6386e-02,  4.9063e-03,  4.5662e-03],
                        [ 2.6894