In [1]:
from torch.utils.data import DataLoader

from src.concept_bottleneck.dataset import CUB200ImageToClass

batch_size = 16
num_workers = 2

training_data = CUB200ImageToClass(train=True)
test_data = CUB200ImageToClass(train=False)

training_dataloader = DataLoader(
    training_data, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)


In [2]:
import torch

from src.concept_bottleneck.networks import get_mlp, get_inception

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


class JointImageToClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inception = get_inception()
        self.mlp = get_mlp()

    def forward(self, x):  # type: ignore
        if self.training:
            x, _ = self.inception(x)
        else:
            x = self.inception(x)
        x = self.mlp(torch.sigmoid(x))
        return x


model = JointImageToClass().to(device)


Using device: cuda


Using cache found in /home/shuangwu/.cache/torch/hub/pytorch_vision_v0.10.0


In [3]:
import numpy as np


def train(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, np.int_]],
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    model.train()
    size = len(dataloader.dataset)  # type: ignore
    for batch, (x, y) in enumerate(dataloader):
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f} [{batch * len(x):>5d}/{size:>5d}]")


def test(
    model: torch.nn.Module,
    dataloader: DataLoader[tuple[torch.Tensor, np.int_]],
    loss_fn: torch.nn.Module,
    device: str,
):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            test_loss += loss_fn(logits, y).item()
            correct += (logits.argmax(dim=1) == y).sum().item()

    test_loss /= len(dataloader)
    accuracy = correct / len(dataloader.dataset)  # type: ignore

    return test_loss, accuracy


In [4]:
from src.concept_bottleneck.train import TrainFn, TestFn, run_epochs, MODEL_PATH
from src.concept_bottleneck.inference import (
    JOINT_IMAGE_TO_ATTRIBUTES_MODEL_NAME,
    JOINT_ATTRIBUTES_TO_CLASS_MODEL_NAME,
)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

train_fn: TrainFn = lambda model: train(
    model,
    training_dataloader,
    loss_fn,
    optimizer,
    device,
)
test_fn: TestFn = lambda model, dataloader: test(model, dataloader, loss_fn, device)

epochs = 150


def on_better_accuracy(model: JointImageToClass, accuracy: float):
    print(
        f"Saving model to {JOINT_IMAGE_TO_ATTRIBUTES_MODEL_NAME} with accuracy {100 * accuracy:>0.4f}%"
    )
    torch.save(
        model.inception.state_dict(), MODEL_PATH / JOINT_IMAGE_TO_ATTRIBUTES_MODEL_NAME
    )

    print(
        f"Saving model to {JOINT_ATTRIBUTES_TO_CLASS_MODEL_NAME} with accuracy {100 * accuracy:>0.4f}%"
    )
    torch.save(
        model.mlp.state_dict(), MODEL_PATH / JOINT_ATTRIBUTES_TO_CLASS_MODEL_NAME
    )


run_epochs(
    epochs,
    model,
    train_fn,
    test_fn,
    training_dataloader,
    test_dataloader,
    on_better_accuracy,
)


Epoch 1/150-------------------
loss: 5.279185 [    0/ 5994]
loss: 5.315689 [ 1600/ 5994]
loss: 5.308527 [ 3200/ 5994]
loss: 5.268953 [ 4800/ 5994]
Training Loss: 5.2992, Training Accuracy: 0.5172%
Test Loss: 5.3002, Test Accuracy: 0.5696%
Saving model to joint_image_to_attributes.pth with accuracy 0.5696%
Saving model to joint_attributes_to_class.pth with accuracy 0.5696%
Epoch 2/150-------------------
loss: 5.329253 [    0/ 5994]
loss: 5.334650 [ 1600/ 5994]
loss: 5.369268 [ 3200/ 5994]
loss: 5.354413 [ 4800/ 5994]
Training Loss: 5.2113, Training Accuracy: 0.9676%
Test Loss: 5.2151, Test Accuracy: 0.9147%
Saving model to joint_image_to_attributes.pth with accuracy 0.9147%
Saving model to joint_attributes_to_class.pth with accuracy 0.9147%
Epoch 3/150-------------------
loss: 5.456625 [    0/ 5994]
loss: 5.313138 [ 1600/ 5994]
loss: 5.457387 [ 3200/ 5994]
loss: 4.947141 [ 4800/ 5994]
Training Loss: 5.0591, Training Accuracy: 1.2346%
Test Loss: 5.0726, Test Accuracy: 1.1046%
Saving mode

OrderedDict([('inception.Conv2d_1a_3x3.conv.weight',
              tensor([[[[-1.1568e+00, -1.2925e+00, -8.7065e-01],
                        [-1.6539e+00, -1.9719e+00, -1.5733e+00],
                        [-1.5897e+00, -1.6926e+00, -1.4974e+00]],
              
                       [[ 2.3193e+00,  2.3398e+00,  2.2714e+00],
                        [ 2.0073e+00,  1.7957e+00,  1.7701e+00],
                        [ 1.9643e+00,  1.9241e+00,  1.9529e+00]],
              
                       [[ 2.1265e-01,  2.2891e-01, -1.0033e-01],
                        [ 5.1204e-02, -1.5918e-01, -3.4908e-01],
                        [-6.6735e-02, -1.4212e-01, -2.2394e-01]]],
              
              
                      [[[-9.2755e-02, -6.8218e-03, -3.6571e-02],
                        [-3.4192e-02,  7.7841e-02,  2.5105e-02],
                        [-6.9930e-02, -3.6620e-03, -1.4488e-02]],
              
                       [[-2.0621e-02,  1.5350e-02,  1.7270e-02],
                      