# EPASS with SimMatch Base for Freesound Audio Tagging

This notebook implements the **EPASS (Ensemble Projectors Aided for Semi-supervised Learning)** algorithm, using **SimMatch** as the base semi-supervised framework, for audio classification on the Freesound dataset (2018).

**Core Concepts:**

1.  **SimMatch Base:** Leverages both pseudo-labeling (like FixMatch) and instance similarity matching (contrastive learning) using two strongly augmented views of unlabeled data.
2.  **EPASS Enhancement:** Instead of a single MLP projector head (mapping encoder features to embeddings for contrastive loss), EPASS uses *multiple* projector heads. The embeddings from these heads are ensembled (averaged) to produce a more robust and less biased representation.
3.  **Goal:** Train models with 20% and 80% labeled data, aiming for high accuracy, demonstrating overfitting/underfitting via plots, and saving the best overall model.

In [None]:
import os
import random
import numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchaudio
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm.notebook import tqdm
import itertools
import math
import copy

## 1. Configuration

In [None]:
class Config:
    """Bundle of every tunable setting used by the notebook.

    All values are plain instance attributes so later cells can read or
    override them (``num_classes`` is overwritten once the label set is known).
    """

    def __init__(self):
        self._set_audio_params()
        self._set_training_params()
        self._set_semi_supervised_params()
        self._set_spec_augment_params()
        self._set_paths()

    def _set_audio_params(self):
        # Waveform / mel-spectrogram settings
        self.sr = 32000           # sample rate
        self.duration = 5         # clip length in seconds
        self.n_mels = 128         # mel bands
        self.n_fft = 1024         # FFT window size
        self.hop_length = 512     # hop between frames

    def _set_training_params(self):
        # Optimisation settings
        self.batch_size = 32      # labeled + unlabeled samples per step
        self.epochs = 50          # epochs per labeled-percentage run
        self.lr = 3e-4            # learning rate
        self.num_classes = 41     # class count (as per train.csv)
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device('cuda' if cuda_ok else 'cpu')
        self.seed = 42            # RNG seed
        self.num_workers = 2      # DataLoader worker processes

    def _set_semi_supervised_params(self):
        # SimMatch + EPASS settings
        self.labeled_percents = [0.2, 0.8]  # labeled fractions to train with
        self.val_percent = 0.1    # fraction of the original training data held out for validation
        self.mu = 7               # unlabeled-to-labeled ratio per batch
        self.wu = 1.0             # weight of the pseudo-label classification loss
        self.wc = 1.0             # weight of the contrastive loss
        self.threshold = 0.95     # pseudo-label confidence threshold (tau)
        self.temperature = 0.1    # contrastive-loss temperature T
        self.embedding_dim = 128  # size of each projected embedding
        self.num_projectors = 3   # number of EPASS projector heads

    def _set_spec_augment_params(self):
        # SpecAugment mask widths used for strong augmentation
        self.freq_mask_param = 27
        self.time_mask_param = 70

    def _set_paths(self):
        # Checkpoint destination
        self.model_save_path = "best_epass_simmatch_model.pth"

        # Dataset locations (update if necessary)
        self.train_csv_path = "/kaggle/input/freesound-audio-tagging-2018/train.csv"
        self.test_csv_path = "/kaggle/input/freesound-audio-tagging-2018/test_post_competition.csv"
        self.audio_train_dir = "/kaggle/input/freesound-audio-tagging-2018/audio_train"
        self.audio_test_dir  = "/kaggle/input/freesound-audio-tagging-2018/audio_test"

config = Config()


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run repeatability
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(config.seed)

print(f"Device: {config.device}")
print(f"Number of projectors (EPASS): {config.num_projectors}")

## 2. Audio Preprocessing & Augmentation

In [None]:
# ----------------------------
# 2.1 Audio Preprocessing Function
# ----------------------------
def preprocess_audio(path, sr=config.sr, duration=config.duration, n_mels=config.n_mels, n_fft=config.n_fft, hop_length=config.hop_length):
    """Load an audio file and return a min-max normalised log-mel spectrogram.

    Args:
        path: Audio file to load.
        sr: Sample rate the audio is resampled to on load.
        duration: Clip length in seconds; shorter clips are zero-padded at the
            end, longer clips are truncated.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: Hop between frames, in samples.

    Returns:
        A float32 array of shape (n_mels, time) scaled to [0, 1]. A clip whose
        spectrogram is constant yields all zeros. If anything fails (e.g. the
        file cannot be read), the error is printed and an all-zero array of
        approximately the same shape is returned so a batch can still be formed.
    """
    # Computed before the try block so the fallback below can always use it,
    # even when loading the file is what failed. The int cast keeps slicing
    # and padding valid when `duration` is fractional.
    max_len = int(sr * duration)
    try:
        y, _ = librosa.load(path, sr=sr)
        # Pad or truncate to fixed length
        if len(y) < max_len:
            y = np.pad(y, (0, max_len - len(y)))
        else:
            y = y[:max_len]
        # Compute mel spectrogram
        mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        # Normalize to [0, 1]
        mel_min = mel_db.min()
        mel_max = mel_db.max()
        if mel_max == mel_min:  # Avoid division by zero for silent clips
            return np.zeros_like(mel_db, dtype=np.float32)
        mel_norm = (mel_db - mel_min) / (mel_max - mel_min)
        return mel_norm.astype(np.float32)  # shape: (n_mels, time)
    except Exception as e:
        # Deliberately best-effort: one unreadable file should not abort an epoch.
        print(f"Error processing {path}: {e}")
        # Approximate frame count for a clip of max_len samples.
        # NOTE(review): presumably equal to librosa's frame count with its
        # default framing - confirm, since batches require identical shapes.
        time_steps = max_len // hop_length + 1
        return np.zeros((n_mels, time_steps), dtype=np.float32)

### 2.2 Augmentation Functions

- **Weak Augmentation:** Identity (no change).
- **Strong Augmentation:** SpecAugment (frequency and time masking).

In [None]:
# ----------------------------
# 2.2 Augmentation Functions
# ----------------------------

# Weak augmentation (identity)
def weak_augment(mel_spec):
    """Identity ("weak") augmentation: return the input as a tensor with a new leading channel axis."""
    spec = mel_spec if isinstance(mel_spec, torch.Tensor) else torch.tensor(mel_spec)
    return torch.unsqueeze(spec, 0)

# Strong augmentation (SpecAugment)
# One frequency mask and one time mask are drawn on every call, so calling the
# transform twice on the same spectrogram yields two different strong views.
# The keyword names follow torchaudio's SpecAugment signature
# (n_time_masks / time_mask_param / n_freq_masks / freq_mask_param).
spec_augment = torchaudio.transforms.SpecAugment(
    n_time_masks=1,
    time_mask_param=config.time_mask_param,
    n_freq_masks=1,
    freq_mask_param=config.freq_mask_param,
    # NOTE(review): per the torchaudio docs, iid masking only takes effect for
    # 4D (batched) input; strong_augment passes a single 3D spectrogram - confirm.
    iid_masks=True,
)

def strong_augment(mel_spec):
    """Return a SpecAugment-masked version of `mel_spec` with a new leading channel axis."""
    if isinstance(mel_spec, torch.Tensor):
        spec = mel_spec
    else:
        spec = torch.tensor(mel_spec)
    # Channel axis is added first; masking is applied by the module-level transform
    return spec_augment(spec.unsqueeze(0))

## 3. Dataset Classes

- Labeled dataset returns one weakly augmented view and the label.
- Unlabeled dataset returns one weakly augmented view and *two* differently strongly augmented views (for SimMatch contrastive loss).
- Test/Validation dataset returns one weakly augmented view (weak augmentation is the identity, so the spectrogram is effectively unaugmented) and the label.

In [None]:
# ----------------------------
# 3. Dataset Classes
# ----------------------------

# Dataset for Labeled Data
class FreesoundLabeledDataset(Dataset):
    """Labeled training split: each item is (augmented spectrogram, class-index tensor).

    `df` must be indexed by file name and carry a 'label' column; `label_map`
    translates label strings to integer class indices.
    """

    def __init__(self, df, audio_dir, label_map, transform=preprocess_audio, augment=weak_augment):
        self.df, self.audio_dir, self.label_map = df, audio_dir, label_map
        self.transform = transform
        # The supervised loss only needs the weak view
        self.augment = augment
        self.fnames = df.index.tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name = self.fnames[idx]
        spectrogram = self.transform(os.path.join(self.audio_dir, name))
        view = self.augment(spectrogram)  # augment adds the channel axis
        class_idx = self.label_map[self.df.loc[name, 'label']]
        return view, torch.tensor(class_idx)

# Dataset for Unlabeled Data
class FreesoundUnlabeledDataset(Dataset):
    """Unlabeled training split: each item is (weak view, strong view 1, strong view 2).

    The ground-truth label is never returned; `label_map` is stored only so it
    stays available for later analysis.
    """

    def __init__(self, df, audio_dir, label_map, transform=preprocess_audio, weak_aug=weak_augment, strong_aug=strong_augment):
        self.df, self.audio_dir, self.label_map = df, audio_dir, label_map
        self.transform = transform
        self.weak_aug, self.strong_aug = weak_aug, strong_aug
        self.fnames = df.index.tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        spectrogram = self.transform(os.path.join(self.audio_dir, self.fnames[idx]))
        weak_view = self.weak_aug(spectrogram)        # source of the pseudo-label
        strong_view_a = self.strong_aug(spectrogram)  # classification + contrastive loss
        strong_view_b = self.strong_aug(spectrogram)  # contrastive loss only
        return weak_view, strong_view_a, strong_view_b

# Dataset for Testing/Validation (uses weak augmentation/no augmentation)
class FreesoundEvalDataset(Dataset):
    """Validation/test split: each item is (spectrogram tensor, class-index tensor).

    Uses the weak (identity) augmentation by default so evaluation inputs are
    not randomly perturbed.
    """

    def __init__(self, df, audio_dir, label_map, transform=preprocess_audio, augment=weak_augment):
        self.df, self.audio_dir, self.label_map = df, audio_dir, label_map
        self.transform = transform
        self.augment = augment
        self.fnames = df.index.tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name = self.fnames[idx]
        spectrogram = self.transform(os.path.join(self.audio_dir, name))
        sample = self.augment(spectrogram)  # augment adds the channel axis
        class_idx = self.label_map[self.df.loc[name, 'label']]
        return sample, torch.tensor(class_idx)

### 4. Prepare Metadata, Label Map, and Data Splits

- Load `train.csv` and `test_post_competition.csv`.
- Create the label map.
- Split the original `train.csv` data into training and validation sets.
- Within the training loop, further split the training set into labeled and unlabeled based on the current `labeled_percent`.

In [None]:
# ----------------------------
# 4. Prepare Metadata and Label Map
# ----------------------------
# Load training CSV
train_df_full = pd.read_csv(config.train_csv_path)
# Ensure fname is index for easy lookup
if 'fname' in train_df_full.columns:
    train_df_full.set_index("fname", inplace=True)

# Create label mapping (alphabetical order). Built before the test CSV is
# processed so test rows can be checked against the classes being trained on.
labels = sorted(train_df_full['label'].unique())
label_map = {label: idx for idx, label in enumerate(labels)}
idx_to_label = {idx: label for label, idx in label_map.items()}
config.num_classes = len(labels)  # Update num_classes based on actual data
print(f"Number of classes: {config.num_classes}")
print(f"Labels: {labels}")

# Load test CSV (for final evaluation). test_df ends up either as a frame in
# which every row has a label present in label_map, or as None.
try:
    test_df = pd.read_csv(config.test_csv_path)
except FileNotFoundError:
    print(f"Warning: Test CSV not found at {config.test_csv_path}. Skipping test evaluation.")
    test_df = None
else:
    if 'fname' in test_df.columns:
        test_df.set_index("fname", inplace=True)
    if 'label' not in test_df.columns:
        print("Cannot evaluate on test set without ground truth labels.")
        test_df = None
    else:
        # Keep only rows whose label is one of the training classes. A missing
        # (NaN) or unknown label has no entry in label_map and would raise a
        # KeyError inside the evaluation dataset.
        has_known_label = test_df['label'].isin(labels)
        num_dropped = int((~has_known_label).sum())
        if num_dropped:
            print(f"Warning: Dropping {num_dropped} test rows with missing or unknown labels.")
        test_df = test_df[has_known_label]
        if test_df.empty:
            print("Cannot evaluate on test set without ground truth labels.")
            test_df = None  # Disable test evaluation

# --- Split Train/Validation ---
# Split the *full* training data first to get a held-out, class-stratified validation set
sss_val = StratifiedShuffleSplit(n_splits=1, test_size=config.val_percent, random_state=config.seed)
train_idx, val_idx = next(sss_val.split(train_df_full.index, train_df_full['label']))

train_df = train_df_full.iloc[train_idx]
val_df = train_df_full.iloc[val_idx]

print(f"\nFull training samples: {len(train_df_full)}")
print(f"Split into: Training samples: {len(train_df)}, Validation samples: {len(val_df)}")
if test_df is not None:
    print(f"Test samples: {len(test_df)}")

# We will split train_df further into labeled/unlabeled inside the training loop

### 5. Define the Model Architecture (Encoder + Classifier + EPASS Projectors)

- Use a pre-trained ResNet18 as the backbone encoder.
- Modify the first convolutional layer for 1-channel (spectrogram) input.
- Add a single linear classifier head.
- Add **multiple** MLP projector heads (EPASS).

In [None]:
# ------------------------------------------------------------------
# 5. Define the Model Architecture (Encoder + Classifier + Projectors)
# ------------------------------------------------------------------
class EpassSimMatchNet(nn.Module):
    """ResNet18 encoder with one classifier head and several EPASS projector heads.

    Args:
        num_classes: Output size of the linear classifier.
        embedding_dim: Output size of every projector MLP.
        num_projectors: Number of projector MLPs whose outputs are averaged.
            Must be at least 1.
        pretrained: If True, load torchvision's default ResNet18 weights.

    Raises:
        ValueError: If `num_projectors` is smaller than 1.
    """

    def __init__(self, num_classes, embedding_dim, num_projectors, pretrained=True):
        super().__init__()
        if num_projectors < 1:
            # forward() averages the projector outputs and cannot average nothing
            raise ValueError(f"num_projectors must be >= 1, got {num_projectors}")

        # Encoder (ResNet18 base)
        base_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if pretrained else None)
        rgb_conv = base_model.conv1
        mono_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Keep the pretrained stem instead of a random one: summing the RGB
            # kernels over the input-channel axis gives the same response as
            # feeding the single channel replicated three times to the
            # original layer.
            with torch.no_grad():
                mono_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        base_model.conv1 = mono_conv
        # Everything except the final fc layer
        self.encoder = nn.Sequential(*list(base_model.children())[:-1])
        encoder_output_dim = base_model.fc.in_features

        # Classifier Head
        self.fc = nn.Linear(encoder_output_dim, num_classes)

        # EPASS Projector Heads (independent two-layer MLPs)
        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(encoder_output_dim, encoder_output_dim),
                nn.ReLU(),
                nn.Linear(encoder_output_dim, embedding_dim)
            ) for _ in range(num_projectors)
        ])
        self.num_projectors = num_projectors

    def forward(self, x):
        """Return (classification logits, projector-averaged embedding) for input batch `x`."""
        features = self.encoder(x)
        flat_features = torch.flatten(features, 1)

        # Classification logits
        logits = self.fc(flat_features)

        # Ensemble: element-wise mean of the embeddings from all projectors
        embeddings = [proj(flat_features) for proj in self.projectors]
        ensembled_embedding = torch.stack(embeddings, dim=0).mean(dim=0)

        return logits, ensembled_embedding

# Instantiate the model
# Build the network from the current config and place it on the target device
model = EpassSimMatchNet(
    config.num_classes,
    config.embedding_dim,
    config.num_projectors,
).to(config.device)

print(f"Model created with {config.num_projectors} projectors and moved to device.")
# Count only parameters that will receive gradients
num_params = sum(
    param.numel() for param in filter(lambda t: t.requires_grad, model.parameters())
)
print(f"Total trainable parameters: {num_params:,}")

### 6. Create DataLoaders

- Create DataLoaders for the labeled, unlabeled, validation, and test sets. The labeled and unlabeled batch sizes are derived from `config.batch_size` and `config.mu`.

In [None]:
# ----------------------------
# 6. Create DataLoaders (Helper Function)
# ----------------------------
def create_dataloaders(labeled_df, unlabeled_df, val_df, test_df, config):
    """Build the labeled, unlabeled, validation and (optional) test DataLoaders.

    The combined `config.batch_size` is split between the labeled and unlabeled
    training loaders according to `config.mu`; evaluation loaders use the full
    batch size. Returns (labeled_loader, unlabeled_loader, val_loader,
    test_loader), where test_loader is None when `test_df` is None.
    """
    labeled_ds = FreesoundLabeledDataset(labeled_df, config.audio_train_dir, label_map)
    unlabeled_ds = FreesoundUnlabeledDataset(unlabeled_df, config.audio_train_dir, label_map)
    val_ds = FreesoundEvalDataset(val_df, config.audio_train_dir, label_map)
    test_ds = None
    if test_df is not None:
        test_ds = FreesoundEvalDataset(test_df, config.audio_test_dir, label_map)

    # One share of the batch goes to labeled data (never fewer than one sample),
    # the remainder to unlabeled data
    labeled_bs = max(1, config.batch_size // (config.mu + 1))
    unlabeled_bs = config.batch_size - labeled_bs
    print(f"  Using Labeled BS: {labeled_bs}, Unlabeled BS: {unlabeled_bs}")

    # Training loaders shuffle and drop the final partial batch
    train_opts = dict(shuffle=True, num_workers=config.num_workers, drop_last=True)
    # Evaluation loaders keep order and every sample
    eval_opts = dict(batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

    labeled_loader = DataLoader(labeled_ds, batch_size=labeled_bs, **train_opts)
    unlabeled_loader = DataLoader(unlabeled_ds, batch_size=unlabeled_bs, **train_opts)
    val_loader = DataLoader(val_ds, **eval_opts)
    test_loader = DataLoader(test_ds, **eval_opts) if test_ds is not None else None

    print(f"  Loaders created. Num labeled batches/epoch: {len(labeled_loader)}, Num unlabeled batches/epoch: {len(unlabeled_loader)}")
    return labeled_loader, unlabeled_loader, val_loader, test_loader

### 7. Define Training and Evaluation Functions

- **`train_one_epoch`**: 
  - Takes model, optimizer, labeled/unlabeled loaders, loss criteria.
  - Iterates through both loaders simultaneously.
  - Calculates supervised loss (`loss_s`) on labeled data.
  - Calculates unsupervised classification loss (`loss_u`) using pseudo-labels.
  - Calculates unsupervised contrastive loss (`loss_c`) using ensembled embeddings from two strong views (SimMatch + EPASS).
  - Combines losses and performs backpropagation.
- **`evaluate`**: 
  - Standard evaluation loop using the classification head.

In [None]:
# ---------------------------------------
# 7. Training and Evaluation Functions
# ---------------------------------------

def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass whenever it runs out.

    Unlike itertools.cycle, nothing is cached: every pass re-iterates the
    loader, so a shuffling loader reshuffles and old batches can be freed.
    Stops immediately if the loader yields no batches at all.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


def train_one_epoch(model, optimizer, labeled_loader, unlabeled_loader, criterion_s, criterion_u, criterion_c, epoch):
    """Run one EPASS + SimMatch training epoch.

    The unlabeled loader defines the epoch length; labeled batches are drawn
    alongside it and the labeled loader is restarted whenever it is exhausted.

    Args:
        model: Network returning (logits, ensembled embedding).
        optimizer: Optimizer stepping `model`'s parameters.
        labeled_loader: Yields (inputs, labels).
        unlabeled_loader: Yields (weak view, strong view 1, strong view 2).
        criterion_s: Supervised loss returning a scalar.
        criterion_u: Per-sample classification loss for pseudo-labels.
        criterion_c: Per-sample loss applied to the similarity matrix.
        epoch: Zero-based epoch index, used only for the progress-bar title.

    Returns:
        Tuple (avg supervised loss, avg pseudo-label loss, avg contrastive
        loss, accuracy on the labeled batches, mean pseudo-label mask ratio).
    """
    model.train()
    running_loss_s = 0.0
    running_loss_u = 0.0
    running_loss_c = 0.0
    correct_labeled = 0
    total_labeled = 0
    total_unlabeled = 0
    mask_ratios = []

    # Unlabeled loader defines the epoch length
    num_batches = len(unlabeled_loader)
    labeled_iter = _endless_batches(labeled_loader)

    train_iterator = tqdm(unlabeled_loader, total=num_batches, desc=f"Epoch {epoch+1}")

    for batch_idx, (inputs_u_w, inputs_u_s1, inputs_u_s2) in enumerate(train_iterator):
        # Get labeled data for this step
        try:
            inputs_l, labels_l = next(labeled_iter)
        except StopIteration:
            # Only reachable when the labeled loader yields no batches at all
            print("Warning: Labeled loader exhausted unexpectedly.")
            continue

        # Move data to device
        inputs_l, labels_l = inputs_l.to(config.device), labels_l.to(config.device)
        inputs_u_w = inputs_u_w.to(config.device)
        inputs_u_s1 = inputs_u_s1.to(config.device)
        inputs_u_s2 = inputs_u_s2.to(config.device)

        labeled_bs = inputs_l.size(0)
        unlabeled_bs = inputs_u_w.size(0)

        # --- Supervised Loss ---
        logits_l, _ = model(inputs_l)
        loss_s = criterion_s(logits_l, labels_l)

        # --- Unsupervised Losses ---
        # 1. Pseudo-labels come from the weak view, without gradients; only
        #    predictions at or above the confidence threshold are kept.
        with torch.no_grad():
            logits_u_w, _ = model(inputs_u_w)
            probs_u_w = torch.softmax(logits_u_w, dim=1)
            max_probs, pseudo_labels_u = torch.max(probs_u_w, dim=1)
            mask = (max_probs >= config.threshold).float()
            mask_ratios.append(mask.mean().item())

        logits_u_s1, embeddings_s1 = model(inputs_u_s1)
        loss_u_vec = criterion_u(logits_u_s1, pseudo_labels_u)
        loss_u = (loss_u_vec * mask).mean()

        # 2. Contrastive loss between the ensembled embeddings of the two strong views
        _, embeddings_s2 = model(inputs_u_s2)

        # Unit-normalise so the dot product below is a cosine similarity
        embeddings_s1_norm = F.normalize(embeddings_s1, dim=1)
        embeddings_s2_norm = F.normalize(embeddings_s2, dim=1)

        sim_matrix = torch.mm(embeddings_s1_norm, embeddings_s2_norm.t()) / config.temperature

        # Row i should match column i (the other view of the same sample)
        targets = torch.arange(unlabeled_bs, device=config.device)

        # Symmetric cross-entropy: s1 -> s2 and s2 -> s1
        loss_c_vec1 = criterion_c(sim_matrix, targets)
        loss_c_vec2 = criterion_c(sim_matrix.t(), targets)
        loss_c = (loss_c_vec1 + loss_c_vec2) / 2.0
        loss_c = (loss_c * mask).mean()  # same confidence mask as the pseudo-label loss

        # --- Combine Losses ---
        total_loss = loss_s + config.wu * loss_u + config.wc * loss_c

        # --- Backpropagation and Optimization ---
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # --- Statistics (weighted by the samples actually processed) ---
        running_loss_s += loss_s.item() * labeled_bs
        running_loss_u += loss_u.item() * unlabeled_bs
        running_loss_c += loss_c.item() * unlabeled_bs

        preds_l = logits_l.argmax(dim=1)
        correct_labeled += (preds_l == labels_l).sum().item()
        total_labeled += labeled_bs
        total_unlabeled += unlabeled_bs

        # Update progress bar
        train_iterator.set_postfix(Loss=f"{total_loss.item():.4f}", Ls=f"{loss_s.item():.4f}", Lu=f"{loss_u.item():.4f}", Lc=f"{loss_c.item():.4f}", Mask=f"{np.mean(mask_ratios[-10:]):.2f}")

    # Epoch averages; fall back to 0 when nothing was processed
    avg_loss_s = running_loss_s / total_labeled if total_labeled > 0 else 0
    avg_loss_u = running_loss_u / total_unlabeled if total_unlabeled > 0 else 0
    avg_loss_c = running_loss_c / total_unlabeled if total_unlabeled > 0 else 0
    acc_labeled = correct_labeled / total_labeled if total_labeled > 0 else 0
    avg_mask_ratio = np.mean(mask_ratios) if mask_ratios else 0

    return avg_loss_s, avg_loss_u, avg_loss_c, acc_labeled, avg_mask_ratio


def evaluate(model, loader, criterion, device):
    """Score `model` on `loader` using only its classification logits.

    Returns (mean loss, accuracy, list of predictions, list of true labels).
    Loss and accuracy are 0 when the loader yields no samples.
    """
    model.eval()
    loss_sum, num_correct, num_seen = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Evaluating", leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # The embedding output is not needed for evaluation
            outputs, _ = model(inputs)
            batch_loss = criterion(outputs, labels)
            # criterion returns a batch mean, so weight it by the batch size
            loss_sum += batch_loss.item() * inputs.size(0)
            preds = torch.argmax(outputs, dim=1)
            num_correct += (preds == labels).sum().item()
            num_seen += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if num_seen > 0:
        return loss_sum / num_seen, num_correct / num_seen, all_preds, all_labels
    return 0, 0, all_preds, all_labels

### 8. Training Loop

- Initialize model, optimizer, loss functions.
- Loop through specified labeled data percentages (20%, 80%).
  - For each percentage:
    - Split the training data into labeled and unlabeled sets.
    - Create dataloaders.
    - Re-initialize model weights and optimizer for a fair comparison.
    - Loop through epochs:
      - Call `train_one_epoch`.
      - Call `evaluate` on the validation set.
      - Track history (losses, accuracies, mask ratio).
      - Check if current validation accuracy is the best *overall*.
      - If so, save the model's state dictionary and record the best accuracy and corresponding labeled percentage.

In [None]:
# ----------------------------
# 8. Training Loop 
# ----------------------------
# Loss functions: the supervised term is mean-reduced; the two unsupervised
# terms stay per-sample so the confidence mask can be applied before averaging.
criterion_s = nn.CrossEntropyLoss()
criterion_u = nn.CrossEntropyLoss(reduction='none')
criterion_c = nn.CrossEntropyLoss(reduction='none')

# Best checkpoint seen across *all* runs
best_val_acc_overall = 0.0
best_model_state = None
best_labeled_percent = -1
history = {}

for labeled_percent in config.labeled_percents:
    print(f"\n----- Training with {labeled_percent*100:.0f}% Labeled Data -----")
    history[labeled_percent] = {
        metric: [] for metric in (
            'train_loss_s', 'train_loss_u', 'train_loss_c',
            'train_acc_l', 'val_loss', 'val_acc', 'mask_ratio',
        )
    }
    run_log = history[labeled_percent]

    # --- Stratified labeled/unlabeled split of train_df (validation rows are already excluded) ---
    sss_label = StratifiedShuffleSplit(
        n_splits=1,
        train_size=labeled_percent,
        random_state=config.seed + int(labeled_percent*100),
    )
    labeled_idx, unlabeled_idx = next(sss_label.split(train_df.index, train_df['label']))
    labeled_df_run = train_df.iloc[labeled_idx]
    unlabeled_df_run = train_df.iloc[unlabeled_idx]
    print(f"  Labeled samples for this run: {len(labeled_df_run)}")
    print(f"  Unlabeled samples for this run: {len(unlabeled_df_run)}")

    # --- DataLoaders for this run ---
    labeled_loader, unlabeled_loader, val_loader, test_loader = create_dataloaders(
        labeled_df_run, unlabeled_df_run, val_df, test_df, config
    )

    # --- Fresh model and optimizer so runs do not share learned weights ---
    print("  Re-initializing model and optimizer...")
    model = EpassSimMatchNet(
        num_classes=config.num_classes,
        embedding_dim=config.embedding_dim,
        num_projectors=config.num_projectors,
    ).to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    # (No learning-rate scheduler is used; one could be created here and stepped once per epoch.)

    best_val_acc_run = 0.0  # Best validation accuracy for *this* run

    # --- Epoch loop for this run ---
    for epoch in range(config.epochs):
        tr_loss_s, tr_loss_u, tr_loss_c, tr_acc_l, mask_ratio = train_one_epoch(
            model, optimizer, labeled_loader, unlabeled_loader,
            criterion_s, criterion_u, criterion_c, epoch
        )
        val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion_s, config.device)

        # Record this epoch's metrics
        run_log['train_loss_s'].append(tr_loss_s)
        run_log['train_loss_u'].append(tr_loss_u)
        run_log['train_loss_c'].append(tr_loss_c)
        run_log['train_acc_l'].append(tr_acc_l)
        run_log['val_loss'].append(val_loss)
        run_log['val_acc'].append(val_acc)
        run_log['mask_ratio'].append(mask_ratio)

        print(f"  Epoch {epoch+1}/{config.epochs} -> "
              f"Loss S: {tr_loss_s:.4f}, Loss U: {tr_loss_u:.4f}, Loss C: {tr_loss_c:.4f}, Acc (L): {tr_acc_l:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | Mask Ratio: {mask_ratio:.3f}")

        # Keep an in-memory snapshot whenever the overall best is beaten;
        # it is written to disk once, after all runs have finished.
        if val_acc > best_val_acc_overall:
            best_val_acc_overall = val_acc
            best_labeled_percent = labeled_percent
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"  *** New best validation accuracy overall: {best_val_acc_overall:.4f} (from {labeled_percent*100:.0f}% run). Saving model state... ***")
            
# --- End of Training Loop --- 
print("\nFinished training across all label percentages.")
print(f"Best overall validation accuracy: {best_val_acc_overall:.4f} achieved with {best_labeled_percent*100:.0f}% labeled data.")

# Persist the overall best model state if one was recorded
if best_model_state is not None:
    print(f"Saving the best overall model state to {config.model_save_path}")
    torch.save(best_model_state, config.model_save_path)
else:
    print("No best model state was saved (perhaps validation accuracy never improved?).")

# Load the best model for final evaluation
print("\nLoading best overall model for final evaluation...")
if best_model_state is not None:
    model = EpassSimMatchNet(  # Recreate the model structure
        num_classes=config.num_classes,
        embedding_dim=config.embedding_dim,
        num_projectors=config.num_projectors
    ).to(config.device)
    model.load_state_dict(best_model_state)
    print("Best model loaded successfully.")
else:
    # The final evaluation cell only runs when a best state exists
    print("Could not load a best model state. Final test evaluation will be skipped.")

# Build the test loader directly from test_df. This avoids re-creating the
# training/validation datasets and does not depend on variables that only
# exist if the training loop ran at least once.
if test_df is not None:
    test_loader = DataLoader(
        FreesoundEvalDataset(test_df, config.audio_test_dir, label_map),
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )
else:
    test_loader = None

### 9. Evaluate on Test Set and Compute Metrics

- Use the loaded best-performing model state.
- Run `evaluate` on the `test_loader` (if available).
- Print final metrics.

In [None]:
# ------------------------------------------
# 9. Evaluate on Test Set and Compute Metrics
# ------------------------------------------

if test_loader is not None and best_model_state is not None:
    print("\nEvaluating the best model on the Test Set...")
    test_loss, test_acc, y_pred_test, y_true_test = evaluate(model, test_loader, criterion_s, config.device)
    print(f"\nFinal Test Results using Best Overall Model (Val Acc: {best_val_acc_overall:.4f}, Labeled: {best_labeled_percent*100:.0f}%):")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # Pass the full class-index list explicitly. Without it, scikit-learn infers
    # the classes from the values present in y_true/y_pred, which no longer
    # lines up with target_names when any class is missing from both.
    class_ids = list(range(config.num_classes))
    target_names = [idx_to_label[i] for i in class_ids]

    print("\nClassification Report on Test Set:")
    print(classification_report(y_true_test, y_pred_test, labels=class_ids, target_names=target_names, digits=4))

    print("\nConfusion Matrix on Test Set:")
    cm = confusion_matrix(y_true_test, y_pred_test, labels=class_ids)
    plt.figure(figsize=(12, 10))
    # annot=False: per-cell counts would be unreadable with this many classes
    sns.heatmap(cm, annot=False, fmt='d', xticklabels=target_names, yticklabels=target_names, cmap='Blues')
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix on Test Data (Best EPASS+SimMatch Model)")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
elif test_loader is None:
    print("\nSkipping final test evaluation as test data/labels were not available.")
else:  # best_model_state is None
    print("\nSkipping final test evaluation as no best model was saved.")

### 10. Plot Training Curves

- Plot losses (Supervised, Unsupervised Classification, Contrastive) and accuracies (Train Labeled, Validation) for **each** labeled percentage run to show overfitting/underfitting trends under different supervision levels.

In [None]:
# ----------------------------
# 10. Plot Training Curves
# ----------------------------
# One row of three panels (losses, accuracies, mask ratio) per labeled-percentage run
num_runs = len(config.labeled_percents)
fig, axes = plt.subplots(num_runs, 3, figsize=(18, 6 * num_runs), squeeze=False)

for i, percent in enumerate(config.labeled_percents):
    run_history = history.get(percent)
    if run_history is None:
        # No recorded history for this run (e.g. training was interrupted): hide its row
        for ax in axes[i]:
            ax.set_visible(False)
        continue
    epochs_range = range(1, len(run_history['train_loss_s']) + 1)

    # Plot Losses
    ax = axes[i, 0]
    ax.plot(epochs_range, run_history['train_loss_s'], label='Train Loss S')
    ax.plot(epochs_range, run_history['train_loss_u'], label='Train Loss U (Class.)')
    ax.plot(epochs_range, run_history['train_loss_c'], label='Train Loss C (Contrast.)')
    ax.plot(epochs_range, run_history['val_loss'], label='Val Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'Loss Curves ({percent*100:.0f}% Labeled)')
    ax.legend()
    ax.grid(True)

    # Plot Accuracies
    ax = axes[i, 1]
    ax.plot(epochs_range, run_history['train_acc_l'], label='Train Acc (on Labeled)')
    ax.plot(epochs_range, run_history['val_acc'], label='Val Acc')
    # Mark the overall best validation accuracy only on the run that achieved
    # it; other runs get no reference line, so their y-axis is not distorted.
    if best_labeled_percent == percent:
        ax.axhline(y=best_val_acc_overall, color='r', linestyle='--',
                   label=f'Best Overall Val Acc ({best_val_acc_overall:.3f})')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title(f'Accuracy Curves ({percent*100:.0f}% Labeled)')
    ax.legend()  # called after all artists are added so every label is listed
    ax.grid(True)

    # Plot Mask Ratio
    ax = axes[i, 2]
    ax.plot(epochs_range, run_history['mask_ratio'], label='Pseudo-Label Mask Ratio')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Ratio')
    ax.set_title(f'Mask Ratio Curve ({percent*100:.0f}% Labeled)')
    ax.legend()
    ax.grid(True)

plt.suptitle('EPASS + SimMatch Training Progress', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

print("\n--- Analysis of Curves ---")
print("Overfitting: Indicated if validation accuracy plateaus/decreases while training accuracy continues to rise, or if validation loss increases while training loss decreases.")
print("Underfitting: Indicated if both training and validation accuracies are low and plateau early, or if losses remain high.")
print(f"Target Accuracy (~80%): Observe if the best validation accuracy ({best_val_acc_overall:.4f}) reached the target.")