In [1]:
import os
import sys

# Make the project root (the parent of this notebook's directory) importable,
# so the `src`, `scripts` and `config` imports in this notebook resolve.
notebook_dir = os.getcwd()
project_root_path = os.path.split(notebook_dir)[0]  # equivalent to os.path.dirname
sys.path.insert(0, project_root_path)

from src import ModelXtoC
from scripts.run_preprocessing import preprocessing_main
from src.utils import *
from src.training import run_epoch_x_to_c

In [2]:
# Run the preprocessing pipeline. It returns the concept labels (used below to
# compute per-concept loss weights) and the train / test loaders used for training.
concept_labels, train_loader, test_loader = preprocessing_main(
    class_concepts=False,
    verbose=True,
)

Found 11788 images.
Processing in 369 batches of size 32 (for progress reporting)...


Processing batches: 100%|█████████████████████| 369/369 [00:54<00:00,  6.75it/s]



Finished processing.
Successfully transformed: 11788 images.
Found 11788 unique images.
Found 312 unique concepts.
Generated concept matrix with shape: (11788, 312)
Found 200 classes.
Found labels for 11788 images.
Generated one-hot matrix with shape: (11788, 200)
Split complete: 5994 train images, 5794 test images.
Dataset initialized with 5994 pre-sorted items.
Dataset initialized with 5794 pre-sorted items.


# Training Implementation

In [3]:
from src.utils import find_class_imbalance
from config import N_TRIMMED_CONCEPTS, N_CLASSES
import torch
import torch.nn as nn
import torch.optim as optim

**Find the device to run the model on (CUDA GPU, Apple MPS, or CPU).**

In [4]:
# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


**Instantiate the model.**

In [5]:
# Build the X -> C model (one output per trimmed concept) and move it to the
# selected device. nn.Module.to() returns the module itself, so chaining is safe.
model = ModelXtoC(
    pretrained=True,
    freeze=True,
    use_aux=True,
    n_concepts=N_TRIMMED_CONCEPTS,
    n_classes=N_CLASSES,  # Still needed for model structure, but won't be trained/used
).to(device)
print("Model Instantiated (X -> C)")

Model Instantiated (X -> C)


### Loss
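We use one loss per concept, optionally weighted by a per-concept ratio computed from the label imbalance in `concept_labels` (via `find_class_imbalance`).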

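`BCEWithLogitsLoss()` combines two steps, fused into a single numerically stable operation:
1. $\sigma(x)$
    - Applies the sigmoid function to the logits to get probabilities.
2. $\text{BCE}(\sigma(x), y) = -\left[ y \cdot \log(\sigma(x)) + (1-y) \cdot \log(1-\sigma(x)) \right]$
    - Computes the binary cross-entropy between the output probabilities ($\sigma(x)$) and the ground truths ($y$).
    - If a `weight` $w$ is supplied, this term is multiplied by $w$ (for both positive and negative examples).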

In [6]:
use_weighted_loss = True # Set to False for simple unweighted loss

# One BCEWithLogitsLoss per concept, so each concept's logit is scored independently.
if use_weighted_loss:
    # One weight per concept, derived from the labels in `concept_labels`.
    concept_weights = find_class_imbalance(concept_labels)
    # NOTE(review): `weight=` multiplies this concept's whole loss term (positives and
    # negatives alike); it does not up-weight positive labels the way `pos_weight=`
    # would. Confirm this is the intended weighting scheme.
    attr_criterion = [
        nn.BCEWithLogitsLoss(weight=torch.tensor([ratio], device=device, dtype=torch.float))
        for ratio in concept_weights
    ]
else:
    attr_criterion = [nn.BCEWithLogitsLoss() for _ in range(N_TRIMMED_CONCEPTS)]

# Both branches must produce exactly one criterion per model output (the model is
# built with n_concepts=N_TRIMMED_CONCEPTS). If `concept_labels` still holds the
# untrimmed concept set, the weights would presumably be paired with the wrong
# concepts; fail loudly instead of training silently with misaligned weights.
assert len(attr_criterion) == N_TRIMMED_CONCEPTS, (
    f"Expected {N_TRIMMED_CONCEPTS} concept criteria, got {len(attr_criterion)}; "
    "check that concept_labels uses the trimmed concept set."
)

### Optimiser
Use the same optimiser settings as the original Concept Bottleneck Models (CBM) repository.

In [7]:
# SGD with momentum and L2 weight decay.
lr = 0.01
weight_decay = 0.00004  # same as lambda in L2-regularisation

# Skip parameters with requires_grad=False (e.g. frozen layers); only trainable ones are optimised.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.SGD(
    trainable_params,
    lr=lr,
    momentum=0.9,
    weight_decay=weight_decay,
)

# StepLR multiplies the LR by gamma=0.1 every `scheduler_step` calls to
# scheduler.step() (called once per epoch in the training loop). With a step of
# 1000 and far fewer epochs, the LR effectively stays constant for this run.
scheduler_step = 1000
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=0.1)

print("Optimizer and Scheduler Ready")

Optimizer and Scheduler Ready


### Training and Validation Loops

In [8]:
# Training-run settings.
epochs, log_interval = 50, 50  # NOTE(review): log_interval is not passed to run_epoch_x_to_c in this notebook

# Best validation accuracy seen so far; a checkpoint is written whenever it improves.
best_val_acc = float(0)

In [9]:
# Train for `epochs` epochs, validating after each one and checkpointing the
# model weights whenever validation accuracy improves on `best_val_acc`.
checkpoint_path = 'x_to_c_best_model.pth'  # single place to change where the best model is saved
history = []  # one dict of metrics per epoch, e.g. for plotting learning curves afterwards

for epoch in range(epochs):
    print(f"--- Epoch {epoch+1}/{epochs} ---")

    # Train
    train_loss, train_acc = run_epoch_x_to_c(model, train_loader, attr_criterion, optimizer, is_training=True, use_aux=True, n_concepts=N_TRIMMED_CONCEPTS, device=device, verbose=True)
    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')
    epoch_metrics = {'epoch': epoch + 1, 'train_loss': train_loss, 'train_acc': train_acc}

    # Validate (skipped when test_loader is None or empty)
    if test_loader:
        # Evaluation only: disable autograd so no computation graph is built.
        with torch.no_grad():
            val_loss, val_acc = run_epoch_x_to_c(model, test_loader, attr_criterion, optimizer, n_concepts=N_TRIMMED_CONCEPTS, device=device, verbose=True)

        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')
        epoch_metrics.update(val_loss=val_loss, val_acc=val_acc)

        # Save best model based on validation accuracy
        if val_acc > best_val_acc:
            print(f"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}). Saving model...")
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved to {checkpoint_path}")

    history.append(epoch_metrics)

    # Scheduler step: once per epoch, so StepLR's step_size is measured in epochs.
    scheduler.step()
    print(f"Current LR: {optimizer.param_groups[0]['lr']}")

--- Epoch 1/50 ---


Training:  53%|███▏  | 50/94 [01:18<00:43,  1.00it/s, acc=69.6772, loss=42.8811]

 Batch:  50/94 | Avg. Loss: 42.8811 | Avg. Acc.: 69.677 | Time: 78.30s


                                                                                

Epoch 1 Train Summary | Loss: 40.8721 | Acc: 71.515


Validation:  55%|██▏ | 50/91 [01:17<00:33,  1.22it/s, acc=79.1339, loss=23.2052]

 Batch:  50/91 | Avg. Loss: 23.2052 | Avg. Acc.: 79.134 | Time: 77.19s


                                                                                

Epoch 1 Val Summary   | Loss: 24.0218 | Acc: 78.211
Validation accuracy improved (0.000 -> 78.211). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 2/50 ---


Training:  53%|███▏  | 50/94 [03:53<03:52,  5.27s/it, acc=74.5831, loss=35.9407]

 Batch:  50/94 | Avg. Loss: 35.9407 | Avg. Acc.: 74.583 | Time: 233.71s


                                                                                

Epoch 2 Train Summary | Loss: 35.9339 | Acc: 74.588


Validation:  55%|██▏ | 50/91 [01:58<00:49,  1.20s/it, acc=79.6939, loss=22.5564]

 Batch:  50/91 | Avg. Loss: 22.5564 | Avg. Acc.: 79.694 | Time: 118.96s


                                                                                

Epoch 2 Val Summary   | Loss: 23.4376 | Acc: 78.646
Validation accuracy improved (78.211 -> 78.646). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 3/50 ---


Training:  53%|███▏  | 50/94 [03:37<05:40,  7.74s/it, acc=75.1970, loss=34.4722]

 Batch:  50/94 | Avg. Loss: 34.4722 | Avg. Acc.: 75.197 | Time: 217.57s


                                                                                

Epoch 3 Train Summary | Loss: 34.3218 | Acc: 75.323


Validation:  55%|██▏ | 50/91 [02:00<01:20,  1.96s/it, acc=80.1071, loss=21.8383]

 Batch:  50/91 | Avg. Loss: 21.8383 | Avg. Acc.: 80.107 | Time: 120.52s


                                                                                

Epoch 3 Val Summary   | Loss: 22.6509 | Acc: 79.086
Validation accuracy improved (78.646 -> 79.086). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 4/50 ---


Training:  53%|███▏  | 50/94 [03:21<02:30,  3.41s/it, acc=75.6390, loss=33.5020]

 Batch:  50/94 | Avg. Loss: 33.5020 | Avg. Acc.: 75.639 | Time: 201.47s


                                                                                

Epoch 4 Train Summary | Loss: 33.4981 | Acc: 75.783


Validation:  55%|██▏ | 50/91 [02:34<01:15,  1.84s/it, acc=80.3527, loss=22.4284]

 Batch:  50/91 | Avg. Loss: 22.4284 | Avg. Acc.: 80.353 | Time: 154.55s


                                                                                

Epoch 4 Val Summary   | Loss: 23.3858 | Acc: 79.198
Validation accuracy improved (79.086 -> 79.198). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 5/50 ---


                                                                                

KeyboardInterrupt: 