<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/val_80_f1_44_(30_)_SC_MIL_Supervised_Contrastive_Multiple_Instance_Learning_for_Imbalanced_Classification_in_Pathology.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# --- Section 1: Imports and Setup ---

# Core libraries for numerical operations and data handling
import numpy as np # Used for numerical operations, especially with arrays and matrices.
import pandas as pd # Used for creating and manipulating dataframes to manage metadata.
import os # Provides a way of using operating system dependent functionality, like reading file paths.
import re # Regular expression library, crucial for parsing patient IDs from filenames.
from collections import defaultdict # A dictionary subclass that calls a factory function to supply missing values.

# Deep learning framework and utilities
import torch # The main deep learning library (PyTorch).
import torch.nn as nn # PyTorch's module for building neural networks.
import torch.optim as optim # Contains optimization algorithms like Adam.
from torch.utils.data import Dataset, DataLoader # Tools for creating custom datasets and efficient data loading.
import torchvision.transforms as transforms # Contains common image transformations for data augmentation.
from torchvision.models import shufflenet_v2_x1_0 # The pretrained ShuffleNet model used as the feature extractor.

# Image processing and visualization
from PIL import Image # Python Imaging Library, used for opening and manipulating image files.
import matplotlib.pyplot as plt # The primary library for creating plots and visualizations.
import seaborn as sns # Built on top of matplotlib for more attractive statistical graphics.

# Utilities and metrics
from tqdm import tqdm # Creates smart progress bars for loops, essential for tracking training progress.
from sklearn.model_selection import train_test_split # A utility to split datasets into train, validation, and test sets.
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score, roc_curve # Metrics for comprehensive model evaluation.

# Ensure reproducibility
import random # Python's built-in library for random number generation.
np.random.seed(42) # Set seed for NumPy for consistent random operations.
torch.manual_seed(42) # Set seed for PyTorch on the CPU.
random.seed(42) # Set seed for Python's random module.
if torch.cuda.is_available(): # Check if a GPU is available.
    torch.cuda.manual_seed_all(42) # Set seed for PyTorch on all available GPUs.

# --- End of Section 1 ---

In [2]:
# --- Section 2: Data Preprocessing and Bag Reconstruction ---

def create_bags_from_folders(data_dir, patient_id_regex=r'^(OAS\d_\d{4})_'):
    """
    Parses image filenames in class-structured folders to reconstruct patient bags.

    This function iterates through subdirectories (each representing a class) of the
    main data directory. It uses a regular expression to extract a patient ID
    from each filename and groups the file paths of all slices belonging to the
    same patient into a "bag".

    Args:
        data_dir (str): The path to the root directory containing class folders.
        patient_id_regex (str): A regular expression to extract patient IDs.
                                The first captured group is used as the ID.

    Returns:
        list: A list of all patient bags. Each bag is a dictionary containing
              the patient ID, a list of image paths, and the class label.
        dict: A mapping from class names to integer labels.
    """
    # Print the starting of the bag creation process.
    print("Starting bag reconstruction from image folders...")

    # Define the classes based on the folder names provided in the dataset description.
    class_names = ['Non Demented', 'Very mild Dementia', 'Mild Dementia', 'Moderate Dementia']

    # Create a mapping from class names to integer indices for model training.
    class_to_label = {name: i for i, name in enumerate(class_names)}

    # Use a defaultdict to conveniently group slices by patient.
    # The key will be (patient_id, class_name), value will be a list of slice paths.
    patient_slice_groups = defaultdict(list)

    # Iterate over each class directory.
    for class_name in class_names:
        # Construct the full path to the class directory.
        class_dir = os.path.join(data_dir, class_name)
        # Check if the directory exists to avoid errors.
        if not os.path.isdir(class_dir):
            # Print a warning if a class folder is not found.
            print(f"Warning: Directory not found for class '{class_name}'. Skipping.")
            # Continue to the next class name in the list.
            continue

        # List all files in the current class directory.
        for filename in os.listdir(class_dir):
            # We are only interested in JPG files.
            if filename.endswith('.jpg'):
                # Try to match the regex to extract the patient ID.
                match = re.match(patient_id_regex, filename)
                # If a match is found, proceed.
                if match:
                    # The patient ID is the first group captured by the regex.
                    patient_id = match.groups()[0]
                    # Get the full path to the image file.
                    image_path = os.path.join(class_dir, filename)
                    # Append the slice path to the corresponding patient's group.
                    patient_slice_groups[(patient_id, class_name)].append(image_path)

    # Now, convert the grouped slices into a list of bag dictionaries.
    all_bags = []
    # Iterate through the patient groups we created.
    for (patient_id, class_name), image_paths in patient_slice_groups.items():
        # Each bag is a dictionary containing its ID, all its image paths, and its label.
        all_bags.append({
            'patient_id': patient_id,          # Unique identifier for the patient.
            'image_paths': sorted(image_paths), # A sorted list of paths to all slices for this patient.
            'label': class_to_label[class_name] # The integer label corresponding to the patient's class.
        })

    # Print a summary of the bag creation process.
    print(f"Successfully created {len(all_bags)} bags from {len(patient_slice_groups)} patients.")
    # Return the list of all bags and the class-to-label mapping.
    return all_bags, class_to_label


class OASISBagDataset(Dataset):
    """
    PyTorch Dataset for loading bags of OASIS MRI slices.

    Each item returned by this dataset is a complete bag, which consists of
    all MRI slices for a single patient, stacked into a single tensor.
    """
    def __init__(self, bags, transform=None):
        """
        Args:
            bags (list): A list of bag dictionaries, from create_bags_from_folders.
            transform (callable, optional): A torchvision.transforms pipeline to be
                                            applied to each slice.
        """
        # Store the list of bag definitions.
        self.bags = bags
        # Store the transformation pipeline.
        self.transform = transform

    def __len__(self):
        """Returns the total number of bags (patients) in the dataset."""
        # The length of the dataset is the number of patients.
        return len(self.bags)

    def __getitem__(self, idx):
        """
        Retrieves one bag (all slices for one patient) from the dataset.

        Args:
            idx (int): The index of the bag to retrieve.

        Returns:
            tuple: A tuple containing:
                   - torch.Tensor: A tensor of shape (num_slices, C, H, W)
                                   containing all transformed slices for the patient.
                   - int: The label for the bag.
        """
        # Get the bag definition dictionary at the specified index.
        bag = self.bags[idx]
        # Get the label for this bag.
        label = bag['label']
        # Retrieve the list of image paths for the slices in this bag.
        image_paths = bag['image_paths']

        # Load and transform each slice in the bag.
        bag_slices = []
        # Iterate over each image path in the bag.
        for path in image_paths:
            # Open the image file using PIL.
            # Convert to RGB since pretrained models expect 3 channels. MRI is grayscale.
            image = Image.open(path).convert("RGB")
            # Apply the transformations if they are defined.
            if self.transform:
                image = self.transform(image)
            # Append the transformed slice tensor to our list.
            bag_slices.append(image)

        # Stack all slice tensors into a single tensor for the entire bag.
        # This tensor represents all instances in the bag.
        stacked_slices = torch.stack(bag_slices, dim=0)

        # Return the stacked slices and the corresponding label.
        return stacked_slices, label

# --- Main Data Loading Execution ---

# Define the path to the dataset zip file.
ZIP_FILE_PATH = '/content/drive/MyDrive/MR& TP/oaisis.zip'
# Define the directory where the dataset will be unzipped.
UNZIP_DIR = '/content/oaisis_unzipped'

# Unzip the dataset if it hasn't been already.
import zipfile
if not os.path.exists(UNZIP_DIR):
    print(f"Unzipping {ZIP_FILE_PATH} to {UNZIP_DIR}...")
    with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
        zip_ref.extractall(UNZIP_DIR)
    print("Unzipping complete.")
else:
    print(f"Dataset already unzipped to {UNZIP_DIR}. Skipping unzipping.")

# Define the path to the dataset directory after unzipping.
DATA_ROOT_DIR = os.path.join(UNZIP_DIR, 'Data')

# Inspect the contents of the data directory to check folder names
print(f"Inspecting contents of {DATA_ROOT_DIR}:")
try:
    for item in os.listdir(DATA_ROOT_DIR):
        item_path = os.path.join(DATA_ROOT_DIR, item)
        if os.path.isdir(item_path):
            print(f"Found directory: {item}")
            # Inspect the first directory found
            print(f"Inspecting contents of {item_path}:")
            try:
                for i, file_item in enumerate(os.listdir(item_path)):
                    print(f"  File {i+1}: {file_item}")
                    if i > 5: # Print only the first few files to avoid flooding the output
                        print("  ...")
                        break
            except Exception as e:
                print(f"Error inspecting directory {item_path}: {e}")
        else:
            print(f"Found file: {item}")
except FileNotFoundError:
    print(f"Error: Data directory not found at {DATA_ROOT_DIR}. Please check the unzipping process.")

# Define a custom transform to add Gaussian noise
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, img):
        # Ensure the image is a PyTorch Tensor
        if not isinstance(img, torch.Tensor):
             img = transforms.ToTensor()(img)

        # Add noise
        return img + torch.randn(img.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


# Define image transformations. The paper uses a 224x224 input size.
# For training, we include data augmentation as described in the paper, plus Gaussian noise.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),       # Resize images to 224x224 for the ShuffleNet model.
    transforms.RandomHorizontalFlip(),   # Randomly flip images horizontally for augmentation.
    transforms.RandomRotation(10),       # Randomly rotate images by up to 10 degrees.
    transforms.ToTensor(),               # Convert PIL Image to a PyTorch tensor.
    AddGaussianNoise(mean=0., std=0.30), # Add Gaussian noise (adjust std for noise level)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats.
])

# For validation and testing, we only resize, convert to tensor, and normalize.
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),       # Resize images to 224x224.
    transforms.ToTensor(),               # Convert PIL Image to a PyTorch tensor.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats.
])

# Reconstruct the bags from the folder structure.
all_patient_bags, class_mapping = create_bags_from_folders(DATA_ROOT_DIR)

# Separate labels for stratification.
bag_labels = [bag['label'] for bag in all_patient_bags]

# Split the bags into training (60%), validation (15%), and test (25%) sets.
# First, split into training (60%) and a temporary set (40%).
train_bags, temp_bags, train_labels, temp_labels = train_test_split(
    all_patient_bags, bag_labels, test_size=0.40, random_state=42, stratify=bag_labels
)

# Split the temporary set into validation (15%) and test (25%).
# This corresponds to a 15/40 = 37.5% split of the temp set for validation.
val_bags, test_bags, _, _ = train_test_split(
    temp_bags, temp_labels, test_size=0.625, random_state=42 # Removed stratify
)

print(f"Dataset split complete: {len(train_bags)} training bags, {len(val_bags)} validation bags, {len(test_bags)} test bags.")

# Create the PyTorch Dataset objects.
train_dataset = OASISBagDataset(train_bags, transform=train_transform)
val_dataset = OASISBagDataset(val_bags, transform=val_test_transform)
test_dataset = OASISBagDataset(test_bags, transform=val_test_transform)

# Create the DataLoaders.
# We use a batch_size of 1 because each bag has a variable number of slices.
# The actual "batch" for contrastive loss will be formed manually in the training loop.
# This is a standard approach for handling variable-sized inputs in MIL.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Print a confirmation that DataLoaders are ready.
print("DataLoaders are ready for training, validation, and testing.")

# --- End of Section 2 ---

Dataset already unzipped to /content/oaisis_unzipped. Skipping unzipping.
Inspecting contents of /content/oaisis_unzipped/Data:
Found directory: Moderate Dementia
Inspecting contents of /content/oaisis_unzipped/Data/Moderate Dementia:
  File 1: OAS1_0351_MR1_mpr-4_105.jpg
  File 2: OAS1_0351_MR1_mpr-1_125.jpg
  File 3: OAS1_0351_MR1_mpr-2_150.jpg
  File 4: OAS1_0351_MR1_mpr-4_133.jpg
  File 5: OAS1_0308_MR1_mpr-2_149.jpg
  File 6: OAS1_0351_MR1_mpr-3_112.jpg
  File 7: OAS1_0308_MR1_mpr-1_124.jpg
  ...
Found directory: Mild Dementia
Inspecting contents of /content/oaisis_unzipped/Data/Mild Dementia:
  File 1: OAS1_0031_MR1_mpr-4_135.jpg
  File 2: OAS1_0035_MR1_mpr-3_157.jpg
  File 3: OAS1_0382_MR1_mpr-2_105.jpg
  File 4: OAS1_0031_MR1_mpr-2_114.jpg
  File 5: OAS1_0035_MR1_mpr-1_130.jpg
  File 6: OAS1_0073_MR1_mpr-1_114.jpg
  File 7: OAS1_0056_MR1_mpr-4_140.jpg
  ...
Found directory: Non Demented
Inspecting contents of /content/oaisis_unzipped/Data/Non Demented:
  File 1: OAS1_0086_MR1_m

In [3]:
# --- Section 3: Model Definition ---

# Set the device for training (use GPU if available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") # Announce which device is being used.

class GatedAttention(nn.Module):
    """
    A Gated Attention module for aggregating instance features into a bag embedding.
    This is a common and effective variant of the attention mechanism used in MIL.
    """
    def __init__(self, input_dim, hidden_dim=128):
        """
        Args:
            input_dim (int): The dimensionality of the instance features (from the feature extractor).
            hidden_dim (int): The dimensionality of the hidden layer in the attention network.
        """
        super(GatedAttention, self).__init__() # Initialize the parent nn.Module class.
        self.input_dim = input_dim # Store the input dimension.
        self.hidden_dim = hidden_dim # Store the hidden dimension.

        # Attention network layers.
        # This part of the network learns to parameterize the attention weights.
        self.attention_V = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim), # First linear layer.
            nn.Tanh() # Tanh activation function.
        )

        # Gating network layers.
        # This part provides a gating mechanism to modulate the attention.
        self.attention_U = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim), # Gating linear layer.
            nn.Sigmoid() # Sigmoid activation for the gate.
        )

        # Final linear layer to compute the attention scores.
        self.attention_weights = nn.Linear(self.hidden_dim, 1)

    def forward(self, x):
        """
        Forward pass for the Gated Attention module.
        Args:
            x (torch.Tensor): A tensor of instance features with shape (num_instances, input_dim).
        Returns:
            torch.Tensor: A tensor of attention scores with shape (num_instances, 1).
        """
        # Calculate the two components of the gated attention.
        A_V = self.attention_V(x)  # (num_instances, hidden_dim)
        A_U = self.attention_U(x)  # (num_instances, hidden_dim)

        # Element-wise multiplication of the two components.
        gated_A = A_V * A_U  # (num_instances, hidden_dim)

        # Compute the final attention scores (unnormalized).
        attention_scores = self.attention_weights(gated_A)  # (num_instances, 1)

        # Return the computed attention scores.
        return attention_scores

class SC_MIL(nn.Module):
    """
    Implementation of the Supervised Contrastive Multiple Instance Learning (SC-MIL) model.
    """
    def __init__(self, input_dim, n_classes, projection_dim=128):
        """
        Args:
            input_dim (int): The dimension of the features from the patch feature extractor.
            n_classes (int): The number of output classes for classification.
            projection_dim (int): The output dimension of the projection head.
        """
        super(SC_MIL, self).__init__() # Initialize the parent nn.Module class.
        self.input_dim = input_dim # Store input feature dimension.
        self.n_classes = n_classes # Store number of classes.
        self.projection_dim = projection_dim # Store projection dimension.

        # 1. Feature Extractor (f)
        # We use a pretrained ShuffleNet and remove its final classifier.
        shufflenet = shufflenet_v2_x1_0(pretrained=True) # Load pretrained ShuffleNet.
        # The feature extractor is all layers except the final fully connected layer.
        self.feature_extractor = nn.Sequential(*list(shufflenet.children())[:-1])
        # Freeze the feature extractor's parameters as it is already trained on ImageNet.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False # This can be set to True for end-to-end finetuning.

        # 2. Attention Module (m)
        # This module will compute the bag embedding from instance features.
        self.attention = GatedAttention(self.input_dim) # Instantiate the attention module.

        # 3. Projector Head (g) for Supervised Contrastive Loss
        # A non-linear MLP as specified in the paper.
        self.projector = nn.Sequential(
            nn.Linear(self.input_dim, self.input_dim), # First linear layer.
            nn.ReLU(), # ReLU activation.
            nn.Linear(self.input_dim, self.projection_dim) # Second linear layer to project to projection_dim.
        )

        # 4. Classifier Head (h) for Cross-Entropy Loss
        # A simple linear layer to map the bag embedding to class logits.
        self.classifier = nn.Linear(self.input_dim, self.n_classes)

    def forward(self, x):
        """
        Forward pass of the SC-MIL model.
        Args:
            x (torch.Tensor): A tensor representing a bag of instances (slices)
                              with shape (num_instances, C, H, W).
        Returns:
            tuple: A tuple containing:
                   - logits (torch.Tensor): The output class logits for CE loss.
                   - projection (torch.Tensor): The projected features for SCL loss.
        """
        # Ensure input tensor is on the correct device.
        x = x.to(device)

        # Step 1: Extract patch-level features using the feature extractor.
        # x is shaped (num_instances, C, H, W)
        instance_features = self.feature_extractor(x) # Output shape: (num_instances, input_dim, 1, 1)
        instance_features = instance_features.view(-1, self.input_dim) # Reshape to (num_instances, input_dim)

        # Step 2: Compute attention scores for each instance.
        attention_scores = self.attention(instance_features) # Output shape: (num_instances, 1)

        # Apply softmax to get normalized attention weights.
        attention_weights = torch.softmax(attention_scores, dim=0) # Shape: (num_instances, 1)

        # Step 3: Compute the bag embedding as the attention-weighted average of instance features.
        bag_embedding = torch.sum(attention_weights * instance_features, dim=0) # Shape: (input_dim)

        # Step 4: Pass the bag embedding through the two heads.
        # Classifier head -> Logits for classification
        logits = self.classifier(bag_embedding) # Shape: (n_classes)

        # Projector head -> Projection for contrastive loss
        projection = self.projector(bag_embedding) # Shape: (projection_dim)

        # The model returns both outputs, to be used by the combined loss function.
        return logits, projection

# --- Model Initialization ---

# Define model parameters based on our setup.
# The feature dimension from ShuffleNet v2 x1.0 is 1024.
FEATURE_DIM = 1024
# We have 4 classes in the OASIS dataset.
NUM_CLASSES = 4
# The projection dimension for contrastive loss, 128 is a common choice.
PROJECTION_DIM = 128

# Instantiate the SC-MIL model.
model = SC_MIL(input_dim=FEATURE_DIM, n_classes=NUM_CLASSES, projection_dim=PROJECTION_DIM)

# Move the model to the selected device (GPU or CPU).
model.to(device)

# Print a summary of the model to verify its architecture.
# Note: The full architecture is complex, this just confirms instantiation.
print("SC-MIL model has been successfully initialized.")
print(f"Feature Dimension: {model.input_dim}, Number of Classes: {model.n_classes}, Projection Dimension: {model.projection_dim}")

# --- End of Section 3 ---

Using device: cuda


Downloading: "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth" to /root/.cache/torch/hub/checkpoints/shufflenetv2_x1-5666bf0f80.pth
100%|██████████| 8.79M/8.79M [00:00<00:00, 36.8MB/s]


SC-MIL model has been successfully initialized.
Feature Dimension: 1024, Number of Classes: 4, Projection Dimension: 128


In [4]:
# --- Section 4: Training Loop with Progress Tracking ---

class SupervisedContrastiveLoss(nn.Module):
    """
    Implementation of the Supervised Contrastive Loss function from the paper.
    Reference: https://arxiv.org/abs/2004.11362
    """
    def __init__(self, temperature=0.07):
        """
        Args:
            temperature (float): A scalar hyperparameter to scale the logits.
        """
        super(SupervisedContrastiveLoss, self).__init__() # Initialize the parent class.
        self.temperature = temperature # Store the temperature value.

    def forward(self, features, labels):
        """
        Forward pass for the SCL loss.
        Args:
            features (torch.Tensor): The projected features from the model's projection head.
                                     Shape: (batch_size, projection_dim).
            labels (torch.Tensor): The ground truth labels for the features.
                                   Shape: (batch_size).
        Returns:
            torch.Tensor: A scalar tensor representing the computed SCL loss.
        """
        # Ensure the features and labels are on the correct device.
        features = features.to(device)
        labels = labels.to(device)

        # Get the batch size from the features tensor.
        batch_size = features.shape[0]
        # Reshape labels to (batch_size, 1) for broadcasting.
        labels = labels.contiguous().view(-1, 1)

        # Create a mask to identify positive pairs (where labels are the same).
        # torch.eq(a, b) computes a == b element-wise.
        # The mask will be True for pairs (i, j) where label[i] == label[j].
        mask = torch.eq(labels, labels.T).float().to(device)

        # Compute dot product between all pairs of features (cosine similarity).
        # This is the core of the contrastive calculation.
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T), # (batch_size, batch_size)
            self.temperature # Scale by the temperature hyperparameter.
        )

        # For numerical stability, subtract the maximum logit from each row.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Create a mask to remove self-comparisons (the diagonal).
        logits_mask = torch.ones_like(mask) - torch.eye(batch_size).to(device)
        mask = mask * logits_mask # Apply the self-comparison mask.

        # Compute the log-probabilities.
        exp_logits = torch.exp(logits) * logits_mask # Exponentiate, ignoring self-comparisons.
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Compute the mean log-likelihood over all positive pairs.
        # We mask out the negative pairs and sum the log probabilities of positive pairs.
        # Avoid division by zero in case a class has only one sample in the batch (mask.sum(1) == 1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)


        # The final loss is the negative of this mean.
        loss = -mean_log_prob_pos.mean()

        # Return the final scalar loss.
        return loss


from sklearn.metrics import accuracy_score

def train_and_validate(model, train_loader, val_loader, optimizer, scl_criterion, ce_criterion, total_iterations, effective_batch_size=16, temp=1.0):
    """
    A single epoch of training and validation.

    Args:
        model (nn.Module): The SC-MIL model to be trained.
        train_loader (DataLoader): DataLoader for the training data.
        val_loader (DataLoader): DataLoader for the validation data.
        optimizer (torch.optim.Optimizer): The optimizer for updating model weights.
        scl_criterion (nn.Module): The supervised contrastive loss function.
        ce_criterion (nn.Module): The cross-entropy loss function.
        total_iterations (int): Total number of training iterations for curriculum weight.
        effective_batch_size (int): The number of bags to accumulate before a backward pass.
        temp (float): The temperature for the SCL loss.

    Returns:
        tuple: A tuple containing average training loss, validation loss, validation F1 score, and validation accuracy.
    """
    # --- Training Phase ---
    model.train() # Set the model to training mode.
    train_loss = 0.0 # Initialize cumulative training loss.

    # Manually accumulate batches for contrastive loss.
    batch_projections = [] # To store projected features of bags.
    batch_logits = []      # To store logits of bags.
    batch_labels = []      # To store labels of bags.

    # Use tqdm for a progress bar over the training data.
    pbar = tqdm(train_loader, desc="Training")
    for i, (bag_slices, labels) in enumerate(pbar):
        # The dataloader returns a batch of 1 bag, so we squeeze the dimensions.
        bag_slices = bag_slices.squeeze(0) # Shape: (num_slices, C, H, W)
        labels = labels.squeeze(0)         # Shape: (1)

        # Perform a forward pass through the model.
        logits, projection = model(bag_slices)

        # Append the results to our manual batch lists.
        batch_logits.append(logits)
        batch_projections.append(projection)
        batch_labels.append(labels)

        # When the manual batch is full, perform a training step.
        if len(batch_labels) >= effective_batch_size or i == len(train_loader) - 1:
            # Convert lists to tensors.
            logits_tensor = torch.stack(batch_logits).to(device)
            projections_tensor = torch.stack(batch_projections).to(device)
            labels_tensor = torch.stack(batch_labels).to(device)

            # Calculate the two loss components.
            loss_scl = scl_criterion(projections_tensor, labels_tensor) # SCL on the batch.
            loss_ce = ce_criterion(logits_tensor, labels_tensor)       # CE on the batch.

            # Calculate the curriculum weight `beta_t` as per the paper.
            current_iter = epoch * len(train_loader) + i + 1 # Corrected current iteration calculation
            beta_t = 1.0 - (current_iter / total_iterations) # Linear decay from 1 to 0.
            beta_t = max(0.0, beta_t) # Ensure beta_t doesn't go below 0


            # Combine the losses using the curriculum weight.
            total_loss = beta_t * loss_scl + (1 - beta_t) * loss_ce

            # Backpropagation.
            optimizer.zero_grad() # Clear previous gradients.
            total_loss.backward() # Compute gradients of the total loss.
            optimizer.step()      # Update model parameters.

            # Accumulate training loss.
            train_loss += total_loss.item()

            # Clear the manual batch lists for the next accumulation.
            batch_projections, batch_logits, batch_labels = [], [], []

            # Update progress bar description.
            pbar.set_postfix({"Loss": total_loss.item()})

    # --- Validation Phase ---
    model.eval() # Set the model to evaluation mode.
    val_loss = 0.0 # Initialize cumulative validation loss.
    all_preds, all_labels = [], [] # To store predictions and labels for F1 score and accuracy.

    # Disable gradient calculations for validation.
    with torch.no_grad():
        for bag_slices, labels in tqdm(val_loader, desc="Validating"):
            bag_slices = bag_slices.squeeze(0) # Get the bag tensor.
            labels = labels.to(device) # Send labels to device.

            # Forward pass. We only need logits for validation.
            logits, _ = model(bag_slices)

            # Calculate cross-entropy loss for validation.
            loss = ce_criterion(logits.unsqueeze(0), labels)
            val_loss += loss.item()

            # Get predictions by finding the class with the highest logit.
            preds = torch.argmax(logits, dim=0)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Calculate average losses and metrics.
    avg_train_loss = train_loss / (len(train_loader) / effective_batch_size)
    avg_val_loss = val_loss / len(val_loader)
    val_f1 = f1_score(all_labels, all_preds, average='macro')
    val_accuracy = accuracy_score(all_labels, all_preds)

    return avg_train_loss, avg_val_loss, val_f1, val_accuracy

# --- Training Execution ---

# Hyperparameters from the paper and for our setup.
EPOCHS = 20 # Number of training epochs. Can be increased for better performance.
LEARNING_RATE = 1e-4 # Learning rate as specified in the paper.
EFFECTIVE_BATCH_SIZE = 16 # Number of bags to process before an update.
TEMPERATURE = 1.0 # Temperature for SCL, as used in the paper's experiments.

# Instantiate the optimizer.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Instantiate the loss functions.
scl_loss_fn = SupervisedContrastiveLoss(temperature=TEMPERATURE)
ce_loss_fn = nn.CrossEntropyLoss()

# Calculate total number of iterations for the curriculum weight.
# This assumes one pass over the data per epoch.
total_training_iterations = len(train_loader) * EPOCHS

# Lists to store metrics for plotting later.
train_loss_history, val_loss_history, val_f1_history, val_accuracy_history = [], [], [], []

print("\n--- Starting Model Training ---")
# Main training loop over epochs.
for epoch in range(EPOCHS):
    # Perform one epoch of training and validation.
    avg_train_loss, avg_val_loss, val_f1, val_accuracy = train_and_validate(
        model, train_loader, val_loader, optimizer, scl_loss_fn, ce_loss_fn, total_training_iterations, EFFECTIVE_BATCH_SIZE, TEMPERATURE
    )

    # Store the metrics.
    train_loss_history.append(avg_train_loss)
    val_loss_history.append(avg_val_loss)
    val_f1_history.append(val_f1)
    val_accuracy_history.append(val_accuracy)

    # Print epoch results.
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val F1: {val_f1:.4f} | Val Accuracy: {val_accuracy:.4f}")

print("--- Training Finished ---")

# --- End of Section 4 ---


--- Starting Model Training ---


Training: 100%|██████████| 58/58 [01:26<00:00,  1.49s/it, Loss=1.95]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s]


Epoch 1/20 | Train Loss: 2.7026 | Val Loss: 1.3780 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:17<00:00,  1.33s/it, Loss=2.11]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.06it/s]


Epoch 2/20 | Train Loss: 2.6957 | Val Loss: 1.3740 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.31s/it, Loss=2.07]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]


Epoch 3/20 | Train Loss: 2.6297 | Val Loss: 1.3691 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:17<00:00,  1.33s/it, Loss=2.03]
Validating: 100%|██████████| 15/15 [00:13<00:00,  1.07it/s]


Epoch 4/20 | Train Loss: 2.5265 | Val Loss: 1.3651 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.32s/it, Loss=1.65]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s]


Epoch 5/20 | Train Loss: 2.4421 | Val Loss: 1.3602 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.31s/it, Loss=1.94]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]


Epoch 6/20 | Train Loss: 2.4287 | Val Loss: 1.3549 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.30s/it, Loss=1.75]
Validating: 100%|██████████| 15/15 [00:13<00:00,  1.11it/s]


Epoch 7/20 | Train Loss: 2.3523 | Val Loss: 1.3494 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:14<00:00,  1.29s/it, Loss=1.84]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.07it/s]


Epoch 8/20 | Train Loss: 2.2894 | Val Loss: 1.3439 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:20<00:00,  1.39s/it, Loss=1.8]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s]


Epoch 9/20 | Train Loss: 2.2193 | Val Loss: 1.3372 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.31s/it, Loss=1.63]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s]


Epoch 10/20 | Train Loss: 2.1160 | Val Loss: 1.3313 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.31s/it, Loss=1.7]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]


Epoch 11/20 | Train Loss: 2.0721 | Val Loss: 1.3249 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.32s/it, Loss=1.57]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.07it/s]


Epoch 12/20 | Train Loss: 1.9954 | Val Loss: 1.3166 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.32s/it, Loss=1.59]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s]


Epoch 13/20 | Train Loss: 1.9222 | Val Loss: 1.3112 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.32s/it, Loss=1.54]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]


Epoch 14/20 | Train Loss: 1.8470 | Val Loss: 1.3039 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.31s/it, Loss=1.49]
Validating: 100%|██████████| 15/15 [00:13<00:00,  1.08it/s]


Epoch 15/20 | Train Loss: 1.7686 | Val Loss: 1.2923 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.30s/it, Loss=1.43]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.07it/s]


Epoch 16/20 | Train Loss: 1.6890 | Val Loss: 1.2853 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.30s/it, Loss=1.37]
Validating: 100%|██████████| 15/15 [00:13<00:00,  1.08it/s]


Epoch 17/20 | Train Loss: 1.6109 | Val Loss: 1.2730 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.29s/it, Loss=1.27]
Validating: 100%|██████████| 15/15 [00:13<00:00,  1.08it/s]


Epoch 18/20 | Train Loss: 1.5202 | Val Loss: 1.2595 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:16<00:00,  1.31s/it, Loss=1.28]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]


Epoch 19/20 | Train Loss: 1.4497 | Val Loss: 1.2454 | Val F1: 0.4444 | Val Accuracy: 0.8000


Training: 100%|██████████| 58/58 [01:15<00:00,  1.30s/it, Loss=1.17]
Validating: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s]

Epoch 20/20 | Train Loss: 1.3604 | Val Loss: 1.2323 | Val F1: 0.4444 | Val Accuracy: 0.8000
--- Training Finished ---



