In [None]:
import torchvision.transforms as transforms
import torch

from src.concept_bottleneck.dataset import CUB200_2011

# Preprocessing taken from the PyTorch Hub Inception v3 example:
# https://pytorch.org/hub/pytorch_vision_inception_v3/
# Resize the shorter side to 299, center-crop to 299x299, convert to a tensor,
# then normalize each channel with the ImageNet mean/std.
INPUT_SIZE = 299
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Both splits share the same image preprocessing. Targets are passed through
# torch.from_numpy, so they are presumably NumPy arrays coming out of the
# dataset — confirm against CUB200_2011.
# NOTE(review): only the test split passes download=False; presumably it relies
# on the training-split constructor having fetched the data first — confirm.
training_data: CUB200_2011[torch.Tensor, torch.Tensor] = CUB200_2011(
    train=True, transform=preprocess, target_transform=torch.from_numpy  # type: ignore
)
test_data: CUB200_2011[torch.Tensor, torch.Tensor] = CUB200_2011(
    train=False, download=False, transform=preprocess, target_transform=torch.from_numpy  # type: ignore
)


In [None]:
from torch.utils.data import DataLoader

# Loader configuration shared by both splits.
batch_size = 16
num_workers = 2

# Only the training split is shuffled; the test split keeps a fixed order.
# Both loaders now honour `num_workers` instead of a hard-coded 2, so editing
# the variable above actually changes the worker count.
train_dataloader = DataLoader(
    training_data, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
test_dataloader = DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers
)


In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Inception v3 from the torchvision hub repo pinned at tag v0.10.0.
# The output layer has one unit per CUB attribute
# (num_classes=training_data.num_attributes), so this network predicts
# attributes rather than species labels.
# NOTE(review): no `pretrained` flag is passed, so presumably training starts
# from random weights; `init_weights=False` presumably skips torchvision's
# custom initialisation in favour of PyTorch defaults — confirm for v0.10.0.
model: torch.nn.Module = torch.hub.load(
    "pytorch/vision:v0.10.0",
    "inception_v3",
    init_weights=False,
    num_classes=training_data.num_attributes,
).to(device)


In [None]:
def train(
    dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    aux_loss_weight: float = 0.4,
):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Batches are moved to the module-level ``device``. Targets are cast to float
    before being handed to ``loss_fn`` (needed e.g. for BCEWithLogitsLoss).

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: network to train. In train mode it may return either a plain
            logits tensor, or a ``(logits, aux_logits)`` tuple (torchvision's
            Inception v3 returns such a pair when its auxiliary head is on).
        loss_fn: criterion called as ``loss_fn(logits, y)``.
        optimizer: optimizer over ``model``'s parameters.
        aux_loss_weight: weight of the auxiliary-head loss added to the main
            loss when the model returns a pair (default 0.4).
    """
    size = len(dataloader.dataset)  # type: ignore
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(torch.float).to(device)  # type: ignore

        outputs = model(X)
        if isinstance(outputs, tuple):
            # Main head + auxiliary head (InceptionOutputs is a NamedTuple).
            logits, aux_logits = outputs
            loss = loss_fn(logits, y) + aux_loss_weight * loss_fn(aux_logits, y)
        else:
            # Single-output model (e.g. Inception built with aux_logits=False).
            loss = loss_fn(outputs, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count samples including this batch, so the last (possibly smaller)
        # batch is also reported correctly.
        seen += len(X)
        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")


In [None]:
def test(
    dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
):
    size = len(dataloader.dataset)  # type: ignore
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        sigmoid = torch.nn.Sigmoid()
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)  # type: ignore
            logits = model(X)
            test_loss += loss_fn(logits, y.to(torch.float)).item()
            correct_attributes = torch.sum(
                ((sigmoid(logits) > 0.5).to(torch.int64) == y).to(torch.float)
            )
            num_attributes = y.shape[1]
            correct += correct_attributes.item() / num_attributes
    test_loss /= num_batches
    accuracy = correct / size
    return test_loss, accuracy


In [None]:
def save_model(model: torch.nn.Module, path: str = "model_weights.pth"):
    """Save ``model``'s ``state_dict`` (parameters and buffers) to ``path``.

    Only the state dict is written, not the module class, so the same
    architecture must be rebuilt before loading it back.
    """
    torch.save(model.state_dict(), path)
    print(f"Saved PyTorch Model State to {path}")


def load_model(model: torch.nn.Module, path: str = "model_weights.pth"):
    """Load the ``state_dict`` stored at ``path`` into ``model`` and set eval mode.

    ``map_location="cpu"`` lets a checkpoint saved on a GPU machine be loaded
    on a CPU-only one. ``load_state_dict`` copies the tensors into ``model``'s
    existing parameters, so they end up on whatever device ``model`` is on.
    """
    # torch.load unpickles the file: only load checkpoints you produced/trust.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    print(f"Loaded PyTorch Model State from {path}")
    model.eval()


In [None]:
def run_epoch(epochs: int):
    """Train the global ``model`` for up to ``epochs`` epochs and record metrics.

    Each epoch runs ``train`` on ``train_dataloader``, then scores both the
    training and the test loader with ``test``.
    - A checkpoint is saved via ``save_model`` when test accuracy exceeds 85%
      on epochs 1, 11, 21, ... (``t % 10 == 0``).
    - Training stops early once test accuracy exceeds 98%.

    Returns:
        ``(train_losses, train_accuracies, test_losses, test_accuracies)``,
        each with one entry per completed epoch.
    """
    # Multi-label objective: each model output is an independent binary target.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    train_losses: list[float] = []
    train_accuracies: list[float] = []
    test_losses: list[float] = []
    test_accuracies: list[float] = []
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)

        # Re-score the training split with the same evaluation routine as the
        # test split, so the two sets of metrics are comparable.
        train_loss, train_accuracy = test(train_dataloader, model, loss_fn)
        print(
            f"Train Accuracy: {(100 * train_accuracy):>0.10f}%, Avg loss: {train_loss:>8f}"
        )
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        test_loss, test_accuracy = test(test_dataloader, model, loss_fn)
        # Previously mislabelled as "Train Accuracy".
        print(
            f"Test Accuracy: {(100 * test_accuracy):>0.10f}%, Avg loss: {test_loss:>8f}"
        )
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        if test_accuracy > 0.85 and t % 10 == 0:
            save_model(model)

        if test_accuracy > 0.98:
            print("Reached 98% accuracy so cancelling training")
            break

    print("Done!")
    return train_losses, train_accuracies, test_losses, test_accuracies


In [None]:
import json

# Train for up to 1000 epochs (run_epoch may stop early) and save the final weights.
train_losses, train_accuracies, test_losses, test_accuracies = run_epoch(epochs=1000)
save_model(model)

# Write each per-epoch metric history to its own JSON file.
metric_files = {
    "train_losses.json": train_losses,
    "train_accuracies.json": train_accuracies,
    "test_losses.json": test_losses,
    "test_accuracies.json": test_accuracies,
}
for filename, values in metric_files.items():
    with open(filename, "w") as f:
        json.dump(values, f)
