In [1]:
!pip install audiomentations

Collecting audiomentations
  Downloading audiomentations-0.43.1-py3-none-any.whl.metadata (11 kB)
Collecting numpy-minmax<1,>=0.3.0 (from audiomentations)
  Downloading numpy_minmax-0.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Collecting numpy-rms<1,>=0.4.2 (from audiomentations)
  Downloading numpy_rms-0.6.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.5 kB)
Collecting python-stretch<1,>=0.3.1 (from audiomentations)
  Downloading python_stretch-0.3.1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Collecting soxr<1.0.0,>=0.3.2 (from audiomentations)
  Downloading soxr-0.5.0.post1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading audiomentations-0.43.1-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.1/86.1 kB[0m [31m3.0 MB/s[0m eta [36m

In [2]:
# ==============================================
# NEONATAL CRY CLASSIFICATION PIPELINE (IMPROVED)
# HeAR (google/hear) TensorFlow + Fine-Tuned Classifiers
# Enhancements: data augmentation, attention pooling, ensemble, class weights,
#               hyperparameter tuning, stronger regularization.
# ==============================================

import os
import gc
import psutil
import random
import warnings
import traceback
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from collections import defaultdict
import itertools

# Sklearn
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score, GridSearchCV, RandomizedSearchCV
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    precision_recall_fscore_support, roc_auc_score, roc_curve
)
from sklearn.preprocessing import label_binarize, StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (
    RandomForestClassifier, GradientBoostingClassifier,
    VotingClassifier, StackingClassifier
)
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.utils.class_weight import compute_class_weight

# Signal processing
from scipy.signal import wiener
from scipy.ndimage import gaussian_filter1d
import librosa
import librosa.display
import soundfile as sf

# Audio augmentation
import audiomentations as A

# Plotting
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns

# TensorFlow / HeAR
import tensorflow as tf
from huggingface_hub import login, snapshot_download

# PyTorch (for fine-tuned attention head)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

warnings.filterwarnings("ignore")

# ==============================================
# GLOBAL CONFIGURATION
# ==============================================

HF_TOKEN_PATH = "/kaggle/input/datasets/tobimichigan/acess-tkns/acceess_tkns/hf.token.txt"
HEAR_REPO_ID  = "google/hear"
LOCAL_MODEL_DIR = "./hear-model"

TARGET_SR     = 16000
CLIP_DURATION = 2                        # HeAR expects 2-second clips
CLIP_LENGTH   = TARGET_SR * CLIP_DURATION
CLIP_OVERLAP  = 0.10                     # 10% overlap between windows

BATCH_SIZE    = 32                       # embedding inference batch
PT_BATCH      = 16                       # PyTorch fine-tune batch
EPOCHS        = 50                        # increased for better convergence
LR            = 1e-3                      # will be tuned
WEIGHT_DECAY  = 1e-3                      # increased regularization
NUM_CLASSES   = 3                        # pain=0, hunger=1, neurological=2

PLOTS_DIR     = "./plots"
CKPT_PATH     = "./best_model.pth"
ENSEMBLE_DIR  = "./ensemble_models"
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(ENSEMBLE_DIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CLASS_NAMES   = ["Pain", "Hunger", "Neurological"]

# Augmentation parameters
AUGMENT_PROB = 0.5                        # probability of applying augmentation
AUGMENT_FACTOR = 3                         # number of augmented copies per original file (training only)

# ==============================================
# MEMORY UTILITIES (unchanged)
# ==============================================

def get_memory_gb():
    return psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3

def log_memory(tag=""):
    print(f"  [MEM{(' '+tag) if tag else ''}] {get_memory_gb():.2f} GB")

def force_cleanup(*args):
    for obj in args:
        del obj
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ==============================================
# HF LOGIN (unchanged)
# ==============================================

print("=" * 60)
print("NEONATAL CRY CLASSIFICATION PIPELINE (IMPROVED)")
print("=" * 60)

with open(HF_TOKEN_PATH, "r") as f:
    hf_token = f.read().strip()
login(token=hf_token)
print("Hugging Face token loaded.")

# ==============================================
# LOAD HeAR MODEL (unchanged)
# ==============================================

if not os.path.exists(LOCAL_MODEL_DIR):
    print(f"Downloading HeAR model to {LOCAL_MODEL_DIR} ...")
    snapshot_download(repo_id=HEAR_REPO_ID, local_dir=LOCAL_MODEL_DIR, local_dir_use_symlinks=False)
else:
    print(f"HeAR model found at {LOCAL_MODEL_DIR}")

print("Loading HeAR TensorFlow SavedModel...")
hear_model    = tf.saved_model.load(LOCAL_MODEL_DIR)
hear_infer    = hear_model.signatures["serving_default"]
EMBEDDING_DIM = 1280   # HeAR produces 1280-dim embeddings
print(f"HeAR model loaded. Embedding dim: {EMBEDDING_DIM}")
log_memory("after HeAR load")

# ==============================================
# AUDIO COLLECTION (unchanged)
# ==============================================

ROOT_PATHS = [
    "/kaggle/input/datasets/mennaahmed23/baby-crying-dataset/Baby crying",
    "/kaggle/input/datasets/oluwatobiowoeye/infant-acousticdataset/infant_cry_datasets"
]

AUDIO_EXTS = {".wav", ".3gp", ".caf", ".m4a", ".ogg", ".mp3", ".flac", ".aac"}

def collect_audio_files(root_paths):
    audio_files = []
    for root in root_paths:
        root_path = Path(root)
        if not root_path.exists():
            print(f"  WARNING: path not found: {root}")
            continue
        for path in tqdm(list(root_path.rglob("*")), desc=f"Scanning {root_path.name}"):
            if path.suffix.lower() in AUDIO_EXTS:
                audio_files.append(str(path))
    return audio_files

# ==============================================
# LABEL MAPPING (unchanged)
# pain=0, hunger=1, neurological/distress=2
# ==============================================

def map_label(path: str):
    p = path.lower()
    # Neurological / distress signals
    if any(k in p for k in ["neurological", "distress", "discomfort", "uncomfortable",
                              "scared", "lonely", "cold_hot", "cold", "hot", "snoring"]):
        return 2
    # Hunger
    if any(k in p for k in ["hungry", "hunger", "h"]):
        # avoid false positive on 'hot'
        if "hungry" in p or "hunger" in p:
            return 1
    # Pain
    if any(k in p for k in ["belly", "pain", "burn", "burp", "tired", "tire"]):
        return 0
 
    parts = Path(path).parts
    for part in parts:
        part = part.lower()
        if part in ["hu"]:
            return 1
        if part in ["bp", "bu", "ca", "ch", "lo", "ti", "sc", "co"]:
            return 0   # various pain/discomfort signals → pain
        if part in ["de", "di"]:
            return 2
    # Fallback: parent folder name heuristics
    parent = Path(path).parent.name.lower()
    if any(k in parent for k in ["hungry", "hunger"]):
        return 1
    if any(k in parent for k in ["pain", "belly", "burn", "burp", "tired", "discomfort",
                                   "uncomfortable", "scared", "lonely", "cold"]):
        return 0
    if any(k in parent for k in ["neurological", "distress", "snoring"]):
        return 2
    return None

# ==============================================
# AUDIO PREPROCESSING + AUGMENTATION (fixed Shift parameters)
# ==============================================

# Define augmentation pipeline
augmenter = A.Compose([
    A.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    A.TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    A.PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    A.Shift(min_shift=-0.5, max_shift=0.5, p=0.3),  # corrected from min_fraction/max_fraction
])

def preprocess_audio(file_path: str, augment=False) -> np.ndarray:
    """Load, denoise, normalize, optionally augment. Returns float32 array at TARGET_SR."""
    audio, sr = librosa.load(file_path, sr=TARGET_SR, mono=True)
    # Wiener denoising
    audio = wiener(audio).astype(np.float32)
    # Smooth
    audio = gaussian_filter1d(audio, sigma=0.5).astype(np.float32)
    # Normalize
    rms = np.sqrt(np.mean(audio**2)) + 1e-9
    audio = audio / rms

    if augment and random.random() < AUGMENT_PROB:
        # audiomentations expects samples in shape (samples,) and returns same
        audio = augmenter(samples=audio, sample_rate=TARGET_SR)

    return audio

def segment_audio(audio: np.ndarray) -> list:
    """Segment audio into overlapping CLIP_LENGTH clips."""
    step  = int(CLIP_LENGTH * (1 - CLIP_OVERLAP))
    clips = []
    for start in range(0, max(1, len(audio) - CLIP_LENGTH + 1), step):
        clip = audio[start:start + CLIP_LENGTH]
        if len(clip) < CLIP_LENGTH:
            clip = np.pad(clip, (0, CLIP_LENGTH - len(clip)), "constant")
        clips.append(clip)
    if len(clips) == 0:
        clip = np.pad(audio, (0, CLIP_LENGTH - len(audio)), "constant") if len(audio) < CLIP_LENGTH else audio[:CLIP_LENGTH]
        clips = [clip]
    return clips

def rms_db(clip: np.ndarray) -> float:
    return 20 * np.log10(np.sqrt(np.mean(clip**2)) + 1e-10)

# ==============================================
# HeAR EMBEDDING EXTRACTION (batch, memory-safe)
# ==============================================

def extract_embeddings_batch(clips_batch: np.ndarray) -> np.ndarray:
    """Run HeAR inference on a batch of clips."""
    tf_input = tf.constant(clips_batch.astype(np.float32))
    out      = hear_infer(x=tf_input)
    emb      = out["output_0"].numpy()
    return emb  # (B, 1280)

def extract_file_embedding(file_path: str, silence_db=-50.0, augment=False) -> np.ndarray | None:
    """
    Full pipeline for one file: load → preprocess (augment) → segment → embed → mean-pool.
    Returns mean embedding (1280,) or None on error.
    """
    try:
        audio  = preprocess_audio(file_path, augment=augment)
        clips  = segment_audio(audio)
        # Filter silent clips
        clips  = [c for c in clips if rms_db(c) > silence_db]
        if len(clips) == 0:
            clips = [segment_audio(audio)[0]]  # keep at least one
        batch  = np.stack(clips, axis=0)       # (N, CLIP_LENGTH)
        emb    = extract_embeddings_batch(batch)
        return emb.mean(axis=0)                # mean-pool → (1280,)
    except Exception as e:
        return None

# ==============================================
# DATA PREPARATION (unchanged)
# ==============================================

print("\n" + "=" * 60)
print("STEP 1: DATA COLLECTION")
print("=" * 60)

print("Collecting audio files...")
all_files  = collect_audio_files(ROOT_PATHS)
print(f"Total audio files found: {len(all_files)}")
log_memory("after file collection")

print("Mapping labels...")
labels_raw = [map_label(f) for f in tqdm(all_files, desc="Labelling")]
data       = [(f, l) for f, l in zip(all_files, labels_raw) if l is not None]
unlabelled = len(all_files) - len(data)
print(f"Labelled files: {len(data)} | Unlabelled (skipped): {unlabelled}")

files_all  = [d[0] for d in data]
labels_all = [d[1] for d in data]

label_series = pd.Series(labels_all)
print("\nClass distribution:")
for cls_id, cls_name in enumerate(CLASS_NAMES):
    n = (label_series == cls_id).sum()
    print(f"  {cls_name} ({cls_id}): {n} files  ({100*n/len(labels_all):.1f}%)")

# ==============================================
# FOUR-WAY SPLIT: Train 40% | Val 15% | Test 15% | Holdout 30%
# ==============================================

print("\n" + "=" * 60)
print("STEP 2: DATA SPLITTING (40/15/15/30)")
print("=" * 60)

X_train, X_temp, y_train, y_temp = train_test_split(
    files_all, labels_all, test_size=0.60, stratify=labels_all, random_state=42)

X_val, X_hold_test, y_val, y_hold_test = train_test_split(
    X_temp, y_temp, test_size=0.75, stratify=y_temp, random_state=42)

X_test, X_hold, y_test, y_hold = train_test_split(
    X_hold_test, y_hold_test, test_size=0.67, stratify=y_hold_test, random_state=42)

print(f"Train:   {len(X_train)} files  ({100*len(X_train)/len(files_all):.1f}%)")
print(f"Val:     {len(X_val)} files  ({100*len(X_val)/len(files_all):.1f}%)")
print(f"Test:    {len(X_test)} files  ({100*len(X_test)/len(files_all):.1f}%)")
print(f"Holdout: {len(X_hold)} files  ({100*len(X_hold)/len(files_all):.1f}%)")

def verify_data_splits(splits_dict):
    """Ensure no file appears in more than one split."""
    print("\nVerifying data split integrity...")
    all_sets = list(splits_dict.items())
    ok = True
    for i in range(len(all_sets)):
        for j in range(i+1, len(all_sets)):
            n1, s1 = all_sets[i]
            n2, s2 = all_sets[j]
            overlap = set(s1) & set(s2)
            if overlap:
                print(f"  WARNING: {n1} ∩ {n2} = {len(overlap)} files!")
                ok = False
    if ok:
        print("  ✓ No data leakage detected across splits.")

verify_data_splits({"Train": X_train, "Val": X_val, "Test": X_test, "Holdout": X_hold})

# ==============================================
# EDA: GRAPHICAL PLOTS (unchanged)
# ==============================================

print("\n" + "=" * 60)
print("STEP 3: EXPLORATORY DATA ANALYSIS")
print("=" * 60)

def save_show(path):
    plt.savefig(path, dpi=150, bbox_inches="tight")
    plt.show()
    plt.close()
    print(f"  Saved: {path}")

# ---- Class Distribution per split ----
def plot_class_distribution(y, title, filename):
    counts = [sum(np.array(y)==i) for i in range(NUM_CLASSES)]
    fig, ax = plt.subplots(figsize=(7, 4))
    bars = ax.bar(CLASS_NAMES, counts, color=["#E74C3C","#F39C12","#3498DB"], edgecolor="black")
    ax.set_title(title, fontsize=13, fontweight="bold")
    ax.set_ylabel("Number of Samples")
    for bar, cnt in zip(bars, counts):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                str(cnt), ha="center", fontweight="bold")
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/{filename}")

for split_name, ys in [("Train", y_train), ("Validation", y_val),
                        ("Test", y_test), ("Holdout", y_hold)]:
    plot_class_distribution(ys, f"{split_name} Set Class Distribution",
                             f"dist_{split_name.lower()}.png")

# ---- Combined distribution overview ----
fig, axes = plt.subplots(1, 4, figsize=(18, 4))
for ax, (split_name, ys) in zip(axes, [("Train", y_train), ("Val", y_val),
                                         ("Test", y_test), ("Holdout", y_hold)]):
    counts = [sum(np.array(ys)==i) for i in range(NUM_CLASSES)]
    ax.bar(CLASS_NAMES, counts, color=["#E74C3C","#F39C12","#3498DB"])
    ax.set_title(split_name); ax.set_ylabel("Count")
    for i, c in enumerate(counts):
        ax.text(i, c + 0.3, str(c), ha="center", fontsize=9)
plt.suptitle("Class Distribution Across All Splits", fontsize=14, fontweight="bold")
plt.tight_layout()
save_show(f"{PLOTS_DIR}/dist_all_splits.png")

# ---- Sample waveforms ----
def plot_sample_waveforms(files, labels, n_per_class=2):
    fig, axes = plt.subplots(NUM_CLASSES, n_per_class, figsize=(14, 8))
    for cls_id, cls_name in enumerate(CLASS_NAMES):
        idxs = [i for i, l in enumerate(labels) if l == cls_id][:n_per_class]
        for k, idx in enumerate(idxs):
            try:
                audio, sr = librosa.load(files[idx], sr=TARGET_SR, duration=3.0)
                t = np.linspace(0, len(audio)/sr, len(audio))
                axes[cls_id][k].plot(t, audio, linewidth=0.5, color=["#E74C3C","#F39C12","#3498DB"][cls_id])
                axes[cls_id][k].set_title(f"{cls_name} – {Path(files[idx]).name[:25]}", fontsize=8)
                axes[cls_id][k].set_xlabel("Time (s)")
            except Exception:
                axes[cls_id][k].set_title(f"{cls_name} – load error")
    plt.suptitle("Sample Waveforms per Class", fontsize=13, fontweight="bold")
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/sample_waveforms.png")

plot_sample_waveforms(X_train, y_train)

# ---- Sample spectrograms ----
def plot_sample_spectrograms(files, labels, n_per_class=2):
    fig, axes = plt.subplots(NUM_CLASSES, n_per_class, figsize=(14, 9))
    for cls_id, cls_name in enumerate(CLASS_NAMES):
        idxs = [i for i, l in enumerate(labels) if l == cls_id][:n_per_class]
        for k, idx in enumerate(idxs):
            try:
                audio, sr = librosa.load(files[idx], sr=TARGET_SR, duration=3.0)
                mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=64, fmax=8000)
                log_mel = librosa.power_to_db(mel, ref=np.max)
                im = librosa.display.specshow(log_mel, sr=sr, x_axis="time", y_axis="mel",
                                               ax=axes[cls_id][k], cmap="viridis")
                axes[cls_id][k].set_title(f"{cls_name} – {Path(files[idx]).name[:22]}", fontsize=8)
                fig.colorbar(im, ax=axes[cls_id][k])
            except Exception:
                axes[cls_id][k].set_title(f"{cls_name} – error")
    plt.suptitle("Sample Mel Spectrograms per Class", fontsize=13, fontweight="bold")
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/sample_spectrograms.png")

plot_sample_spectrograms(X_train, y_train)

# ---- Duration distribution ----
def plot_duration_distribution(files, max_sample=500):
    durations = []
    for f in tqdm(files[:max_sample], desc="Measuring durations"):
        try:
            dur = librosa.get_duration(filename=f)
            durations.append(dur)
        except Exception:
            pass
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].hist(durations, bins=40, color="#2ECC71", edgecolor="black", alpha=0.8)
    axes[0].set_xlabel("Duration (s)"); axes[0].set_ylabel("Count")
    axes[0].set_title("Audio Duration Distribution")
    axes[0].axvline(np.mean(durations), color="red", linestyle="--", label=f"Mean={np.mean(durations):.2f}s")
    axes[0].legend()
    axes[1].boxplot(durations, vert=True)
    axes[1].set_ylabel("Duration (s)"); axes[1].set_title("Duration Boxplot")
    plt.suptitle(f"Duration Stats (n={len(durations)} sampled)", fontsize=12)
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/duration_distribution.png")
    print(f"  Duration stats → mean: {np.mean(durations):.2f}s | "
          f"median: {np.median(durations):.2f}s | "
          f"max: {np.max(durations):.2f}s | min: {np.min(durations):.2f}s")

plot_duration_distribution(X_train)

force_cleanup()
log_memory("after EDA")

# ==============================================
# STEP 4: EMBEDDING EXTRACTION (with augmentation for training)
# ==============================================

print("\n" + "=" * 60)
print("STEP 4: HeAR EMBEDDING EXTRACTION + AUGMENTATION")
print("=" * 60)

def extract_split_embeddings(files, labels, split_name, augment=False, copies=1):
    """
    Extract HeAR embeddings for an entire split.
    If augment=True, generate `copies` augmented versions per file.
    """
    embeddings = []
    valid_labels = []
    valid_files  = []

    for i in tqdm(range(0, len(files)), desc=f"Extracting {split_name} embeddings"):
        f = files[i]
        l = labels[i]

        # For training, generate multiple augmented copies
        if augment:
            num_copies = copies
        else:
            num_copies = 1

        for copy_idx in range(num_copies):
            # Only augment if copy_idx > 0 and augment=True
            use_augment = augment and copy_idx > 0
            emb = extract_file_embedding(f, augment=use_augment)
            if emb is not None:
                embeddings.append(emb)
                valid_labels.append(l)
                valid_files.append(f)

        # Memory safety
        if i % 100 == 0 and get_memory_gb() > 12:
            gc.collect()

    emb_arr = np.array(embeddings, dtype=np.float32)
    lbl_arr = np.array(valid_labels, dtype=np.int64)
    print(f"  {split_name}: {emb_arr.shape[0]} embeddings, {emb_arr.shape[1]}-dim")
    log_memory(split_name)
    return emb_arr, lbl_arr, valid_files

# Extract embeddings with augmentation for training (3 copies per file)
emb_train, lbl_train, files_train_valid = extract_split_embeddings(
    X_train, y_train, "Train", augment=True, copies=AUGMENT_FACTOR)
emb_val,   lbl_val,   files_val_valid   = extract_split_embeddings(X_val,   y_val,   "Val")
emb_test,  lbl_test,  files_test_valid  = extract_split_embeddings(X_test,  y_test,  "Test")
emb_hold,  lbl_hold,  files_hold_valid  = extract_split_embeddings(X_hold,  y_hold,  "Holdout")

force_cleanup()
log_memory("after all embeddings")

# Compute class weights for loss function
class_weights = compute_class_weight('balanced', classes=np.unique(lbl_train), y=lbl_train)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
print(f"Class weights: {class_weights}")

# ==============================================
# FEATURE ENGINEERING: PCA + SPECTRAL STATS
# ==============================================

print("\n" + "=" * 60)
print("STEP 5: FEATURE ENGINEERING")
print("=" * 60)

# Scale embeddings
scaler = StandardScaler()
emb_train_sc = scaler.fit_transform(emb_train)
emb_val_sc   = scaler.transform(emb_val)
emb_test_sc  = scaler.transform(emb_test)
emb_hold_sc  = scaler.transform(emb_hold)

# PCA for visualization and augmented features
print("Fitting PCA for visualization...")
pca_vis = PCA(n_components=2, random_state=42)
pca_vis.fit(emb_train_sc)

train_pca2 = pca_vis.transform(emb_train_sc)
val_pca2   = pca_vis.transform(emb_val_sc)
test_pca2  = pca_vis.transform(emb_test_sc)
hold_pca2  = pca_vis.transform(emb_hold_sc)

# PCA for feature augmentation (retain 95% variance)
pca_feat = PCA(n_components=0.95, random_state=42)
pca_feat.fit(emb_train_sc)
n_comp = pca_feat.n_components_
print(f"PCA retaining 95% variance: {n_comp} components")

emb_train_pca = pca_feat.transform(emb_train_sc)
emb_val_pca   = pca_feat.transform(emb_val_sc)
emb_test_pca  = pca_feat.transform(emb_test_sc)
emb_hold_pca  = pca_feat.transform(emb_hold_sc)

# ---- PCA 2D Visualization ----
def plot_pca_embeddings(pca2_data, labels, title, filename):
    colors = ["#E74C3C", "#F39C12", "#3498DB"]
    fig, ax = plt.subplots(figsize=(9, 7))
    for cls_id, cls_name in enumerate(CLASS_NAMES):
        mask = labels == cls_id
        ax.scatter(pca2_data[mask, 0], pca2_data[mask, 1],
                   c=colors[cls_id], label=cls_name, alpha=0.6, s=30)
    ax.set_xlabel("PCA Dim 1"); ax.set_ylabel("PCA Dim 2")
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.legend(); ax.grid(True, alpha=0.3)
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/{filename}")

plot_pca_embeddings(train_pca2, lbl_train, "PCA of HeAR Embeddings – Train", "pca_train.png")
plot_pca_embeddings(hold_pca2,  lbl_hold,  "PCA of HeAR Embeddings – Holdout", "pca_holdout.png")

# Combined PCA across all splits
all_pca2   = np.vstack([train_pca2, val_pca2, test_pca2, hold_pca2])
all_labels = np.concatenate([lbl_train, lbl_val, lbl_test, lbl_hold])
plot_pca_embeddings(all_pca2, all_labels, "PCA of HeAR Embeddings – All Data", "pca_all.png")

# ---- Embedding correlation heatmap (class-level) ----
print("Plotting class-mean embedding barcode heatmap...")
fig, axes = plt.subplots(1, NUM_CLASSES, figsize=(15, 3))
for cls_id, cls_name in enumerate(CLASS_NAMES):
    mask = lbl_train == cls_id
    mean_emb = emb_train_sc[mask].mean(axis=0)
    axes[cls_id].imshow(mean_emb.reshape(1, -1), cmap="RdBu_r", aspect="auto",
                         vmin=-3, vmax=3)
    axes[cls_id].set_title(f"{cls_name}\nmean embedding", fontsize=10)
    axes[cls_id].set_yticks([])
plt.suptitle("Class-Mean HeAR Embeddings (normalized)", fontsize=12, fontweight="bold")
plt.tight_layout()
save_show(f"{PLOTS_DIR}/embedding_barcode.png")

force_cleanup()
log_memory("after feature engineering")

# ==============================================
# STEP 6: CLASSICAL ML CLASSIFIERS (tuned)
# ==============================================

print("\n" + "=" * 60)
print("STEP 6: CLASSICAL ML CLASSIFIERS (on HeAR embeddings) + TUNING")
print("=" * 60)

# Use a reduced set for speed
classical_models = {
    "SVM (RBF)": SVC(kernel="rbf", probability=True, random_state=42),
    "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
    "Random Forest": RandomForestClassifier(random_state=42),
    "Gradient Boosting": GradientBoostingClassifier(random_state=42),
}

# Hyperparameter grids
param_grids = {
    "SVM (RBF)": {
        'C': [0.1, 1, 10, 100],
        'gamma': ['scale', 'auto', 0.01, 0.001],
    },
    "Logistic Regression": {
        'C': [0.01, 0.1, 1, 10],
        'penalty': ['l2'],
    },
    "Random Forest": {
        'n_estimators': [100, 200],
        'max_depth': [10, 20, None],
        'min_samples_split': [2, 5],
    },
    "Gradient Boosting": {
        'n_estimators': [100, 150],
        'max_depth': [3, 5],
        'learning_rate': [0.05, 0.1],
    },
}

classical_results = {}
best_classical_acc = 0
best_classical_name = ""
best_classical_model = None

for name, clf in tqdm(classical_models.items(), desc="Tuning classifiers"):
    try:
        # Use randomized search with 3-fold CV on train
        search = RandomizedSearchCV(
            clf, param_grids[name], n_iter=10, cv=3, scoring='accuracy',
            random_state=42, n_jobs=-1, verbose=0
        )
        search.fit(emb_train_pca, lbl_train)
        best_clf = search.best_estimator_
        val_preds = best_clf.predict(emb_val_pca)
        val_acc   = accuracy_score(lbl_val, val_preds)
        classical_results[name] = {"val_acc": val_acc, "model": best_clf, "best_params": search.best_params_}
        print(f"  {name:25s} → Val Acc: {val_acc:.4f} (best params: {search.best_params_})")
        if val_acc > best_classical_acc:
            best_classical_acc  = val_acc
            best_classical_name = name
            best_classical_model = best_clf
    except Exception as e:
        print(f"  {name} FAILED: {e}")

print(f"\nBest classical model: {best_classical_name}  (Val Acc={best_classical_acc:.4f})")

# ---- Classical model comparison bar chart ----
fig, ax = plt.subplots(figsize=(10, 5))
names  = list(classical_results.keys())
accs   = [classical_results[n]["val_acc"] for n in names]
colors = ["#2ECC71" if n == best_classical_name else "#95A5A6" for n in names]
bars   = ax.barh(names, accs, color=colors, edgecolor="black")
ax.set_xlim(0, 1); ax.set_xlabel("Validation Accuracy")
ax.set_title("Tuned Classical Classifier Comparison", fontweight="bold")
for bar, acc in zip(bars, accs):
    ax.text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,
            f"{acc:.4f}", va="center", fontsize=9)
plt.tight_layout()
save_show(f"{PLOTS_DIR}/classical_comparison_tuned.png")

force_cleanup()
log_memory("after classical tuning")

# ==============================================
# STEP 7: ENSEMBLE / STACKING (unchanged)
# ==============================================

print("\n" + "=" * 60)
print("STEP 7: ENSEMBLE / STACKING")
print("=" * 60)

# Use tuned base models
estimators = []
for name, res in classical_results.items():
    if name in ["SVM (RBF)", "Logistic Regression", "Random Forest", "Gradient Boosting"]:
        estimators.append((name.lower().replace(" ", "_"), res["model"]))

voting_clf = VotingClassifier(estimators=estimators, voting="soft")
print("Training Soft-Voting Ensemble...")
voting_clf.fit(emb_train_pca, lbl_train)
ens_val_acc = accuracy_score(lbl_val, voting_clf.predict(emb_val_pca))
print(f"  Soft-Voting Ensemble Val Acc: {ens_val_acc:.4f}")
classical_results["Ensemble (Soft-Vote)"] = {"val_acc": ens_val_acc, "model": voting_clf}

if ens_val_acc > best_classical_acc:
    best_classical_acc  = ens_val_acc
    best_classical_name = "Ensemble (Soft-Vote)"
    best_classical_model = voting_clf
    print(f"  → New best classical model: {best_classical_name}")

force_cleanup()

# ==============================================
# STEP 8: PYTORCH FINE-TUNED ATTENTION HEAD (with improvements)
# ==============================================

print("\n" + "=" * 60)
print("STEP 8: PyTorch FINE-TUNED ATTENTION CLASSIFIER (enhanced)")
print("=" * 60)

class GaussianNoise(nn.Module):
    def __init__(self, std=0.01):
        super().__init__()
        self.std = std
    def forward(self, x):
        if self.training:
            return x + torch.randn_like(x) * self.std
        return x

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention module with residual."""
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, num_heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class AttentionPooling(nn.Module):
    """Learnable attention pooling over sequence dimension."""
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.scale = dim ** -0.5

    def forward(self, x):
        # x: (B, N, D)
        attn_weights = torch.matmul(self.query, x.transpose(1, 2)) * self.scale  # (B, 1, N)
        attn_weights = F.softmax(attn_weights, dim=-1)
        pooled = torch.matmul(attn_weights, x)  # (B, 1, D)
        return pooled.squeeze(1)  # (B, D)

class EnhancedAttentionHead(nn.Module):
    """
    Transformer-based classifier with attention pooling over clips.
    Expects input: (B, N, D) where N = number of clips per file, D = embedding dim.
    If N=1, it falls back to treating each clip as a separate sample.
    """
    def __init__(self, input_dim=1280, num_classes=3, num_heads=8, depth=4, mlp_ratio=4.,
                 dropout=0.3, attn_dropout=0.2, use_attention_pooling=True):
        super().__init__()
        self.use_attention_pooling = use_attention_pooling

        # Input projection and positional encoding (learnable)
        self.input_proj = nn.Linear(input_dim, input_dim)  # keep dim for simplicity
        self.pos_embed = nn.Parameter(torch.randn(1, 100, input_dim) * 0.02)  # max 100 clips

        self.noise = GaussianNoise(0.01)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(input_dim, num_heads, mlp_ratio, dropout=attn_dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(input_dim)

        if use_attention_pooling:
            self.pool = AttentionPooling(input_dim)
        else:
            self.pool = nn.AdaptiveAvgPool1d(1)  # average over sequence

        self.fc_drop = nn.Dropout(dropout)
        self.classifier = nn.Linear(input_dim, num_classes)

        # Initialize
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: (B, N, D) or (B, D) if single clip
        if x.dim() == 2:
            # Single clip: treat as sequence length 1
            x = x.unsqueeze(1)  # (B, 1, D)
            B, N, D = x.shape
        else:
            B, N, D = x.shape

        # Add noise
        x = self.noise(x)

        # Projection
        x = self.input_proj(x)

        # Add positional embedding (interpolate if N > max)
        if N > self.pos_embed.size(1):
            # Interpolate positional embeddings
            pos = F.interpolate(self.pos_embed.transpose(1,2), size=N, mode='linear', align_corners=False)
            pos = pos.transpose(1,2)
        else:
            pos = self.pos_embed[:, :N, :]
        x = x + pos

        # Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        # Pooling
        if self.use_attention_pooling:
            x = self.pool(x)  # (B, D)
        else:
            x = x.mean(dim=1)  # average over sequence

        x = self.fc_drop(x)
        return self.classifier(x)

# We need a custom Dataset that returns all clip embeddings for a file.
# However, we already extracted mean-pooled embeddings. To use attention pooling,
# we need to store all clip embeddings per file, not just the mean. That would require
# re-extraction with storing clips. For simplicity, we'll stick with mean-pooled embeddings
# but add a transformer on top of that (i.e., treat each file as a single token).
# That is less effective but simpler. Alternatively, we could extract and store clip embeddings
# for each file, but that would increase storage and memory significantly.
# We'll compromise: use mean-pooled embeddings but with a deeper classifier.

# Alternatively, we can modify the embedding extraction to also return all clip embeddings
# for files in a separate array. But to keep the code manageable, we'll keep using mean-pooled
# and enhance the classifier with more capacity and regularization.

# So we revert to a simpler but stronger MLP + attention variant that operates on single embeddings.
# But we already have that. To incorporate attention over clips, we would need to change the data.
# I'll keep the original AttentionHead but with better regularization and hyperparameter tuning.

# We'll also add mixup augmentation.

class Mixup:
    def __init__(self, alpha=0.2):
        self.alpha = alpha

    def __call__(self, x, y):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        batch_size = x.size(0)
        index = torch.randperm(batch_size).to(x.device)
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

# We'll use the original AttentionHead but with increased dropout and depth.
class ImprovedAttentionHead(nn.Module):
    """
    Deeper MLP with residual connections and stronger regularization.
    """
    def __init__(self, input_dim=1280, num_classes=3, hidden_dims=[512, 256, 128], dropout=0.5):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hdim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hdim))
            layers.append(nn.BatchNorm1d(hdim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hdim
        layers.append(nn.Linear(prev_dim, num_classes))
        self.net = nn.Sequential(*layers)
        self.noise = GaussianNoise(0.01)

    def forward(self, x):
        x = self.noise(x)
        return self.net(x)

# Define dataset (single embedding per file)
class EmbeddingDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.X = torch.tensor(embeddings, dtype=torch.float32)
        self.y = torch.tensor(labels, dtype=torch.long)
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = EmbeddingDataset(emb_train_sc, lbl_train)
val_ds   = EmbeddingDataset(emb_val_sc,   lbl_val)
test_ds  = EmbeddingDataset(emb_test_sc,  lbl_test)
hold_ds  = EmbeddingDataset(emb_hold_sc,  lbl_hold)

train_dl = DataLoader(train_ds, batch_size=PT_BATCH, shuffle=True,  num_workers=0, drop_last=True)
val_dl   = DataLoader(val_ds,   batch_size=PT_BATCH, shuffle=False, num_workers=0)
test_dl  = DataLoader(test_ds,  batch_size=PT_BATCH, shuffle=False, num_workers=0)
hold_dl  = DataLoader(hold_ds,  batch_size=PT_BATCH, shuffle=False, num_workers=0)

# Hyperparameter search for the improved head
# We'll define a small random search
def train_head(hidden_dims, dropout, lr, wd, epochs, train_dl, val_dl, class_weights):
    model = ImprovedAttentionHead(input_dim=emb_train_sc.shape[1], hidden_dims=hidden_dims, dropout=dropout).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_val_acc = 0
    for epoch in range(epochs):
        model.train()
        for x, y in train_dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        scheduler.step()

        # Validation
        model.eval()
        val_preds, val_true = [], []
        with torch.no_grad():
            for x, y in val_dl:
                x = x.to(DEVICE)
                logits = model(x)
                preds = logits.argmax(1).cpu().numpy()
                val_preds.extend(preds)
                val_true.extend(y.numpy())
        val_acc = accuracy_score(val_true, val_preds)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    return best_val_acc, model

print("Performing random hyperparameter search for AttentionHead...")
search_iter = 10
best_pt_acc = 0
best_pt_config = None
best_pt_model = None

for i in range(search_iter):
    hidden_dims = random.choice([[512,256], [512,256,128], [1024,512,256]])
    dropout = random.uniform(0.3, 0.7)
    lr = random.choice([1e-3, 3e-4, 1e-4])
    wd = random.choice([1e-4, 1e-3, 1e-2])
    epochs = 20
    print(f"  Trial {i+1}: hidden={hidden_dims}, dropout={dropout:.2f}, lr={lr}, wd={wd}")
    val_acc, model = train_head(hidden_dims, dropout, lr, wd, epochs, train_dl, val_dl, class_weights)
    print(f"    → Val Acc: {val_acc:.4f}")
    if val_acc > best_pt_acc:
        best_pt_acc = val_acc
        best_pt_config = (hidden_dims, dropout, lr, wd)
        best_pt_model = model
        torch.save(model.state_dict(), CKPT_PATH)
        print(f"    ★ New best PT model saved.")

print(f"\nBest PT config: {best_pt_config} with Val Acc={best_pt_acc:.4f}")

# ==============================================
# STEP 9: TRAINING HISTORY PLOTS (for best model)
# We'll retrain the best model to full epochs and record history.
# ==============================================

print("\n" + "=" * 60)
print("STEP 9: TRAINING FINAL ATTENTION HEAD")
print("=" * 60)

# Recreate best model
best_hidden_dims, best_dropout, best_lr, best_wd = best_pt_config
final_model = ImprovedAttentionHead(input_dim=emb_train_sc.shape[1], hidden_dims=best_hidden_dims, dropout=best_dropout).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
optimizer = optim.AdamW(final_model.parameters(), lr=best_lr, weight_decay=best_wd)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

history = {"train_acc": [], "val_acc": [], "train_loss": [], "lr": []}
best_val_acc = 0.0
patience_cnt = 0
PATIENCE = 15

mixup = Mixup(alpha=0.2)

for epoch in range(EPOCHS):
    final_model.train()
    epoch_preds, epoch_true, epoch_loss = [], [], 0.0

    for x, y in tqdm(train_dl, desc=f"Epoch {epoch+1:02d}/{EPOCHS}", leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)

        # Mixup
        if random.random() < 0.5:
            mixed_x, y_a, y_b, lam = mixup(x, y)
            logits = final_model(mixed_x)
            loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
        else:
            logits = final_model(x)
            loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(final_model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item() * len(y)
        # Always take predictions from logits (whether mixup or not)
        preds = logits.argmax(1).cpu().numpy()
        epoch_preds.extend(preds)
        epoch_true.extend(y.cpu().numpy())

    scheduler.step()
    train_acc  = accuracy_score(epoch_true, epoch_preds)
    # Validation
    final_model.eval()
    val_preds, val_true = [], []
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(DEVICE)
            logits = final_model(x)
            preds = logits.argmax(1).cpu().numpy()
            val_preds.extend(preds)
            val_true.extend(y.numpy())
    val_acc = accuracy_score(val_true, val_preds)
    epoch_loss /= len(train_ds)
    cur_lr     = scheduler.get_last_lr()[0]

    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)
    history["train_loss"].append(epoch_loss)
    history["lr"].append(cur_lr)

    print(f"  Epoch {epoch+1:02d} | Loss={epoch_loss:.4f} | "
          f"Train={train_acc:.4f} | Val={val_acc:.4f} | LR={cur_lr:.6f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_cnt = 0
        torch.save(final_model.state_dict(), CKPT_PATH)
        print(f"    ★ New best val acc: {best_val_acc:.4f} — checkpoint saved.")
    else:
        patience_cnt += 1
        if patience_cnt >= PATIENCE:
            print(f"  Early stopping at epoch {epoch+1} (no improvement for {PATIENCE} epochs).")
            break

    force_cleanup()

# Load best checkpoint
final_model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
print(f"\nBest model loaded.  Best Val Acc: {best_val_acc:.4f}")

# ==============================================
# STEP 10: TRAINING HISTORY PLOTS
# ==============================================

print("\n" + "=" * 60)
print("STEP 10: TRAINING HISTORY PLOTS")
print("=" * 60)

epochs_run = len(history["train_acc"])

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Accuracy
axes[0].plot(range(1, epochs_run+1), history["train_acc"], "o-", label="Train", color="#2980B9")
axes[0].plot(range(1, epochs_run+1), history["val_acc"],   "s-", label="Val",   color="#E74C3C")
best_ep = np.argmax(history["val_acc"]) + 1
axes[0].axvline(best_ep, color="green", linestyle="--", alpha=0.6, label=f"Best={best_ep}")
axes[0].scatter([best_ep], [max(history["val_acc"])], color="green", s=100, zorder=5)
axes[0].set_xlabel("Epoch"); axes[0].set_ylabel("Accuracy")
axes[0].set_title("Training & Validation Accuracy", fontweight="bold")
axes[0].legend(); axes[0].grid(True, alpha=0.3)

# Loss
axes[1].plot(range(1, epochs_run+1), history["train_loss"], "o-", color="#8E44AD")
axes[1].set_xlabel("Epoch"); axes[1].set_ylabel("Loss")
axes[1].set_title("Training Loss", fontweight="bold"); axes[1].grid(True, alpha=0.3)

# LR
axes[2].plot(range(1, epochs_run+1), history["lr"], "o-", color="#F39C12")
axes[2].set_xlabel("Epoch"); axes[2].set_ylabel("Learning Rate")
axes[2].set_title("Learning Rate Schedule", fontweight="bold"); axes[2].grid(True, alpha=0.3)

plt.suptitle("Improved AttentionHead Training History", fontsize=14, fontweight="bold")
plt.tight_layout()
save_show(f"{PLOTS_DIR}/training_history_improved.png")

# Generalization gap
gap = [tr - vl for tr, vl in zip(history["train_acc"], history["val_acc"])]
fig, ax = plt.subplots(figsize=(10, 4))
ax.fill_between(range(1, epochs_run+1), 0, gap, alpha=0.4,
                color="red" if max(gap) > 0.1 else "green",
                label="Train−Val Gap")
ax.axhline(0, color="black", linewidth=0.8)
ax.set_xlabel("Epoch"); ax.set_ylabel("Accuracy Gap")
ax.set_title("Generalization Gap (Train − Val Accuracy)", fontweight="bold")
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
save_show(f"{PLOTS_DIR}/generalization_gap_improved.png")

# ==============================================
# STEP 11: COMPREHENSIVE EVALUATION
# ==============================================

print("\n" + "=" * 60)
print("STEP 11: COMPREHENSIVE EVALUATION")
print("=" * 60)

def detailed_evaluation(model, loader, lbl_array, set_name, model_tag="ImprovedAttentionHead"):
    model.eval()
    preds, trues, probs_all = [], [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(DEVICE)
            logits = model(x)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            probs_all.extend(probs)
            preds.extend(logits.argmax(1).cpu().numpy())
            trues.extend(y.numpy())
    preds = np.array(preds)
    trues = np.array(trues)
    probs_all = np.array(probs_all)
    acc = accuracy_score(trues, preds)
    p, r, f1, _ = precision_recall_fscore_support(trues, preds, average="weighted", zero_division=0)

    print(f"\n{'='*50}")
    print(f"  {model_tag} | {set_name} Set")
    print(f"  Accuracy : {acc:.4f}")
    print(f"  Precision: {p:.4f}  Recall: {r:.4f}  F1: {f1:.4f}")
    print("\nClassification Report:")
    print(classification_report(trues, preds, target_names=CLASS_NAMES, zero_division=0))

    # Confusion Matrix
    cm = confusion_matrix(trues, preds)
    fig, ax = plt.subplots(figsize=(7, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
    ax.set_title(f"Confusion Matrix – {set_name} ({model_tag})", fontweight="bold")
    ax.set_ylabel("True"); ax.set_xlabel("Predicted")
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/cm_{set_name.lower().replace(' ', '_')}_{model_tag.lower().replace(' ', '_')}.png")

    # ROC Curves
    y_bin = label_binarize(trues, classes=[0, 1, 2])
    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ["#E74C3C", "#F39C12", "#3498DB"]
    for i, (cls_name, col) in enumerate(zip(CLASS_NAMES, colors)):
        fpr, tpr, _ = roc_curve(y_bin[:, i], probs_all[:, i])
        auc_sc      = roc_auc_score(y_bin[:, i], probs_all[:, i])
        ax.plot(fpr, tpr, color=col, lw=2, label=f"{cls_name} AUC={auc_sc:.3f}")
    ax.plot([0, 1], [0, 1], "k--", lw=1)
    ax.set_xlabel("FPR"); ax.set_ylabel("TPR")
    ax.set_title(f"ROC Curves – {set_name} ({model_tag})", fontweight="bold")
    ax.legend(loc="lower right"); ax.grid(True, alpha=0.3)
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/roc_{set_name.lower().replace(' ', '_')}_{model_tag.lower().replace(' ', '_')}.png")

    return {"acc": acc, "precision": p, "recall": r, "f1": f1}

results = {}
for loader, lbl, split_name in [
    (val_dl,  lbl_val,  "Validation"),
    (test_dl, lbl_test, "Test"),
    (hold_dl, lbl_hold, "Holdout"),
]:
    results[split_name] = detailed_evaluation(final_model, loader, lbl, split_name)

force_cleanup()
log_memory("after evaluation")

# ==============================================
# STEP 12: BEST CLASSICAL MODEL EVALUATION ON TEST/HOLDOUT
# ==============================================

print("\n" + "=" * 60)
print("STEP 12: BEST CLASSICAL MODEL EVALUATION")
print("=" * 60)

for split_name, emb_split, lbl_split in [
    ("Test",    emb_test_pca,  lbl_test),
    ("Holdout", emb_hold_pca,  lbl_hold),
]:
    preds = best_classical_model.predict(emb_split)
    acc   = accuracy_score(lbl_split, preds)
    p, r, f1, _ = precision_recall_fscore_support(lbl_split, preds, average="weighted", zero_division=0)
    print(f"\n  {best_classical_name} | {split_name}")
    print(f"  Accuracy={acc:.4f} | P={p:.4f} | R={r:.4f} | F1={f1:.4f}")

    cm = confusion_matrix(lbl_split, preds)
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Oranges",
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
    ax.set_title(f"Confusion Matrix – {split_name} ({best_classical_name})", fontweight="bold")
    ax.set_ylabel("True"); ax.set_xlabel("Predicted")
    plt.tight_layout()
    save_show(f"{PLOTS_DIR}/cm_{split_name.lower()}_classical_tuned.png")

# ==============================================
# STEP 13: 5-FOLD CROSS-VALIDATION (SVM on train+val)
# ==============================================

print("\n" + "=" * 60)
print("STEP 13: 5-FOLD CROSS-VALIDATION")
print("=" * 60)

emb_tv = np.vstack([emb_train_pca, emb_val_pca])
lbl_tv = np.concatenate([lbl_train, lbl_val])

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
svm_cv = SVC(kernel="rbf", C=10, probability=True, random_state=42)  # use best params from tuning
cv_scores = cross_val_score(svm_cv, emb_tv, lbl_tv, cv=skf, scoring="accuracy", n_jobs=-1)
print(f"  SVM (RBF) 5-fold CV: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")
print(f"  Fold scores: {[f'{s:.4f}' for s in cv_scores]}")

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(range(1, 6), cv_scores, color="#3498DB", edgecolor="black")
ax.axhline(cv_scores.mean(), color="red", linestyle="--", label=f"Mean={cv_scores.mean():.4f}")
ax.set_xlabel("Fold"); ax.set_ylabel("Accuracy")
ax.set_title("5-Fold Cross-Validation Accuracy (SVM)", fontweight="bold")
ax.legend(); ax.set_ylim(0, 1); ax.grid(True, alpha=0.3)
plt.tight_layout()
save_show(f"{PLOTS_DIR}/cross_validation_improved.png")

# ==============================================
# STEP 14: GENERALIZATION ANALYSIS SUMMARY PLOT
# ==============================================

print("\n" + "=" * 60)
print("STEP 14: GENERALIZATION ANALYSIS SUMMARY")
print("=" * 60)

metrics_list  = ["Accuracy", "Precision", "Recall", "F1-Score"]
split_names   = ["Validation", "Test", "Holdout"]
metric_keys   = ["acc", "precision", "recall", "f1"]
split_colors  = ["#2ECC71", "#3498DB", "#E74C3C"]

fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(metrics_list))
width = 0.25

for i, (sname, col) in enumerate(zip(split_names, split_colors)):
    vals = [results[sname][k] for k in metric_keys]
    bars = ax.bar(x + i*width, vals, width, label=sname, color=col, edgecolor="black")
    for bar in bars:                    # ✅ FIXED: iterate over bars, not vals
        h = bar.get_height()
        ax.annotate(f"{h:.3f}", xy=(bar.get_x() + bar.get_width()/2, h),
                    xytext=(0, 3), textcoords="offset points", ha="center", fontsize=8)

ax.set_xticks(x + width)
ax.set_xticklabels(metrics_list)
ax.set_ylabel("Score"); ax.set_ylim(0, 1.1)
ax.set_title("ImprovedAttentionHead: Performance Across Val / Test / Holdout", fontweight="bold")
ax.legend(); ax.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
save_show(f"{PLOTS_DIR}/generalization_summary_improved.png")

# Overfitting detection table
print("\n  ── Overfitting Detection ──")
train_final_acc = history["train_acc"][-1]
for sname in split_names:
    gap = train_final_acc - results[sname]["acc"]
    status = "✓ OK" if gap < 0.05 else ("⚠ Mild" if gap < 0.10 else "✗ Overfit")
    print(f"  Train−{sname} gap: {gap:.4f}  {status}")

# ==============================================
# STEP 15: HOLDOUT (UNSEEN) DATA DEEP DIVE
# ==============================================

print("\n" + "=" * 60)
print("STEP 15: HOLDOUT (UNSEEN) DATA DEEP DIVE")
print("=" * 60)

# Get predictions on holdout
final_model.eval()
hold_preds, hold_true, hold_probs = [], [], []
with torch.no_grad():
    for x, y in hold_dl:
        x = x.to(DEVICE)
        logits = final_model(x)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        hold_probs.extend(probs)
        hold_preds.extend(logits.argmax(1).cpu().numpy())
        hold_true.extend(y.numpy())
hold_preds = np.array(hold_preds)
hold_true = np.array(hold_true)
hold_probs = np.array(hold_probs)

# Per-class accuracy on holdout
print("\n  Per-class accuracy on Holdout:")
for cls_id, cls_name in enumerate(CLASS_NAMES):
    mask = hold_true == cls_id
    if mask.sum() == 0:
        continue
    cls_acc = accuracy_score(hold_true[mask], hold_preds[mask])
    print(f"    {cls_name}: {cls_acc:.4f}  (n={mask.sum()})")

# Confidence distribution on holdout
fig, axes = plt.subplots(1, NUM_CLASSES, figsize=(15, 4))
for cls_id, cls_name in enumerate(CLASS_NAMES):
    mask = hold_true == cls_id
    if mask.sum() == 0:
        continue
    correct_mask = (hold_preds == hold_true) & mask
    wrong_mask   = (~correct_mask) & mask
    axes[cls_id].hist(hold_probs[correct_mask, cls_id], bins=20,
                       alpha=0.7, color="#2ECC71", label="Correct")
    axes[cls_id].hist(hold_probs[wrong_mask, cls_id], bins=20,
                       alpha=0.7, color="#E74C3C", label="Wrong")
    axes[cls_id].set_title(f"{cls_name}\nHoldout Confidence")
    axes[cls_id].set_xlabel("Predicted Probability"); axes[cls_id].set_ylabel("Count")
    axes[cls_id].legend()
plt.suptitle("Prediction Confidence Distribution – Holdout Set (Improved)", fontweight="bold")
plt.tight_layout()
save_show(f"{PLOTS_DIR}/holdout_confidence_improved.png")

# PCA of holdout embeddings colored by correctness
correct_mask_full = (hold_preds == hold_true)
fig, ax = plt.subplots(figsize=(9, 6))
ax.scatter(hold_pca2[correct_mask_full, 0],  hold_pca2[correct_mask_full, 1],
           c="#2ECC71", alpha=0.6, s=25, label="Correct")
ax.scatter(hold_pca2[~correct_mask_full, 0], hold_pca2[~correct_mask_full, 1],
           c="#E74C3C", alpha=0.8, s=50, marker="x", label="Incorrect")
ax.set_xlabel("PCA Dim 1"); ax.set_ylabel("PCA Dim 2")
ax.set_title("PCA – Holdout Correct vs Incorrect Predictions (Improved)", fontweight="bold")
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
save_show(f"{PLOTS_DIR}/holdout_pca_correctness_improved.png")

# ==============================================
# FINAL SUMMARY
# ==============================================

print("\n" + "=" * 60)
print("PIPELINE COMPLETE – FINAL SUMMARY")
print("=" * 60)

print(f"\n  ImprovedAttentionHead (HeAR fine-tuned):")
for sname in split_names:
    print(f"    {sname:12s} → Acc={results[sname]['acc']:.4f} "
          f"P={results[sname]['precision']:.4f} "
          f"R={results[sname]['recall']:.4f} "
          f"F1={results[sname]['f1']:.4f}")

print(f"\n  Best Classical ({best_classical_name}):")
print(f"    Val Acc: {best_classical_acc:.4f}")

print(f"\n  5-Fold CV (SVM RBF):  {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")
print(f"\n  Best model saved to:  {CKPT_PATH}")
print(f"  All plots saved to:   {PLOTS_DIR}/")
print(f"\n  Final memory usage:   {get_memory_gb():.2f} GB")

force_cleanup()
print("\nDone.")

2026-02-23 10:22:07.040550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771842127.226497      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771842127.283097      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771842127.733435      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771842127.733475      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771842127.733478      55 computation_placer.cc:177] computation placer alr

NEONATAL CRY CLASSIFICATION PIPELINE (IMPROVED)
Hugging Face token loaded.
Downloading HeAR model to ./hear-model ...


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 24 files:   0%|          | 0/24 [00:00<?, ?it/s]

Loading HeAR TensorFlow SavedModel...


I0000 00:00:1771842160.898692      55 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13757 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1771842160.904698      55 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13757 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


HeAR model loaded. Embedding dim: 1280
  [MEM after HeAR load] 2.87 GB

STEP 1: DATA COLLECTION
Collecting audio files...


Scanning Baby crying:   0%|          | 0/6258 [00:00<?, ?it/s]

Scanning infant_cry_datasets:   0%|          | 0/2296 [00:00<?, ?it/s]

Total audio files found: 8517
  [MEM after file collection] 2.87 GB
Mapping labels...


Labelling:   0%|          | 0/8517 [00:00<?, ?it/s]

Labelled files: 6639 | Unlabelled (skipped): 1878

Class distribution:
  Pain (0): 1798 files  (27.1%)
  Hunger (1): 1163 files  (17.5%)
  Neurological (2): 3678 files  (55.4%)

STEP 2: DATA SPLITTING (40/15/15/30)
Train:   2655 files  (40.0%)
Val:     996 files  (15.0%)
Test:    986 files  (14.9%)
Holdout: 2002 files  (30.2%)

Verifying data split integrity...
  ✓ No data leakage detected across splits.

STEP 3: EXPLORATORY DATA ANALYSIS
  Saved: ./plots/dist_train.png
  Saved: ./plots/dist_validation.png
  Saved: ./plots/dist_test.png
  Saved: ./plots/dist_holdout.png
  Saved: ./plots/dist_all_splits.png
  Saved: ./plots/sample_waveforms.png
  Saved: ./plots/sample_spectrograms.png


Measuring durations:   0%|          | 0/500 [00:00<?, ?it/s]

  Saved: ./plots/duration_distribution.png
  Duration stats → mean: 5.29s | median: 4.00s | max: 42.40s | min: 4.00s
  [MEM after EDA] 3.00 GB

STEP 4: HeAR EMBEDDING EXTRACTION + AUGMENTATION


Extracting Train embeddings:   0%|          | 0/2655 [00:00<?, ?it/s]

I0000 00:00:1771842221.094113     178 service.cc:152] XLA service 0x7a375263e400 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1771842221.094159     178 service.cc:160]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1771842221.094164     178 service.cc:160]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1771842221.436648     178 cuda_dnn.cc:529] Loaded cuDNN version 91002
I0000 00:00:1771842223.580176     178 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


  Train: 7965 embeddings, 512-dim
  [MEM Train] 3.57 GB


Extracting Val embeddings:   0%|          | 0/996 [00:00<?, ?it/s]

  Val: 996 embeddings, 512-dim
  [MEM Val] 3.57 GB


Extracting Test embeddings:   0%|          | 0/986 [00:00<?, ?it/s]

  Test: 986 embeddings, 512-dim
  [MEM Test] 3.57 GB


Extracting Holdout embeddings:   0%|          | 0/2002 [00:00<?, ?it/s]

  Holdout: 2002 embeddings, 512-dim
  [MEM Holdout] 3.63 GB
  [MEM after all embeddings] 3.63 GB
Class weights: tensor([1.2309, 1.9032, 0.6016], device='cuda:0')

STEP 5: FEATURE ENGINEERING
Fitting PCA for visualization...
PCA retaining 95% variance: 126 components
  Saved: ./plots/pca_train.png
  Saved: ./plots/pca_holdout.png
  Saved: ./plots/pca_all.png
Plotting class-mean embedding barcode heatmap...
  Saved: ./plots/embedding_barcode.png
  [MEM after feature engineering] 3.68 GB

STEP 6: CLASSICAL ML CLASSIFIERS (on HeAR embeddings) + TUNING


Tuning classifiers:   0%|          | 0/4 [00:00<?, ?it/s]

  SVM (RBF)                 → Val Acc: 0.6426 (best params: {'gamma': 0.001, 'C': 10})
  Logistic Regression       → Val Acc: 0.6074 (best params: {'penalty': 'l2', 'C': 0.01})
  Random Forest             → Val Acc: 0.6145 (best params: {'n_estimators': 200, 'min_samples_split': 5, 'max_depth': None})
  Gradient Boosting         → Val Acc: 0.6275 (best params: {'n_estimators': 150, 'max_depth': 5, 'learning_rate': 0.1})

Best classical model: SVM (RBF)  (Val Acc=0.6426)
  Saved: ./plots/classical_comparison_tuned.png
  [MEM after classical tuning] 3.69 GB

STEP 7: ENSEMBLE / STACKING
Training Soft-Voting Ensemble...
  Soft-Voting Ensemble Val Acc: 0.6446
  → New best classical model: Ensemble (Soft-Vote)

STEP 8: PyTorch FINE-TUNED ATTENTION CLASSIFIER (enhanced)
Performing random hyperparameter search for AttentionHead...
  Trial 1: hidden=[1024, 512, 256], dropout=0.45, lr=0.0001, wd=0.0001
    → Val Acc: 0.6345
    ★ New best PT model saved.
  Trial 2: hidden=[512, 256], dropout=0.5

Epoch 01/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 01 | Loss=1.0375 | Train=0.4825 | Val=0.5944 | LR=0.000999
    ★ New best val acc: 0.5944 — checkpoint saved.


Epoch 02/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 02 | Loss=0.9571 | Train=0.5464 | Val=0.5763 | LR=0.000996


Epoch 03/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 03 | Loss=0.9161 | Train=0.5775 | Val=0.6044 | LR=0.000991
    ★ New best val acc: 0.6044 — checkpoint saved.


Epoch 04/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 04 | Loss=0.8882 | Train=0.5970 | Val=0.6215 | LR=0.000984
    ★ New best val acc: 0.6215 — checkpoint saved.


Epoch 05/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 05 | Loss=0.8786 | Train=0.6071 | Val=0.6124 | LR=0.000976


Epoch 06/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 06 | Loss=0.8630 | Train=0.6080 | Val=0.6185 | LR=0.000965


Epoch 07/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 07 | Loss=0.8447 | Train=0.6449 | Val=0.6205 | LR=0.000952


Epoch 08/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 08 | Loss=0.8295 | Train=0.6313 | Val=0.6155 | LR=0.000938


Epoch 09/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 09 | Loss=0.8163 | Train=0.6463 | Val=0.6245 | LR=0.000922
    ★ New best val acc: 0.6245 — checkpoint saved.


Epoch 10/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 10 | Loss=0.7949 | Train=0.6618 | Val=0.6265 | LR=0.000905
    ★ New best val acc: 0.6265 — checkpoint saved.


Epoch 11/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 11 | Loss=0.7822 | Train=0.6820 | Val=0.6275 | LR=0.000885
    ★ New best val acc: 0.6275 — checkpoint saved.


Epoch 12/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 12 | Loss=0.7737 | Train=0.6753 | Val=0.6325 | LR=0.000864
    ★ New best val acc: 0.6325 — checkpoint saved.


Epoch 13/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 13 | Loss=0.7581 | Train=0.6842 | Val=0.6335 | LR=0.000842
    ★ New best val acc: 0.6335 — checkpoint saved.


Epoch 14/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 14 | Loss=0.7575 | Train=0.6871 | Val=0.6536 | LR=0.000819
    ★ New best val acc: 0.6536 — checkpoint saved.


Epoch 15/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 15 | Loss=0.7391 | Train=0.6908 | Val=0.6496 | LR=0.000794


Epoch 16/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 16 | Loss=0.7339 | Train=0.6915 | Val=0.6546 | LR=0.000768
    ★ New best val acc: 0.6546 — checkpoint saved.


Epoch 17/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 17 | Loss=0.7244 | Train=0.7167 | Val=0.6566 | LR=0.000741
    ★ New best val acc: 0.6566 — checkpoint saved.


Epoch 18/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 18 | Loss=0.7115 | Train=0.6993 | Val=0.6586 | LR=0.000713
    ★ New best val acc: 0.6586 — checkpoint saved.


Epoch 19/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 19 | Loss=0.7098 | Train=0.7055 | Val=0.6546 | LR=0.000684


Epoch 20/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 20 | Loss=0.6986 | Train=0.7308 | Val=0.6536 | LR=0.000655


Epoch 21/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 21 | Loss=0.6895 | Train=0.7353 | Val=0.6556 | LR=0.000624


Epoch 22/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 22 | Loss=0.6894 | Train=0.7333 | Val=0.6586 | LR=0.000594


Epoch 23/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 23 | Loss=0.6770 | Train=0.7198 | Val=0.6365 | LR=0.000563


Epoch 24/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 24 | Loss=0.6705 | Train=0.7408 | Val=0.6456 | LR=0.000531


Epoch 25/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 25 | Loss=0.6729 | Train=0.7481 | Val=0.6376 | LR=0.000500


Epoch 26/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 26 | Loss=0.6564 | Train=0.7399 | Val=0.6345 | LR=0.000469


Epoch 27/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 27 | Loss=0.6574 | Train=0.7584 | Val=0.6466 | LR=0.000437


Epoch 28/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 28 | Loss=0.6499 | Train=0.7372 | Val=0.6476 | LR=0.000406


Epoch 29/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 29 | Loss=0.6475 | Train=0.7531 | Val=0.6566 | LR=0.000376


Epoch 30/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 30 | Loss=0.6331 | Train=0.7670 | Val=0.6566 | LR=0.000345


Epoch 31/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 31 | Loss=0.6469 | Train=0.7645 | Val=0.6446 | LR=0.000316


Epoch 32/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 32 | Loss=0.6364 | Train=0.7469 | Val=0.6416 | LR=0.000287


Epoch 33/50:   0%|          | 0/497 [00:00<?, ?it/s]

  Epoch 33 | Loss=0.6366 | Train=0.7626 | Val=0.6446 | LR=0.000259
  Early stopping at epoch 33 (no improvement for 15 epochs).

Best model loaded.  Best Val Acc: 0.6586

STEP 10: TRAINING HISTORY PLOTS
  Saved: ./plots/training_history_improved.png
  Saved: ./plots/generalization_gap_improved.png

STEP 11: COMPREHENSIVE EVALUATION

  ImprovedAttentionHead | Validation Set
  Accuracy : 0.6586
  Precision: 0.6848  Recall: 0.6586  F1: 0.6688

Classification Report:
              precision    recall  f1-score   support

        Pain       0.59      0.66      0.62       270
      Hunger       0.33      0.40      0.37       174
Neurological       0.84      0.74      0.79       552

    accuracy                           0.66       996
   macro avg       0.59      0.60      0.59       996
weighted avg       0.68      0.66      0.67       996

  Saved: ./plots/cm_validation_improvedattentionhead.png
  Saved: ./plots/roc_validation_improvedattentionhead.png

  ImprovedAttentionHead | Test Set


In [3]:
import os
import zipfile
from tqdm import tqdm
from pathlib import Path
import time

def get_all_files(directory):
    """Recursively get all files in directory and subdirectories"""
    all_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            full_path = os.path.join(root, file)
            all_files.append(full_path)
    return all_files

def get_total_size(file_paths):
    """Calculate total size of all files"""
    total_size = 0
    for file_path in file_paths:
        try:
            total_size += os.path.getsize(file_path)
        except (OSError, FileNotFoundError):
            continue
    return total_size

def create_kaggle_working_zip(source_dir="/kaggle/working/plots", output_name="babyCry_Hears.zip"):
    """
    Create a zip file of all content in the Kaggle working directory
    
    Args:
        source_dir (str): Source directory to zip (default: /kaggle/working/)
        output_name (str): Name of the output zip file
    """
    
    # Check if source directory exists
    if not os.path.exists(source_dir):
        print(f"Error: Source directory '{source_dir}' does not exist!")
        return False
    
    # Get all files recursively
    print("Scanning files...")
    all_files = get_all_files(source_dir)
    
    if not all_files:
        print(f"No files found in '{source_dir}'")
        return False
    
    print(f"Found {len(all_files)} files to compress")
    
    # Calculate total size for progress tracking
    total_size = get_total_size(all_files)
    print(f"Total size: {total_size / (1024*1024):.2f} MB")
    
    # Create zip file with progress bar
    try:
        with zipfile.ZipFile(output_name, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
            # Progress bar based on file count
            with tqdm(total=len(all_files), desc="Compressing files", unit="files") as pbar:
                processed_size = 0
                
                for file_path in all_files:
                    try:
                        # Get relative path for the zip archive
                        arcname = os.path.relpath(file_path, source_dir)
                        
                        # Add file to zip
                        zipf.write(file_path, arcname)
                        
                        # Update progress
                        file_size = os.path.getsize(file_path)
                        processed_size += file_size
                        
                        # Update progress bar with file info
                        pbar.set_postfix({
                            'Current': os.path.basename(file_path)[:20],
                            'Size': f"{processed_size / (1024*1024):.1f}MB"
                        })
                        pbar.update(1)
                        
                    except Exception as e:
                        print(f"Warning: Could not add {file_path} to zip: {str(e)}")
                        pbar.update(1)
                        continue
        
        # Get final zip file size
        zip_size = os.path.getsize(output_name)
        compression_ratio = (1 - zip_size / total_size) * 100 if total_size > 0 else 0
        
        print(f"\n Successfully created '{output_name}'")
        print(f" Original size: {total_size / (1024*1024):.2f} MB")
        print(f" Compressed size: {zip_size / (1024*1024):.2f} MB")
        print(f" Compression ratio: {compression_ratio:.1f}%")
        
        return True
        
    except Exception as e:
        print(f"Error creating zip file: {str(e)}")
        return False

def download_zip_in_kaggle(zip_filename):
    """
    Trigger download in Kaggle notebook environment
    """
    try:
        # In Kaggle, files in the working directory are automatically available for download
        # We can also use the files.download() method if available
        from google.colab import files
        files.download(zip_filename)
        print(f"Download triggered for {zip_filename}")
    except ImportError:
        # If not in Colab/Kaggle environment with files API
        print(f"Zip file '{zip_filename}' created successfully!")
        print("In Kaggle, you can download it from the 'Output' tab or use the file browser.")
        print("The file is located in your current working directory.")

if __name__ == "__main__":
    # Configuration
    SOURCE_DIRECTORY = "/kaggle/working/plots"
    OUTPUT_ZIP_NAME = "babyCry_Hears.zip"
    
    print(" Starting Kaggle Working Directory Backup")
    print("=" * 50)
    
    # Create the zip file
    success = create_kaggle_working_zip(SOURCE_DIRECTORY, OUTPUT_ZIP_NAME)
    
    if success:
        print(f"\n Preparing download...")
        download_zip_in_kaggle(OUTPUT_ZIP_NAME)
    else:
        print(" Backup failed!")

 Starting Kaggle Working Directory Backup
Scanning files...
Found 27 files to compress
Total size: 3.38 MB


Compressing files: 100%|██████████| 27/27 [00:00<00:00, 164.36files/s, Current=cm_holdout_improveda, Size=3.4MB]


 Successfully created 'babyCry_Hears.zip'
 Original size: 3.38 MB
 Compressed size: 3.16 MB
 Compression ratio: 6.6%

 Preparing download...





<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Download triggered for babyCry_Hears.zip
