# Standard Imports

In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, f1_score
from tqdm import tqdm

# Custom Imports

In [None]:
# Absolute location of the project root that holds the local packages
# imported right below (utils, scripts, models).
custom_modules_path = os.path.abspath(r'F:\Capstone\DFCA')

# Register the project root on the import path; skip it when it is already
# listed so that re-running this cell does not add duplicates.
if sys.path.count(custom_modules_path) == 0:
    sys.path.append(custom_modules_path)

from utils.datasets import PairedSpectrogramDataset, WindowedPairedSpectrogramDataset
from utils.metrics_utils import calculate_pAUC, plot_confusion_matrix
from scripts.pretrain_pipeline import FusedModel
from models.heads import AnomalyScorer

# Configuration

In [None]:
# Runtime device: use the GPU whenever CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    print(f"Using device: {device} - {torch.cuda.get_device_name(0)}")
else:
    print(f"Using device: {device}")

# Evaluation settings shared by the cells below.
WINDOW_SIZE = 5               # window_size given to WindowedPairedSpectrogramDataset
BATCH_SIZE = 32               # DataLoader batch size
USE_TEMPORAL_DECODER = True   # forwarded as use_decoder (FusedModel) / use_temporal (evaluate_model)
SEQ_LOSS_WEIGHT = 0.3         # weight of the auxiliary sequence loss in evaluate_model

# Directory the trained checkpoint is read from in the main pipeline.
save_path = os.path.abspath(r"F:\CapStone\DFCA\checkpoints\Classifier[0_dB_valve]")

# Helper Functions

In [None]:
def _compute_primary_probs_and_loss_from_head(outputs, labels, criterion):
    """
        Returns: probs [B], preds [B], loss (scalar)
        Assumes outputs is:
            - logits tensor for classifier/mlp/classifier-1
            - distance/anomaly score for prototype (AnomalyScorer prototype)
            - embeddings for embedding head (handled separately!)
    """
    # ===== DEBUG PRINT ===================
    # print("Logits shape: ", outputs.shape)
    # print("Labels shape: ",labels.shape)
    # ===== DEBUG PRINT ===================
    logits = outputs.squeeze()
    probs = torch.sigmoid(logits)
    loss = criterion(logits, labels.float())
    preds = (probs > 0.5).long()
    
    return probs, preds, loss

def _temporal_aux_loss(seq_scores, labels, criterions_for_seq):
    """
        seq_scores: (B, T) raw logits from TemporalSmoothingDEcoder (Linear output)
        labels: (B, ) => expand to (B, T)
        criterion_for_seq: BCEWithLogitsLoss (or similar) for temporal smoothing
    """
    if seq_scores.ndim == 2:
        B, T = seq_scores.shape
        labels_T = labels.float().unsqueeze(1).expand(B, T)
    
    elif seq_scores.ndim == 1:
        B = seq_scores.shape[0]
        T = 1
        seq_scores = seq_scores.unsqueeze(1)
        labels_T = labels.float().unsqueeze(1)
    
    else:
        raise ValueError(f"Unexpected seq_scores shape {seq_scores.shape}")
    
    aux_loss = criterions_for_seq(seq_scores, labels_T)
    # derive a sequence-level probability for metrics by averaging sigmoid(seq_scores)
    seq_probs = torch.sigmoid(seq_scores).mean(dim=1)
    
    return aux_loss, seq_probs

# Evaluate Model

In [None]:
def evaluate_model(model, data_loader, criterion, phase="Evaluation", device="cpu", sample_count=10, threshold=0.5, use_temporal=False, aux_seq_weight=SEQ_LOSS_WEIGHT):
    """
        Evaluate ``model`` on ``data_loader`` without gradients and print a metric report.

        Args:
            model: called as ``model(stft, cqt)``. With ``use_temporal=True`` it must
                return ``(head_out, seq_scores)``. Otherwise the head output is used
                directly (if a tuple/list comes back, its first element is taken).
            data_loader: yields dict batches with 'stft', 'cqt' and 'label' tensors.
                Labels are expected to be 0 or 1.
            criterion: loss for the primary head logits (e.g. BCEWithLogitsLoss).
            phase (str): name used in the progress bar and log lines. The value
                "Validation" additionally triggers the F1 threshold sweep.
            device: device the batch tensors are moved to.
            sample_count (int): number of leading samples printed in the report.
            threshold (float): decision threshold used unless phase == "Validation".
            use_temporal (bool): blend the temporal decoder scores into the
                probabilities and add the auxiliary sequence loss.
            aux_seq_weight (float): weight of the auxiliary sequence loss.

        Returns:
            tuple: (avg_loss, auc_score, acc_score, bacc_score, f1,
                    all_labels, all_probs, best_threshold).
            ``auc_score`` is NaN when the labels contain a single class.
    """
    model.eval()
    running_loss = 0.0
    all_labels, all_probs = [], []

    # For temporal aux loss
    seq_criterion = nn.BCEWithLogitsLoss()

    class_counts = {0: 0, 1: 0}

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=phase):
            stft = batch['stft'].to(device)
            cqt = batch['cqt'].to(device)
            labels = batch['label'].to(device).long()

            # Flatten once so counting and accumulation see one label per sample.
            labels_np = labels.cpu().numpy().reshape(-1)
            for lbl in labels_np:
                class_counts[int(lbl)] += 1

            outputs = model(stft, cqt)
            if use_temporal:
                head_out, seq_scores = outputs  # head_out: [B, ?], seq_scores: (B, T)
                probs_primary, _, primary_loss = _compute_primary_probs_and_loss_from_head(
                    head_out, labels, criterion
                )
                # Auxiliary temporal smoothing
                aux_loss, seq_probs = _temporal_aux_loss(seq_scores, labels, seq_criterion)

                # Merge probs for metrics (blend primary with sequence; keep primary dominant)
                probs = 0.7 * probs_primary + 0.3 * seq_probs
                loss = primary_loss + aux_seq_weight * aux_loss
            else:
                # Without the temporal decoder only the primary head contributes.
                # NOTE(review): assumes the model then returns the head output alone
                # or as the first element of a tuple -- confirm against FusedModel.
                head_out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
                probs, _, loss = _compute_primary_probs_and_loss_from_head(
                    head_out, labels, criterion
                )

            running_loss += loss.item() * stft.size(0)
            all_labels.extend(labels_np)
            # reshape(-1) keeps extend() working even if probs comes back 0-d.
            all_probs.extend(probs.detach().cpu().numpy().reshape(-1))

    print(f"[DEBUG] {phase} label counts: {class_counts}")

    probs_arr = np.array(all_probs)
    has_both_classes = len(np.unique(all_labels)) > 1

    # Optimal threshold sweep on Validation
    best_threshold = threshold
    f1 = 0.0
    if phase == "Validation":
        best_f1 = 0
        current_optimal_threshold = 0.5
        for thresh in np.arange(0.01, 1.0, 0.01):
            f1_candidate = f1_score(all_labels, (probs_arr > thresh).astype(int))

            if f1_candidate > best_f1:
                best_f1 = f1_candidate
                current_optimal_threshold = thresh
        best_threshold = current_optimal_threshold
        f1 = best_f1
        print(f"Optimal Threshold (F1-score): {best_threshold:.2f}")
        print(f"Best F1-score on Validation Set: {best_f1:.4f}")

    # Metrics under chosen threshold
    all_preds = (probs_arr > best_threshold).astype(int)
    if phase != "Validation":
        f1 = f1_score(all_labels, all_preds) if has_both_classes else 0.0

    avg_loss = running_loss / len(data_loader.dataset)
    auc_score = roc_auc_score(all_labels, all_probs) if has_both_classes else float('nan')
    acc_score = accuracy_score(all_labels, all_preds)
    bacc_score = balanced_accuracy_score(all_labels, all_preds)

    print(f"{phase} Loss: {avg_loss:.4f} | {phase} AUC: {auc_score:.4f} | {phase} ACC: {acc_score:.4f} | {phase} BACC: {bacc_score:.4f}")
    print(f"[DEBUG] {phase} Prediction Distribution: {dict(Counter(all_preds))}")
    print(f"[DEBUG] {phase} Label Distribution: {dict(Counter(all_labels))}")
    print("==================== Misclassification & Samples ====================")
    num_errors = int(np.sum(all_preds != np.array(all_labels)))
    print(f"{phase} Misclassified Samples: {num_errors} / {len(all_labels)}")
    print("\nSample Predictions Vs Lables:")
    num_shown = max(0, min(sample_count, len(all_labels)))
    shown = zip(all_preds[:num_shown], all_probs[:num_shown], all_labels[:num_shown])
    for idx, (pred, prob, true_label) in enumerate(shown, start=1):
        print(f"Sample {idx}: Pred = {pred}, Prob = {prob:.4f}, True = {true_label}")

    return avg_loss, auc_score, acc_score, bacc_score, f1, all_labels, all_probs, best_threshold

# Function to test on new dataset

In [None]:
def test_dataset(model, features_dir, model_path, best_threshold, device, save_image_path=None):
    """
        Load a trained model and evaluate it on a new dataset directory.

        Prints the test metrics and plots a confusion matrix. Returns None;
        when the checkpoint cannot be loaded the error is printed and the
        evaluation is skipped.

        Args:
            model (nn.Module): The model architecture to load the weights into.
            features_dir (str): Path to the directory containing the new test features.
            model_path (str): Path to the saved .pth model file.
            best_threshold (float): The optimal threshold found during training.
            device (torch.device): The device to run evaluation on ('cuda' or 'cpu').
            save_image_path (str | None): Destination handed to plot_confusion_matrix
                as ``save_path``. When omitted, a module-level ``save_image_path``
                is used if one exists; otherwise the plot is skipped.
    """
    # normpath drops a trailing separator, which would make basename return ''.
    dataset_name = os.path.basename(os.path.normpath(features_dir))
    print(f"=== Evaluating on:{dataset_name} ===")

    if save_image_path is None:
        # Backward compatibility: earlier callers set this as a module-level
        # global before calling; look it up without failing when it is absent.
        save_image_path = globals().get("save_image_path")

    try:
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(model_path, map_location=device)
        load_result = model.load_state_dict(state_dict, strict=False)
        model.to(device)
        print(f"Model Loaded successfully from {model_path}")
    except Exception as error:
        # Deliberate best-effort: report the problem and skip this dataset.
        print(f"Error in loading model from {model_path}: {error}")
        return

    # strict=False tolerates key mismatches, so report them instead of
    # silently evaluating a partially initialised model.
    missing_keys = list(getattr(load_result, "missing_keys", []))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
    if missing_keys:
        print(f"[WARN] Checkpoint is missing {len(missing_keys)} key(s): {missing_keys}")
    if unexpected_keys:
        print(f"[WARN] Checkpoint has {len(unexpected_keys)} unexpected key(s): {unexpected_keys}")

    new_base_dataset = PairedSpectrogramDataset(base_dir=features_dir, transform=None)
    new_test_set = WindowedPairedSpectrogramDataset(base_dataset=new_base_dataset, window_size=WINDOW_SIZE)
    new_test_loader = DataLoader(dataset=new_test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Test Set Size: {len(new_test_set)}")
    print(f"Label Distribution: {Counter(new_base_dataset.labels)}")

    # pos_weight of 1.0 leaves the positive class unweighted.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], dtype=torch.float32).to(device))

    test_loss, test_auc, test_acc, test_bacc, test_f1, all_labels, all_probs, _ = evaluate_model(
        model=model, data_loader=new_test_loader, criterion=criterion, phase="Test",
        device=device, threshold=best_threshold, use_temporal=USE_TEMPORAL_DECODER, aux_seq_weight=SEQ_LOSS_WEIGHT
    )

    if len(np.unique(all_labels)) > 1:
        pauc_score = calculate_pAUC(labels=all_labels, preds=all_probs, max_fpr=0.2)
        print(f"Partial AUC (pAUC @ 0.2 FPR): {pauc_score:.4f}")

    else:
        pauc_score = float('nan')
        print("Test set contains only one class; Cannot compute pAUC")

    print(f"\n Final Test Metrics (Threshold: {best_threshold:.2f})")
    print(f"Loss:{test_loss:.4f} | AUC:{test_auc:.4f} | Accuracy: {test_acc:.4f} | Balanced Accuracy:{test_bacc:.4f} | F1-Score:{test_f1:.4f} | pAUC:{pauc_score:.4f}")

    if save_image_path is None:
        print("No save_image_path provided; skipping confusion matrix plot.")
    else:
        labels_display = ["Normal", "Abnormal"]
        all_preds = (np.array(all_probs) > best_threshold).astype(int)
        plot_confusion_matrix(y_true=all_labels, y_pred=all_preds, labels=labels_display, save_path=save_image_path, title=f"Confusion Matrix - {dataset_name}")
    print(f"\nCompleted Test on {features_dir}\n")

# Main Pipeline

In [None]:
if __name__ == "__main__":
    # One device object for both the banner and every evaluation call.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}" + (f" - {torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else ""))

    # Output directory for the confusion-matrix images. This must remain a
    # module-level name: test_dataset is called below without an explicit
    # image path and picks it up from the module globals.
    save_image_path = os.path.abspath(r"F:\Capstone\DFCA\checkpoints\Results\[6 and -6_valve](0.07)")
    save_image_path = os.path.join(save_image_path, 'AUC_Model')
    # save_image_path = os.path.join(save_image_path,'BACC_Model')
    os.makedirs(save_image_path, exist_ok=True)
    print(f"Image Saving Path: {save_image_path}")

    # Checkpoint to evaluate (switch to the commented line for the BACC model).
    SAVED_MODEL_PATH = os.path.join(save_path, "best_model.pth")
    # SAVED_MODEL_PATH = os.path.join(save_path, "best_bacc.pth")
    OPTIMAL_THRESHOLD = 0.07

    # Base directory
    BASE_FEATURES_DIRECTORY = os.path.abspath(r'F:\Capstone\DFCA\data\features')

    # Paths to the new test directories
    SIX_DB_VALVE_DIR = os.path.join(BASE_FEATURES_DIRECTORY, '6_dB_valve_features')
    MINUS_SIX_DB_VALVE_DIR = os.path.join(BASE_FEATURES_DIRECTORY, '-6_dB_valve_features')

    # Instantiate the head and the fused model the checkpoint is loaded into.
    head = AnomalyScorer(in_dim=256, dropout=0.4, mode='classifier-1')
    model = FusedModel(stft_dim=512, cqt_dim=320, fusion_dim=256, head=head, use_decoder=USE_TEMPORAL_DECODER, temporal_hidden=64)

    # Test on the 6 dB dataset first, then the -6 dB dataset; a missing
    # directory is reported and skipped.
    for features_dir in (SIX_DB_VALVE_DIR, MINUS_SIX_DB_VALVE_DIR):
        if os.path.isdir(features_dir):
            test_dataset(
                model=model,
                features_dir=features_dir,
                model_path=SAVED_MODEL_PATH,
                best_threshold=OPTIMAL_THRESHOLD,
                device=DEVICE
            )
        else:
            print(f"Directory not found: {features_dir}")