# Importing the Necessary Libraries

In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
from collections import Counter, defaultdict
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, roc_curve, f1_score 

# Importing the Custom Modules

In [None]:
# Absolute path of the project root that holds the `utils`, `models` and
# `scripts` packages imported just below.
custom_modules_path = os.path.abspath(r'F:\Capstone\DFCA')

# Register the project root on the import search path only once, so that
# re-running this cell does not keep growing sys.path.
if not any(entry == custom_modules_path for entry in sys.path):
    sys.path.append(custom_modules_path)

from utils.datasets import PairedSpectrogramDataset, WindowedPairedSpectrogramDataset
from utils.augmentations import ComposeT, ToTensor, SpecAugment, SpecTimePitchWarp
from utils.gradcam_utils import build_gradcam_for_model, run_and_save_gradcams
from utils.metrics_utils import calculate_pAUC, plot_confusion_matrix
from scripts.pretrain_pipeline import FusedModel 
from models.heads import SimpleAnomalyMLP, ComplexAnomalyMLP, EmbeddingMLP, AnomalyScorer
from models.losses import ContrastiveLoss, BinaryFocalLoss, FocalLoss

# Configuration

In [None]:
# Prefer CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: {}{}".format(
    device,
    f" - {torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "",
))

# ---- data / optimisation ----
FEATURES_DIR = os.path.abspath(r'F:\CapStone\DFCA\data\features\-6_dB_features')
BATCH_SIZE = 32
NUM_EPOCHs = 5
LR = 5e-5
WEIGHT_DECAY = 1e-2

# ---- checkpoints ----
CHECKPOINT_DIR = os.path.abspath(r'F:\Capstone\DFCA\checkpoints')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ---- head ----
HEAD_MODE = 'mlp'  # one of: 'mlp', 'classifier', 'classifier-1', 'prototype', 'embedding'
EMB_DIM = 64
CONTRASTIVE_MARGIN = 0.5

# ---- temporal decoder ----
USE_TEMPORAL_DECODER = True
WINDOW_SIZE = 5
SEQ_LOSS_WEIGHT = 0.3

# Run-specific output directory inside CHECKPOINT_DIR.
save_path = os.path.join(CHECKPOINT_DIR, '[Anomaly-With-Transformations-dropout=0.4]_MLP(5e-5)')
os.makedirs(save_path, exist_ok=True)

# The window size is only meaningful (and only reported) for the temporal decoder.
print(
    f"Learning Rate: {LR} | Weight decay: {WEIGHT_DECAY} | SEQ Loss Weight: {SEQ_LOSS_WEIGHT} | Head Mode: {HEAD_MODE}"
    + (f" |Window size: {WINDOW_SIZE}" if USE_TEMPORAL_DECODER else "")
)

# Helper Functions for Model Evaluation

In [None]:
# ==================================
# Evaluation (Supports both modes)
# ==================================
def _compute_primary_probs_and_loss_from_head(head_mode, outputs, labels, criterion):
    """Turn a non-embedding head's raw output into probabilities, predictions and a loss.

    Args:
        head_mode: 'classifier', 'classifier-1', 'mlp' or 'prototype'. The
            'embedding' head is handled by the callers, not here.
        outputs: raw head output for a batch:
            - 'classifier': either (B, 2) logits (softmax; ``criterion`` gets
              long labels) or one logit per sample (sigmoid; float labels);
            - 'classifier-1' / 'mlp': one logit per sample, (B,) or (B, 1);
            - 'prototype': one distance-like score per sample (higher => more
              anomalous), flattened to (B,).
        labels: per-sample class labels (assumed shape (B,)).
        criterion: loss applied to the logits/scores and the labels.

    Returns:
        probs [B], preds [B], loss (as returned by ``criterion``).
        Single-score heads predict 1 when ``probs > 0.5``; the two-logit
        classifier predicts the argmax class.

    Raises:
        ValueError: for an unsupported ``head_mode``.
    """
    if head_mode == "classifier" and outputs.ndim == 2 and outputs.shape[1] == 2:
        # Two-class softmax head: P(anomaly) is the class-1 probability.
        probs = torch.softmax(outputs, dim=1)[:, 1]
        loss = criterion(outputs, labels.long())
        preds = torch.argmax(outputs.detach(), dim=1)
        return probs, preds, loss

    if head_mode in ("classifier", "classifier-1"):
        # A bare squeeze() turns a batch of one into a 0-d tensor, which breaks
        # the loss (target shape [1]) and the callers' list.extend(); keep >= 1-D.
        logits = torch.atleast_1d(outputs.squeeze())
    elif head_mode == "mlp":
        logits = outputs.squeeze(1) if outputs.ndim == 2 else outputs
    elif head_mode == "prototype":
        # AnomalyScorer(mode='prototype') returns distances (higher => more
        # anomalous); the sigmoid below maps them into [0, 1].
        logits = outputs.reshape(-1)
    else:
        raise ValueError(f"Unsupported head_mode in primary loss: {head_mode}")

    probs = torch.sigmoid(logits)
    loss = criterion(logits, labels.float())  # BCE / BinaryFocalLoss on raw scores
    preds = (probs > 0.5).long()
    return probs, preds, loss

def _temporal_aux_loss(seq_scores, labels, criterions_for_seq):
    """Auxiliary per-step loss for the temporal smoothing decoder.

    Every time step of a window is supervised with that window's label.

    Args:
        seq_scores: raw per-step logits, shape (B, T); a 1-D (B,) tensor is
            treated as a single step (T = 1).
        labels: per-window labels, (B,).
        criterions_for_seq: loss taking (logits, targets) of equal shape,
            e.g. ``nn.BCEWithLogitsLoss``.

    Returns:
        (aux_loss, seq_probs): the criterion's value and the time-averaged
        sigmoid of ``seq_scores`` with shape (B,), used for metrics.

    Raises:
        ValueError: if ``seq_scores`` is neither 1-D nor 2-D.
    """
    rank = seq_scores.ndim
    if rank not in (1, 2):
        raise ValueError(f"Unexpected seq_scores shape {seq_scores.shape}")

    targets = labels.float().unsqueeze(1)
    if rank == 1:
        # Promote (B,) to a single-step sequence (B, 1).
        seq_scores = seq_scores.unsqueeze(1)
    else:
        # Repeat the window label across all T steps.
        targets = targets.expand(*seq_scores.shape)

    aux_loss = criterions_for_seq(seq_scores, targets)
    seq_probs = torch.sigmoid(seq_scores).mean(dim=1)
    return aux_loss, seq_probs

# Evaluate and Train Functions

In [None]:
def evaluate_model(model, data_loader, criterion, phase="Evaluation", device=device, head_mode = 'classifier', sample_count=10, threshold=0.5, use_temporal=False, aux_seq_weight=SEQ_LOSS_WEIGHT):
    """Run ``model`` over ``data_loader`` without gradients and report metrics.

    Args:
        model: called as ``model(stft, cqt)``. With ``use_temporal`` it must
            return ``(head_out, seq_scores)``. In ``"embedding"`` mode it must
            expose ``model.head.normal_prototype``.
        data_loader: yields dicts with ``'stft'``, ``'cqt'`` and ``'label'`` tensors.
        criterion: loss for the primary head output.
        phase: name used in the logs. For ``"Validation"`` the decision
            threshold is additionally swept over 0.01..0.99 to maximise F1.
        device: device the batch tensors are moved to.
        head_mode: 'classifier', 'classifier-1', 'mlp', 'prototype' or
            'embedding' (anomaly score = 1 - cosine similarity to the prototype).
        sample_count: number of leading samples printed as pred / prob / label.
        threshold: decision threshold used when ``phase != "Validation"``.
        use_temporal: whether the model also returns temporal ``seq_scores``.
        aux_seq_weight: weight of the temporal auxiliary BCE loss.

    Returns:
        (avg_loss, auc, acc, balanced_acc, f1, all_labels, all_probs,
        threshold_used). ``auc`` is NaN when only one class is present.
    """
    model.eval()
    running_loss = 0.0
    n_seen = 0
    all_labels, all_probs = [], []
    best_threshold = threshold

    # Auxiliary loss on the temporal decoder's per-step logits.
    seq_criterion = nn.BCEWithLogitsLoss()

    class_counts = {0: 0, 1: 0}

    def _embedding_scores_and_loss(embeddings, labels):
        # Anomaly score = 1 - cosine similarity to the normalised normal prototype.
        emb = F.normalize(embeddings, dim=1)
        proto = F.normalize(model.head.normal_prototype, dim=0)
        scores = 1 - (emb * proto.unsqueeze(0)).sum(dim=1)
        if isinstance(criterion, ContrastiveLoss):
            loss = criterion(emb, proto, labels)
        else:
            loss = criterion(scores, labels.float())
        return scores, loss

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=phase):
            stft = batch['stft'].to(device)
            cqt = batch['cqt'].to(device)
            labels = batch['label'].to(device).long()

            for lbl in labels.cpu().numpy().flatten():
                class_counts[int(lbl)] += 1

            if use_temporal:
                head_out, seq_scores = model(stft, cqt)  # head_out: [B, ?], seq_scores: (B, T)
                aux_loss, seq_probs = _temporal_aux_loss(seq_scores, labels, seq_criterion)
                if head_mode == "embedding":
                    # Previously left `probs`/`loss` unassigned (UnboundLocalError).
                    # The score comes from the embedding only; the temporal head
                    # still contributes its auxiliary loss.
                    probs, primary_loss = _embedding_scores_and_loss(head_out, labels)
                else:
                    probs_primary, _, primary_loss = _compute_primary_probs_and_loss_from_head(
                        head_mode, head_out, labels, criterion
                    )
                    # Blend primary and sequence probabilities; keep the primary dominant.
                    probs = 0.7 * probs_primary + 0.3 * seq_probs
                loss = primary_loss + aux_seq_weight * aux_loss
            else:
                outputs = model(stft, cqt)
                if head_mode == "embedding":
                    probs, loss = _embedding_scores_and_loss(outputs, labels)
                else:
                    probs, _, loss = _compute_primary_probs_and_loss_from_head(
                        head_mode, outputs, labels, criterion
                    )

            batch_size = stft.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())

    print(f"[DEBUG] {phase} label counts: {class_counts}")

    probs_arr = np.asarray(all_probs)
    labels_arr = np.asarray(all_labels)
    has_both_classes = len(np.unique(labels_arr)) > 1

    # Optimal threshold sweep on Validation (maximise F1).
    f1 = 0.0
    if phase == "Validation":
        best_f1 = 0
        current_optimal_threshold = 0.5
        for thresh in np.arange(0.01, 1.0, 0.01):
            f1_candidate = f1_score(all_labels, (probs_arr > thresh).astype(int))
            if f1_candidate > best_f1:
                best_f1 = f1_candidate
                current_optimal_threshold = thresh
        best_threshold = current_optimal_threshold
        f1 = best_f1
        print(f"Optimal Threshold (F1-score): {best_threshold:.2f}")
        print(f"Best F1-score on Validation Set: {best_f1:.4f}")

    # Metrics under the chosen threshold.
    all_preds = (probs_arr > best_threshold).astype(int)
    if phase != "Validation":
        f1 = f1_score(all_labels, all_preds) if has_both_classes else 0.0

    # Average over samples actually seen (robust to drop_last / empty loaders).
    avg_loss = running_loss / n_seen if n_seen else float('nan')
    auc_score = roc_auc_score(all_labels, all_probs) if has_both_classes else float('nan')
    acc_score = accuracy_score(all_labels, all_preds)
    bacc_score = balanced_accuracy_score(all_labels, all_preds)

    print(f"{phase} Loss: {avg_loss:.4f} | {phase} AUC: {auc_score:.4f} | {phase} ACC: {acc_score:.4f} | {phase} BACC: {bacc_score:.4f}")
    print(f"[DEBUG] {phase} Prediction Distribution: {dict(Counter(all_preds))}")
    print(f"[DEBUG] {phase} Label Distribution: {dict(Counter(all_labels))}")
    print("==================== Misclassification & Samples ====================")
    n_errors = int(np.count_nonzero(all_preds != labels_arr))
    print(f"{phase} Misclassified Samples: {n_errors} / {len(all_labels)}")
    print("\nSample Predictions Vs Lables:")
    # Honour the caller's sample_count (it was previously ignored in favour of 10).
    for i in range(min(sample_count, len(all_labels))):
        print(f"Sample {i+1}: Pred = {all_preds[i]}, Prob = {all_probs[i]:.4f}, True = {all_labels[i]}")
    print("=====================================================================")

    return avg_loss, auc_score, acc_score, bacc_score, f1, all_labels, all_probs, best_threshold

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, head_mode, schedular=None, num_epochs=5, model_save_path="best_model.pth", device=device, save_plots=True, use_temporal=False, aux_seq_weight=SEQ_LOSS_WEIGHT):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    After every epoch the model is scored with
    ``evaluate_model(phase="Validation")``. The state dict is saved to
    ``<model_save_path stem>_best_loss<ext>`` whenever the validation loss
    improves, and to ``model_save_path`` whenever the validation AUC improves.

    Args:
        model: called as ``model(stft, cqt)``. With ``use_temporal`` it
            returns ``(head_out, seq_scores)``. Embedding mode reads
            ``model.head.normal_prototype``.
        train_loader, val_loader: yield dicts with 'stft', 'cqt', 'label'.
        criterion: loss for the primary head output.
        optimizer: optimizer stepping the model parameters.
        head_mode: 'classifier', 'classifier-1', 'mlp', 'prototype' or 'embedding'.
        schedular: optional LR scheduler, stepped once per epoch.
            ``ReduceLROnPlateau`` is stepped with the validation loss.
        num_epochs: number of training epochs.
        model_save_path: path of the best-AUC checkpoint.
        device: device the batch tensors are moved to.
        save_plots: if true, plot loss/AUC/ACC/BACC curves into ``save_path``.
        use_temporal: whether the model also returns temporal ``seq_scores``.
        aux_seq_weight: weight of the temporal auxiliary BCE loss.

    Returns:
        The validation-optimal decision threshold from the best-AUC epoch
        (0.5 if the validation AUC was never defined).

    Raises:
        ValueError: for an unsupported ``head_mode``.
    """
    best_val_auc = -np.inf
    best_val_loss = np.inf
    # Fixed training-time threshold (embedding-mode preds; validation replaces
    # it with its own F1 sweep). It is never updated during training.
    current_threshold = 0.5
    best_threshold = 0.5

    # splitext keeps the two checkpoint paths distinct even when
    # model_save_path has no ".pth" suffix (str.replace would make them equal).
    save_root, save_ext = os.path.splitext(model_save_path)
    loss_path = f"{save_root}_best_loss{save_ext}"

    train_losses, val_losses = [], []
    train_aucs, val_aucs = [], []
    train_accs, val_accs = [], []
    train_baccs, val_baccs = [], []

    model.to(device)
    seq_criterion = nn.BCEWithLogitsLoss()

    def _embedding_scores_and_loss(embeddings, labels, epoch_stats):
        # Anomaly score = 1 - cosine similarity to the normalised normal prototype;
        # also records the mean similarity per class for the epoch summary.
        emb = F.normalize(embeddings, dim=1)
        proto = F.normalize(model.head.normal_prototype, dim=0)
        cos_sim = (emb * proto.unsqueeze(0)).sum(dim=1)
        for key, mask in (("normal_sim", labels == 0), ("anomaly_sim", labels == 1)):
            if mask.any():
                epoch_stats[key].append(cos_sim[mask].mean().item())
        scores = 1 - cos_sim
        if isinstance(criterion, ContrastiveLoss):
            loss = criterion(emb, proto, labels)
        else:
            loss = criterion(scores, labels.float())
        return scores, loss

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_seen = 0
        all_labels, all_probs, all_preds = [], [], []

        class_counts_train = {0: 0, 1: 0}
        epoch_stats = defaultdict(list)

        print(f"Epoch {epoch+1}/{num_epochs}")
        for batch in tqdm(train_loader, desc="Train"):
            stft = batch['stft'].to(device)
            cqt = batch['cqt'].to(device)
            labels = batch['label'].to(device).long()

            for lbl in labels.cpu().numpy().flatten():
                class_counts_train[int(lbl)] += 1

            optimizer.zero_grad()

            if use_temporal:
                head_out, seq_scores = model(stft, cqt)  # (B, ?), (B, T)
                aux_loss, seq_probs = _temporal_aux_loss(seq_scores, labels, seq_criterion)
                if head_mode == "embedding":
                    # Previously `loss` was never assigned on this path.
                    probs, primary_loss = _embedding_scores_and_loss(head_out, labels, epoch_stats)
                    preds = (probs > current_threshold).long()
                else:
                    probs_primary, preds, primary_loss = _compute_primary_probs_and_loss_from_head(
                        head_mode, head_out, labels, criterion
                    )
                    # Fuse primary and sequence probabilities for metrics.
                    probs = 0.7 * probs_primary + 0.3 * seq_probs
                loss = primary_loss + aux_seq_weight * aux_loss
            else:
                outputs = model(stft, cqt)
                if head_mode == "embedding":
                    probs, loss = _embedding_scores_and_loss(outputs, labels, epoch_stats)
                    preds = (probs > current_threshold).long()
                else:
                    # Shared helper (thresholds at 0.5, equal to current_threshold).
                    # This also fixes the former 2-logit classifier path, which
                    # assigned the loss to `labels` instead of `loss`.
                    probs, preds, loss = _compute_primary_probs_and_loss_from_head(
                        head_mode, outputs, labels, criterion
                    )

            loss.backward()
            optimizer.step()

            batch_size = stft.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())
            all_preds.extend(preds.detach().cpu().numpy())

        print(f"[DEBUG] Train label counts (epoch {epoch+1}): {class_counts_train}")
        if epoch_stats['normal_sim']:
            normal_sims = epoch_stats["normal_sim"]
            anomaly_sims = epoch_stats["anomaly_sim"]
            avg_normal_sim = sum(normal_sims) / len(normal_sims)
            # An epoch without anomaly samples must not raise ZeroDivisionError.
            avg_anomaly_sim = sum(anomaly_sims) / len(anomaly_sims) if anomaly_sims else float('nan')
            print(f"[DEBUG] Average Normal CosSim: {avg_normal_sim:.4f} | Average Anomaly CosSim: {avg_anomaly_sim:.4f}")

        # Divide by the samples actually trained on (train loader uses drop_last).
        epoch_loss = running_loss / n_seen if n_seen else float('nan')
        train_losses.append(epoch_loss)
        train_auc = roc_auc_score(all_labels, all_probs) if len(np.unique(all_labels)) > 1 else float('nan')
        train_acc = accuracy_score(all_labels, all_preds)
        train_bacc = balanced_accuracy_score(all_labels, all_preds)
        train_aucs.append(train_auc)
        train_accs.append(train_acc)
        train_baccs.append(train_bacc)

        print(f"Train Loss: {epoch_loss:.4f} | Train AUC: {train_auc:.4f} | Train Acc: {train_acc:.4f} | Train BAcc: {train_bacc:.4f}")

        # Validation
        val_loss, val_auc, val_acc, val_bacc, _, _, _, current_optimal_threshold = evaluate_model(
            model, val_loader, criterion, phase="Validation", device=device, head_mode=head_mode,
            sample_count=5, threshold=current_threshold, use_temporal=use_temporal, aux_seq_weight=aux_seq_weight
        )
        val_losses.append(val_loss)
        val_aucs.append(val_auc)
        val_accs.append(val_acc)
        val_baccs.append(val_bacc)

        if schedular is not None:
            # ReduceLROnPlateau needs the monitored metric; a bare step() on it
            # raised and was silently swallowed, so it never reduced the LR.
            if isinstance(schedular, torch.optim.lr_scheduler.ReduceLROnPlateau):
                schedular.step(val_loss)
            else:
                schedular.step()

        print(f"Epoch {epoch+1}: Learning Rate = {optimizer.param_groups[0]['lr']:.6f}")

        # Save by best loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), loss_path)
            print(f"Saved Best-Loss model to {loss_path} (val_loss improved to {best_val_loss:.4f})")

        # Save by best AUC
        if not np.isnan(val_auc) and val_auc > best_val_auc:
            best_val_auc = val_auc
            best_threshold = current_optimal_threshold
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved Best-AUC model to {model_save_path} (val_auc improved to {best_val_auc:.4f})")
        else:
            print(f"Val AUC {val_auc:.4f} did not improve from {best_val_auc:.4f}")

    if save_plots:
        epochs = range(1, num_epochs + 1)
        panels = (
            (train_losses, val_losses, 'Train Loss', 'Val Loss', "Train / Validation Loss"),
            (train_aucs, val_aucs, 'Train AUC', 'Val AUC', "Train / Validation AUC"),
            (train_accs, val_accs, 'Train ACC', 'Val ACC', "Train / Validation Accuracy"),
            (train_baccs, val_baccs, 'Train BACC', 'Val BACC', "Train / Validation BACC"),
        )
        plt.figure(figsize=(18, 4))
        for idx, (train_vals, val_vals, train_label, val_label, title) in enumerate(panels, start=1):
            plt.subplot(1, 4, idx)
            plt.plot(epochs, train_vals, label=train_label)
            plt.plot(epochs, val_vals, label=val_label)
            plt.legend()
            plt.grid(True)
            plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, "Training_summpary.png"))
        plt.show()
        plt.close()

    return best_threshold

# Main Pipeline

In [None]:
# =========================================================
# Main
# =========================================================
def main():
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Applying the Transformations
    train_transform = ComposeT([
        ToTensor(),
        SpecTimePitchWarp(max_time_scale=1.1, max_freq_scale=1.1),
        SpecAugment(freq_mask_param=2, time_mask_param=2, n_freq_masks=1, n_time_masks=1),
    ])

    no_transform = ComposeT([
        ToTensor(),
    ])

    # Base Paired dataset
    base_dataset = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=None)
    all_labels = [int(x) for x in base_dataset.labels]

    # Wrapping with windowed dataset if temporal
    if USE_TEMPORAL_DECODER:
        wrapped_dataset = WindowedPairedSpectrogramDataset(base_dataset=base_dataset, window_size=WINDOW_SIZE)
        effective_len = len(wrapped_dataset)

        windowed_labels = [all_labels[i] for i in range(effective_len)]

        idxs = list(range(effective_len))
        train_idx, temp_idx = train_test_split(idxs, test_size=0.3, stratify=windowed_labels, random_state=42)
        val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=[windowed_labels[i] for i in temp_idx], random_state=42)

        train_base = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=train_transform)
        val_base = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=no_transform)
        test_base = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=None)

        train_set = Subset(WindowedPairedSpectrogramDataset(base_dataset=train_base, window_size=WINDOW_SIZE), train_idx)
        val_set = Subset(WindowedPairedSpectrogramDataset(base_dataset=val_base, window_size=WINDOW_SIZE), val_idx)
        test_set = Subset(WindowedPairedSpectrogramDataset(base_dataset=test_base, window_size=WINDOW_SIZE), test_idx)
    
    else:
        # Stratified split indices of base dataset
        idxs = list(range(len(base_dataset)))
        train_idx, temp_idx = train_test_split(idxs, test_size=0.3, stratify=all_labels, random_state=42)
        val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=[all_labels[i] for i in temp_idx], random_state=42)

        train_base = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=train_transform)
        val_base = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=no_transform)
        test_base = PairedSpectrogramDataset(base_dir=FEATURES_DIR, transform=no_transform)

        train_set = Subset(train_base, train_idx)
        val_set = Subset(val_base, val_idx)
        test_set = Subset(test_base, test_idx)

    # Data Loaders
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Splir Sizes => Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")
    print(f"Label Distribution (Train): {Counter([int(base_dataset[i]['label']) for i in train_idx])}")
    print(f"Lable Distribution (Valdation): {Counter([int(base_dataset[i]['label']) for i in val_idx])}")
    print(f"Label Distribution (Test): {Counter([int(base_dataset[i]['label']) for i in test_idx])}")

    head_mode = HEAD_MODE.lower()

    # Configure head + criterion
    if head_mode == "prototype":
        # Note: AnomalyScorer('prototype') returns a distance score.
        # Use BCE/ BinaryFocalLoss on this score (Not ContrastiveLoss)
        head = AnomalyScorer(in_dim=256, dropout=0.4, mode='prototype')
        pos_count = sum(all_labels); neg_count = len(all_labels) - pos_count
        pos_weight = torch.tensor([neg_count / (pos_count + 1e-8)], dtype=torch.float32).to(device)
        criterion = BinaryFocalLoss(alpha=0.25, gamma=2.0, pos_weight=pos_weight, reduction='mean')
        # print(f"Using head: {head}")
        print("\nTransformations Applied to Train Set")
        for transform in train_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
        print("Transformation's Applied to Validation/Test Set")
        for transform in no_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
    
    elif head_mode == "mlp":
        head = ComplexAnomalyMLP(in_dim=256, dropout=0.4, out_dim=1)
        pos_count = sum(all_labels); neg_count = len(all_labels) - pos_count
        pos_weight = torch.tensor([neg_count / (pos_count + 1e-8)], dtype=torch.float32).to(device)
        criterion = BinaryFocalLoss(alpha=0.25, gamma=2.0, pos_weight=pos_weight, reduction='mean')
        # print(f"Used Head: {head}")
        print("\nTransformations Applied to Train Set")
        for transform in train_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
        print("Transformation's Applied to Validation/Test Set")
        for transform in no_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
    
    elif head_mode == "embedding":
        head = EmbeddingMLP(in_dim=256, hidden=128, dropout=0.4, emb_dim=EMB_DIM)
        criterion = ContrastiveLoss(margin=CONTRASTIVE_MARGIN)
        # print(f"Used head: {head}")
        print("\nTransformations Applied to Train Set")
        for transform in train_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
        print("Transformation's Applied to Validation/Test Set")
        for transform in no_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
    
    elif head_mode == "classifier":
        head = SimpleAnomalyMLP(in_dim=256, dropout=0.4, hidden=128, out_dim=2)
        class_counts = [2624, 319]
        alpha = 0.7
        total = sum(class_counts)
        class_weights = [(total /c ) ** alpha for c in class_counts]
        class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        # print(f"Used head: {head}")
        print("\nTransformations Applied to Train Set")
        for transform in train_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
        print("Transformation's Applied to Validation/Test Set")
        for transform in no_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
    
    elif head_mode == "classifier-1":
        head = AnomalyScorer(in_dim=256, dropout=0.4, mode='classifier-1')
        pos_count = sum(all_labels); neg_count = len(all_labels) - pos_count
        pos_weight = torch.tensor([neg_count / (pos_count + 1e-8)], dtype=torch.float32).to(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        # print(f"Using head: {head}")
        print("\nTransformations Applied to Train Set")
        for transform in train_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
        print("Transformation's Applied to Validation/Test Set")
        for transform in no_transform.transforms:
            # Print the name of the transformation class
            print(f"  - {transform.__class__.__name__}")
    
            # Check for specific transformations and print their parameters
            if isinstance(transform, SpecTimePitchWarp):

                print(f"    - time_scale: {getattr(transform, 'max_time_scale', {transform.max_time})}")
                print(f"    - freq_scale: {getattr(transform, 'max_freq_scale', {transform.max_freq})}")
            if isinstance(transform, SpecAugment):
                print(f"    - freq_mask_param: {getattr(transform,'freq_mask_param',{transform.fm})}")
                print(f"    - time_mask_param: {getattr(transform,'time_mask_param',{transform.tm})}")
                print(f"    - n_freq_masks: {getattr(transform,'n_freq_masks', {transform.nf})}")
                print(f"    - n_time_masks: {getattr(transform,'n_time_masks', {transform.nt})}")
    
    else:
        raise ValueError("Invalid Head_MODE")
    
    # Build Model 
    model = FusedModel(
        stft_dim=512, cqt_dim=320, fusion_dim=256,
        head= head, head_mode=head_mode, use_decoder=USE_TEMPORAL_DECODER, temporal_hidden=64
    ).to(device)
    # print(f"\nUsing the Model: {model}\n")
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHs, eta_min=1e-6)

    model_path = os.path.join(save_path, "best_model.pth")
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    best_threshold = train_model(model=model, train_loader=train_loader, val_loader=val_loader, criterion=criterion, optimizer=optimizer,
                                 head_mode=head_mode, schedular=scheduler, num_epochs=NUM_EPOCHs, model_save_path=model_path,
                                 device=device, save_plots=True, use_temporal=USE_TEMPORAL_DECODER, aux_seq_weight=SEQ_LOSS_WEIGHT
                                )
    
    print("\n--- Final Test Evaluation ---")
    model.load_state_dict(torch.load(model_path, map_location=device))

    safe_threshold = float(best_threshold) if best_threshold is not None else 0.5

    test_loss, test_auc, test_acc, test_bacc, test_f1, all_labels_test, all_probs_test, _ = evaluate_model(
        model=model, data_loader=test_loader, criterion=criterion, phase="Test", device=device, head_mode=head_mode,sample_count=5, threshold=safe_threshold, 
        use_temporal=USE_TEMPORAL_DECODER, aux_seq_weight=SEQ_LOSS_WEIGHT
    )
    print(f"\nFinal Test Metrics (With best Validation threshold): {best_threshold:.2f}")
    print(f"Loss: {test_loss:.4f} | AUC: {test_auc:.4f} | Accuracy: {test_acc:.4f} | Balanced Accuracy: {test_bacc:.4f} | F1-Score: {test_f1:.4f}")

    if len(np.unique(all_labels_test)) > 1:
        final_pauc = calculate_pAUC(labels=all_labels_test, preds = all_probs_test, max_fpr=0.2)
    else:
        print("Test set contains only one class; cannot compute pAUC")
    
    # ROC
    fpr, tpr, _ = roc_curve(y_true=all_labels_test, y_score=all_probs_test)
    plt.figure(figsize=(6,6))
    plt.plot(fpr, tpr, lw=2, label=f"{test_auc:.4f}")
    plt.plot([0,1],[0,1], linestyle='--', lw=1)
    plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title("Test ROC With Optimal Threshold"); plt.legend()
    plt.grid(True); plt.tight_layout()
    plt.savefig(os.path.join(save_path, "roc_test_optimal.png")); plt.close()
    plt.show()
    plt.close()

    labels = ["Normal", "Abnormal"]
    all_preds_test = (np.array(all_probs_test) > safe_threshold).astype(int)
    plot_confusion_matrix(y_true=all_labels_test, y_pred=all_preds_test, labels=labels, save_path=save_path, title="Test Set Confusion Matrix")

    # GradCAM
    try:
        cams = build_gradcam_for_model(model=model, device=device)
        cam_out_dir = os.path.join(save_path, 'gradcam')
        ref_set = test_set
        run_and_save_gradcams(model=model, cams=cams, dataset=ref_set, device=device, out_dir=cam_out_dir, n_samples=5)
    except Exception as error:
        print(f"GradCAM step Failed: {error}")

if __name__ == '__main__':
    # Script entry point: run the training / test-evaluation pipeline only when
    # this file is executed directly, never as a side effect of importing it.
    main()