In [1]:
import os
import sys

notebook_dir = os.getcwd()
project_root_path = os.path.dirname(notebook_dir)
sys.path.insert(0, project_root_path)

from src import ModelXtoC
from scripts.preprocesing import preprocessing_main

In [2]:
concept_labels, train_loader, val_loader, test_loader = preprocessing_main(verbose=True)

Found 11788 images.
Processing in 369 batches of size 32 (for progress reporting)...


Processing batches: 100%|█████████████████████| 369/369 [00:54<00:00,  6.83it/s]



Finished processing.
Successfully transformed: 11788 images.
Found 11788 unique images.
Found 312 unique concepts.
Generated concept matrix with shape: (11788, 312)
Found 200 classes.
Found labels for 11788 images.
Generated one-hot matrix with shape: (11788, 200)
Split complete: 5994 train images, 5794 test images.
Dataset initialized with 5994 pre-sorted items.
Dataset initialized with 5794 pre-sorted items.


# TRAINING IMPLEMENTATION (IN PROGRESS)

In [3]:
from src.utils import find_class_imbalance
from src import inception_v3
from config import N_CONCEPTS, N_CLASSES
import torch
import torch.nn as nn
import torch.optim as optim
import time

In [4]:
device = torch.device("cuda" if torch.cuda.is_available()
                    else "mps" if torch.backends.mps.is_available()
                    else "cpu")
print(f"Using device: {device}")

Using device: mps


In [5]:
# Instantiate the model
model = ModelXtoC(pretrained=True,
                freeze=True,
                n_classes=N_CLASSES, # Still needed for model structure, but won't be trained/used
                use_aux=True,
                n_concepts=N_CONCEPTS)

model = model.to(device)
print("Model Instantiated (X -> C)")

Model Instantiated (X -> C)


### Loss
Original code from CBM Github uses a separate loss for each attribute

`BCEWithLogitsLoss()` performs 2 steps:
1. $\sigma(x)$
    - Applies the sigmoid function to the logits to get probabilities.
2. $\text{BCE}(\sigma(x), y) = y \cdot \text{log}(\sigma(x)) + (1-y) \cdot (1-\text{log}(\sigma(x)))$
    - Compute binary cross-entropy between output probabilities ($\sigma(x)$) and ground truths ($y$)

In [6]:
use_weighted_loss = True # Set to False for simple unweighted loss

if use_weighted_loss:
    concept_weights = find_class_imbalance(concept_labels)
    attr_criterion = [nn.BCEWithLogitsLoss(weight=torch.tensor([ratio], device=device, dtype=torch.float))
                    for ratio in concept_weights]
else:
    attr_criterion = [nn.BCEWithLogitsLoss() for _ in range(N_CONCEPTS)]

### Optimiser

In [7]:
lr = 0.01
weight_decay = 0.00004
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                    lr=lr,
                    momentum=0.9,
                    weight_decay=weight_decay)

scheduler_step = 1000 # Set large for almost constant LR, or smaller (e.g., 30, 50) for decay
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=0.1)

print("Optimizer and Scheduler Ready")

Optimizer and Scheduler Ready


### HELPERS

In [8]:
class AverageMeter(object):
    """
    Computes and stores the average and current value
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def binary_accuracy(output, target):
    """
    Computes the accuracy for multiple binary predictions
    output and target are Torch tensors
    """
    target_int = target.int()

    pred = output >= 0.5

    # .eq returns a boolean tensor, summing True values gives the count of correct predictions
    correct_sum = (pred.int()).eq(target_int).sum()
    num_elements = target.numel()

    # Calculate accuracy
    accuracy = correct_sum.float() * 100.0 / num_elements

    return accuracy

### Training and Validation Loops

In [9]:
def run_epoch_x_to_c(model, loader, criterion_list, optimizer, is_training, use_aux, n_concepts):
    """
    Modified run_epoch focused ONLY on X -> C training.
    criterion_list: List of loss functions for each concept.
    """
    if is_training:
        model.train() # use dropout layers and calculates gradients
    else:
        model.eval() # sets layers like BatchNorm to use running statistics

    # Create instances of AverageMeter to track the average loss and accuracy throughout the epoch.
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    start_time = time.time()

    for batch_idx, data in enumerate(loader):
        inputs, concept_labels, _, _ = data # Ignore class labels (_) and image ids
        inputs, concept_labels = inputs.to(device), concept_labels.to(device)

        if is_training:
            optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        if len(outputs) == 2:
            main_outputs, aux_outputs = outputs[0], outputs[1]
        else:
            main_outputs, aux_outputs = outputs, None

        # Loss calculation
        total_loss = 0
        main_outputs_for_acc = [] # To store tensors for accuracy calculation

        for i in range(n_concepts):
            # Ensure concept_labels[:, i] is float for BCEWithLogitsLoss
            target_concepts = concept_labels[:, i].float()

            # --- Access the main output tensor for the i-th attribute ---
            output_i_tensor = main_outputs[i]
            output_i = output_i_tensor.squeeze() # Shape [N]
            main_outputs_for_acc.append(output_i_tensor) # Keep [N, 1] shape for cat

            # --- Calculate loss using the main output ---
            loss_i = criterion_list[i](output_i, target_concepts)

            # --- Optionally add auxiliary loss if training and use_aux_loss is True ---
            if is_training and use_aux and aux_outputs is not None:
                 # Access the aux output tensor for the i-th attribute
                 aux_output_i_tensor = aux_outputs[i]
                 aux_output_i = aux_output_i_tensor.squeeze() # Shape [N]
                 loss_aux = criterion_list[i](aux_output_i, target_concepts)
                 loss_i += 0.4 * loss_aux # Add weighted auxiliary loss

            total_loss += loss_i

        # Average loss over attributes in the batch
        avg_batch_loss = total_loss / n_concepts

        # Backward pass and optimization
        if is_training:
            avg_batch_loss.backward()
            optimizer.step()

        # Accuracy Calculation (using the main outputs collected)
        sigmoid_outputs = torch.sigmoid(torch.cat(main_outputs_for_acc, dim=1))
        acc = binary_accuracy(sigmoid_outputs, concept_labels.int())

        # Update meters
        loss_meter.update(avg_batch_loss.item(), inputs.size(0))
        acc_meter.update(acc.item(), inputs.size(0))

        # Logging
        if is_training and (batch_idx + 1) % log_interval == 0:
             elapsed_time = time.time() - start_time
             print(f' Batch: {batch_idx+1}/{len(loader)} | Loss: {loss_meter.val:.4f} ({loss_meter.avg:.4f}) |'
                   f' Acc: {acc_meter.val:.3f} ({acc_meter.avg:.3f}) | Time: {elapsed_time:.2f}s')
             start_time = time.time() # Reset timer

    return loss_meter.avg, acc_meter.avg

In [11]:
# --- Training Loop ---
epochs = 50 # Adjust as needed
log_interval = 50 # How often to print progress

# print("\nStarting Training Loop...")
best_val_acc = 0.0

for epoch in range(epochs):
    print(f"--- Epoch {epoch+1}/{epochs} ---")

    # Train
    train_loss, train_acc = run_epoch_x_to_c(model, train_loader, attr_criterion, optimizer, is_training=True, use_aux=True, n_concepts=N_CONCEPTS)
    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')

    # Validate
    if test_loader: # Assuming test_loader is your validation loader
        with torch.no_grad():
            val_loss, val_acc = run_epoch_x_to_c(model, val_loader, attr_criterion, optimizer, is_training=False, use_aux=False, n_concepts=N_CONCEPTS)
        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')

#         # Save best model based on validation accuracy
#         if val_acc > best_val_acc:
#             print(f"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}). Saving model...")
#             best_val_acc = val_acc
#             torch.save(model.state_dict(), 'x_to_c_best_model.pth')
#             print("Model saved to x_to_c_best_model.pth")

    # Scheduler step
    scheduler.step()
    print(f"Current LR: {optimizer.param_groups[0]['lr']}")

print("\nTraining Finished.")
# Load best model for potential further use/testing
model.load_state_dict(torch.load('x_to_c_best_model.pth'))
print("Best model loaded.")

--- Epoch 1/50 ---
 Batch: 40/149 | Loss: 5.7586 (7.7479) | Acc: 86.378 (86.032) | Time: 35.14s
 Batch: 80/149 | Loss: 6.5056 (7.5115) | Acc: 88.091 (86.790) | Time: 31.49s
 Batch: 120/149 | Loss: 3.6607 (7.4355) | Acc: 88.802 (87.074) | Time: 31.88s
Epoch 1 Train Summary | Loss: 7.5188 | Acc: 87.204
Epoch 1 Val Summary   | Loss: 4.5155 | Acc: 89.581
Current LR: 0.01
--- Epoch 2/50 ---


KeyboardInterrupt: 

KeyboardInterrupt: 