<a href="https://colab.research.google.com/github/tonyPooyappallil/OSA_Detection_Project/blob/main/signal_window_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# mounting google drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import os

# PATH to the uploaded ZIP file in Google drive
zip_file_path = '/content/drive/MyDrive/AI_Sleep_Apnea/data/selected_records208.zip'

#zip_file_path = '/content/drive/MyDrive/AI_Sleep_Apnea/data/selected_records_stratified_30.zip'
# folder name inside the zip
extracted_folder_name = 'selected_records208'
#extracted_folder_name = 'selected_records_stratified_30'

# Defining where to unzip inside Colab's temporary storage
# Since Using /content/ makes it faster to access during processing
unzip_dir = '/content/temp_data'

# --- Unzipping command ---
print(f"Creating directory {unzip_dir} if it doesn't exist...")
os.makedirs(unzip_dir, exist_ok=True)

print(f"Unzipping {zip_file_path} to {unzip_dir}...")
# Using '-o' to overwrite without asking if run again, '-q' for quiet
!unzip -oq "{zip_file_path}" -d "{unzip_dir}"

print("Unzipping complete.")

# --- Defining the path to the actual data *after* the unzipping ---
# This is the path that will be used in the main script's config
colab_data_path = os.path.join(unzip_dir, extracted_folder_name)
print(f"Data should now be available in: {colab_data_path}")

# Verifying some contents
print("Example content listing:")
!ls "{colab_data_path}" | head

Creating directory /content/temp_data if it doesn't exist...
Unzipping /content/drive/MyDrive/AI_Sleep_Apnea/data/selected_records208.zip to /content/temp_data...
Unzipping complete.
Data should now be available in: /content/temp_data/selected_records208
Example content listing:
ls: cannot access '/content/temp_data/selected_records208': No such file or directory


In [6]:
#installing Dependencies
!pip install antropy scikit-image h5py wfdb -q
print("Dependencies installed.")

Dependencies installed.


In [3]:
# --- Required Libraries ---
import os
import re
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
import scipy.stats
import scipy.signal
import h5py # For the arousal .mat files v7.3
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_auc_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.ensemble import RandomForestClassifier # For the secondary model
import shutil
import random
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
import time
import gc
import warnings
import wfdb
import skimage.transform
# import zipfile # Zipping logic now removed, import can be removed if not re-added
import copy # For deep copying of config
import multiprocessing

# Trying to import antropy, skimage, h5py needed for features/resizing/loading
try: import antropy
except ImportError: print("Error: pip install antropy"); exit()
try: import skimage.transform
except ImportError: print("Error: pip install scikit-image"); exit()
try: import h5py
except ImportError: print("Error: pip install h5py"); exit()

# --- Configuration ---
DRIVE_PROJECT_PATH = '/content/drive/MyDrive/AI_Sleep_Apnea/' # DRIVE PATH
UNZIPPED_DATA_PARENT_DIR = '/content/temp_data' # Local temp directory for the unzipped data
DATA_FOLDER_NAME = 'selected_records208' # Using the 208 records
DATA_DIR = os.path.join(UNZIPPED_DATA_PARENT_DIR, DATA_FOLDER_NAME)
LABELS_FILE = os.path.join(DATA_DIR, "selected_records_ahi_results_with_severity.csv") # Path to AHI results ground truth
DEMOGRAPHICS_FILE = os.path.join(DATA_DIR, "age-sex.csv") # Path to the demographics file
AROUSAL_MAT_SUFFIX = "-arousal.mat" # Suffix for the arousal files (used by wfdb.rdann)

# Output Directories
PERSISTENT_OUTPUT_DIR_BASE = os.path.join(DRIVE_PROJECT_PATH, "output/sleep_apnea_analysis_simplified_208_v2_secondary_model") # New output folder
LOCAL_DATA_DIR_BASE = "/content/processed_data_local_simplified_v2" # Local temp for the processed window data
FINAL_MODEL_RESULTS_DIR = os.path.join(PERSISTENT_OUTPUT_DIR_BASE, "final_model_evaluation_results")

# Processed Data Storage
PERSISTENT_METADATA_CSV_NAME = "all_processed_windows_metadata.csv"
# PERSISTENT_WINDOW_DATA_ZIP_NAME = "all_processed_windows_data.zip" # Zipping removed

# Signal Processing Parameters
SAMPLING_RATE = 200
WINDOW_SECONDS = 30
WINDOW_SAMPLES = int(WINDOW_SECONDS * SAMPLING_RATE)
WINDOW_OVERLAP_RATIO = 0.5
WINDOW_STEP_SAMPLES = int(WINDOW_SAMPLES * (1 - WINDOW_OVERLAP_RATIO))
if WINDOW_STEP_SAMPLES <= 0: WINDOW_STEP_SAMPLES = WINDOW_SAMPLES # Ensuring step is at least 1 window

SPECTROGRAM_SIGNALS = ["SaO2", "ECG"] # Signals to be converted to spectrograms
RAW_TS_SIGNALS = ['AIRFLOW', 'C3-M2'] # Signals to be used as raw time series
BASE_TABULAR_FEATURES = ['Age', 'Sex_encoded'] # Tabular features from demographics data

SPECTROGRAM_NPERSEG = 256
SPECTROGRAM_NOVERLAP = 128
SPECTROGRAM_TARGET_SIZE = (64, 128) # (height, width) for the spectrogram images

# Model Training Parameters
RANDOM_STATE = 42
BATCH_SIZE = 32
GLOBAL_INITIAL_EPOCHS = 15 # Epochs for the initial training phase
GLOBAL_FINE_TUNE_EPOCHS = 15 # Epochs for the fine-tuning phase

# Default Hyperparameters (used directly as K-Fold is now removed)
DEFAULT_LEARNING_RATE = 1e-4
DEFAULT_FINE_TUNE_LR_FACTOR = 0.1 # Factor to reduce LR for the fine-tuning
DEFAULT_FINE_TUNE_LAYERS = 15 # Number of layers to unfreeze in MobileNetV2 for fine-tuning
DEFAULT_FOCAL_LOSS_GAMMA = 3.0 # Using increased gamma
DEFAULT_L2_REG_FACTOR = 1e-5
DEFAULT_ENABLE_OVERSAMPLING = True # For window-level labels during training
DEFAULT_OVERSAMPLING_TARGET_POSITIVE_RATIO = 0.3
DEFAULT_AUGMENT_SPECTROGRAMS = True
DEFAULT_AUGMENT_RAW_TS = True

# Workflow Control
MAX_RECORDS_TO_PROCESS = 208 # Process all of the 208 records
FORCE_REPROCESS_ALL_WINDOWS = False # If True, re-generates all window data. If False, tries to load the existing.
FINAL_HOLDOUT_TEST_SET_RATIO = 0.2 # Proportion of the data for the final test set
# Proportion of the train+val pool for validation set, used for threshold tuning
FINAL_MODEL_VALIDATION_RATIO = 0.15 / (1.0 - FINAL_HOLDOUT_TEST_SET_RATIO) if (1.0 - FINAL_HOLDOUT_TEST_SET_RATIO) > 0 else 0.15
USE_SECONDARY_PATIENT_MODEL = True # Flag to use the secondary model approach

# Multi-Class Severity Configuration
PATIENT_SEVERITY_TARGET_COL = 'AHI_Severity_MultiClass' # Column name in patient_true_labels_df
PATIENT_BINARY_TARGET_COL = 'OSA_Severity_Binary' # Column name for binary OSA (AHI >= 5)
# Default thresholds for mapping the positive window ratio to patient severity (0:Normal, 1:Mild, 2:Moderate, 3:Severe)
# These will be tuned on the validation set if USE_SECONDARY_PATIENT_MODEL is False.
PATIENT_SEVERITY_THRESHOLDS_ON_RATIO = [0.04, 0.10, 0.20] # Initial guess


# --- Function to Save the Confusion Matrix Plot ---
def save_confusion_matrix_plot(cm, class_names, title, filename):
    """
    Saves a confusion matrix plot using seaborn and matplotlib.
    Args:
        cm (numpy.array): The confusion matrix.
        class_names (list): List of class names for labels.
        title (str): Title for the plot.
        filename (str): Full path to save the plot image.
    """
    try:
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title(title)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        os.makedirs(os.path.dirname(filename), exist_ok=True) # Ensure directory exists
        plt.savefig(filename)
        plt.close() # Close the figure to free memory
        print(f"Saved confusion matrix plot to: {filename}")
    except Exception as e:
        print(f"Error in save_confusion_matrix_plot for '{title}': {e}")

# --- Helper Functions (Data Loading, Processing) ---
def prepare_output_dirs_windowed(persistent_base_dir, local_base_dir_config_val):
    print(f"Ensuring persistent base output directory exists: {persistent_base_dir}")
    os.makedirs(persistent_base_dir, exist_ok=True)

    print(f"Preparing local base directory for processing/training: {local_base_dir_config_val}")
    if os.path.exists(local_base_dir_config_val):
        print(f"  Removing existing local directory: {local_base_dir_config_val}")
        shutil.rmtree(local_base_dir_config_val)
    # Create the structure for storing processed window data locally
    # This is where .npy files will be saved.
    os.makedirs(os.path.join(local_base_dir_config_val, "all_data_temp", "raw_ts"), exist_ok=True)
    for signal_name in SPECTROGRAM_SIGNALS:
        os.makedirs(os.path.join(local_base_dir_config_val, "all_data_temp", "spectrograms", signal_name), exist_ok=True)

    os.makedirs(os.path.join(persistent_base_dir, "final_model_evaluation_results"), exist_ok=True)
    print("Output directories prepared.")

def load_labels_and_metadata(labels_file, demo_file,
                             patient_binary_target_col='OSA_Severity_Binary',
                             patient_multiclass_target_col='AHI_Severity_MultiClass',
                             ahi_threshold_binary=5,
                             ahi_thresholds_multiclass=[5, 15, 30]):
    labels_df = pd.DataFrame()
    demo_dict = {}
    try:
        labels_df = pd.read_csv(labels_file)
        if 'Record' in labels_df.columns: labels_df['Record'] = labels_df['Record'].astype(str)
        else: print(f"Warning: 'Record' column not found in labels file: {labels_file}"); labels_df['Record'] = None

        if 'AHI' in labels_df.columns:
            labels_df[patient_binary_target_col] = (labels_df['AHI'] >= ahi_threshold_binary).astype(int)
            conditions = [
                (labels_df['AHI'] < ahi_thresholds_multiclass[0]),
                (labels_df['AHI'] >= ahi_thresholds_multiclass[0]) & (labels_df['AHI'] < ahi_thresholds_multiclass[1]),
                (labels_df['AHI'] >= ahi_thresholds_multiclass[1]) & (labels_df['AHI'] < ahi_thresholds_multiclass[2]),
                (labels_df['AHI'] >= ahi_thresholds_multiclass[2])
            ]
            choices = [0, 1, 2, 3]
            labels_df[patient_multiclass_target_col] = np.select(conditions, choices, default=0)
            print(f"Created patient multi-class target '{patient_multiclass_target_col}' using AHI thresholds: <{ahi_thresholds_multiclass[0]} (0), "
                  f"[{ahi_thresholds_multiclass[0]}-{ahi_thresholds_multiclass[1]}) (1), "
                  f"[{ahi_thresholds_multiclass[1]}-{ahi_thresholds_multiclass[2]}) (2), "
                  f">={ahi_thresholds_multiclass[2]} (3).")
            print(f"Created patient binary target '{patient_binary_target_col}' using AHI threshold: >={ahi_threshold_binary} (1).")

        else:
            print(f"Warning: 'AHI' column not found in labels file. Cannot create patient targets '{patient_binary_target_col}' or '{patient_multiclass_target_col}'.")
            labels_df[patient_binary_target_col] = pd.NA
            labels_df[patient_multiclass_target_col] = pd.NA

    except FileNotFoundError:
        print(f"Warning: Labels file not found at: {labels_file}. Patient-level true labels will be unavailable.")
        labels_df = pd.DataFrame(columns=['Record', patient_binary_target_col, patient_multiclass_target_col])
    except Exception as e:
        print(f"Error loading labels file or creating patient target: {e}")
        labels_df = pd.DataFrame(columns=['Record', patient_binary_target_col, patient_multiclass_target_col])

    try:
        demo_df = pd.read_csv(demo_file)
        if 'Record' in demo_df.columns: demo_df['Record'] = demo_df['Record'].astype(str)
        if 'Sex' in demo_df.columns: demo_df['Sex_encoded'] = LabelEncoder().fit_transform(demo_df['Sex'])
        else: demo_df['Sex_encoded'] = np.nan

        if 'Record' in demo_df.columns: demo_dict = demo_df.set_index('Record').to_dict('index')
        else: print(f"Warning: 'Record' column not found in demographics file: {demo_file}")
    except FileNotFoundError: print(f"Warning: Demographics file not found at: {demo_file}")
    except Exception as e: print(f"Error loading demographics file: {e}")

    return labels_df, demo_dict

def read_signal_names(hea_path):
    try:
        with open(hea_path, 'r') as f: lines = f.readlines()
        return [line.split(' ')[-1].strip() for line in lines[1:] if len(line.split(' ')) > 7]
    except Exception as e: print(f"Error reading signal names from {hea_path}: {e}"); return None

def load_signal_data(mat_path):
    try:
        with h5py.File(mat_path, 'r') as f:
            data_key = 'val'; found=False
            if data_key in f and isinstance(f[data_key], h5py.Dataset): found=True
            else:
                for k_search in ['signal_data', 'data', 'signals', 'val']:
                    if k_search in f and isinstance(f[k_search], h5py.Dataset): data_key=k_search; found=True; break
            if not found:
                largest_size = 0
                for k_search in f.keys():
                    if isinstance(f[k_search], h5py.Dataset) and f[k_search].size > largest_size:
                         largest_size = f[k_search].size; data_key = k_search; found=True
            if not found: print(f"Error: Cannot find signal data key in HDF5 file {mat_path}"); return None
            data = f[data_key][()].astype(np.float32)
            return data
    except OSError:
        try:
            mat = scipy.io.loadmat(mat_path); data_key = 'val'
            if data_key not in mat:
                for k_search,v_search in mat.items():
                    if k_search.startswith('__'): continue
                    if isinstance(v_search, np.ndarray) and v_search.ndim > 1 : data_key=k_search; break
            if data_key not in mat or mat[data_key] is None : print(f"Error: No signal data in MAT file {mat_path} (scipy)"); return None
            return mat[data_key].astype(np.float32)
        except Exception as e_scipy: print(f"Error loading signal data {mat_path} with scipy: {e_scipy}"); return None
    except Exception as e: print(f"Error loading signal data {mat_path}: {e}"); return None

def load_apnea_hypopnea_events_wfdb(record_base_path, event_patterns_re, annotation_extension='arousal'):
    events = []
    try:
        annotation = wfdb.rdann(record_base_path, extension=annotation_extension)
        n_ann = len(annotation.sample)
        i = 0
        while i < n_ann:
            note = annotation.aux_note[i].strip()
            if note.startswith('('):
                tag_content = note[1:]
                matched_pattern = None
                for pattern in event_patterns_re:
                    if pattern.fullmatch(tag_content): matched_pattern = pattern; break
                if matched_pattern:
                    start_sample = annotation.sample[i]
                    expected_end_tag = tag_content + ')'
                    found_end = False
                    for j in range(i + 1, n_ann):
                        end_note = annotation.aux_note[j].strip()
                        if end_note == expected_end_tag:
                            end_sample = annotation.sample[j]
                            if start_sample < end_sample: events.append((start_sample, end_sample))
                            else: print(f"Warning: End tag '{expected_end_tag}' before start in {record_base_path} at {end_sample} (start {start_sample})")
                            i = j; found_end = True; break
            i += 1
        return events
    except FileNotFoundError: print(f"Warning: Annotation file {record_base_path}.{annotation_extension} not found."); return None
    except Exception as e: print(f"Error reading WFDB annotations for {record_base_path}: {e}"); return None

def generate_and_save_numerical_spectrogram(signal_window, sr, output_path_npy, nperseg, noverlap, target_img_size_hw):
    try:
        if len(signal_window) < nperseg:
            padding = nperseg - len(signal_window)
            signal_window = np.pad(signal_window, (0, padding), 'constant', constant_values=0.0)

        f, t, Sxx = scipy.signal.spectrogram(signal_window, fs=sr, nperseg=nperseg, noverlap=noverlap)

        if Sxx.size == 0:
            resized_Sxx_db = np.zeros(target_img_size_hw, dtype=np.float32)
        else:
            Sxx_db = 10 * np.log10(np.maximum(Sxx, 1e-9))
            resized_Sxx_db = skimage.transform.resize(Sxx_db, target_img_size_hw, anti_aliasing=True, preserve_range=True).astype(np.float32)

        os.makedirs(os.path.dirname(output_path_npy), exist_ok=True)
        np.save(output_path_npy, resized_Sxx_db)
        return True
    except Exception as e:
        if isinstance(e, OSError) and e.errno == 28:
            print(f"FATAL ERROR: No space left on device saving: {output_path_npy}");
            raise e
        print(f"Error generating window spectrogram {output_path_npy}: {e}")
        return False

def load_numerical_spectrogram(path_npy, target_size_hw, augment=False, config_params=None):
    try:
        spec_array_2d = np.load(path_npy)
        if spec_array_2d.shape != target_size_hw:
            spec_array_2d = skimage.transform.resize(spec_array_2d, target_size_hw, anti_aliasing=True, preserve_range=True).astype(np.float32)

        if augment and config_params and config_params.get('AUGMENT_SPECTROGRAMS', False):
            if random.random() < 0.3:
                h, w = spec_array_2d.shape
                cutout_h = random.randint(max(1, h // 10), max(2, h // 4))
                cutout_w = random.randint(max(1, w // 10), max(2, w // 4))
                if h > cutout_h and w > cutout_w :
                    y_start = random.randint(0, h - cutout_h); x_start = random.randint(0, w - cutout_w)
                    spec_array_2d[y_start:y_start+cutout_h, x_start:x_start+cutout_w] = np.mean(spec_array_2d)
            if random.random() < 0.2:
                factor = random.uniform(0.7, 1.3)
                spec_array_2d = spec_array_2d * factor
                spec_array_2d = np.clip(spec_array_2d, np.min(spec_array_2d), np.max(spec_array_2d))

        spec_array_3d_single = np.expand_dims(spec_array_2d, axis=-1)
        spec_array_3d_rgb = np.concatenate([spec_array_3d_single] * 3, axis=-1)
        return spec_array_3d_rgb.astype(np.float32)
    except Exception as e:
        print(f"Error loading spec {path_npy}: {e}")
        # Return a NumPy array of zeros with correct shape
        return np.zeros(target_size_hw + (3,), dtype=np.float32)

def load_npy(path, expected_shape=None, augment=False, config_params=None):
    try:
        arr = np.load(path).astype(np.float32)
        if expected_shape and arr.shape != expected_shape:
            print(f"Warning: Shape mismatch loading {path}. Expected {expected_shape}, got {arr.shape}.")
            return None
        if augment and config_params and config_params.get('AUGMENT_RAW_TS', False) and arr.ndim == 1:
            if random.random() < 0.3:
                noise_factor = 0.005 * np.std(arr) if np.std(arr) > 1e-6 else 0.005
                if noise_factor > 0 : arr = arr + np.random.normal(0, noise_factor, arr.shape)
            if random.random() < 0.2:
                scale_factor = random.uniform(0.9, 1.1); arr = arr * scale_factor
        return arr
    except Exception as e: print(f"Error loading npy {path}: {e}"); return None

def process_record_windowed_mp_wrapper(args):
    record_id_str, split_name_for_saving, config_dict_mp, demo_meta_record_mp, base_save_path_mp = args
    return process_record_windowed(record_id_str, split_name_for_saving, config_dict_mp, demo_meta_record_mp, base_save_path_mp)

def process_record_windowed(record_id, split_name_for_saving, config_dict, demo_meta_record, base_save_path):
    # print(f"PROCESS_RECORD_WINDOWED: Start processing record_id: {record_id}, PID: {os.getpid()}") # Debug log (optional)

    record_dir = os.path.join(config_dict['data_dir'], record_id)
    mat_path = os.path.join(record_dir, f"{record_id}.mat")
    hea_path = os.path.join(record_dir, f"{record_id}.hea")
    record_base_path = os.path.join(record_dir, record_id)
    processed_windows_metadata = []

    if not (os.path.exists(mat_path) and os.path.exists(hea_path)):
        print(f"Warning: MAT or HEA file missing for record {record_id}. Skipping. MAT: {mat_path}, HEA: {hea_path}")
        return []

    signal_names = read_signal_names(hea_path)
    if signal_names is None:
        print(f"Warning: Failed to read signal names for record {record_id}. Skipping.")
        return []

    signal_data = load_signal_data(mat_path)
    if signal_data is None:
        print(f"Warning: Failed to load signal data for record {record_id}. Skipping.")
        return []
    if signal_data.shape[0] != len(signal_names):
        print(f"Warning: Signal data channels ({signal_data.shape[0]}) mismatch HEA names ({len(signal_names)}) for {record_id}. Skipping.")
        return []
    total_samples = signal_data.shape[1]

    try:
        event_patterns_re = [re.compile(p) for p in config_dict['apnea_hypopnea_patterns']]
    except Exception as e:
        print(f"Error compiling regex patterns for {record_id}: {e}. Skipping.")
        return []

    apnea_hypopnea_events = load_apnea_hypopnea_events_wfdb(
        record_base_path,
        event_patterns_re,
        annotation_extension=config_dict.get('arousal_suffix','').replace('-','').replace('.mat','')
    )
    if apnea_hypopnea_events is None:
        # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Warning: Could not load relevant events. Treating as no events.") # Optional log
        apnea_hypopnea_events = []
    # else:
        # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Loaded {len(apnea_hypopnea_events)} apnea/hypopnea events.") # Optional log

    signal_name_to_index = {name: i for i, name in enumerate(signal_names)}
    try:
        spec_indices = [signal_name_to_index[name] for name in config_dict['spectrogram_signals']]
        raw_ts_indices = [signal_name_to_index[name] for name in config_dict['raw_ts_signals']]
    except KeyError as e:
        print(f"Warning: Required signal '{e}' not found in record {record_id}. Skipping record.")
        return []

    window_count = 0
    sr = config_dict.get('sampling_rate', SAMPLING_RATE)
    window_samples = config_dict['window_samples']
    step_samples = config_dict['window_step_samples']
    # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Windowing parameters: SR={sr}, WinSamples={window_samples}, StepSamples={step_samples}") # Optional log

    for start_sample in range(0, total_samples - window_samples + 1, step_samples):
        end_sample = start_sample + window_samples
        window_id = f"{record_id}_w{window_count:06d}"
        # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Processing window {window_id} ({start_sample}-{end_sample})") # Debug log (optional)

        window_label = 0
        for event_start, event_end in apnea_hypopnea_events:
            if max(start_sample, event_start) < min(end_sample, event_end):
                window_label = 1; break

        modalities = {'window_id': window_id, 'record_id': record_id, 'label': window_label, 'split': split_name_for_saving}
        modalities.update(demo_meta_record if demo_meta_record else {})
        modalities_saved_ok = True; spec_paths = {}; raw_ts_paths = {}

        for i, signal_idx in enumerate(spec_indices):
            signal_name = config_dict['spectrogram_signals'][i]
            signal_window_data = signal_data[signal_idx, start_sample:end_sample]
            spec_filename = f"{window_id}_{signal_name}_spec.npy"
            spec_rel_path = os.path.join("spectrograms", signal_name, spec_filename)
            spec_full_path = os.path.join(base_save_path, split_name_for_saving, spec_rel_path)

            # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Generating spectrogram for {signal_name}, window {window_id}") # Debug log (optional)
            success = generate_and_save_numerical_spectrogram(
                signal_window_data, sr, spec_full_path,
                config_dict.get('spec_nperseg', SPECTROGRAM_NPERSEG),
                config_dict.get('spec_noverlap', SPECTROGRAM_NOVERLAP),
                config_dict['spectrogram_target_size'])
            if success:
                spec_paths[f'spec_{signal_name}_path'] = spec_rel_path
            else:
                print(f"Warning: Failed to save spectrogram for {signal_name}, window {window_id} of record {record_id}. Skipping window.")
                modalities_saved_ok = False; break
        if not modalities_saved_ok: continue

        modalities.update(spec_paths)

        for i, signal_idx in enumerate(raw_ts_indices):
            signal_name = config_dict['raw_ts_signals'][i]
            raw_segment = signal_data[signal_idx, start_sample:end_sample]
            raw_filename = f"{window_id}_{signal_name}_rawts.npy"
            raw_rel_path = os.path.join("raw_ts", raw_filename)
            raw_full_path = os.path.join(base_save_path, split_name_for_saving, raw_rel_path)

            # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Preparing to save raw TS for {signal_name}, window {window_id} to {raw_full_path}") # Debug log (optional)
            try:
                if len(raw_segment) == window_samples:
                    os.makedirs(os.path.dirname(raw_full_path), exist_ok=True)
                    np.save(raw_full_path, raw_segment.astype(np.float32))
                    # print(f"PROCESS_RECORD_WINDOWED: {record_id} - Successfully saved raw TS for {signal_name}, window {window_id}") # Debug log (optional)
                    raw_ts_paths[f'raw_{signal_name}_path'] = raw_rel_path
                else:
                    print(f"Warning: Raw segment length mismatch for {signal_name}, window {window_id} of record {record_id}. Expected {window_samples}, got {len(raw_segment)}. Skipping window.")
                    modalities_saved_ok = False; break
            except Exception as e:
                print(f"Error saving raw TS {raw_filename} for record {record_id}: {e}. Skipping window.")
                modalities_saved_ok = False; break
        if not modalities_saved_ok: continue

        modalities.update(raw_ts_paths)
        processed_windows_metadata.append(modalities)
        window_count += 1

    del signal_data; gc.collect()
    # print(f"PROCESS_RECORD_WINDOWED: Finished processing record_id: {record_id}. Generated {window_count} windows. PID: {os.getpid()}") # Debug log (optional)
    return processed_windows_metadata


# --- Data Loading for the Windowed Training ---
def data_generator_windowed(metadata_df, config_dict, batch_size_val, scaler=None, feature_cols_list=None, is_training=True, data_base_path_val=None):
    num_samples_original = len(metadata_df)
    data_base_path = data_base_path_val if data_base_path_val else config_dict.get('training_data_dir', config_dict['persistent_output_dir_base'])

    tab_feature_cols = feature_cols_list if feature_cols_list else [col for col in config_dict['base_tab_features'] if col in metadata_df.columns]
    num_tab_features = len(tab_feature_cols)

    if is_training and config_dict.get('ENABLE_OVERSAMPLING', False) and 'label' in metadata_df.columns:
        positive_df = metadata_df[metadata_df['label'] == 1]; negative_df = metadata_df[metadata_df['label'] == 0]
        positive_indices = positive_df.index.tolist(); negative_indices = negative_df.index.tolist()
        if positive_indices and negative_indices:
            target_positive_ratio = config_dict.get('OVERSAMPLING_TARGET_POSITIVE_RATIO', 0.3)
            num_epoch_samples = num_samples_original
            num_pos_epoch = int(num_epoch_samples * target_positive_ratio)
            num_neg_epoch = num_epoch_samples - num_pos_epoch

            oversampled_pos_indices = np.random.choice(positive_indices, size=num_pos_epoch, replace=True) if num_pos_epoch > 0 else np.array([], dtype=int)
            oversampled_neg_indices = np.random.choice(negative_indices, size=num_neg_epoch, replace=num_neg_epoch > len(negative_indices)) if num_neg_epoch > 0 else np.array([], dtype=int)

            epoch_indices = np.concatenate([oversampled_pos_indices, oversampled_neg_indices]); np.random.shuffle(epoch_indices)
            metadata_to_iterate_from = metadata_df; num_samples_for_iteration = len(epoch_indices)
        else:
            epoch_indices = np.arange(num_samples_original); np.random.shuffle(epoch_indices)
            metadata_to_iterate_from = metadata_df; num_samples_for_iteration = num_samples_original
    else:
        epoch_indices = np.arange(num_samples_original)
        if is_training: np.random.shuffle(epoch_indices)
        metadata_to_iterate_from = metadata_df; num_samples_for_iteration = num_samples_original

    for start_idx in range(0, num_samples_for_iteration, batch_size_val):
        current_batch_original_indices = epoch_indices[start_idx : min(start_idx + batch_size_val, num_samples_for_iteration)]
        actual_batch_size = len(current_batch_original_indices)
        if actual_batch_size == 0: continue

        batch_specs_np = {f'spec_{name}': np.zeros((actual_batch_size,) + config_dict['spectrogram_target_size'] + (3,), dtype=np.float32) for name in config_dict['spectrogram_signals']}
        batch_raw_ts_np = {f'raw_{name}': np.zeros((actual_batch_size, config_dict['window_samples']), dtype=np.float32) for name in config_dict['raw_ts_signals']}
        if num_tab_features > 0:
            batch_tabular_np = np.zeros((actual_batch_size, num_tab_features), dtype=np.float32)
        # Else, batch_tabular_np is not created if the num_tab_features is 0, it won't be yielded.

        batch_labels_np = np.zeros((actual_batch_size, 1), dtype=np.int32)
        valid_windows_in_batch_count = 0

        for i_in_batch, original_df_idx in enumerate(current_batch_original_indices):
            try:
                row = metadata_to_iterate_from.iloc[original_df_idx]
            except IndexError:
                continue

            valid_modalities = True; loaded_specs = {}; loaded_raw = {}
            current_split_for_path = row.get('split', 'all_data_temp')
            base_path_for_loading = os.path.join(data_base_path, current_split_for_path)
            apply_augmentation = is_training

            for signal_name in config_dict['spectrogram_signals']:
                spec_path_key = f'spec_{signal_name}_path'
                if spec_path_key not in row or pd.isna(row[spec_path_key]):
                    valid_modalities=False; break
                spec_path_npy = os.path.join(base_path_for_loading, row[spec_path_key])
                spec_array = load_numerical_spectrogram(spec_path_npy, config_dict['spectrogram_target_size'], augment=apply_augmentation, config_params=config_dict)
                if np.sum(np.isnan(spec_array)) > 0 : # Check for NaNs in NumPy array
                    valid_modalities=False; break
                loaded_specs[signal_name] = spec_array
            if not valid_modalities: continue

            for signal_name in config_dict['raw_ts_signals']:
                raw_path_key = f'raw_{signal_name}_path'
                if raw_path_key not in row or pd.isna(row[raw_path_key]):
                    valid_modalities=False; break
                raw_path_npy = os.path.join(base_path_for_loading, row[raw_path_key])
                raw_segment = load_npy(raw_path_npy, expected_shape=(config_dict['window_samples'],), augment=apply_augmentation, config_params=config_dict)
                if raw_segment is None:
                    valid_modalities=False; break
                loaded_raw[signal_name] = raw_segment
            if not valid_modalities: continue

            current_tab_data = None # Initialize for this window
            if num_tab_features > 0:
                try:
                    current_tab_data = row[tab_feature_cols].values.astype(np.float32)
                    if np.isnan(current_tab_data).any():
                        current_tab_data = np.nan_to_num(current_tab_data, nan=0.0)
                except Exception:
                    valid_modalities = False # Skip window if tabular data fails for it
            if not valid_modalities: continue


            try:
                label = int(row['label'])
            except Exception:
                continue

            current_idx_in_batch = valid_windows_in_batch_count
            for s_name in config_dict['spectrogram_signals']: batch_specs_np[f'spec_{s_name}'][current_idx_in_batch] = loaded_specs[s_name]
            for s_name in config_dict['raw_ts_signals']: batch_raw_ts_np[f'raw_{s_name}'][current_idx_in_batch] = loaded_raw[s_name]
            if num_tab_features > 0 and current_tab_data is not None: # Ensure current_tab_data was successfully processed
                batch_tabular_np[current_idx_in_batch] = current_tab_data
            elif num_tab_features > 0 and current_tab_data is None: # Should not happen if logic above is correct
                print(f"Warning: Tabular data expected but not loaded for window {row.get('window_id', 'Unknown')}. Filling with zeros.")
                batch_tabular_np[current_idx_in_batch] = np.zeros(num_tab_features, dtype=np.float32)


            batch_labels_np[current_idx_in_batch, 0] = label
            valid_windows_in_batch_count += 1

        if valid_windows_in_batch_count == 0:
            continue

        if valid_windows_in_batch_count < actual_batch_size:
            for s_name in config_dict['spectrogram_signals']: batch_specs_np[f'spec_{s_name}'] = batch_specs_np[f'spec_{s_name}'][:valid_windows_in_batch_count]
            for s_name in config_dict['raw_ts_signals']: batch_raw_ts_np[f'raw_{s_name}'] = batch_raw_ts_np[f'raw_{s_name}'][:valid_windows_in_batch_count]
            if num_tab_features > 0: # Only trim if batch_tabular_np was created
                batch_tabular_np = batch_tabular_np[:valid_windows_in_batch_count]
            batch_labels_np = batch_labels_np[:valid_windows_in_batch_count]

        # Prepare model inputs dictionary
        model_inputs = {}
        for s_name in config_dict['spectrogram_signals']: model_inputs[f"spec_{s_name}_input"] = batch_specs_np[f'spec_{s_name}']
        for s_name in config_dict['raw_ts_signals']: model_inputs[f"raw_{s_name}_input"] = np.expand_dims(batch_raw_ts_np[f'raw_{s_name}'], axis=-1)

        if num_tab_features > 0:
            # Apply scaling to the tabular data if scaler is provided and features exist
            if scaler and hasattr(scaler, 'mean_') and batch_tabular_np.shape[0] > 0 and batch_tabular_np.shape[1] == scaler.n_features_in_:
                try:
                    batch_tabular_np_scaled = scaler.transform(batch_tabular_np)
                    model_inputs["tabular_input"] = batch_tabular_np_scaled
                except Exception:
                    model_inputs["tabular_input"] = batch_tabular_np # Use unscaled if error
            else: # No scaler or mismatch
                 model_inputs["tabular_input"] = batch_tabular_np


        yield model_inputs, batch_labels_np

# --- Create TF Dataset Function ---
def create_tf_dataset_windowed(metadata_df, config_dict, batch_size_val, scaler, feature_cols_list, is_training=True, repeat_for_fit=False, data_base_path_val=None):
    final_feature_cols = [col for col in feature_cols_list if col in metadata_df.columns]
    num_final_tab_features = len(final_feature_cols)

    # Define the output signature for the tf.data.Dataset
    output_signature_inputs = {} # Start empty
    for name in config_dict['spectrogram_signals']:
        output_signature_inputs[f"spec_{name}_input"] = tf.TensorSpec(shape=(None,) + config_dict['spectrogram_target_size'] + (3,), dtype=tf.float32)
    for name in config_dict['raw_ts_signals']:
        output_signature_inputs[f"raw_{name}_input"] = tf.TensorSpec(shape=(None, config_dict['window_samples'], 1), dtype=tf.float32)

    if num_final_tab_features > 0: # Conditionally add tabular input signature
        output_signature_inputs["tabular_input"] = tf.TensorSpec(shape=(None, num_final_tab_features), dtype=tf.float32)

    output_signature = (output_signature_inputs, tf.TensorSpec(shape=(None, 1), dtype=tf.int32))

    required_paths = [f'spec_{name}_path' for name in config_dict['spectrogram_signals']] + \
                     [f'raw_{name}_path' for name in config_dict['raw_ts_signals']]
    # Essential columns for filtering (tabular features are only essential if they are going to be used)
    essential_cols_for_filtering = ['split', 'label'] + required_paths
    if num_final_tab_features > 0:
        essential_cols_for_filtering += final_feature_cols

    metadata_df_filtered = metadata_df.copy()
    cols_to_check_for_na = [col for col in essential_cols_for_filtering if col in metadata_df_filtered.columns and col != 'split']
    metadata_df_filtered.dropna(subset=cols_to_check_for_na, inplace=True)
    metadata_df_filtered.reset_index(drop=True, inplace=True)

    if len(metadata_df_filtered) < len(metadata_df):
        print(f"Filtered {len(metadata_df) - len(metadata_df_filtered)} rows from metadata due to missing essential data for tf.data.Dataset.")
    if len(metadata_df_filtered) == 0:
        print("Error: Metadata is empty after filtering for essential columns. Cannot create tf.data.Dataset.")
        return None

    gen_lambda = lambda: data_generator_windowed(
        metadata_df_filtered,
        config_dict,
        batch_size_val,
        scaler,
        final_feature_cols, # Pass the determined final_feature_cols
        is_training=is_training,
        data_base_path_val=data_base_path_val
    )
    dataset = tf.data.Dataset.from_generator(gen_lambda, output_signature=output_signature)

    if is_training:
        dataset = dataset.shuffle(buffer_size=max(1000, batch_size_val*10)).repeat()
    elif repeat_for_fit:
        dataset = dataset.repeat()

    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# --- Windowed Model Definition ---
def build_multimodal_model_windowed(config_dict, num_tabular_features): # num_tabular_features is len(final_feature_cols)
    l2_reg = config_dict.get('l2_reg_factor', 1e-5)

    spectrogram_inputs = [keras.Input(shape=config_dict['spectrogram_target_size'] + (3,), name=f"spec_{name}_input") for name in config_dict['spectrogram_signals']]
    raw_ts_inputs = [keras.Input(shape=(config_dict['window_samples'], 1), name=f"raw_{name}_input") for name in config_dict['raw_ts_signals']]

    model_inputs_list = spectrogram_inputs + raw_ts_inputs

    tabular_input_layer = None # Keep track of the Keras Input layer for the tabular data
    if num_tabular_features > 0:
        tabular_input_layer = keras.Input(shape=(num_tabular_features,), name="tabular_input")
        model_inputs_list.append(tabular_input_layer)

    spectrogram_cnn_outputs = []
    for i, spec_input_raw in enumerate(spectrogram_inputs):
        signal_name = config_dict['spectrogram_signals'][i]
        spec_input_processed = tf.keras.applications.mobilenet_v2.preprocess_input(spec_input_raw)
        base_model = keras.applications.MobileNetV2(
            input_shape=config_dict['spectrogram_target_size']+(3,),
            include_top=False,
            weights='imagenet',
            pooling='avg',
            name=f"mobilenet_{signal_name}"
        )
        base_model.trainable = False

        img_features = base_model(spec_input_processed, training=False)
        img_branch = layers.Dropout(0.5, name=f"dropout_spec_{signal_name}")(img_features)
        img_branch = layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(l2_reg), name=f"spec_features_{signal_name}")(img_branch)
        spectrogram_cnn_outputs.append(img_branch)

    raw_ts_cnn_outputs = []
    for i, raw_input_val in enumerate(raw_ts_inputs):
        signal_name = config_dict['raw_ts_signals'][i]
        ts_branch = layers.Conv1D(32, 16, activation='relu', padding='same', kernel_regularizer=regularizers.l2(l2_reg), name=f"ts_conv1_{signal_name}")(raw_input_val)
        ts_branch = layers.BatchNormalization(name=f"ts_bn1_{signal_name}")(ts_branch)
        ts_branch = layers.MaxPooling1D(4, name=f"ts_pool1_{signal_name}")(ts_branch)
        ts_branch = layers.Conv1D(64, 16, activation='relu', padding='same', kernel_regularizer=regularizers.l2(l2_reg), name=f"ts_conv2_{signal_name}")(ts_branch)
        ts_branch = layers.BatchNormalization(name=f"ts_bn2_{signal_name}")(ts_branch)
        ts_branch = layers.GlobalMaxPooling1D(name=f"ts_features_{signal_name}")(ts_branch)
        ts_branch = layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(l2_reg), name=f"ts_dense_{signal_name}")(ts_branch)
        raw_ts_cnn_outputs.append(ts_branch)

    features_to_fuse = spectrogram_cnn_outputs + raw_ts_cnn_outputs

    if num_tabular_features > 0 and tabular_input_layer is not None: # Check if the Input layer was created
        tab_branch = layers.Dense(32, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))(tabular_input_layer) # Use the Input layer here
        tab_branch = layers.BatchNormalization()(tab_branch)
        tab_branch = layers.Dropout(0.3)(tab_branch)
        tab_branch = layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(l2_reg), name="tabular_features")(tab_branch)
        features_to_fuse.append(tab_branch)

    if not features_to_fuse:
        raise ValueError("No input branches to fuse! Check signal and tabular feature configurations.")

    fused_features = features_to_fuse[0] if len(features_to_fuse) == 1 else layers.Concatenate(name="fused_features")(features_to_fuse)

    x = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))(fused_features)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    output = layers.Dense(1, activation='sigmoid', name="output")(x)

    model = keras.Model(inputs=model_inputs_list, outputs=output, name="windowed_multimodal_classifier")
    return model

# Function to Engineer the Patient-Level Features ---
def engineer_patient_features(window_df, tuned_window_threshold, proba_col='win_proba'):
    """
    Engineers patient-level features from window-level predictions.
    Args:
        window_df (pd.DataFrame): DataFrame with window predictions, must include 'record_id' and proba_col.
        tuned_window_threshold (float): The threshold to consider a window as an "event".
        proba_col (str): Name of the column containing window probabilities.
    Returns:
        pd.DataFrame: DataFrame with one row per patient and engineered features.
    """
    if window_df.empty or proba_col not in window_df.columns:
        return pd.DataFrame()

    # Create the binary predictions based on the tuned threshold
    window_df['win_pred_binary_event'] = (window_df[proba_col] > tuned_window_threshold).astype(int)

    patient_features_list = []
    for record_id, group in window_df.groupby('record_id'):
        features = {'record_id': record_id}

        # Feature 1: Positive Window Ratio (fraction of windows predicted as event)
        features['positive_window_ratio'] = group['win_pred_binary_event'].mean()

        # Feature 2: Mean probability of windows predicted as positive
        positive_windows = group[group['win_pred_binary_event'] == 1]
        if not positive_windows.empty:
            features['mean_event_proba'] = positive_windows[proba_col].mean()
        else:
            features['mean_event_proba'] = 0.0 # Or np.nan, or a low value

        # Feature 3: Standard deviation of all window probabilities for the patient
        features['std_proba'] = group[proba_col].std()

        # Feature 4: Median of all window probabilities for the patient
        features['median_proba'] = group[proba_col].median()

        # Feature 5: 90th percentile of window probabilities
        features['p90_proba'] = group[proba_col].quantile(0.90)

        # Feature 6: Number of "event segments" (consecutive windows predicted as event)
        # This is a bit more complex; for simplicity, we'll use a basic version
        # A segment is a run of 1s in 'win_pred_binary_event'
        group['is_event_shifted'] = group['win_pred_binary_event'].shift(1, fill_value=0)
        group['segment_start'] = (group['win_pred_binary_event'] == 1) & (group['is_event_shifted'] == 0)
        features['num_event_segments'] = group['segment_start'].sum()

        patient_features_list.append(features)

    patient_features_df = pd.DataFrame(patient_features_list)
    # Handle NaNs that might arise from std (if only 1 window) or if no positive windows for mean_event_proba
    patient_features_df.fillna(0.0, inplace=True)
    return patient_features_df


# --- Main Execution ---
if __name__ == "__main__":
    main_start_time = time.time()
    config = {
        'data_dir': DATA_DIR,
        'persistent_output_dir_base': PERSISTENT_OUTPUT_DIR_BASE,
        'training_data_dir': LOCAL_DATA_DIR_BASE,
        'final_model_output_dir': FINAL_MODEL_RESULTS_DIR,
        'labels_file': LABELS_FILE, 'demo_file': DEMOGRAPHICS_FILE,
        'arousal_suffix': AROUSAL_MAT_SUFFIX.replace('-','').replace('.mat',''),
        'spectrogram_signals': SPECTROGRAM_SIGNALS, 'raw_ts_signals': RAW_TS_SIGNALS,
        'base_tab_features': BASE_TABULAR_FEATURES,
        'spectrogram_target_size': SPECTROGRAM_TARGET_SIZE,
        'sampling_rate': SAMPLING_RATE,
        'window_seconds': WINDOW_SECONDS, 'window_samples': WINDOW_SAMPLES,
        'window_overlap_ratio': WINDOW_OVERLAP_RATIO, 'window_step_samples': WINDOW_STEP_SAMPLES,
        'spec_nperseg': SPECTROGRAM_NPERSEG, 'spec_noverlap': SPECTROGRAM_NOVERLAP,
        'learning_rate': DEFAULT_LEARNING_RATE,
        'fine_tune_lr_factor': DEFAULT_FINE_TUNE_LR_FACTOR,
        'fine_tune_layers': DEFAULT_FINE_TUNE_LAYERS,
        'focal_loss_gamma': DEFAULT_FOCAL_LOSS_GAMMA,
        'l2_reg_factor': DEFAULT_L2_REG_FACTOR,
        'ENABLE_OVERSAMPLING': DEFAULT_ENABLE_OVERSAMPLING,
        'OVERSAMPLING_TARGET_POSITIVE_RATIO': DEFAULT_OVERSAMPLING_TARGET_POSITIVE_RATIO,
        'AUGMENT_SPECTROGRAMS': DEFAULT_AUGMENT_SPECTROGRAMS,
        'AUGMENT_RAW_TS': DEFAULT_AUGMENT_RAW_TS,
        'INITIAL_EPOCHS': GLOBAL_INITIAL_EPOCHS,
        'FINE_TUNE_EPOCHS': GLOBAL_FINE_TUNE_EPOCHS,
        'BATCH_SIZE': BATCH_SIZE,
        'apnea_hypopnea_patterns': [
            r"\(?resp_centralapnea\)?", r"\(?resp_hypopnea\)?",
            r"\(?resp_obstructiveapnea\)?", r"\(?resp_mixedapnea\)?", r"mixed apnea"
        ],
        'binary_target_col': PATIENT_BINARY_TARGET_COL,
        'patient_multiclass_target_col': PATIENT_SEVERITY_TARGET_COL,
        'patient_ahi_threshold_binary': 5,
        'patient_ahi_thresholds_multiclass': [5, 15, 30],
        'PATIENT_SEVERITY_THRESHOLDS_ON_RATIO': list(PATIENT_SEVERITY_THRESHOLDS_ON_RATIO),
        'USE_SECONDARY_PATIENT_MODEL': USE_SECONDARY_PATIENT_MODEL
    }
    prepare_output_dirs_windowed(config['persistent_output_dir_base'], config['training_data_dir'])

    print("--- STAGE 1: Data Preprocessing and Windowing ---")
    patient_true_labels_df, demo_dict = load_labels_and_metadata(
        config['labels_file'], config['demo_file'],
        patient_binary_target_col=config['binary_target_col'],
        patient_multiclass_target_col=config['patient_multiclass_target_col'],
        ahi_threshold_binary=config['patient_ahi_threshold_binary'],
        ahi_thresholds_multiclass=config['patient_ahi_thresholds_multiclass']
    )

    all_records_in_dir = [d for d in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, d))]
    valid_record_ids_stage1 = set()
    if not patient_true_labels_df.empty and 'Record' in patient_true_labels_df.columns:
        valid_record_ids_stage1.update(patient_true_labels_df['Record'].astype(str).unique())
    if demo_dict:
        valid_record_ids_stage1.update(demo_dict.keys())

    all_valid_records_for_processing = [r for r in all_records_in_dir if r in valid_record_ids_stage1]
    print(f"Found {len(all_valid_records_for_processing)} records in '{DATA_DIR}' that have corresponding entries in labels/demographics files.")

    if not all_valid_records_for_processing:
        print(f"Error: No valid records found for Stage 1 processing. Check paths and content. Exiting."); exit()

    subset_records_for_stage1 = all_valid_records_for_processing
    if MAX_RECORDS_TO_PROCESS > 0 and MAX_RECORDS_TO_PROCESS < len(all_valid_records_for_processing):
        print(f"Selecting {MAX_RECORDS_TO_PROCESS} records RANDOMLY for processing.")
        random.seed(RANDOM_STATE);
        subset_records_for_stage1 = random.sample(all_valid_records_for_processing, MAX_RECORDS_TO_PROCESS)

    subset_records_array = np.array(subset_records_for_stage1)

    all_windows_meta_path = os.path.join(config['persistent_output_dir_base'], PERSISTENT_METADATA_CSV_NAME)
    all_windows_df = pd.DataFrame()
    local_all_data_temp_path = os.path.join(config['training_data_dir'], "all_data_temp")

    if not FORCE_REPROCESS_ALL_WINDOWS and os.path.exists(all_windows_meta_path):
        print(f"Attempting to load existing metadata from: {all_windows_meta_path}")
        all_windows_df = pd.read_csv(all_windows_meta_path)
        if 'record_id' in all_windows_df.columns:
            all_windows_df['record_id'] = all_windows_df['record_id'].astype(str)

        if not set(subset_records_array).issubset(set(all_windows_df['record_id'].unique())):
            print("Warning: Existing metadata does not cover all currently selected records. Reprocessing NPY files.")
            FORCE_REPROCESS_ALL_WINDOWS = True
        else:
            raw_ts_path_check = os.path.join(local_all_data_temp_path, "raw_ts")
            if not os.path.exists(raw_ts_path_check) or not os.listdir(raw_ts_path_check):
                 print(f"Warning: Metadata found, but local NPY data at '{local_all_data_temp_path}' (specifically 'raw_ts' subdir) seems missing/empty. Reprocessing NPY files.")
                 FORCE_REPROCESS_ALL_WINDOWS = True
            else:
                 print("Metadata loaded. Assuming local NPY files are present. Skipping NPY generation.")
    else:
        if FORCE_REPROCESS_ALL_WINDOWS: print("FORCE_REPROCESS_ALL_WINDOWS is True. Reprocessing NPY files.")
        else: print("Existing processed metadata not found or incomplete. Reprocessing NPY files.")
        FORCE_REPROCESS_ALL_WINDOWS = True


    if FORCE_REPROCESS_ALL_WINDOWS:
        print(f"Processing {len(subset_records_array)} records into windows (NPY files)...")
        if os.path.exists(local_all_data_temp_path):
            shutil.rmtree(local_all_data_temp_path)
        os.makedirs(os.path.join(local_all_data_temp_path, "raw_ts"), exist_ok=True)
        for signal_name in SPECTROGRAM_SIGNALS:
            os.makedirs(os.path.join(local_all_data_temp_path, "spectrograms", signal_name), exist_ok=True)

        processing_start_time = time.time()
        process_args_list = []
        for record_id_proc in subset_records_array:
            record_id_str = str(record_id_proc)
            demo_meta = {}
            if demo_dict and record_id_str in demo_dict:
                demo_meta = {feat: demo_dict[record_id_str].get(feat, np.nan) for feat in config['base_tab_features']}
            else:
                demo_meta = {feat: np.nan for feat in config['base_tab_features']}

            process_args_list.append((record_id_str, "all_data_temp", config, demo_meta, config['training_data_dir']))

        num_processes = min(os.cpu_count(), 4) if os.cpu_count() else 2
        print(f"Using {num_processes} processes for window generation.")

        all_processed_windows_list_of_lists = []
        if num_processes > 1 and __name__ == '__main__':
            try:
                with multiprocessing.Pool(processes=num_processes) as pool:
                    with tqdm(total=len(process_args_list), desc="Stage 1: Processing records (MP)") as pbar:
                        for result in pool.imap_unordered(process_record_windowed_mp_wrapper, process_args_list):
                            if result:
                                all_processed_windows_list_of_lists.append(result)
                            pbar.update(1)
            except Exception as mp_e:
                print(f"Multiprocessing failed with error: {mp_e}. Falling back to sequential processing.")
                all_processed_windows_list_of_lists = []
                for args_item in tqdm(process_args_list, desc="Stage 1: Processing records (Sequential Fallback)"):
                    result = process_record_windowed_mp_wrapper(args_item)
                    if result:
                        all_processed_windows_list_of_lists.append(result)
        else:
            print("Executing record processing sequentially.")
            for args_item in tqdm(process_args_list, desc="Stage 1: Processing records (Sequential)"):
                result = process_record_windowed_mp_wrapper(args_item)
                if result:
                    all_processed_windows_list_of_lists.append(result)

        all_processed_windows_list = [item for sublist in all_processed_windows_list_of_lists if sublist for item in sublist]
        print(f"\nWindow processing (Stage 1) took: {time.time() - processing_start_time:.2f} seconds.")
        print(f"Total windows generated: {len(all_processed_windows_list)}")

        if not all_processed_windows_list:
            print("Error: No windows were processed. Check logs. Exiting."); exit()

        all_windows_df = pd.DataFrame(all_processed_windows_list)
        if not all_windows_df.empty:
            all_windows_df['record_id'] = all_windows_df['record_id'].astype(str)
            all_windows_df.to_csv(all_windows_meta_path, index=False)
            print(f"Saved all processed windows metadata to: {all_windows_meta_path}")

        del all_processed_windows_list, all_processed_windows_list_of_lists; gc.collect()

    if all_windows_df.empty:
        print("Error: Window data DataFrame is empty after Stage 1. Cannot proceed. Exiting."); exit()
    print("--- Stage 1 Finished ---")

    print("\n--- FINAL MODEL TRAINING AND EVALUATION STAGE ---")

    record_to_label_map_binary = {}
    record_to_label_map_multiclass = {}
    if not patient_true_labels_df.empty and 'Record' in patient_true_labels_df.columns and config['binary_target_col'] in patient_true_labels_df.columns:
        record_to_label_map_binary = {str(r):l for r,l in patient_true_labels_df.set_index('Record')[config['binary_target_col']].to_dict().items()}
    if not patient_true_labels_df.empty and 'Record' in patient_true_labels_df.columns and config['patient_multiclass_target_col'] in patient_true_labels_df.columns:
        record_to_label_map_multiclass = {str(r):l for r,l in patient_true_labels_df.set_index('Record')[config['patient_multiclass_target_col']].to_dict().items()}

    subset_binary_labels_for_split = np.array([record_to_label_map_binary.get(str(rec_id), 0) for rec_id in subset_records_array])
    stratify_final_split = subset_binary_labels_for_split if len(np.unique(subset_binary_labels_for_split)) > 1 else None

    final_train_val_pool_ids, final_test_ids = train_test_split(
        subset_records_array,
        test_size=FINAL_HOLDOUT_TEST_SET_RATIO,
        random_state=RANDOM_STATE,
        stratify=stratify_final_split
    )

    final_train_val_multiclass_labels = np.array([record_to_label_map_multiclass.get(str(rec_id), 0) for rec_id in final_train_val_pool_ids])
    stratify_val_split = final_train_val_multiclass_labels if len(np.unique(final_train_val_multiclass_labels)) > 1 else None

    final_train_ids, final_val_ids = [], []
    if len(final_train_val_pool_ids) > 0 :
        if len(final_train_val_pool_ids) < 2 or FINAL_MODEL_VALIDATION_RATIO <= 0 or FINAL_MODEL_VALIDATION_RATIO >=1:
            final_train_ids = final_train_val_pool_ids
            print("Validation set not created. Using entire pool for training.")
        else:
            final_train_ids, final_val_ids = train_test_split(
                final_train_val_pool_ids,
                test_size=FINAL_MODEL_VALIDATION_RATIO,
                random_state=RANDOM_STATE,
                stratify=stratify_val_split
            )

    print(f"Final Data Split (by Record IDs): Train: {len(final_train_ids)}, Validation: {len(final_val_ids)}, Test: {len(final_test_ids)}")

    if len(final_train_ids) == 0:
        print("Error: Final training set is empty. Exiting."); exit()

    # Filter all_windows_df for train, val, test based on the patient IDs
    # These DFs will be used for the window-level training and later for patient-level feature engineering
    train_windows_df = all_windows_df[all_windows_df['record_id'].isin(final_train_ids)].copy(); train_windows_df.loc[:, 'split'] = 'all_data_temp'
    val_windows_df = pd.DataFrame()
    if len(final_val_ids) > 0:
        val_windows_df = all_windows_df[all_windows_df['record_id'].isin(final_val_ids)].copy(); val_windows_df.loc[:, 'split'] = 'all_data_temp'
    test_windows_df = pd.DataFrame()
    if len(final_test_ids) > 0:
        test_windows_df = all_windows_df[all_windows_df['record_id'].isin(final_test_ids)].copy(); test_windows_df.loc[:, 'split'] = 'all_data_temp'

    print(f"Window counts: Train: {len(train_windows_df)}, Validation: {len(val_windows_df)}, Test: {len(test_windows_df)}")

    final_feature_cols = [col for col in config['base_tab_features'] if col in train_windows_df.columns] # Use train_windows_df for column check
    print(f"Using tabular features for primary model: {final_feature_cols}")
    num_final_tab_features = len(final_feature_cols)

    final_imputation_values = {}
    if not train_windows_df.empty:
        for col in final_feature_cols:
            if train_windows_df[col].isnull().all():
                final_imputation_values[col] = 0.0
            else:
                final_imputation_values[col] = train_windows_df[col].median()
    print(f"Imputation values for tabular features: {final_imputation_values}")

    for df_loop in [train_windows_df, val_windows_df, test_windows_df]:
        if not df_loop.empty:
            for col in final_feature_cols:
                if col in df_loop.columns:
                    df_loop.loc[:, col] = df_loop[col].fillna(final_imputation_values.get(col, 0.0))

    final_scaler = None
    if num_final_tab_features > 0 and not train_windows_df.empty:
        numeric_train_tab = train_windows_df[final_feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0)
        if not numeric_train_tab.empty and numeric_train_tab.shape[1] > 0:
             final_scaler = StandardScaler().fit(numeric_train_tab.values)
             print("StandardScaler fitted on training tabular data.")

    class_weights = None
    if not train_windows_df.empty and 'label' in train_windows_df.columns:
        labels_train = train_windows_df['label'].values
        unique_classes, class_counts = np.unique(labels_train, return_counts=True)
        if len(unique_classes) == 2:
            class_weights_arr = compute_class_weight(class_weight='balanced', classes=unique_classes, y=labels_train)
            class_weights = dict(zip(map(int,unique_classes), class_weights_arr))
        print(f"Window-level class weights: {class_weights}")

    current_bs = config['BATCH_SIZE']
    train_ds = create_tf_dataset_windowed(train_windows_df, config, current_bs, final_scaler, final_feature_cols, is_training=True, data_base_path_val=config['training_data_dir'])

    val_ds_fit = None
    if not val_windows_df.empty:
        val_ds_fit = create_tf_dataset_windowed(val_windows_df, config, current_bs, final_scaler, final_feature_cols, is_training=False, repeat_for_fit=True, data_base_path_val=config['training_data_dir'])

    if train_ds is None:
        print("Error: Failed to create training tf.data.Dataset. Exiting."); exit()

    train_steps = int(np.ceil(len(train_windows_df) / current_bs)) if not train_windows_df.empty and current_bs > 0 else 0
    val_steps = int(np.ceil(len(val_windows_df) / current_bs)) if not val_windows_df.empty and current_bs > 0 else 0

    if train_steps == 0:
        print("Error: Training steps is 0. Exiting."); exit()

    # Don't delete the train_windows_df, val_windows_df, test_windows_df yet, needed for patient-level feature engineering
    # del train_windows_df; gc.collect()

    final_model = build_multimodal_model_windowed(config, num_final_tab_features)
    final_model.summary(line_length=120)

    loss_fn = tf.keras.losses.BinaryFocalCrossentropy(gamma=config['focal_loss_gamma'])
    metrics_list = [
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc')
    ]

    final_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=config['learning_rate']),
        loss=loss_fn,
        metrics=metrics_list
    )

    checkpoint_path = os.path.join(FINAL_MODEL_RESULTS_DIR, "best_final_model.keras")
    cb_monitor = 'val_auc' if val_ds_fit and val_steps > 0 else 'auc'
    cb_monitor_loss = 'val_loss' if val_ds_fit and val_steps > 0 else 'loss'

    callbacks_p1 = [
        keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_best_only=True, monitor=cb_monitor, mode='max', verbose=1),
        keras.callbacks.EarlyStopping(monitor=cb_monitor_loss, patience=5, verbose=1, mode='min', restore_best_weights=False),
        keras.callbacks.ReduceLROnPlateau(monitor=cb_monitor_loss, factor=0.2, patience=3, verbose=1, min_lr=1e-6)
    ]

    print(f"\n--- Phase 1 Training ({config['INITIAL_EPOCHS']} epochs) ---")
    print(f"Monitoring '{cb_monitor}' for ModelCheckpoint and '{cb_monitor_loss}' for EarlyStopping/ReduceLROnPlateau.")
    hist_p1 = final_model.fit(
        train_ds,
        epochs=config['INITIAL_EPOCHS'],
        steps_per_epoch=train_steps,
        validation_data=val_ds_fit,
        validation_steps=val_steps,
        callbacks=callbacks_p1,
        verbose=1,
        class_weight=class_weights
    )

    if os.path.exists(checkpoint_path):
        print(f"Loading best weights from Phase 1 checkpoint: {checkpoint_path}")
        final_model.load_weights(checkpoint_path)
    else:
        print("Warning: Checkpoint from Phase 1 not found.")

    fine_tune_layers_count = config.get('fine_tune_layers', 0)
    hist_p2 = None

    if fine_tune_layers_count > 0 and config['FINE_TUNE_EPOCHS'] > 0:
        print(f"\n--- Setting up for Phase 2 Fine-tuning ---")
        print(f"Unfreezing last {fine_tune_layers_count} layers of MobileNetV2 backbones.")

        for sn in config['spectrogram_signals']:
            try:
                base_model_layer = final_model.get_layer(f"mobilenet_{sn}")
                base_model_layer.trainable = True

                num_layers_in_backbone = len(base_model_layer.layers)
                num_to_unfreeze = min(fine_tune_layers_count, num_layers_in_backbone)

                print(f"  For '{sn}': Unfreezing last {num_to_unfreeze} of {num_layers_in_backbone} layers.")

                for layer_idx, layer in enumerate(base_model_layer.layers[:-num_to_unfreeze]):
                    layer.trainable = False

                for layer_idx, layer in enumerate(base_model_layer.layers[-num_to_unfreeze:]):
                    layer.trainable = True
                print(f"  Successfully set trainability for layers in '{sn}' backbone.")

            except ValueError:
                print(f"  Warning: MobileNetV2 backbone for '{sn}' not found. Cannot fine-tune.")
                continue

        final_model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=config['learning_rate'] * config['fine_tune_lr_factor']),
            loss=loss_fn,
            metrics=metrics_list
        )
        print(f"Model re-compiled for fine-tuning with LR: {config['learning_rate'] * config['fine_tune_lr_factor']:.2e}")

        callbacks_p2 = [
            keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_best_only=True, monitor=cb_monitor, mode='max', verbose=1),
            keras.callbacks.EarlyStopping(monitor=cb_monitor_loss, patience=10, verbose=1, mode='min', restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor=cb_monitor_loss, factor=0.2, patience=5, verbose=1, min_lr=1e-7)
        ]

        start_epoch_p2 = hist_p1.epoch[-1] + 1 if hist_p1 and hist_p1.epoch else config['INITIAL_EPOCHS']
        total_epochs_combined = config['INITIAL_EPOCHS'] + config['FINE_TUNE_EPOCHS']

        print(f"\n--- Phase 2 Fine-tuning (up to {config['FINE_TUNE_EPOCHS']} additional epochs) ---")
        print(f"Starting from epoch {start_epoch_p2}, total epochs planned: {total_epochs_combined}.")

        hist_p2 = final_model.fit(
            train_ds,
            epochs=total_epochs_combined,
            initial_epoch=start_epoch_p2,
            steps_per_epoch=train_steps,
            validation_data=val_ds_fit,
            validation_steps=val_steps,
            callbacks=callbacks_p2,
            verbose=1,
            class_weight=class_weights
        )
    else:
        print("\nSkipping Phase 2 Fine-tuning.")

    best_model = final_model
    if os.path.exists(checkpoint_path):
        print(f"Loading overall best model from checkpoint: {checkpoint_path}")
        try:
            best_model = keras.models.load_model(checkpoint_path, custom_objects={'BinaryFocalCrossentropy': loss_fn})
        except Exception as e:
            print(f"Error loading best model from checkpoint: {e}. Using model from end of training.")
            best_model = final_model
    else:
        print("Warning: Final best model checkpoint not found. Using model from end of training.")

    # Get Window Probabilities for ALL data splits (Train, Val, Test) for patient-level feature engineering
    print("\n--- Generating Window Probabilities for Patient-Level Feature Engineering ---")
    all_windows_df_with_probas = all_windows_df.copy() # Start with all metadata

    # Create a dataset for all the windows to get predictions
    all_windows_ds_for_pred = create_tf_dataset_windowed(all_windows_df_with_probas, config, current_bs, final_scaler, final_feature_cols, is_training=False, repeat_for_fit=False, data_base_path_val=config['training_data_dir'])
    all_windows_steps = int(np.ceil(len(all_windows_df_with_probas) / current_bs)) if not all_windows_df_with_probas.empty and current_bs > 0 else 0

    if all_windows_ds_for_pred and all_windows_steps > 0:
        print(f"Predicting probabilities for all {len(all_windows_df_with_probas)} windows...")
        all_probas = best_model.predict(all_windows_ds_for_pred, steps=all_windows_steps, verbose=1).flatten()
        if len(all_probas) == len(all_windows_df_with_probas):
            all_windows_df_with_probas['win_proba'] = all_probas
        else:
            print(f"Warning: Mismatch in predicted probabilities ({len(all_probas)}) and total windows ({len(all_windows_df_with_probas)}). Patient-level features might be incorrect.")
            # Fallback: add a dummy column to prevent key errors, though results will be meaningless
            all_windows_df_with_probas['win_proba'] = 0.0
    else:
        print("Warning: Could not create dataset for all window predictions. Patient-level features might be incorrect.")
        all_windows_df_with_probas['win_proba'] = 0.0 # Fallback

    # Now filter this df for train, val, test sets
    train_windows_for_patient_model_df = all_windows_df_with_probas[all_windows_df_with_probas['record_id'].isin(final_train_ids)].copy()
    val_windows_for_patient_model_df = all_windows_df_with_probas[all_windows_df_with_probas['record_id'].isin(final_val_ids)].copy()
    test_windows_for_patient_model_df = all_windows_df_with_probas[all_windows_df_with_probas['record_id'].isin(final_test_ids)].copy()


    # --- Tune the Window-Level Decision Threshold ---
    tuned_window_threshold = 0.5 # Default if tuning fails or val_df is empty
    if best_model and not val_windows_for_patient_model_df.empty and 'win_proba' in val_windows_for_patient_model_df.columns:
        print("\n--- Tuning Window-Level Decision Threshold (Validation Set Windows) ---")
        val_win_probas_for_thresh_tuning = val_windows_for_patient_model_df['win_proba'].values
        val_true_labels_for_thresh_tuning = val_windows_for_patient_model_df['label'].values

        best_f1_win = -1.0
        candidate_win_thresholds = np.arange(0.1, 0.91, 0.05)
        print(f"  Searching for optimal window decision threshold from candidates: {np.round(candidate_win_thresholds,2)}")
        for th in candidate_win_thresholds:
            y_pred_win_val = (val_win_probas_for_thresh_tuning > th).astype(int)
            f1 = f1_score(val_true_labels_for_thresh_tuning, y_pred_win_val, pos_label=1, average='binary', zero_division=0)
            if f1 > best_f1_win:
                best_f1_win = f1
                tuned_window_threshold = th
        print(f"  Optimal Window-Level Decision Threshold (Val F1 for Event Class {best_f1_win:.4f}): {tuned_window_threshold:.2f}")
    else:
        print("\nSkipping window-level decision threshold tuning (no model/val_df/val_steps or win_proba missing). Using default 0.5.")
    config['tuned_window_threshold'] = tuned_window_threshold


    # --- Patient-Level Modeling ---
    secondary_patient_model = None
    patient_feature_names = ['positive_window_ratio', 'mean_event_proba', 'std_proba', 'median_proba', 'p90_proba', 'num_event_segments']

    if config['USE_SECONDARY_PATIENT_MODEL']:
        print("\n--- Training Secondary Patient-Level Model ---")
        # Engineer features for train, val
        X_train_patient_df = engineer_patient_features(train_windows_for_patient_model_df, tuned_window_threshold, proba_col='win_proba')
        X_val_patient_df = engineer_patient_features(val_windows_for_patient_model_df, tuned_window_threshold, proba_col='win_proba')

        # Merge with the true patient labels
        patient_true_labels_df_str_rec = patient_true_labels_df.copy()
        patient_true_labels_df_str_rec['Record'] = patient_true_labels_df_str_rec['Record'].astype(str)

        train_patient_data = pd.merge(X_train_patient_df, patient_true_labels_df_str_rec[['Record', config['patient_multiclass_target_col'], config['binary_target_col']]], left_on='record_id', right_on='Record', how='inner')
        val_patient_data = pd.merge(X_val_patient_df, patient_true_labels_df_str_rec[['Record', config['patient_multiclass_target_col'], config['binary_target_col']]], left_on='record_id', right_on='Record', how='inner')

        if not train_patient_data.empty and not val_patient_data.empty:
            X_train_patient = train_patient_data[patient_feature_names].values
            y_train_patient_multiclass = train_patient_data[config['patient_multiclass_target_col']].values

            X_val_patient = val_patient_data[patient_feature_names].values
            y_val_patient_multiclass = val_patient_data[config['patient_multiclass_target_col']].values

            secondary_patient_model = RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE, class_weight='balanced')
            secondary_patient_model.fit(X_train_patient, y_train_patient_multiclass)
            print("Secondary patient-level model (RandomForestClassifier) trained.")

            # Optional: Evaluate the secondary model on patient validation set
            if len(X_val_patient) > 0:
                y_pred_val_mc_secondary = secondary_patient_model.predict(X_val_patient)
                val_f1_secondary = f1_score(y_val_patient_multiclass, y_pred_val_mc_secondary, average='macro', zero_division=0)
                print(f"  Secondary model F1 Macro on patient validation set: {val_f1_secondary:.4f}")
        else:
            print("Warning: Not enough data to train/validate secondary patient model. Falling back to thresholding.")
            config['USE_SECONDARY_PATIENT_MODEL'] = False # Fallback

    # If not using the secondary model or it failed, tune patient severity thresholds
    tuned_patient_thresholds = list(config['PATIENT_SEVERITY_THRESHOLDS_ON_RATIO'])
    if not config['USE_SECONDARY_PATIENT_MODEL']:
        print("\n--- Tuning Patient Severity Thresholds (Validation Set Windows - Thresholding Approach) ---")
        if not val_windows_for_patient_model_df.empty and 'win_proba' in val_windows_for_patient_model_df.columns:
            val_windows_for_patient_model_df.loc[:, 'win_pred_binary_for_agg'] = (val_windows_for_patient_model_df['win_proba'] > tuned_window_threshold).astype(int)
            val_patient_agg = val_windows_for_patient_model_df.groupby('record_id')['win_pred_binary_for_agg'].mean().reset_index().rename(columns={'win_pred_binary_for_agg':'positive_window_ratio'})
            val_patient_agg['record_id'] = val_patient_agg['record_id'].astype(str)

            patient_true_labels_df_str_rec = patient_true_labels_df.copy()
            if 'Record' in patient_true_labels_df_str_rec.columns:
                 patient_true_labels_df_str_rec['Record'] = patient_true_labels_df_str_rec['Record'].astype(str)
            true_val_multi = patient_true_labels_df_str_rec[patient_true_labels_df_str_rec['Record'].isin(val_patient_agg['record_id'])][['Record', config['patient_multiclass_target_col']]].set_index('Record')
            val_comp_df = pd.merge(val_patient_agg.set_index('record_id'), true_val_multi, left_index=True, right_index=True, how='inner')

            if not val_comp_df.empty and config['patient_multiclass_target_col'] in val_comp_df.columns:
                y_true_val_mc = val_comp_df[config['patient_multiclass_target_col']]
                y_score_val_ratio = val_comp_df['positive_window_ratio']
                best_f1_val_mc = -1.0
                candidate_th_sets = [
                    [0.01, 0.03, 0.05], [0.02, 0.05, 0.10], [0.03, 0.08, 0.15],
                    [0.04, 0.10, 0.20], [0.05, 0.12, 0.25], [0.05, 0.15, 0.30],
                    [0.07, 0.20, 0.35], [0.10, 0.25, 0.40], [0.10, 0.30, 0.50],
                    [0.15, 0.35, 0.55], [0.20, 0.40, 0.60]
                ]
                print(f"  Searching for optimal patient severity thresholds from candidates: {candidate_th_sets}")
                for th_set in candidate_th_sets:
                    conditions_mc = [(y_score_val_ratio < th_set[0]), (y_score_val_ratio >= th_set[0]) & (y_score_val_ratio < th_set[1]), (y_score_val_ratio >= th_set[1]) & (y_score_val_ratio < th_set[2]), (y_score_val_ratio >= th_set[2])]
                    y_pred_val_mc = np.select(conditions_mc, [0,1,2,3], default=0)
                    f1 = f1_score(y_true_val_mc, y_pred_val_mc, average='macro', zero_division=0)
                    if f1 > best_f1_val_mc: best_f1_val_mc = f1; tuned_patient_thresholds = th_set
                print(f"  Optimal Patient Severity Thresholds (Val F1 Macro {best_f1_val_mc:.4f}): {tuned_patient_thresholds}")
            else: print("  Warning: Validation comparison DataFrame empty for patient threshold tuning.")
        else: print("  Warning: Validation window data missing for patient threshold tuning.")
    config['PATIENT_SEVERITY_THRESHOLDS_ON_RATIO'] = tuned_patient_thresholds

    # --- Evaluate on the Test Set ---
    if best_model and not test_windows_for_patient_model_df.empty:
        print("\n--- Evaluating Final Model (Test Set) ---")

        # Window-level evaluation (using the tuned_window_threshold)
        if 'win_proba' in test_windows_for_patient_model_df.columns:
            test_win_probas = test_windows_for_patient_model_df['win_proba'].values
            test_true_labels_win = test_windows_for_patient_model_df['label'].values
            current_tuned_window_threshold = config.get('tuned_window_threshold', 0.5)
            test_pred_labels_win = (test_win_probas > current_tuned_window_threshold).astype(int)

            print("\n--- Window-Level Results (Test Set) ---")
            # Note: Re-evaluating loss, auc etc. on test_ds_eval would be more accurate than just using the probas
            # For simplicity here, we'll focus on the metrics from binarized predictions
            # test_win_results = best_model.evaluate(test_ds_eval, steps=test_steps_eval, verbose=0, return_dict=True) # If test_ds_eval is available
            # for name, value in test_win_results.items(): print(f"  {name}: {value:.4f}")
            print(f"  (Using tuned window threshold: {current_tuned_window_threshold:.2f})")
            print(f"  Accuracy: {accuracy_score(test_true_labels_win, test_pred_labels_win):.4f}")
            try: # AUC needs probabilities
                print(f"  AUC: {roc_auc_score(test_true_labels_win, test_win_probas):.4f}")
            except ValueError:
                print("  AUC could not be calculated for window-level (possibly all labels are same).")


            cm_win_test = confusion_matrix(test_true_labels_win, test_pred_labels_win)
            save_confusion_matrix_plot(cm_win_test, ['No Event', 'Event'], f'Window-Level CM (Test Set, Th={current_tuned_window_threshold:.2f})', os.path.join(FINAL_MODEL_RESULTS_DIR, "cm_window_test.png"))
            print(f"\nWindow-Level Classification Report (Test Set, Threshold={current_tuned_window_threshold:.2f}):")
            print(classification_report(test_true_labels_win, test_pred_labels_win, target_names=['No Event', 'Event'], zero_division=0))
        else:
            print("Warning: 'win_proba' not found in test_windows_for_patient_model_df. Skipping window-level report.")


        # Patient-level evaluation
        X_test_patient_df = engineer_patient_features(test_windows_for_patient_model_df, tuned_window_threshold, proba_col='win_proba')
        patient_true_labels_df_str_rec_test = patient_true_labels_df.copy()
        patient_true_labels_df_str_rec_test['Record'] = patient_true_labels_df_str_rec_test['Record'].astype(str)
        test_patient_data_final = pd.merge(X_test_patient_df, patient_true_labels_df_str_rec_test[['Record', config['patient_multiclass_target_col'], config['binary_target_col']]], left_on='record_id', right_on='Record', how='inner')

        if not test_patient_data_final.empty:
            X_test_patient_final = test_patient_data_final[patient_feature_names].values
            y_true_test_mc = test_patient_data_final[config['patient_multiclass_target_col']].values
            y_true_test_binary = test_patient_data_final[config['binary_target_col']].values

            y_pred_test_mc = None
            prediction_method_info = ""

            if config['USE_SECONDARY_PATIENT_MODEL'] and secondary_patient_model:
                y_pred_test_mc = secondary_patient_model.predict(X_test_patient_final)
                prediction_method_info = "Secondary Model (RandomForest)"
            else: # Fallback to thresholding
                # positive_window_ratio is already a column in X_test_patient_df (and thus in test_patient_data_final)
                th_multi = config['PATIENT_SEVERITY_THRESHOLDS_ON_RATIO']
                conditions_test_mc = [
                    (test_patient_data_final['positive_window_ratio'] < th_multi[0]),
                    (test_patient_data_final['positive_window_ratio'] >= th_multi[0]) & (test_patient_data_final['positive_window_ratio'] < th_multi[1]),
                    (test_patient_data_final['positive_window_ratio'] >= th_multi[1]) & (test_patient_data_final['positive_window_ratio'] < th_multi[2]),
                    (test_patient_data_final['positive_window_ratio'] >= th_multi[2])
                ]
                y_pred_test_mc = np.select(conditions_test_mc, [0,1,2,3], default=0)
                prediction_method_info = f"Thresholding (WinTh:{tuned_window_threshold:.2f}, PatientTh:{th_multi})"

            # Multi-class patient severity results
            print(f"\n--- Patient-Level Multi-Class Severity Results (Test Set, Method: {prediction_method_info}) ---")
            print(f"  Accuracy: {accuracy_score(y_true_test_mc, y_pred_test_mc):.4f}")
            cm_pat_test_mc = confusion_matrix(y_true_test_mc, y_pred_test_mc, labels=[0,1,2,3])
            severity_names = ['Normal','Mild','Moderate','Severe']
            save_confusion_matrix_plot(cm_pat_test_mc, severity_names, f'Patient Multi-Class Severity CM (Test Set - {prediction_method_info})', os.path.join(FINAL_MODEL_RESULTS_DIR, "cm_patient_severity_multiclass_test.png"))
            print("\nPatient-Level Multi-Class Severity Classification Report (Test Set):")
            print(classification_report(y_true_test_mc, y_pred_test_mc, target_names=severity_names, labels=[0,1,2,3], zero_division=0))

            # Binary patient OSA results (derived from multi-class)
            y_pred_test_binary = (y_pred_test_mc > 0).astype(int) # Normal (0) vs OSA (1,2,3 -> 1)
            print(f"\n--- Patient-Level Binary OSA Results (Test Set, Derived from Multi-Class, Method: {prediction_method_info}) ---")
            print(f"  Accuracy: {accuracy_score(y_true_test_binary, y_pred_test_binary):.4f}")
            cm_pat_test_binary = confusion_matrix(y_true_test_binary, y_pred_test_binary, labels=[0,1])
            binary_names = ['No OSA','OSA']
            save_confusion_matrix_plot(cm_pat_test_binary, binary_names, f'Patient Binary OSA CM (Test Set - {prediction_method_info})', os.path.join(FINAL_MODEL_RESULTS_DIR, "cm_patient_binary_osa_test.png"))
            print("\nPatient-Level Binary OSA Classification Report (Test Set):")
            print(classification_report(y_true_test_binary, y_pred_test_binary, target_names=binary_names, labels=[0,1], zero_division=0))

            # Save detailed predictions
            test_patient_data_final['pred_severity_multiclass'] = y_pred_test_mc
            test_patient_data_final['pred_osa_binary'] = y_pred_test_binary
            test_patient_data_final.to_csv(os.path.join(FINAL_MODEL_RESULTS_DIR, "detailed_patient_predictions_test.csv"), index=False)
            print(f"Saved detailed patient predictions to: {os.path.join(FINAL_MODEL_RESULTS_DIR, 'detailed_patient_predictions_test.csv')}")
        else:
            print("  Warning: Test patient data empty or missing target. Skipping patient-level evaluation.")
    else:
        print("\nSkipping Test Set Evaluation (no model or test_df empty).")

    combined_history = {}
    if hist_p1 and hist_p1.history:
        for key, val in hist_p1.history.items():
            combined_history[key] = list(val)

    if hist_p2 and hist_p2.history:
        for key, val in hist_p2.history.items():
            if key in combined_history:
                combined_history[key].extend(list(val))
            else:
                combined_history[key] = list(val)

    if combined_history.get('loss'):
        plt.figure(figsize=(12, 10))

        plt.subplot(2, 1, 1)
        plt.plot(combined_history['loss'], label='Train Loss')
        if 'val_loss' in combined_history:
            plt.plot(combined_history['val_loss'], label='Validation Loss')
        plt.title('Model Loss During Training')
        plt.xlabel('Epoch'); plt.ylabel('Loss')
        initial_epochs_count = len(hist_p1.history['loss']) if hist_p1 and hist_p1.history and 'loss' in hist_p1.history else 0
        if hist_p2 and hist_p2.history and initial_epochs_count > 0 and initial_epochs_count < len(combined_history['loss']):
             plt.axvline(x=initial_epochs_count -1 , color='gray', linestyle='--', label='Fine-tuning Start')
        plt.legend(); plt.grid(True)

        plt.subplot(2, 1, 2)
        if 'accuracy' in combined_history: plt.plot(combined_history['accuracy'], label='Train Accuracy')
        if 'val_accuracy' in combined_history: plt.plot(combined_history['val_accuracy'], label='Validation Accuracy')
        if 'auc' in combined_history: plt.plot(combined_history['auc'], label='Train AUC')
        if 'val_auc' in combined_history: plt.plot(combined_history['val_auc'], label='Validation AUC')
        plt.title('Model Performance Metrics During Training')
        plt.xlabel('Epoch'); plt.ylabel('Metric Value')
        if hist_p2 and hist_p2.history and initial_epochs_count > 0 and initial_epochs_count < len(combined_history['loss']):
             plt.axvline(x=initial_epochs_count -1, color='gray', linestyle='--', label='Fine-tuning Start')
        plt.legend(); plt.grid(True)

        plt.tight_layout()
        history_plot_path = os.path.join(FINAL_MODEL_RESULTS_DIR, "training_history_combined.png")
        plt.savefig(history_plot_path); plt.close()
        print(f"Saved combined training history plot to: {history_plot_path}")
    else:
        print("No training history found to plot.")

    print(f"\nAttempting to clean up local data directory: {config['training_data_dir']}")
    if os.path.exists(config['training_data_dir']):
        try:
            shutil.rmtree(config['training_data_dir'])
            print(f"Successfully removed local data directory: {config['training_data_dir']}")
        except Exception as e:
            print(f"Error removing local data directory {config['training_data_dir']}: {e}")
            print("You may need to remove it manually.")

    print(f"\n--- Sleep Apnea Analysis Script Finished ---")
    print(f"Total execution time: {time.time() - main_start_time:.2f} seconds.")
    print(f"All final model evaluation results and plots are saved in: {FINAL_MODEL_RESULTS_DIR}")

Ensuring persistent base output directory exists: /content/drive/MyDrive/AI_Sleep_Apnea/output/sleep_apnea_analysis_simplified_208_v2_secondary_model
Preparing local base directory for processing/training: /content/processed_data_local_simplified_v2
Output directories prepared.
--- STAGE 1: Data Preprocessing and Windowing ---
Created patient multi-class target 'AHI_Severity_MultiClass' using AHI thresholds: <5 (0), [5-15) (1), [15-30) (2), >=30 (3).
Created patient binary target 'OSA_Severity_Binary' using AHI threshold: >=5 (1).
Found 208 records in '/content/temp_data/selected_records208' that have corresponding entries in labels/demographics files.
Attempting to load existing metadata from: /content/drive/MyDrive/AI_Sleep_Apnea/output/sleep_apnea_analysis_simplified_208_v2_secondary_model/all_processed_windows_metadata.csv
Processing 208 records into windows (NPY files)...
Using 4 processes for window generation.


Stage 1: Processing records (MP): 100%|██████████| 208/208 [07:19<00:00,  2.11s/it]



Window processing (Stage 1) took: 440.22 seconds.
Total windows generated: 386036
Saved all processed windows metadata to: /content/drive/MyDrive/AI_Sleep_Apnea/output/sleep_apnea_analysis_simplified_208_v2_secondary_model/all_processed_windows_metadata.csv
--- Stage 1 Finished ---

--- FINAL MODEL TRAINING AND EVALUATION STAGE ---
Final Data Split (by Record IDs): Train: 134, Validation: 32, Test: 42
Window counts: Train: 249563, Validation: 58568, Test: 77905
Using tabular features for primary model: ['Age', 'Sex_encoded']
Imputation values for tabular features: {'Age': 54.0, 'Sex_encoded': 1.0}
StandardScaler fitted on training tabular data.
Window-level class weights: {0: np.float64(0.5440634660411334), 1: np.float64(6.173634474569563)}


  base_model = keras.applications.MobileNetV2(



--- Phase 1 Training (15 epochs) ---
Monitoring 'val_auc' for ModelCheckpoint and 'val_loss' for EarlyStopping/ReduceLROnPlateau.
Epoch 1/15
[1m7799/7799[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 153ms/step - accuracy: 0.5929 - auc: 0.7501 - loss: 0.3449 - precision: 0.4129 - recall: 0.8554
Epoch 1: val_auc improved from -inf to 0.84160, saving model to /content/drive/MyDrive/AI_Sleep_Apnea/output/sleep_apnea_analysis_simplified_208_v2_secondary_model/final_model_evaluation_results/best_final_model.keras
[1m7799/7799[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1632s[0m 186ms/step - accuracy: 0.5929 - auc: 0.7501 - loss: 0.3449 - precision: 0.4129 - recall: 0.8554 - val_accuracy: 0.4428 - val_auc: 0.8416 - val_loss: 0.1558 - val_precision: 0.1154 - val_recall: 0.9523 - learning_rate: 1.0000e-04
Epoch 2/15
[1m7799/7799[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 146ms/step - accuracy: 0.5880 - auc: 0.8363 - loss: 0.1287 - precision: 0.4176 - recall: 0.9406
E