In [36]:
import torchvision.transforms as transforms
import torch

from src.concept_bottleneck.dataset import CUB200_2011

# Image preprocessing for Inception v3, following the PyTorch Hub example:
# https://pytorch.org/hub/pytorch_vision_inception_v3/
# Resize, center-crop to the 299x299 input size, convert to a tensor, and
# normalize each channel with the mean/std below. The pipeline is fully
# deterministic (no random augmentation) and is shared by both splits.
INCEPTION_INPUT_SIZE = 299
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(INCEPTION_INPUT_SIZE),
    transforms.CenterCrop(INCEPTION_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Targets arrive as NumPy arrays and are converted with torch.from_numpy.
# They are presumably per-image attribute vectors (the model below is sized by
# `num_attributes`) — TODO confirm against the CUB200_2011 dataset class.
training_data: CUB200_2011[torch.Tensor, torch.Tensor] = CUB200_2011(
    train=True,
    transform=preprocess,
    target_transform=torch.from_numpy,  # type: ignore
)
# download=False: the test split presumably relies on files already on disk
# (e.g. fetched/extracted by the call above) — confirm.
test_data: CUB200_2011[torch.Tensor, torch.Tensor] = CUB200_2011(
    train=False,
    download=False,
    transform=preprocess,
    target_transform=torch.from_numpy,  # type: ignore
)


Using downloaded and verified file: /home/shuangwu/interactive-concept-bottleneck/src/concept_bottleneck/data/CUB_200_2011.tgz
Extracting /home/shuangwu/interactive-concept-bottleneck/src/concept_bottleneck/data/CUB_200_2011.tgz to /home/shuangwu/interactive-concept-bottleneck/src/concept_bottleneck/data


In [37]:
from torch.utils.data import DataLoader

# Loader settings shared by both splits.
batch_size = 8
num_workers = 2  # background worker processes used by each DataLoader

# Only the training set is shuffled; the test set is iterated in a fixed order
# so evaluation results are comparable across epochs.
train_dataloader = DataLoader(
    training_data, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
test_dataloader = DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers
)


In [38]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Inception v3 from the torchvision v0.10.0 hub repo, with the classifier sized
# to emit one logit per attribute of the training set. No `pretrained` flag is
# passed. `init_weights=False` is forwarded to torchvision — presumably it skips
# torchvision's custom weight initialization; confirm against that release.
model: torch.nn.Module = torch.hub.load(
    "pytorch/vision:v0.10.0",
    "inception_v3",
    num_classes=training_data.num_attributes,
    init_weights=False,
).to(device)


Using cuda device


Using cache found in /home/shuangwu/.cache/torch/hub/pytorch_vision_v0.10.0


In [59]:
def train(
    dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    aux_weight: float = 0.4,
    log_every: int = 100,
):
    """Run one training epoch over ``dataloader`` and return sampled losses.

    Each batch is moved to the device that holds ``model``'s parameters, and
    targets are cast to float, as ``BCEWithLogitsLoss`` requires.

    If the model returns a ``(logits, aux_logits)`` pair (as the Inception v3
    model in this notebook does in training mode), the auxiliary head's loss
    is added with weight ``aux_weight``. A model that returns a single tensor
    is trained on that output alone.

    Args:
        dataloader: yields ``(inputs, targets)`` batches.
        model: network to train; switched to training mode here.
        loss_fn: criterion applied to ``(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        aux_weight: weight of the auxiliary-head loss (0.4 by default).
        log_every: print and record the loss every ``log_every`` batches,
            starting with batch 0; must be positive.

    Returns:
        The loss values (Python floats) of the logged batches.
    """
    size = len(dataloader.dataset)  # type: ignore
    # Follow the model rather than a notebook-global `device`, so the function
    # works regardless of which cells have been run.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    losses: list[float] = []
    samples_seen = 0
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device=device, dtype=torch.float)

        outputs = model(X)
        if isinstance(outputs, tuple):
            logits, aux_logits = outputs
        else:
            logits, aux_logits = outputs, None
        loss = loss_fn(logits, y)
        if aux_logits is not None:
            loss = loss + aux_weight * loss_fn(aux_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count samples exactly (the last batch may be smaller than the others).
        samples_seen += len(X)
        if batch % log_every == 0:
            loss_value = loss.item()
            print(f"loss: {loss_value:>7f}  [{samples_seen:>5d}/{size:>5d}]")
            losses.append(loss_value)
    return losses


In [60]:
def test(
    dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
):
    size = len(dataloader.dataset)  # type: ignore
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        sigmoid = torch.nn.Sigmoid()
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)  # type: ignore
            logits, _ = model(X)
            test_loss += loss_fn(logits, y).item()
            correct += torch.mean(
                ((sigmoid(logits) > 0).to(torch.int64) == y).to(torch.float)
            ).item()
    test_loss /= num_batches
    accuracy = correct / size
    print(
        f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
    return test_loss, accuracy


In [62]:
def run_epoch(epochs: int = 5, lr: float = 1e-3):
    """Train the notebook's global ``model`` and evaluate it after each epoch.

    Despite the name, this runs all ``epochs`` epochs, not just one. It uses
    ``BCEWithLogitsLoss`` and plain SGD, and reads the globals ``model``,
    ``train_dataloader`` and ``test_dataloader`` defined in earlier cells.
    A fresh optimizer is created on every call.

    Args:
        epochs: number of passes over the training data (default 5).
        lr: SGD learning rate (default 1e-3).

    Returns:
        ``(train_losses, test_losses, test_accuracies)``: the losses that
        ``train`` sampled across all epochs, plus one test loss and one test
        accuracy per epoch.
    """
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    train_losses: list[float] = []
    test_losses: list[float] = []
    test_accuracies: list[float] = []
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_losses.extend(train(train_dataloader, model, loss_fn, optimizer))
        test_loss, accuracy = test(test_dataloader, model, loss_fn)
        test_losses.append(test_loss)
        test_accuracies.append(accuracy)
    print("Done!")
    return train_losses, test_losses, test_accuracies

In [63]:
def save_model(model: torch.nn.Module, path: str = "model_weights.pth"):
    """Save ``model.state_dict()`` to ``path`` with ``torch.save``.

    Only the state dict (parameters and buffers) is written, not the module
    class, so loading requires constructing the same architecture first.

    Args:
        model: module whose state dict is saved.
        path: destination file, relative to the working directory unless
            absolute (default ``"model_weights.pth"``).
    """
    torch.save(model.state_dict(), path)
    # The message previously named "model.pth", which is not the file written.
    print(f"Saved PyTorch Model State to {path}")

def load_model(model: torch.nn.Module, path: str = "model_weights.pth"):
    """Load a state dict from ``path`` into ``model`` and switch it to eval mode.

    Args:
        model: module with the same architecture as the one that was saved;
            its parameters are overwritten in place.
        path: checkpoint file (default ``"model_weights.pth"``).
    """
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a CPU-only
    # machine. load_state_dict then copies the values onto whatever device the
    # model's parameters already live on.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (newer torch versions also support weights_only=True).
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    # The message previously named "model.pth", which is not the file read.
    print(f"Loaded PyTorch Model State from {path}")
    model.eval()