In [1]:
import os
import sys

# Put the repository root first on sys.path so the `src.*` imports below resolve.
# Assumes the kernel's working directory is a direct child of the project root
# (e.g. a `notebooks/` folder) — TODO confirm if the notebook is moved.
notebook_dir = os.getcwd()
project_root_path = os.path.split(notebook_dir)[0]  # same as os.path.dirname
sys.path.insert(0, project_root_path)

from src.models import ModelXtoCResNet
from src.preprocessing.RIVAL10 import preprocessing_rival10
from src.utils import *
from src.training import run_epoch_x_to_c

In [None]:
# Load RIVAL10 with class-level concept labels. Returns the concept labels (later fed to
# find_class_imbalance) plus train/test loaders — presumably torch DataLoaders.
concept_labels, train_loader, test_loader = preprocessing_rival10(
    training=True,
    class_concepts=True,
    verbose=True,
)

Found 26384 unique images.
Found 18 unique concepts.
Generated one-hot training matrix with shape: (21098, 10)
Found 21098 images.
Processing in 330 batches of size 64 (for progress reporting)...


Processing batches: 100%|█████████████████████| 330/330 [00:36<00:00,  9.07it/s]



Finished processing.
Successfully transformed: 21098 images.
Found 5286 images.
Processing in 83 batches of size 64 (for progress reporting)...


Processing batches: 100%|███████████████████████| 83/83 [00:13<00:00,  6.31it/s]



Finished processing.
Successfully transformed: 5286 images.
Dataset initialized with 21098 pre-sorted items.
Dataset initialized with 5286 pre-sorted items.


# Training Implementation

In [None]:
from src.utils import find_class_imbalance
from src.config import RIVAL10_CONFIG as config_dict
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# Concept and class counts taken from the RIVAL10 config; N_TRIMMED_CONCEPTS sizes the
# model's concept output and the per-concept loss list below.
N_TRIMMED_CONCEPTS, N_CLASSES = (config_dict[key] for key in ('N_TRIMMED_CONCEPTS', 'N_CLASSES'))

**Select the device to run the model on (CUDA GPU, Apple MPS, or CPU).**

In [5]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

Using device: mps


**Instantiate the model.**

In [6]:
# Build the X -> C (image -> concepts) model with one output per trimmed concept.
# freeze=True presumably freezes the pretrained backbone; the optimizer cell only
# receives parameters with requires_grad=True.
model = ModelXtoCResNet(
    pretrained=True,
    freeze=True,
    n_concepts=N_TRIMMED_CONCEPTS,
).to(device)  # nn.Module.to moves parameters in place and returns the module itself
print("Model Instantiated (X -> C)")



Model Instantiated (X -> C)


### Loss
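We use a per-concept weighted loss: each concept gets its own `BCEWithLogitsLoss`, weighted by a per-concept value computed by `find_class_imbalance` from the concept labels.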

`BCEWithLogitsLoss()` conceptually performs 2 steps (fused internally for numerical stability):
1. $\sigma(x)$
    - Applies the sigmoid function to the logits to get probabilities.
2. $\text{BCE}(\sigma(x), y) = -\left[y \cdot \log(\sigma(x)) + (1-y) \cdot \log(1-\sigma(x))\right]$
    - Computes the binary cross-entropy between the output probabilities ($\sigma(x)$) and the ground truths ($y$). When a `weight` $w$ is passed, this term is multiplied by $w$.

In [7]:
use_weighted_loss = True # Set to False for simple unweighted loss

# One BCE-with-logits criterion per concept; the list is passed to run_epoch_x_to_c,
# which presumably applies criterion i to concept output i — TODO confirm in src.training.
if use_weighted_loss:
    concept_weights = find_class_imbalance(concept_labels)
    # NOTE(review): `weight=` rescales the entire loss of a concept by `ratio`; it does NOT
    # reweight positive vs. negative examples within that concept (that is `pos_weight=`).
    # Switch to `pos_weight` only if positive-class reweighting is what is intended.
    attr_criterion = [
        nn.BCEWithLogitsLoss(weight=torch.tensor([ratio], device=device, dtype=torch.float))
        for ratio in concept_weights
    ]
else:
    attr_criterion = [nn.BCEWithLogitsLoss() for _ in range(N_TRIMMED_CONCEPTS)]

# Both branches must produce exactly one criterion per predicted concept; otherwise the
# mismatch would only surface later as an indexing error or silently unweighted concepts.
assert len(attr_criterion) == N_TRIMMED_CONCEPTS, (
    f"Expected {N_TRIMMED_CONCEPTS} concept criteria, got {len(attr_criterion)}"
)

### Optimiser
Use the same settings as the CBM (Concept Bottleneck Models) reference repository.

In [8]:
# SGD with momentum and L2 weight decay (same settings as the CBM reference code).
lr = 0.01
weight_decay = 0.00004 # same as lambda in L2-regularisation

# Hand the optimizer only the parameters that will actually be updated
# (presumably just the unfrozen part of the model, since it was built with freeze=True).
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.SGD(
    trainable_params,
    lr=lr,
    momentum=0.9,
    weight_decay=weight_decay,
)

# StepLR multiplies the LR by gamma=0.1 every `scheduler_step` calls to scheduler.step()
# (the training loop calls it once per epoch).
# NOTE(review): with scheduler_step = 1000 the LR never decays in any run shorter than
# 1000 epochs, so the schedule is effectively a constant LR here.
scheduler_step = 1000
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=0.1)

print("Optimizer and Scheduler Ready")

Optimizer and Scheduler Ready


### Training and Validation Loops

In [9]:
# Training-loop settings.
epochs, log_interval = 50, 50  # NOTE(review): log_interval is not used by the loop below

# Best validation accuracy seen so far; the loop checkpoints the model whenever it improves.
best_val_acc = 0.0

In [10]:
# Train for `epochs` epochs. After each epoch, validate on test_loader and checkpoint the
# model whenever validation accuracy reaches a new best.
best_ckpt_path = 'x_to_c_best_model.pth'

for epoch in range(epochs):
    print(f"--- Epoch {epoch+1}/{epochs} ---")

    # Training pass.
    train_loss, train_acc = run_epoch_x_to_c(
        model, train_loader, attr_criterion, optimizer,
        is_training=True, use_aux=True,
        n_concepts=N_TRIMMED_CONCEPTS, device=device, verbose=True,
    )
    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')

    # Validation pass with gradients disabled; is_training/use_aux are left at
    # run_epoch_x_to_c's defaults (presumably evaluation mode — TODO confirm).
    if test_loader:
        with torch.no_grad():
            val_loss, val_acc = run_epoch_x_to_c(
                model, test_loader, attr_criterion, optimizer,
                n_concepts=N_TRIMMED_CONCEPTS, device=device, verbose=True,
            )
        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')

        improved = val_acc > best_val_acc
        if improved:
            print(f"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}). Saving model...")
            best_val_acc = val_acc
            # Pickles the whole nn.Module (not a state_dict), so loading it back requires
            # the model class to be importable.
            torch.save(model, best_ckpt_path)
            print("Model saved to x_to_c_best_model.pth")

    # Advance the LR schedule once per epoch.
    scheduler.step()
    print(f"Current LR: {optimizer.param_groups[0]['lr']}")

--- Epoch 1/50 ---


                                                                                

Epoch 1 Train Summary | Loss: 0.3839 | Acc: 96.637


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(90369) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90377) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90378) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90379) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 1 Val Summary   | Loss: 0.1848 | Acc: 98.576
Validation accuracy improved (0.000 -> 98.576). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 2/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(90436) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90437) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90438) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90439) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 2 Train Summary | Loss: 0.1876 | Acc: 98.435


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(90565) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90566) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90567) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90569) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 2 Val Summary   | Loss: 0.1527 | Acc: 98.770
Validation accuracy improved (98.576 -> 98.770). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 3/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(90584) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90585) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90586) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90587) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 3 Train Summary | Loss: 0.1654 | Acc: 98.533


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(90723) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90724) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90726) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90727) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 3 Val Summary   | Loss: 0.1453 | Acc: 98.808
Validation accuracy improved (98.770 -> 98.808). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 4/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(90754) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90755) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90766) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90769) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 4 Train Summary | Loss: 0.1536 | Acc: 98.618


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(90924) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90925) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90926) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90927) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 4 Val Summary   | Loss: 0.1373 | Acc: 98.851
Validation accuracy improved (98.808 -> 98.851). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 5/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(90962) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90964) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90965) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(90966) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 5 Train Summary | Loss: 0.1474 | Acc: 98.685


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91023) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91024) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91025) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91027) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 5 Val Summary   | Loss: 0.1328 | Acc: 98.848
Current LR: 0.01
--- Epoch 6/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91038) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91039) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91040) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91041) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 6 Train Summary | Loss: 0.1428 | Acc: 98.684


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91090) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91091) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91092) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91094) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 6 Val Summary   | Loss: 0.1309 | Acc: 98.896
Validation accuracy improved (98.851 -> 98.896). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 7/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91097) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91098) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91100) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91101) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 7 Train Summary | Loss: 0.1360 | Acc: 98.731


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91151) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91152) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91153) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91154) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 7 Val Summary   | Loss: 0.1289 | Acc: 98.899
Validation accuracy improved (98.896 -> 98.899). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 8/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91165) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91166) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91167) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91168) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 8 Train Summary | Loss: 0.1326 | Acc: 98.759


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91227) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91228) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91229) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91230) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 8 Val Summary   | Loss: 0.1279 | Acc: 98.924
Validation accuracy improved (98.899 -> 98.924). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 9/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91235) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91236) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91237) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91238) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 9 Train Summary | Loss: 0.1324 | Acc: 98.736


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91293) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91294) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91295) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91296) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 9 Val Summary   | Loss: 0.1257 | Acc: 98.910
Current LR: 0.01
--- Epoch 10/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91314) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91316) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91318) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91319) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 10 Train Summary | Loss: 0.1298 | Acc: 98.788


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91525) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91532) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91533) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91535) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 10 Val Summary   | Loss: 0.1244 | Acc: 98.936
Validation accuracy improved (98.924 -> 98.936). Saving model...
Model saved to x_to_c_best_model.pth
Current LR: 0.01
--- Epoch 11/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91584) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91585) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91586) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91587) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 11 Train Summary | Loss: 0.1283 | Acc: 98.758


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(91716) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91717) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91718) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91719) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 11 Val Summary   | Loss: 0.1260 | Acc: 98.926
Current LR: 0.01
--- Epoch 12/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(91761) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91771) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91772) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(91773) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 12 Train Summary | Loss: 0.1244 | Acc: 98.820


Validation:   0%|                                        | 0/83 [00:00<?, ?it/s]python3.11(92100) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(92105) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(92108) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(92109) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

Epoch 12 Val Summary   | Loss: 0.1253 | Acc: 98.928
Current LR: 0.01
--- Epoch 13/50 ---


Training:   0%|                                         | 0/330 [00:00<?, ?it/s]python3.11(92158) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(92159) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(92160) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python3.11(92161) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
                                                                                

KeyboardInterrupt: 

In [None]:
# Evaluate the best checkpoint saved by the training loop, not whatever weights `model`
# holds after the last (or interrupted) epoch.
# NOTE(review): test_loader is also the loader used for checkpoint selection, so this
# score is not an unbiased held-out estimate.
best_model_path = 'x_to_c_best_model.pth'

if os.path.exists(best_model_path):
    # The training loop pickles the whole nn.Module, so weights_only must be False
    # (newer PyTorch versions default it to True). Only load checkpoints you created yourself.
    best_model = torch.load(best_model_path, map_location=device, weights_only=False)
else:
    print(f"Checkpoint {best_model_path} not found; evaluating the in-memory model instead.")
    best_model = model

if test_loader:
    with torch.no_grad():
        test_loss, test_acc, outputs = run_epoch_x_to_c(
            best_model, test_loader, attr_criterion, optimizer=None,
            n_concepts=N_TRIMMED_CONCEPTS, device=device,
            return_outputs='sigmoid', verbose=True,
        )
    # Printed inside the guard: test_loss/test_acc only exist when evaluation ran.
    print(f'Best Model Summary   | Loss: {test_loss:.4f} | Acc: {test_acc:.3f}')
else:
    print("No test_loader available; skipping evaluation.")