<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/Supervised%E2%80%AFContrastive%E2%80%AFLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================
#  Snippet 1 : Library imports, environment config, and        #
#              downloading the OASIS dataset from Kaggle       #
# ============================================================

# ---- Standard libraries ------------------------------------
import os           # For path manipulations and environment variables
import zipfile      # To unzip any compressed Kaggle files
import random       # For reproducible split shuffling
import math         # For basic mathematical operations
from pathlib import Path  # For elegant filesystem paths
%pip install pytorch_lightning

# ---- Numerical / data science ------------------------------
import numpy as np              # Core numerical operations on nd‑arrays
import pandas as pd             # Tabular data handling (metadata, splits)
from tqdm import tqdm           # Neat progress bars for loops

# ---- PyTorch ecosystem -------------------------------------
import torch                    # Deep‑learning tensor library
import torch.nn as nn           # Neural‑network layers and losses
import torch.nn.functional as F # Functional interface (activations, etc.)
from torch.utils.data import Dataset, DataLoader, Sampler # Data pipeline
import torchvision              # Vision utilities and pretrained models
import torchvision.transforms as T  # Composable image transforms

# ---- Visualisation -----------------------------------------
import matplotlib.pyplot as plt # Plotting curves and images
import seaborn as sns           # Statistical visualisation (confusion‐mat.)

# ---- KaggleHub (supplied by user) --------------------------
import kagglehub                # Helper to pull Kaggle datasets in Colab

# ---- Lightning (optional – easier multi‑GPU & LARS) --------
import pytorch_lightning as pl   # High‑level training loop framework

# ---- Environment & reproducibility -------------------------
SEED = 42                        # Global random seed
random.seed(SEED)                # Seed Python’s RNG
np.random.seed(SEED)             # Seed NumPy RNG
torch.manual_seed(SEED)          # Seed Torch (CPU) RNG
torch.cuda.manual_seed_all(SEED) # Seed Torch (GPU) RNG

# ---- Device selection --------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Pick GPU if ready
print(f"Using device: {DEVICE}")                         # Inform user

# ---- Kaggle download ---------------------------------------
#  Downloads Kaggle dataset `ninadaithal/imagesoasis` to default location
oasis_path = kagglehub.dataset_download('ninadaithal/imagesoasis')  # Cache dir
print("Dataset downloaded to:", oasis_path)                          # Confirm

# Note: files are already individual .jpgs inside class folders      # Info
# End of Snippet 1                                                   # End

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.7.4-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.1.0->pytorch_lightning)
  Downloadi

100%|██████████| 1.23G/1.23G [00:13<00:00, 95.8MB/s]

Extracting files...





Dataset downloaded to: /root/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1


In [26]:
# ============================================================
#  Snippet 2 : Build metadata, patient‑level splits,           #
#              mean/std computation, transforms, datasets,    #
#              balanced samplers, and DataLoaders             #
# ============================================================

# ---------- 1. Build metadata -------------------------------------------
def build_metadata(root_dir):
    """
    Walk the dataset directory, collect filepaths, integer labels,
    and unique patient IDs.

    Filenames look like:  'OAS1_0001_slice105.jpg'
    True subject ID  ->    'OAS1_0001'  (first two underscore tokens)
    """
    paths, labels, pids = [], [], []                              # Lists to fill

    # Map each class folder to a numeric label (sorted for consistency)
    class_names = sorted(os.listdir(root_dir))                    # e.g. ['Mild Dementia', ...]
    class_map = {cls: idx for idx, cls in enumerate(class_names)} # Dict  name→idx
    print("Found classes:", class_names)                          # Debug info

    # Walk through every *.jpg in every class directory
    for cls in class_names:                                       # Loop each class
        cls_dir = Path(root_dir) / cls                            # /root/class_name
        for img_path in cls_dir.glob("*.jpg"):                    # Loop each image
            stem_parts = img_path.stem.split('_')                 # Split filename
            pid = '_'.join(stem_parts[:2])                        # 'OAS1_0001'
            paths.append(str(img_path))                           # Absolute path
            labels.append(class_map[cls])                         # Integer label
            pids.append(pid)                                      # Patient ID

    # Assemble a DataFrame for easy manipulation
    df = pd.DataFrame(
        {"filepath": paths, "label": labels, "patient_id": pids}
    )
    print(f"Total slices: {len(df)}")                             # Count slices
    print(f"Total unique patients: {df['patient_id'].nunique()}") # Count subjects
    print(df.head())                                              # Show sample rows
    return df                                                     # Return metadata

# Build metadata from oasis_path produced in Snippet 1
meta_df = build_metadata(os.path.join(oasis_path, 'Data')) # Modified to point to the correct subdirectory

# ---------- 2. Patient‑level train/val/test split -----------------------
def split_patients(meta, train_ratio=0.7, val_ratio=0.15, seed=42):
    """
    Split patient IDs (not slices!) into train, val, test sets.
    """
    rng = np.random.RandomState(seed)                             # Reproducible RNG
    pats = meta['patient_id'].unique()                            # Unique IDs
    rng.shuffle(pats)                                             # Shuffle order
    n_train = int(len(pats) * train_ratio)                        # #train IDs
    n_val   = int(len(pats) * val_ratio)                          # #val IDs
    train_ids = pats[:n_train]                                    # Slice
    val_ids   = pats[n_train:n_train + n_val]                     # Slice
    test_ids  = pats[n_train + n_val:]                            # Remaining
    return train_ids, val_ids, test_ids                           # Tuple

train_ids, val_ids, test_ids = split_patients(meta_df)            # Perform split

# Mark each row with its split label
meta_df['split'] = np.select(
    [meta_df.patient_id.isin(train_ids),
     meta_df.patient_id.isin(val_ids)],
    ['train', 'val'],
    default='test'
)

print("Split distribution:\n", meta_df['split'].value_counts())   # Verify balance

# ---------- 3. Compute global mean & std (grayscale) --------------------
from PIL import Image                                             # Pillow import

def compute_mean_std(df, sample_size=2000):
    """
    Estimate per‑channel mean/std on a random subset for normalisation.
    If df is empty (edge case), return 0.5/0.5 defaults.
    """
    if len(df) == 0:                                              # Guard
        print("Warning: Empty dataframe passed to compute_mean_std. "
              "Returning default values.")
        return 0.5, 0.5                                           # Defaults
    sample_paths = np.random.choice(df['filepath'],               # Random subset
                                    size=min(sample_size, len(df)),
                                    replace=False)
    to_tensor = T.ToTensor()                                      # Simple tensor tfm
    pixels = []                                                   # Collect tensors
    for p in tqdm(sample_paths, desc="Mean/Std"):
        img = to_tensor(Image.open(p).convert('L'))               # Load grayscale
        pixels.append(img)
    stack = torch.stack(pixels)                                   # (N,1,H,W)
    mean = stack.mean().item()                                    # Scalar mean
    std  = stack.std().item()                                     # Scalar std
    return mean, std

MRI_MEAN, MRI_STD = compute_mean_std(meta_df[meta_df.split == 'train'])
print("Dataset mean:", MRI_MEAN, "std:", MRI_STD)

# ---------- 4. Define transforms ----------------------------------------
train_tfms = T.Compose([
    T.Resize(224),                                                # Resize images
    T.RandomRotation(10),                                         # ±10° rotation
    T.RandomAffine(0, translate=(0.1, 0.1)),                      # 10 % translate
    T.RandomHorizontalFlip(),                                     # LR flip (brain symmetric)
    T.ToTensor(),                                                 # To [0,1]
    T.Normalize(MRI_MEAN, MRI_STD)                                # Normalise
])

eval_tfms = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(MRI_MEAN, MRI_STD)
])

# ---------- 5. Custom Dataset (2 views) ---------------------------------
class OasisSliceDataset(torch.utils.data.Dataset):
    """
    Returns *two* augmented views of each slice (for SupCon) and its label.
    """
    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)                       # Store subset df
        self.transform = transform                                # Transform pipeline

    def __len__(self):
        return len(self.df)                                       # Number of rows

    def __getitem__(self, idx):
        row = self.df.loc[idx]                                    # Get row
        img = Image.open(row.filepath).convert('L')               # Load as grayscale
        v1 = self.transform(img)                                  # Augmented view 1
        v2 = self.transform(img)                                  # Augmented view 2
        pair = torch.stack([v1, v2], dim=0)                       # Shape (2,1,H,W)
        return pair, row.label                                    # Return data

# ---------- 6. Balanced batch sampler (using WeightedRandomSampler) -----
from torch.utils.data.sampler import WeightedRandomSampler # Import the sampler
import collections # Import collections for Counter

def create_weighted_sampler(labels):
    """
    Create a WeightedRandomSampler to balance classes based on inverse frequency.
    """
    # Convert labels to standard Python integers for counting and weights dict keys
    labels_int = [int(label) for label in labels]

    # Calculate class counts using collections.Counter
    class_counts = collections.Counter(labels_int)
    print(f"Class counts (int keys): {class_counts}") # Debug print

    # Calculate inverse frequency weights using standard integer keys
    class_weights = {label: 1.0 / count for label, count in class_counts.items()}
    print(f"Class weights (int keys dict): {class_weights}") # Debug print

    # Assign weights to each sample based on its original label (which might be np.int64)
    # Convert label to int before dictionary lookup
    sample_weights = []
    for label in labels:
        # Convert label to standard int for lookup
        weight = class_weights[int(label)]
        sample_weights.append(weight)

    sample_weights = np.array(sample_weights)


    # The number of samples to draw in an epoch. Use the total number of
    # samples in the dataset to ensure each sample is seen roughly once per epoch.
    num_samples = len(labels)

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=num_samples,
        replacement=True # Use replacement to ensure sufficient samples from small classes
    )
    return sampler


# ---------- 7. Create Dataset objects -----------------------------------
train_df = meta_df[meta_df.split == 'train']
val_df   = meta_df[meta_df.split == 'val']
test_df  = meta_df[meta_df.split == 'test']

# Safety check to ensure we actually have training data
assert len(train_df) > 0, "Train split is empty – check the splitting logic."

train_ds = OasisSliceDataset(train_df, train_tfms)
val_ds   = OasisSliceDataset(val_df,   eval_tfms)
test_ds  = OasisSliceDataset(test_df,  eval_tfms)

# ---------- 8. DataLoaders with weighted random sampler -----------------
BATCH_SIZE = 32                                                   # Reduced batch size as suggested

# Create weighted sampler for the training data
train_sampler = create_weighted_sampler(train_df.label.values)

train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,                                                 # Shuffle is done by the sampler
    sampler=train_sampler,
    num_workers=2,                                                # Reduced num_workers
    pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,                                                 # No sampler needed for validation/test
    num_workers=2,                                                # Reduced num_workers
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,                                                # Reduced num_workers
    pin_memory=True
)


print("DataLoaders ready ✔")                                       # Confirmation
# End of Snippet 2                                                   #

Found classes: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
Total slices: 86437
Total unique patients: 347
                                            filepath  label patient_id
0  /root/.cache/kagglehub/datasets/ninadaithal/im...      0  OAS1_0291
1  /root/.cache/kagglehub/datasets/ninadaithal/im...      0  OAS1_0073
2  /root/.cache/kagglehub/datasets/ninadaithal/im...      0  OAS1_0278
3  /root/.cache/kagglehub/datasets/ninadaithal/im...      0  OAS1_0316
4  /root/.cache/kagglehub/datasets/ninadaithal/im...      0  OAS1_0122
Split distribution:
 split
train    61000
test     12871
val      12566
Name: count, dtype: int64


Mean/Std: 100%|██████████| 2000/2000 [00:03<00:00, 595.33it/s]


Dataset mean: 0.16600006818771362 std: 0.1789686381816864
Class counts (int keys): Counter({2: 47336, 3: 9516, 0: 3904, 1: 244})
Class weights (int keys dict): {0: 0.00025614754098360657, 1: 0.004098360655737705, 2: 2.112557039040054e-05, 3: 0.00010508617065994115}
DataLoaders ready ✔


In [27]:
# ============================================================
#  Snippet 3 : SupCon model = ResNet backbone + projection MLP #
#              and custom Supervised Contrastive loss          #
# ============================================================

class ProjectionHead(nn.Module):                                           # MLP head
    """
    2‑layer MLP → 128‑D contrastive embedding (SupCon paper).              # Doc
    """
    def __init__(self, in_dim=2048, hid_dim=2048, out_dim=128):            # Ctor
        super().__init__()                                                 # Super call
        self.fc1 = nn.Linear(in_dim, hid_dim)                              # First FC
        self.relu = nn.ReLU(inplace=True)                                  # Non‑linearity
        self.fc2 = nn.Linear(hid_dim, out_dim)                             # Second FC

    def forward(self, x):                                                  # Forward
        x = self.fc1(x)                                                    # FC1
        x = self.relu(x)                                                   # ReLU
        x = self.fc2(x)                                                    # FC2
        return F.normalize(x, dim=1)                                       # ℓ2 normalise

# ============================================================
#  Patched SupCon model (grayscale‑ready)                     #
#  => paste this over the old SupConNet class in Snippet 3    #
# ============================================================

class SupConNet(nn.Module):
    """
    ResNet‑50 backbone adapted for 1‑channel MRI slices,
    plus projection head (SupCon) and linear classifier.
    """
    def __init__(self, num_classes=4, in_channels=1):
        super().__init__()

        # -------- Load vanilla ResNet‑50 ---------------------
        # resnet = torchvision.models.resnet50(weights=None)          # No ImageNet weights
        # Optional: change resnet50 -> resnet18
        resnet = torchvision.models.resnet18(weights=None)          # Using ResNet18

        # -------- Patch first conv for grayscale -------------
        if in_channels != 3:
            # Save original conv to reuse its weights
            orig_conv = resnet.conv1                                 # (64, 3, 7, 7)
            # New conv: same hyper‑params but 1 input channel
            resnet.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2,
                padding=3, bias=False
            )
            # Weight initialise: mean across RGB channels
            with torch.no_grad():
                resnet.conv1.weight[:] = orig_conv.weight.mean(dim=1, keepdim=True)

        # -------- Register encoder / heads --------------------
        # Adjust in_dim for ProjectionHead and classifier if using ResNet18
        encoder_out_dim = 512  # ResNet18 output features
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # Remove FC
        self.proj = ProjectionHead(in_dim=encoder_out_dim)                      # 2‑layer MLP
        self.classifier = nn.Linear(encoder_out_dim, num_classes)               # Linear eval

    def forward(self, x, contrastive=True):
        """
        x : (N, 1, 256, 256)  MRI slice batch
        """
        feat = self.encoder(x).squeeze()                             # (N, 2048) for ResNet50, (N, 512) for ResNet18
        if contrastive:
            z = self.proj(feat)                                      # 128‑D ℓ2 norm
            return z
        else:
            # The user's code had `return self` here, which seems incorrect.
            # Based on the structure and purpose of the class, it should likely
            # return the logits from the classifier when contrastive is False.
            logits = self.classifier(feat)
            return logits

# ============================================================
#  Patched Supervised‑Contrastive loss (handles 2 views)      #
# ============================================================

class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss from Khosla et al. (2020).
    Works when `features` has `n_views` per sample (n_views = 2 here).
    """
    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.T = temperature                       # temperature τ

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Args
        ----
        features : tensor, shape (N, D) where N = n_views * batch_size
                   Must be ℓ2‑normalised already.
        labels   : tensor, shape (batch_size,)
                   Class indices **before** duplication.
        """
        device = features.device
        batch_size = labels.size(0)                # B
        n_views = features.size(0) // batch_size   # usually 2
        assert n_views * batch_size == features.size(0), "Mismatch n_views"

        # Build mask of positive pairs --------------------------------------
        lbl = labels.view(-1, 1)                   # (B,1)
        mask = torch.eq(lbl, lbl.T).float().to(device)  # (B,B)
        mask = mask.repeat(n_views, n_views)       # → (N,N)

        # Self‑contrast mask -------------------------------------------------
        logits_mask = torch.ones_like(mask) - torch.eye(mask.size(0)).to(device) # N*N identity

        # Compute positive log‑probabilities -------------------------------
        # Equivalent to: sim_matrix = torch.matmul(features, features.T) / self.T
        # Anchor‑positive pairs (excluding self):
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.T)
        # For numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach() # subtract max for stability

        # Apply masks and compute loss --------------------------------------
        # Remove diagonal
        logits = logits * logits_mask
        # Calculate log-probability
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Positive log-probability
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -mean_log_prob_pos.mean()

        return loss


# Instantiate model & loss                                                 # Comment
model = SupConNet(num_classes=4).to(DEVICE)                                # Move to GPU
criterion_supcon = SupConLoss().to(DEVICE)                                 # SupCon loss
criterion_ce = nn.CrossEntropyLoss()                                       # CE loss
# End of Snippet 3                                                         # End

In [None]:
# ============================================================
#  Snippet 4 (memory‑friendly) : SupCon pre‑train + linear eval
# ============================================================

# ---- Adjust global BATCH_SIZE FIRST (in Snippet 2) ----------
# BATCH_SIZE = 64   # <= do this where loaders are created
# -------------------------------------------------------------

from torch.amp import autocast, GradScaler   # AMP utilities
from sklearn.metrics import f1_score, accuracy_score # Import metrics

EPOCHS_PRE = 10                                   # SupCon epochs
scaler = GradScaler()                             # Removed device_type='cuda'
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS_PRE
)

def train_supcon():
    model.train()
    print("\nStarting SupCon Pre-training...")
    for epoch in range(EPOCHS_PRE):
        epoch_loss = 0
        pbar = tqdm(train_loader, desc=f"SupCon Epoch {epoch+1}/{EPOCHS_PRE}")
        for pairs, lbl in pbar:
            pairs = pairs.to(DEVICE)              # (B,2,1,224,224)
            lbl   = lbl.to(DEVICE)
            bsz   = pairs.size(0)
            pairs = pairs.view(-1, 1, 224, 224)   # Adjusted size to 224x224

            optimizer.zero_grad(set_to_none=True) # Slightly faster

            with autocast(device_type='cuda'):    # Reverted to original autocast call
                feats = model(pairs, contrastive=True)
                loss  = criterion_supcon(feats, lbl)

            scaler.scale(loss).backward()        # Scaled backward
            scaler.step(optimizer)               # Optimiser step
            scaler.update()                      # Update scaler

            epoch_loss += loss.item() * bsz
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        scheduler.step()
        torch.cuda.empty_cache()                 # Free VRAM
        mean_loss = epoch_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}/{EPOCHS_PRE}: SupCon Mean Loss: {mean_loss:.4f}")

# ---- Linear classifier (also AMP) ---------------------------
def train_classifier(epochs=10, lr=1e-3):
    """
    Freeze encoder, train linear classifier on embeddings and evaluate.
    """
    print("\nStarting Linear Classifier Training...")
    for p in model.encoder.parameters():
        p.requires_grad = False
    # Ensure classifier parameters are trainable
    for p in model.classifier.parameters():
        p.requires_grad = True

    optimizer_l = torch.optim.Adam(model.classifier.parameters(), lr=lr)
    scaler_l = GradScaler() # Removed device_type='cuda'

    for ep in range(epochs):
        model.train()
        total, correct, ce_loss = 0, 0, 0
        all_preds, all_labels = [], [] # To collect predictions and labels for metrics

        pbar = tqdm(train_loader, desc=f"Linear Train Epoch {ep+1}/{epochs}")
        for pairs, lbl in pbar:
            img = pairs[:, 0].to(DEVICE) # Use first view
            lbl = lbl.to(DEVICE)

            optimizer_l.zero_grad(set_to_none=True)

            with autocast(device_type='cuda'): # Reverted to original autocast call
                logits = model(img, contrastive=False)
                loss = criterion_ce(logits, lbl)

            scaler_l.scale(loss).backward()
            scaler_l.step(optimizer_l)
            scaler_l.update()

            pred = logits.argmax(1)
            correct += (pred == lbl).sum().item()
            total += lbl.size(0)
            ce_loss += loss.item() * lbl.size(0)

            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(lbl.cpu().numpy())

        torch.cuda.empty_cache() # Free VRAM

        # Calculate training metrics
        train_acc = correct / total
        train_f1 = f1_score(all_labels, all_preds, average='weighted') # Use weighted for imbalanced data
        mean_ce_loss = ce_loss / total

        print(f"Epoch {ep+1}/{epochs}: Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, Train Loss: {mean_ce_loss:.4f}")

        # ---- Validation step ----------------------------------------
        model.eval() # Set model to evaluation mode
        val_total, val_correct, val_loss = 0, 0, 0
        val_all_preds, val_all_labels = [], []

        with torch.no_grad(): # No gradient calculation during validation
             vbar = tqdm(val_loader, desc=f"Linear Val Epoch {ep+1}/{epochs}")
             for pairs, lbl in vbar:
                 img = pairs[:, 0].to(DEVICE)
                 lbl = lbl.to(DEVICE)

                 with autocast(device_type='cuda'):
                    logits = model(img, contrastive=False)
                    loss = criterion_ce(logits, lbl)

                 pred = logits.argmax(1)
                 val_correct += (pred == lbl).sum().item()
                 val_total += lbl.size(0)
                 val_loss += loss.item() * lbl.size(0)

                 val_all_preds.extend(pred.cpu().numpy())
                 val_all_labels.extend(lbl.cpu().numpy())

        # Calculate validation metrics
        val_acc = val_correct / val_total
        val_f1 = f1_score(val_all_labels, val_all_preds, average='weighted')
        mean_val_loss = val_loss / val_total

        print(f"Epoch {ep+1}/{epochs}: Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, Val Loss: {mean_val_loss:.4f}")

        torch.cuda.empty_cache() # Free VRAM


# ---- Run training --------------------------------------------------------
train_supcon()                                                 # Pre‑train SupCon
train_classifier()                                             # Train classifier
# End of Snippet 4                                                         # End


Starting SupCon Pre-training...


SupCon Epoch 1/10: 100%|██████████| 1907/1907 [07:52<00:00,  4.04it/s, loss=3.2305]


Epoch 1/10: SupCon Mean Loss: 4.7992


SupCon Epoch 2/10: 100%|██████████| 1907/1907 [07:42<00:00,  4.13it/s, loss=3.2290]


Epoch 2/10: SupCon Mean Loss: 4.7979


SupCon Epoch 3/10: 100%|██████████| 1907/1907 [07:43<00:00,  4.11it/s, loss=3.1054]


Epoch 3/10: SupCon Mean Loss: 4.7960


SupCon Epoch 4/10: 100%|██████████| 1907/1907 [07:49<00:00,  4.06it/s, loss=3.1147]


Epoch 4/10: SupCon Mean Loss: 4.7946


SupCon Epoch 5/10: 100%|██████████| 1907/1907 [07:54<00:00,  4.02it/s, loss=3.0588]


Epoch 5/10: SupCon Mean Loss: 4.7935


SupCon Epoch 6/10: 100%|██████████| 1907/1907 [07:48<00:00,  4.07it/s, loss=3.1135]


Epoch 6/10: SupCon Mean Loss: 4.7920


SupCon Epoch 7/10: 100%|██████████| 1907/1907 [07:41<00:00,  4.13it/s, loss=3.1100]


Epoch 7/10: SupCon Mean Loss: 4.7909


SupCon Epoch 8/10: 100%|██████████| 1907/1907 [07:40<00:00,  4.14it/s, loss=3.0563]


Epoch 8/10: SupCon Mean Loss: 4.7903


SupCon Epoch 9/10: 100%|██████████| 1907/1907 [07:36<00:00,  4.17it/s, loss=3.0729]


Epoch 9/10: SupCon Mean Loss: 4.7895


SupCon Epoch 10/10: 100%|██████████| 1907/1907 [07:40<00:00,  4.14it/s, loss=3.1199]


Epoch 10/10: SupCon Mean Loss: 4.7889

Starting Linear Classifier Training...


Linear Train Epoch 1/10: 100%|██████████| 1907/1907 [06:22<00:00,  4.98it/s]


Epoch 1/10: Train Acc: 0.3945, Train F1: 0.3874, Train Loss: 1.2663


Linear Val Epoch 1/10: 100%|██████████| 393/393 [01:01<00:00,  6.41it/s]


Epoch 1/10: Val Acc: 0.4465, Val F1: 0.5367, Val Loss: 1.2358


Linear Train Epoch 2/10: 100%|██████████| 1907/1907 [06:18<00:00,  5.03it/s]


Epoch 2/10: Train Acc: 0.4176, Train F1: 0.4094, Train Loss: 1.2413


Linear Val Epoch 2/10: 100%|██████████| 393/393 [01:02<00:00,  6.30it/s]


Epoch 2/10: Val Acc: 0.4760, Val F1: 0.5757, Val Loss: 1.1836


Linear Train Epoch 3/10: 100%|██████████| 1907/1907 [06:07<00:00,  5.19it/s]


Epoch 3/10: Train Acc: 0.4329, Train F1: 0.4248, Train Loss: 1.2260


Linear Val Epoch 3/10: 100%|██████████| 393/393 [01:02<00:00,  6.32it/s]


Epoch 3/10: Val Acc: 0.5143, Val F1: 0.6067, Val Loss: 1.0599


Linear Train Epoch 4/10: 100%|██████████| 1907/1907 [06:17<00:00,  5.05it/s]


Epoch 4/10: Train Acc: 0.4364, Train F1: 0.4295, Train Loss: 1.2189


Linear Val Epoch 4/10: 100%|██████████| 393/393 [01:04<00:00,  6.09it/s]


Epoch 4/10: Val Acc: 0.5376, Val F1: 0.6148, Val Loss: 1.0247


Linear Train Epoch 5/10:  31%|███▏      | 598/1907 [01:56<05:49,  3.75it/s]