In [1]:
from torch.utils.data import DataLoader
from torchvision import transforms

from src.concept_bottleneck.dataset import (
    CUB200ImageToAttributes,
    NUM_ATTRIBUTES,
)

batch_size = 16
num_workers = 2


training_preprocess = transforms.Compose(
    [
        transforms.RandomResizedCrop(299),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
training_data = CUB200ImageToAttributes(train=True, transform=training_preprocess)
training_dataloader = DataLoader(
    training_data, batch_size=batch_size, num_workers=num_workers, shuffle=True
)

test_data = CUB200ImageToAttributes(train=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)


In [2]:
import torch
from src.concept_bottleneck.networks import get_inception

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

model: torch.nn.Module = get_inception().to(device)


Using cuda device


Using cache found in /home/shuangwu/.cache/torch/hub/pytorch_vision_v0.10.0


In [3]:
import numpy.typing as npt
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, npt.NDArray[np.float32]]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    model.train()
    size = len(dataloader.dataset)  # type: ignore
    for batch, (x, y) in enumerate(dataloader):
        x = x.to(device)
        y = y.to(device)

        logits, aux_logits = model(x)
        loss = loss_fn(logits, y) + 0.4 * loss_fn(aux_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f} [{batch * len(x):>5d}/{size:>5d}]")


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, npt.NDArray[np.float32]]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()
    test_loss = 0
    correct = 0
    total_correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()

            correct_attributes = (
                ((torch.sigmoid(logits) >= 0.5) == (y >= 0.5)).sum().item()
            )
            correct += correct_attributes / NUM_ATTRIBUTES

            total_correct += ( # Count the number of images with all attributes correct
                torch.all((torch.sigmoid(logits) >= 0.5) == (y >= 0.5), dim=1)
                .sum()
                .item()
            )

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore
    total_accuracy = total_correct / len(dataloader.dataset)  # type: ignore
    print(f"Total accuracy: {total_accuracy:>0.10f}%")

    return test_loss, accuracy


In [4]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs
from src.concept_bottleneck.inference import INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME

loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

train_fn: TrainFn = lambda model: train(
    model, training_dataloader, loss_fn, optimizer, device
)
test_fn: TestFn = lambda model, dataloader: test(model, dataloader, loss_fn, device)

epochs = 40

run_epochs(
    epochs,
    model,
    train_fn,
    test_fn,
    training_dataloader,
    test_dataloader,
    save_name=INDEPENDENT_IMAGE_TO_ATTRIBUTES_MODEL_NAME,
)


Epoch 1/40-------------------
loss: 0.994618 [    0/ 5994]
loss: 0.980477 [ 1600/ 5994]
loss: 0.963351 [ 3200/ 5994]
loss: 0.948408 [ 4800/ 5994]
Total accuracy: 0.0000000000%
Training Loss: 0.6494, Training Accuracy: 68.4494%
Total accuracy: 0.0000000000%
Test Loss: 0.6492, Test Accuracy: 68.6302%
Saving model to image-to-attributes.pth with accuracy 68.6302%
Epoch 2/40-------------------
loss: 0.934690 [    0/ 5994]
loss: 0.921234 [ 1600/ 5994]
loss: 0.904710 [ 3200/ 5994]
loss: 0.895340 [ 4800/ 5994]
Total accuracy: 0.0000000000%
Training Loss: 0.5930, Training Accuracy: 81.4441%
Total accuracy: 0.0000000000%
Test Loss: 0.5925, Test Accuracy: 81.4879%
Saving model to image-to-attributes.pth with accuracy 81.4879%
Epoch 3/40-------------------
loss: 0.886735 [    0/ 5994]
loss: 0.866741 [ 1600/ 5994]
loss: 0.853452 [ 3200/ 5994]
loss: 0.840987 [ 4800/ 5994]
Total accuracy: 0.0000000000%
Training Loss: 0.5503, Training Accuracy: 87.2709%
Total accuracy: 0.0000000000%
Test Loss: 0.5536

OrderedDict([('Conv2d_1a_3x3.conv.weight',
              tensor([[[[ 5.6951e-02,  1.0615e-01, -1.7129e-01],
                        [ 1.0368e-01, -6.3231e-02,  2.1417e-02],
                        [ 1.8407e-01, -1.1232e-01, -1.3360e-01]],
              
                       [[-1.5117e-01,  1.9434e-02, -7.9931e-02],
                        [ 7.0975e-02, -1.4965e-01, -1.2370e-01],
                        [ 9.3350e-03, -1.7277e-01, -4.0761e-03]],
              
                       [[-3.4978e-03,  9.3602e-02,  5.4404e-02],
                        [-1.3265e-01, -4.2537e-02,  9.1765e-02],
                        [ 6.9057e-02,  7.2528e-02,  5.1086e-02]]],
              
              
                      [[[ 1.5154e-01,  1.5209e-01,  1.2308e-01],
                        [ 1.8178e-01, -6.2739e-02,  5.6475e-02],
                        [ 6.2908e-02,  1.5792e-01, -1.5148e-01]],
              
                       [[ 1.2345e-01, -5.1615e-02, -5.5928e-02],
                        [ 3.8004