# Standard Imports

In [None]:
import os
import sys
import random
import torchaudio
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from collections import Counter, defaultdict
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, roc_curve, auc, f1_score, confusion_matrix
from tqdm import tqdm
import seaborn as sns
from torchvision.models import resnet18
from timm import create_model

# Custom Class

In [None]:
class PairedSpectrogramDataset(Dataset):
    """Dataset of paired STFT / CQT spectrograms stored as ``.npy`` files.

    The directory tree under ``base_dir`` is expected to look like::

        <machine>/<normal|abnormal>/stft/<name>.npy
        <machine>/<normal|abnormal>/cqt/<name>.npy

    where the trailing ``_``-separated token of ``<machine>`` is an integer
    (e.g. ``id_00`` -> 0). A sample is indexed only when the same file name
    exists in both the ``stft`` and the ``cqt`` directory.

    Args:
        base_dir: Root directory that contains one sub-directory per machine.
        transform: Optional callable applied to each spectrogram tensor.
            It is called separately for the STFT and the CQT view, so a
            random transform draws independent parameters for each view.
    """
    def __init__(self, base_dir, transform=None):  # transform is optional
        self.transform = transform
        self.stft_paths, self.cqt_paths = [], []
        self.labels = []
        self.categories, self.machine_ids = [], []

        # os.listdir returns entries in arbitrary order. Sorting makes the
        # sample index -> file mapping deterministic, which matters when index
        # lists computed on one instance are applied to another instance.
        for machine in sorted(os.listdir(base_dir)):
            machine_path = os.path.join(base_dir, machine)
            if not os.path.isdir(machine_path):
                continue
            machine_id = int(machine.split('_')[-1])  # id_00 -> 0
            for category in ['normal', 'abnormal']:
                stft_dir = os.path.join(machine_path, category, 'stft')
                cqt_dir = os.path.join(machine_path, category, 'cqt')

                if not (os.path.isdir(stft_dir) and os.path.isdir(cqt_dir)):
                    continue

                for filename in sorted(os.listdir(stft_dir)):
                    if not filename.endswith('.npy'):
                        continue
                    stft_path = os.path.join(stft_dir, filename)
                    cqt_path = os.path.join(cqt_dir, filename)
                    # Skip unpaired files here instead of failing later
                    # inside __getitem__ in the middle of an epoch.
                    if not os.path.isfile(cqt_path):
                        continue

                    self.stft_paths.append(stft_path)
                    self.cqt_paths.append(cqt_path)
                    self.labels.append(0 if category == 'normal' else 1)
                    self.machine_ids.append(machine_id)
                    self.categories.append(category)

    def __len__(self):
        return len(self.stft_paths)

    def __getitem__(self, idx):
        """Load one STFT/CQT pair and return it with its metadata.

        Each array is converted to float32 and gets a leading axis of size 1
        before the optional transform is applied.
        """
        stft = torch.tensor(np.load(self.stft_paths[idx]), dtype=torch.float32).unsqueeze(0)
        cqt = torch.tensor(np.load(self.cqt_paths[idx]), dtype=torch.float32).unsqueeze(0)

        if self.transform is not None:
            stft = self.transform(stft)
            cqt = self.transform(cqt)

        return {
            'stft': stft,
            'cqt': cqt,
            'label': self.labels[idx],
            'machine_id': self.machine_ids[idx],
            'category': self.categories[idx],
            'stft_path': self.stft_paths[idx],
            'cqt_path': self.cqt_paths[idx],
        }


class ComposeT:
    """Chain several callables into one transform, applied in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        # Feed the output of each step into the next one.
        result = x
        for step in self.transforms:
            result = step(result)
        return result

class ToTensor:
    """Convert the input to a torch tensor; existing tensors pass through untouched."""

    def __call__(self, x):
        if torch.is_tensor(x):
            # Already a tensor: returned as-is (dtype is not changed).
            return x
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).float()
        # Anything else (lists, scalars, ...) goes through torch.tensor.
        return torch.tensor(x).float()


class SpecAugment:
    """
    Apply torchaudio frequency masking followed by time masking to one spectrogram.

    Accepts a NumPy array or tensor shaped (freq, time) or (1, freq, time);
    a 2-D input gains a leading axis. 4-D (batched) input is rejected.
    """
    def __init__(self, freq_mask_param=15, time_mask_param=35, n_freq_masks=1, n_time_masks=1):
        self.fm = freq_mask_param
        self.tm = time_mask_param
        self.FM = torchaudio.transforms.FrequencyMasking(self.fm)
        self.TM = torchaudio.transforms.TimeMasking(self.tm)
        self.nf = n_freq_masks
        self.nt = n_time_masks

    def __call__(self, spec):
        if isinstance(spec, np.ndarray):
            spec = torch.from_numpy(spec).float()

        if spec.ndim == 4:  # (B, C, F, T)
            raise ValueError("SpecAugment expects single spectrogram, got batched input")
        if spec.ndim == 2:  # (F, T) -> (1, F, T)
            spec = spec.unsqueeze(0)

        # All frequency masks first, then all time masks.
        for mask, repeats in ((self.FM, self.nf), (self.TM, self.nt)):
            for _ in range(repeats):
                spec = mask(spec)

        return spec

class SpecTimePitchWarp:
    """
    Spectrogram-domain approximation of time-stretch / pitch-shift.

    The spectrogram is bilinearly rescaled along both axes by random factors
    and then centre-cropped / zero-padded back to its original size, so the
    output always has the same shape as the (3-D) input.
    """
    def __init__(self, max_time_scale=1.2, max_freq_scale=1.1):
        self.max_time = max_time_scale
        self.max_freq = max_freq_scale

    def _resize_and_crop(self, spec, target_f, target_t):
        """Resize ``spec`` (C, F, T) to (target_f, target_t), then restore (F, T)."""
        _, orig_f, orig_t = spec.shape

        # interpolate() needs a batch axis: (C, F, T) -> (1, C, F, T) -> (C, F', T')
        resized = torch.nn.functional.interpolate(
            spec.unsqueeze(0),
            size=(target_f, target_t),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)

        # Centre crop along any axis that grew.
        off_f = max(0, (resized.shape[1] - orig_f) // 2)
        off_t = max(0, (resized.shape[2] - orig_t) // 2)
        resized = resized[:, off_f:off_f + orig_f, off_t:off_t + orig_t]

        # Zero-pad (at the end of the axis) along any axis that shrank.
        missing_f = orig_f - resized.shape[1]
        missing_t = orig_t - resized.shape[2]
        if missing_f > 0 or missing_t > 0:
            resized = torch.nn.functional.pad(resized, (0, missing_t, 0, missing_f))

        return resized

    def __call__(self, spec):
        if isinstance(spec, np.ndarray):
            spec = torch.from_numpy(spec).float()

        # Normalise the rank: batched input is rejected, 2-D gains a channel axis.
        if spec.ndim == 4:        # (B, C, F, T) -> not supported here
            raise ValueError("SpecTimePitchWarp expects single spectrogram, got batched input")
        if spec.ndim == 2:        # (F, T) -> (1, F, T)
            spec = spec.unsqueeze(0)

        _, n_freq, n_time = spec.shape
        # Draw order (time first, then frequency) determines the random stream.
        time_factor = random.uniform(1.0 / self.max_time, self.max_time)
        freq_factor = random.uniform(1.0 / self.max_freq, self.max_freq)
        warped_t = max(2, int(n_time * time_factor))
        warped_f = max(2, int(n_freq * freq_factor))

        return self._resize_and_crop(spec, warped_f, warped_t)  # (C, F, T)

# Feature Extractor, CAFM

In [None]:
class STFTFrequencyAdaptiveFeatureExtractor(nn.Module):
    """ResNet-18 trunk (randomly initialised) adapted to 1-channel input.

    The stem convolution is replaced by a single-input-channel one, and every
    conv1/conv2 inside layer3 and layer4 is replaced by a (1, k) kernel
    (k = 7 and 15), i.e. a kernel that spans only the last (width) axis.
    The output is the layer4 feature map; no pooling or classifier is applied.
    """
    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=None)

        # Single-channel replacement for ResNet's 3-channel first convolution.
        first_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.stem = nn.Sequential(first_conv, backbone.bn1, backbone.relu, backbone.maxpool)

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = self._make_adaptive_layer(backbone.layer3, kernel_size=(1, 7))
        self.layer4 = self._make_adaptive_layer(backbone.layer4, kernel_size=(1, 15))

    def _make_adaptive_layer(self, layer, kernel_size):
        """Swap conv1/conv2 of every block in ``layer`` for ``kernel_size`` convs (in place)."""
        # "Same"-style padding for odd kernel sizes.
        same_pad = (kernel_size[0] // 2, kernel_size[1] // 2)
        for block in layer:
            # conv1 before conv2 keeps the parameter-initialisation order stable.
            for attr in ("conv1", "conv2"):
                old_conv = getattr(block, attr)
                new_conv = nn.Conv2d(
                    in_channels=old_conv.in_channels,
                    out_channels=old_conv.out_channels,
                    kernel_size=kernel_size,
                    stride=old_conv.stride,
                    padding=same_pad,
                    bias=False,
                )
                setattr(block, attr, new_conv)

        return layer

    def forward(self, x):
        for stage in (self.stem, self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

# CQT Feature Extractor using the mobilevit_xxs model
class CQTFeatureExtractor(nn.Module):
    """timm ``mobilevit_xxs`` backbone (no pretrained weights, no classifier, no pooling)
    whose stem convolution is replaced to accept 1-channel input."""
    def __init__(self):
        super().__init__()
        backbone = create_model('mobilevit_xxs', pretrained=False, num_classes=0, global_pool='')
        # Replace the stem convolution with a single-input-channel one.
        backbone.stem.conv = nn.Conv2d(  # type: ignore
            1, 16, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features

# Projection + pooling block: maps extractor features to a common shape (256 channels, spatial size (4, 16)) so both branches match.
class FeatureProjector(nn.Module):
    """1x1 convolution to ``out_channels`` followed by adaptive average pooling to ``target_hw``.

    Input (B, in_channels, H, W) -> output (B, out_channels, *target_hw).
    """
    def __init__(self, in_channels, out_channels=256, target_hw=(4,16)):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d(target_hw)

    def forward(self, x):
        # Channel projection first, then spatial pooling to the fixed grid.
        return self.pool(self.proj(x))

class CAFM(nn.Module):
    """Bidirectional cross-attention fusion of two equally shaped feature maps.

    Each (B, C, H, W) map is flattened to a sequence of H*W tokens of size C.
    STFT tokens attend to CQT tokens (wq1/wq2/wq3 = query/key/value) and CQT
    tokens attend to STFT tokens (wq4/wq5/wq6). Both attended sequences are
    mean-pooled over tokens, concatenated and passed through a 2-layer MLP,
    giving a (B, dim) vector.
    """
    def __init__(self, dim=256):
        super().__init__()
        # STFT -> CQT direction: query, key, value projections.
        self.wq1 = nn.Linear(dim, dim)
        self.wq2 = nn.Linear(dim, dim)
        self.wq3 = nn.Linear(dim, dim)

        # CQT -> STFT direction: query, key, value projections.
        self.wq4 = nn.Linear(dim, dim)
        self.wq5 = nn.Linear(dim, dim)
        self.wq6 = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def _attend(self, q_proj, k_proj, v_proj, query_seq, context_seq):
        """Scaled dot-product attention of ``query_seq`` over ``context_seq``."""
        queries = q_proj(query_seq)
        keys = k_proj(context_seq)
        values = v_proj(context_seq)
        scale = queries.size(-1) ** 0.5
        weights = self.softmax(torch.bmm(queries, keys.transpose(1, 2)) / scale)
        return torch.bmm(weights, values)

    def forward(self, stft_feat, cqt_feat):
        _b, _c, _h, _w = stft_feat.shape  # unpacking also enforces a 4-D input
        assert stft_feat.shape == cqt_feat.shape, "Shape mismatch between STFT and CQT features"

        # (B, C, H, W) -> (B, N, C) with N = H * W
        stft_tokens = stft_feat.flatten(2).transpose(1, 2)
        cqt_tokens = cqt_feat.flatten(2).transpose(1, 2)

        stft_on_cqt = self._attend(self.wq1, self.wq2, self.wq3, stft_tokens, cqt_tokens)
        cqt_on_stft = self._attend(self.wq4, self.wq5, self.wq6, cqt_tokens, stft_tokens)

        # Mean-pool over tokens, concatenate both directions, fuse with the MLP.
        pooled = torch.cat([stft_on_cqt.mean(1), cqt_on_stft.mean(1)], dim=1)
        return self.out(pooled)

# Fused Model

In [None]:
class FusedModel(nn.Module):
    """
    Dual-branch model: STFT and CQT feature extractors, per-branch projection
    to a common shape, cross-attention fusion (CAFM) and an optional head.

    Args:
        stft_dim (int): Channel count fed into the STFT branch projector
            (i.e. the channel count produced by the STFT extractor).
        cqt_dim (int): Channel count fed into the CQT branch projector.
        fusion_dim (int): Channel count both projectors emit and the width of
            the fused embedding.
        head (nn.Module, optional): Module applied to the fused embedding.
            When ``None`` the fused embedding itself is returned.
    """
    def __init__(self,stft_dim=512,cqt_dim=320,fusion_dim=256,head=None):
        super().__init__()
        self.head = head

        # Feature extractors
        self.stft_net = STFTFrequencyAdaptiveFeatureExtractor()
        self.cqt_net = CQTFeatureExtractor()
        # Both projectors must emit fusion_dim channels, otherwise the linear
        # layers inside CAFM(fusion_dim) receive the wrong feature width.
        self.stft_proj = FeatureProjector(stft_dim, out_channels=fusion_dim)
        self.cqt_proj = FeatureProjector(cqt_dim, out_channels=fusion_dim)
        self.fuser = CAFM(fusion_dim)

    def forward(self, stft, cqt):
        """Return ``head(fused)`` (or ``fused`` when no head is set) for a batch of pairs.

        Raises:
            ValueError: If either input is not a 4-D tensor.
            RuntimeError: If a projected feature map is not 4-D.
        """
        if stft.dim() != 4 or cqt.dim() != 4:
            raise ValueError(
                f"FusedModel expects [B, C, H, W] inputs, got stft:{tuple(stft.shape)}, cqt:{tuple(cqt.shape)}"
            )

        stft_feat = self.stft_proj(self.stft_net(stft))
        cqt_feat  = self.cqt_proj(self.cqt_net(cqt))
        if stft_feat.dim() != 4 or cqt_feat.dim() != 4:
            raise RuntimeError(
                f"Expected 4D tensors before CAFM, got stft:{stft_feat.shape}, cqt:{cqt_feat.shape}. "
                "Check FeatureProjector to ensure it does not flatten."
            )

        fused = self.fuser(stft_feat, cqt_feat) # [B, fusion_dim]
        if self.head is None:
            return fused
        return self.head(fused)

# Head

In [None]:
class AnomalyScorer(nn.Module):
    """Scoring head applied to a fused embedding.

    Args:
        in_dim (int): Size of the incoming embedding.
        dropout (float): Dropout probability.
        mode (str): ``'classifier-1'`` builds a 2-layer MLP that returns one
            logit per sample; ``'prototype'`` returns the Euclidean distance
            to a learnable prototype vector.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    SUPPORTED_MODES = ('classifier-1', 'prototype')

    def __init__(self, in_dim=256, dropout = 0.4, mode = 'classifier-1'):
        super().__init__()

        # Fail at construction time: an unknown mode used to build an empty
        # module whose forward() silently returned None.
        if mode not in self.SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {self.SUPPORTED_MODES}"
            )

        self.mode = mode
        self.dropout = nn.Dropout(p=dropout)

        if mode == 'classifier-1':
            self.head = nn.Sequential(
                nn.Linear(in_dim, 128),
                nn.ReLU(),
                self.dropout,
                nn.Linear(128,1), # Binary Classification
            )
        else:  # 'prototype'
            self.prototype = nn.Parameter(torch.randn(in_dim)) # Learnable normal prototype

    def forward(self, x):
        """Return logits of shape (B, 1) or distances of shape (B, 1), depending on ``mode``."""
        if self.mode == 'classifier-1':
            return self.head(x) # logits for BCEWithLogitsLoss

        if self.mode == 'prototype':
            # A 3-D input is averaged over its middle axis first.
            if x.dim() == 3:
                x = x.mean(dim=1)

            x = self.dropout(x)
            dist = torch.norm(x - self.prototype, dim=1, keepdim=True)
            return dist

        # Only reachable if self.mode was modified after construction.
        raise ValueError(f"Unsupported mode {self.mode!r}")

In [None]:
def calculate_pAUC(labels, preds, max_fpr = 0.1):
    """
    Calculates Partial AUC (pAUC) over the FPR range [0, max_fpr].

    The area under the ROC curve up to ``max_fpr`` is divided by ``max_fpr``,
    so a perfect ranking scores 1.0 and the chance diagonal scores
    ``max_fpr / 2``.

    Args:
        labels (array): True binary labels.
        preds (array): Scores for the positive class (higher = more positive).
        max_fpr (float): Maximum False Positive Rate, in (0, 1].
    Returns:
        float: pAUC score, or NaN when ``labels`` contains fewer than two classes.
    Raises:
        ValueError: If ``max_fpr`` is outside (0, 1].
    """
    if not 0 < max_fpr <= 1:
        raise ValueError(f"max_fpr must be in (0, 1], got {max_fpr!r}")

    if len(np.unique(labels)) < 2:
        return float('nan')

    fpr, tpr, _ = roc_curve(labels, preds)
    # roc_curve returns fpr in non-decreasing order, so this mask selects a prefix.
    mask = fpr <= max_fpr
    fpr_filtered, tpr_filtered = fpr[mask], tpr[mask]

    if fpr_filtered.size == 0:
        return 0.0

    last = fpr_filtered.size - 1
    if fpr_filtered[last] < max_fpr and last + 1 < len(fpr):
        # Close the curve exactly at max_fpr by linear interpolation between
        # the last kept ROC point and the first dropped one.
        x1, y1 = fpr[last], tpr[last]
        x2, y2 = fpr[last + 1], tpr[last + 1]
        tpr_interp = y1 + (y2 - y1) * (max_fpr - x1) / (x2 - x1) if (x2 - x1) > 0 else y1
        # max_fpr is strictly larger than every kept fpr, so appending keeps
        # the arrays ordered. No re-sort: an unstable sort could reorder tied
        # fpr values (vertical ROC segments) and change the area.
        fpr_filtered = np.append(fpr_filtered, max_fpr)
        tpr_filtered = np.append(tpr_filtered, tpr_interp)

    return auc(fpr_filtered, tpr_filtered) / max_fpr if len(fpr_filtered) >= 2 else 0.0


def plot_confusion_matrix(y_true, y_pred, labels, save_path, title="Confusion Matrix"):
    """
    Prints TP/TN/FP/FN with precision, recall and specificity, then plots and
    saves a binary confusion matrix.

    Args:
        y_true (list or np.array): Ground truth labels (0 = negative, 1 = positive).
        y_pred (list or np.array): Predicted labels (0 / 1).
        labels (list): Tick labels for the matrix axes (['Normal', 'Abnormal'])
        save_path (str): Directory (created if missing) in which
            "Confusion Matrix.png" is written.
        title (str): Title for the plot
    """
    # Pin the class order to [0, 1]: without it a one-class prediction/label
    # set produces a 1x1 matrix and the 4-way unpack below raises.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    print(f"TP: {tp} | TN: {tn} | FP: {fp} | FN: {fn} | Precision: {precision:.4f} | Recall: {recall:.4f} | Specificity: {specificity:.4f}")
    plt.figure(figsize=(8,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(title)
    os.makedirs(save_path, exist_ok=True)
    plt.tight_layout()
    # Save before show(): the figure is closed right after being displayed.
    plt.savefig(os.path.join(save_path, "Confusion Matrix.png"))
    plt.show()
    plt.close()

# Configurations 

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_banner = f"Using Device: {device}"
if torch.cuda.is_available():
    device_banner += f" - {torch.cuda.get_device_name(0)}"
print(device_banner)

# Root of the pre-computed spectrogram features read by the dataset.
FEATURES_DIR = os.path.abspath(r'F:\CapStone\DFCA\data\features\-6_dB_features')

# Output directory for checkpoints and plots; created up front.
CHECKPOINT_DIR = os.path.abspath(r'F:\Capstone\DFCA\checkpoints\CAFM')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")


# Training hyperparameters.
BATCH_SIZE = 32
NUM_EPOCHS = 50
LR = 5e-5            # optimizer learning rate
WEIGHT_DECAY = 1e-2  # optimizer weight decay
PATIENCE = 5         # early-stopping patience, in epochs

print(f"Learning Rate: {LR} | Weight decay: {WEIGHT_DECAY}")

# Helper, evaluate_model, train_model

In [None]:
def _compute_primary_probs_and_loss_from_head(outputs, labels, criterion):
    """
    Computes probabilities, predictions, and loss for a binary classifier head.
    """
    logits = outputs.squeeze()
    probs = torch.sigmoid(logits)
    loss = criterion(logits, labels.float())
    preds = (probs > 0.5).long()
    return probs, preds, loss

def evaluate_model(model, data_loader, criterion, phase="Evaluation", device=device, threshold=0.5):
    """Run the model over ``data_loader`` without gradients and report metrics.

    Args:
        model: Model called as ``model(stft, cqt)``.
        data_loader: Yields dict batches with 'stft', 'cqt' and 'label'.
        criterion: Loss used for the reported average loss.
        phase (str): Progress-bar / log label. When it equals "Validation",
            the decision threshold is re-selected to maximise F1 on this data.
        device: Device the batches are moved to.
        threshold (float): Decision threshold used when no search is done
            (and the fallback when the search finds no F1 above 0).
    Returns:
        tuple: ``(avg_loss, auc, acc, balanced_acc, f1, labels, probs, threshold_used)``.
    """
    model.eval()
    running_loss = 0.0
    all_labels, all_probs = [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=phase):
            stft, cqt, labels = batch['stft'].to(device), batch['cqt'].to(device), batch['label'].to(device)
            outputs = model(stft, cqt)

            probs, _, loss = _compute_primary_probs_and_loss_from_head(outputs, labels, criterion)

            # Weight by batch size so the final average is per sample.
            running_loss += loss.item() * stft.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())

    # --- Metrics Calculation ---
    avg_loss = running_loss / len(data_loader.dataset)

    # Convert once; the threshold sweep below reuses this array.
    probs_arr = np.asarray(all_probs)

    # During validation, find the best F1-score and its threshold
    best_threshold = threshold
    if phase == "Validation":
        best_f1 = 0
        for thresh in np.arange(0.01, 1.0, 0.01):
            preds_at_thresh = (probs_arr > thresh).astype(int)
            # zero_division=0: thresholds that predict no positives score 0
            # without emitting a warning per threshold.
            f1_candidate = f1_score(all_labels, preds_at_thresh, zero_division=0)
            if f1_candidate > best_f1:
                best_f1 = f1_candidate
                best_threshold = float(thresh)
        print(f"Optimal Threshold found: {best_threshold:.2f} (Best F1-score: {best_f1:.4f})")

    # Use the best threshold for final predictions
    all_preds = (probs_arr > best_threshold).astype(int)

    # AUC is undefined when only one class is present.
    auc_score = roc_auc_score(all_labels, all_probs) if len(np.unique(all_labels)) > 1 else float('nan')
    acc_score = accuracy_score(all_labels, all_preds)
    bacc_score = balanced_accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    print(f"{phase} -> Loss: {avg_loss:.4f} | AUC: {auc_score:.4f} | ACC: {acc_score:.4f} | BACC: {bacc_score:.4f} | F1: {f1:.4f}")
    # tolist() gives plain-int keys so the printed dict stays readable.
    print(f"Prediction Distribution: {dict(Counter(all_preds.tolist()))}")

    return avg_loss, auc_score, acc_score, bacc_score, f1, all_labels, all_probs, best_threshold

def _plot_training_history(train_hist, val_hist, model_save_path):
    """Plot train/val curves for each tracked metric and save them next to the checkpoint."""
    # (history key, legend suffix, subplot title)
    panels = [
        ('loss', 'Loss', 'Loss'),
        ('auc', 'AUC', 'AUC'),
        ('acc', 'Accuracy', 'Accuracy'),
        ('bacc', 'BACC', 'Balanced Accuracy'),
    ]
    epochs = range(1, len(train_hist['loss']) + 1)
    plt.figure(figsize=(18, 5))

    for position, (key, legend, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.plot(epochs, train_hist[key], label=f'Train {legend}')
        plt.plot(epochs, val_hist[key], label=f'Val {legend}')
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(os.path.dirname(model_save_path), "training_summary.png"))
    plt.show()
    plt.close()


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_save_path, device=device, save_plots=True, patience=5):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Three checkpoints are written: best validation AUC (``model_save_path``),
    best validation loss (``*_best_loss.pth``) and best validation balanced
    accuracy (``*_best_bacc.pth``). Early stopping counts epochs without a
    balanced-accuracy improvement.

    Args:
        model: Model called as ``model(stft, cqt)``.
        train_loader / val_loader: Loaders yielding dict batches with
            'stft', 'cqt' and 'label'.
        criterion: Loss taking ``(logits, float_targets)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler stepped once per epoch.
        num_epochs (int): Maximum number of epochs.
        model_save_path (str): Path of the best-AUC checkpoint (".pth").
        device: Device the batches are moved to.
        save_plots (bool): Whether to plot and save the training history.
        patience (int): Early-stopping patience in epochs.
    Returns:
        float: The validation threshold that belongs to the best-AUC checkpoint.
    """
    best_val_auc = -np.inf
    best_val_loss = np.inf
    best_val_bacc = -np.inf
    best_threshold = 0.5
    patience_counter = 0

    metric_names = ('loss', 'auc', 'acc', 'bacc')
    train_hist = {name: [] for name in metric_names}
    val_hist = {name: [] for name in metric_names}

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        model.train()
        running_loss = 0.0
        samples_seen = 0
        all_labels, all_probs, all_preds = [], [], []

        for batch in tqdm(train_loader, desc="Train"):
            stft, cqt, labels = batch['stft'].to(device), batch['cqt'].to(device), batch['label'].to(device)

            optimizer.zero_grad()

            logits = model(stft, cqt)

            probs, preds, loss = _compute_primary_probs_and_loss_from_head(logits, labels, criterion)

            loss.backward()
            optimizer.step()

            batch_size = stft.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())
            all_preds.extend(preds.detach().cpu().numpy())

        # --- End of Epoch: Calculate Training Metrics ---
        # Divide by the samples actually processed: with drop_last=True the
        # loader skips the trailing partial batch, so len(dataset) would
        # understate the average loss.
        epoch_loss = running_loss / max(samples_seen, 1)
        # AUC is undefined when an epoch contains a single class.
        epoch_auc = roc_auc_score(all_labels, all_probs) if len(np.unique(all_labels)) > 1 else float('nan')
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_bacc = balanced_accuracy_score(all_labels, all_preds)

        train_hist['loss'].append(epoch_loss)
        train_hist['auc'].append(epoch_auc)
        train_hist['acc'].append(epoch_acc)
        train_hist['bacc'].append(epoch_bacc)

        print(f"Train -> Loss: {epoch_loss:.4f} | AUC: {epoch_auc:.4f} | ACC: {epoch_acc:.4f} | BACC: {epoch_bacc:.4f}")

        # --- Validation Step ---
        val_loss, val_auc, val_acc, val_bacc, _, _, _, current_optimal_threshold = evaluate_model(
            model, val_loader, criterion, phase="Validation", device=device, threshold=best_threshold # type: ignore
        )

        val_hist['loss'].append(val_loss)
        val_hist['auc'].append(val_auc)
        val_hist['acc'].append(val_acc)
        val_hist['bacc'].append(val_bacc)

        if scheduler:
            scheduler.step()

        # Balanced accuracy drives early stopping.
        if val_bacc > best_val_bacc:
            best_val_bacc = val_bacc
            patience_counter = 0
            bacc_path = model_save_path.replace(".pth", "_best_bacc.pth")
            torch.save(model.state_dict(), bacc_path)
            print(f"Saved Best-BACC model to {bacc_path} (val_bacc improved to {best_val_bacc:.4f})")
        else:
            patience_counter += 1
            print(f"Val BACC {val_bacc:.4f} did not improve from {best_val_bacc:.4f}. Patience: {patience_counter}/{patience}")

        # Save by best loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            loss_path = model_save_path.replace(".pth", "_best_loss.pth")
            torch.save(model.state_dict(), loss_path)
            print(f"Saved Best-Loss model to {loss_path} (val_loss improved to {best_val_loss:.4f})")

        # Save by best AUC; the returned threshold always belongs to this checkpoint.
        if not np.isnan(val_auc) and val_auc > best_val_auc:
            best_val_auc = val_auc
            best_threshold = current_optimal_threshold
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved Best-AUC model to {model_save_path} (val_auc improved to {best_val_auc:.4f})")
        else:
            print(f"Val AUC {val_auc:.4f} did not improve from best {best_val_auc:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Plotting training history
    if save_plots:
        _plot_training_history(train_hist, val_hist, model_save_path)

    return best_threshold

# Main Pipeline

In [None]:
def main():
    """Run the full pipeline: seed, split, train, then evaluate on the test split.

    Reads features from FEATURES_DIR and writes checkpoints and plots to
    CHECKPOINT_DIR.
    """
    # --- Configuration ---
    SEED = 42
    # SpecTimePitchWarp draws its scale factors from Python's `random`,
    # so it must be seeded too for reproducible augmentation.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # --- Data Transformations ---
    train_transforms = ComposeT([
        ToTensor(),
        SpecTimePitchWarp(max_time_scale=1.1, max_freq_scale=1.1),
        SpecAugment(freq_mask_param=2, time_mask_param=2, n_freq_masks=1, n_time_masks=1),
    ])
    no_transform = ComposeT([
        ToTensor()
    ])

    # --- Dataset Loading ---
    print("Loading datasets for CQT spectrograms...")
    # Base paired dataset: used only for its label list and sample order.
    base_dataset = PairedSpectrogramDataset(FEATURES_DIR, transform=None)
    all_labels = [int(x) for x in base_dataset.labels]

    idxs = list(range(len(base_dataset)))
    train_idx, temp_idx = train_test_split(idxs, test_size=0.3, stratify=all_labels, random_state=SEED)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=[all_labels[i] for i in temp_idx], random_state=SEED)

    # Same files, different transforms: augmentation for training only.
    train_base = PairedSpectrogramDataset(FEATURES_DIR, transform=train_transforms)
    eval_base  = PairedSpectrogramDataset(FEATURES_DIR, transform=no_transform)

    # The split indices were computed on base_dataset; they are only valid
    # for the other instances if all of them enumerate files in the same order.
    if train_base.stft_paths != base_dataset.stft_paths or eval_base.stft_paths != base_dataset.stft_paths:
        raise RuntimeError("Dataset instances enumerate files in different orders; split indices would be misaligned.")

    train_set = Subset(train_base, train_idx)
    val_set   = Subset(eval_base,  val_idx)
    test_set  = Subset(eval_base,  test_idx)

    # --- Data Loaders ---
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)
    val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Count labels from the in-memory list; indexing the dataset would load
    # every spectrogram from disk just to read its label.
    train_labels = [all_labels[i] for i in train_idx]
    print(f"Split sizes => Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")
    print("Label Distribution (Train):", Counter(train_labels))
    print("Label Distribution (Validation):", Counter(all_labels[i] for i in val_idx))
    print("Label Distribution (Test):", Counter(all_labels[i] for i in test_idx))

    # --- Model, Criterion, and Optimizer Setup ---
    head = AnomalyScorer(in_dim=256, dropout=0.4, mode='classifier-1')
    model = FusedModel(
        stft_dim=512, cqt_dim=320, fusion_dim=256,
        head=head
    ).to(device)

    # Use weighted Binary Cross-Entropy loss for imbalanced data.
    # The weight is derived from the training split only, so validation/test
    # label statistics do not leak into training.
    pos_count = sum(train_labels)
    neg_count = len(train_labels) - pos_count
    pos_weight = torch.tensor([neg_count / (pos_count + 1e-8)], dtype=torch.float32).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    print(f"\nUsing BCEWithLogitsLoss with pos_weight: {pos_weight.item():.2f}")

    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

    model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # --- Training ---
    print("\nStarting model training...")
    best_threshold = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=NUM_EPOCHS,
        model_save_path=model_path,
        device=device,
        save_plots=True,
        patience=PATIENCE
    )

    # --- Final Test Evaluation ---
    print("\n--- Final Test Evaluation ---")
    # Reload the best-AUC checkpoint; best_threshold was selected for it.
    model.load_state_dict(torch.load(model_path, map_location=device))

    print(f"Evaluating test set with optimal threshold: {best_threshold:.2f}")

    # With phase="Test" no threshold search happens: the returned metrics are
    # computed at best_threshold.
    _, test_auc, test_acc, test_bacc, test_f1, all_labels_test, all_probs_test, _ = evaluate_model(
        model=model, data_loader=test_loader, criterion=criterion, phase="Test", device=device, threshold=best_threshold
    )

    # Hard predictions at the same threshold, needed for the confusion matrix.
    all_preds_test = (np.array(all_probs_test) > best_threshold).astype(int)
    final_pauc = calculate_pAUC(labels=all_labels_test, preds=all_probs_test, max_fpr=0.2)

    print(f"\nFinal Test Metrics (Threshold = {best_threshold:.2f}):")
    print(f"  -> Accuracy (ACC)         : {test_acc:.4f}")
    print(f"  -> Balanced Accuracy (BACC): {test_bacc:.4f}")
    print(f"  -> AUC                    : {test_auc:.4f}")
    print(f"  -> pAUC (FPR<=0.2)        : {final_pauc:.4f}")
    print(f"  -> F1-Score               : {test_f1:.4f}")

    # --- Plotting ---
    # ROC Curve
    fpr, tpr, _ = roc_curve(y_true=all_labels_test, y_score=all_probs_test)
    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr, lw=2, label=f"AUC = {test_auc:.4f}")
    plt.plot([0, 1], [0, 1], linestyle='--', lw=1, color='gray')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Test ROC With Optimal Threshold")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(CHECKPOINT_DIR, "roc_test_optimal.png"))
    plt.show()
    plt.close()

    class_labels = ["Normal", "Abnormal"]
    plot_confusion_matrix(y_true=all_labels_test, y_pred=all_preds_test, labels=class_labels, save_path=CHECKPOINT_DIR, title="Test Set Confusion Matrix")


if __name__ == "__main__":
    main()