# IoMT Cyber-Medical Fusion Framework

Unified notebook combining:
1. Security module (IoMT-TrafficData, Zenodo record 8116338; note that the cells below load this dataset, not CICIoMT2024)
2. Physiological module (VitalDB, replacing MIMIC-IV Demo)
3. Fusion and evaluation framework

All original comments are preserved; only the physiological data-loading step was refactored to use VitalDB.

## 1. Security Module (original notebook `01_security_module.ipynb`)

In [None]:
# ==============================================================================
# STEP 0: ENVIRONMENT SETUP
# ==============================================================================
# Colab-only cell: mounts Google Drive (all project data lives under
# /content/drive/MyDrive/...) and installs the extra ML libraries used later.

# --- 1. Mount Google Drive ---
# This command connects your Colab notebook to your Google Drive.
# You will be prompted to authorize the connection.
from google.colab import drive
drive.mount('/content/drive')
print("Google Drive mounted successfully!")

# --- 2. Install necessary libraries (if not already installed) ---
# IPython shell escape (not plain Python); -q keeps pip's output quiet.
!pip install xgboost shap -q

print("Setup complete. You can now proceed with Step 1.")

In [None]:
# ==============================================================================
# STEP 1: DOWNLOAD, EXTRACT AND LOAD IoMT-TrafficData (IP-Based Flows)
# ==============================================================================

import pandas as pd
import numpy as np
import os
import gc
import subprocess

# ----------------------------------------------------------------------
# Project locations on Google Drive
# ----------------------------------------------------------------------
BASE_DIR = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
DATA_DIR = os.path.join(BASE_DIR, "data_iomt_traffic")
os.makedirs(DATA_DIR, exist_ok=True)  # no error if the folder already exists

# IoMT-TrafficData archive, published on Zenodo under CC-BY 4.0
# (record page: https://zenodo.org/records/8116338). The spaces in the
# archive file name are URL-encoded as %20.
ZENO_ZIP_URL = "https://zenodo.org/records/8116338/files/ML-Based%20IDS%20IoMT.zip?download=1"
ZIP_PATH = os.path.join(DATA_DIR, "ML-Based_IDS_IoMT.zip")

# Location of the flows CSV once the archive has been unpacked into DATA_DIR.
# Inside the archive the file sits at:
#   Dataset & Captures/Datasets/IP-Based/Flows/IP-Based Flows Dataset.csv
FLOWS_CSV_PATH = os.path.join(
    DATA_DIR, "Dataset & Captures", "Datasets", "IP-Based", "Flows", "IP-Based Flows Dataset.csv"
)

# ----------------------------------------------------------------------
# Helper: reduce memory usage
# ----------------------------------------------------------------------
def reduce_mem_usage(df: pd.DataFrame, verbose: bool = True) -> pd.DataFrame:
    """
    Downcast numeric columns of ``df`` in place to smaller dtypes to save memory.

    - Signed integer columns (int16/int32/int64) are cast to the narrowest of
      int8/int16/int32/int64 whose range contains the column's min and max.
      Bounds are inclusive, so e.g. a max of 127 still fits int8.
    - float64 columns whose min/max lie within the float32 range are cast to
      float32. This is lossy: float32 keeps only ~7 significant digits.
    - Columns are never upcast. float16/float32 columns keep their dtype, and
      columns whose min/max are NaN (empty or all-missing) are left unchanged.
    - Other dtypes (int8, unsigned, bool, object, category, ...) are skipped.

    Args:
        df: DataFrame to optimise; its columns are reassigned in place.
        verbose: If True, print the new memory usage and the relative reduction.

    Returns:
        The same DataFrame object that was passed in.
    """
    numerics = ["int16", "int32", "int64", "float16", "float32", "float64"]
    int_candidates = (np.int8, np.int16, np.int32, np.int64)
    float32_info = np.finfo(np.float32)
    start_mem = df.memory_usage().sum() / 1024**2

    for col in df.columns:
        col_type = df[col].dtypes
        if col_type not in numerics:
            continue

        c_min = df[col].min()
        c_max = df[col].max()

        if str(col_type).startswith("int"):
            # Pick the first (narrowest) integer type that holds [c_min, c_max].
            # NaN bounds (empty column) compare False, so nothing is cast.
            for candidate in int_candidates:
                info = np.iinfo(candidate)
                if info.min <= c_min and c_max <= info.max:
                    df[col] = df[col].astype(candidate)
                    break
        elif str(col_type) == "float64":
            # Only float64 can shrink; NaN or +/-inf bounds keep the column as float64.
            if float32_info.min <= c_min and c_max <= float32_info.max:
                df[col] = df[col].astype(np.float32)

    end_mem = df.memory_usage().sum() / 1024**2
    if verbose:
        # Guard against dividing by zero if the frame reports no memory at all.
        reduction = 100 * (start_mem - end_mem) / start_mem if start_mem else 0.0
        print(
            f"Mem. usage decreased to {end_mem:5.2f} Mb "
            f"({reduction:.1f}% reduction)"
        )
    return df

# ----------------------------------------------------------------------
# STEP 1A: Download zip from Zenodo if not already present
# ----------------------------------------------------------------------
if not os.path.exists(ZIP_PATH):
    print(f"[IoMT-TrafficData] Zip not found, downloading from Zenodo to:\n  {ZIP_PATH}")
    download_ok = False
    try:
        # Use wget via subprocess to keep everything in a Python cell
        result = subprocess.run(
            ["wget", "-O", ZIP_PATH, ZENO_ZIP_URL],
            check=False,
            text=True,
            capture_output=True,
        )
    except OSError as e:
        # e.g. the wget executable is not available in this runtime
        print(f"Exception during download: {e}")
    else:
        if result.returncode != 0:
            print("ERROR: wget failed.")
            print("STDOUT:", result.stdout[:1000])
            print("STDERR:", result.stderr[:1000])
        else:
            download_ok = True
            print("Download completed successfully.")

    if not download_ok:
        # `wget -O FILE` creates FILE before downloading, so a failed attempt
        # leaves an empty/partial zip behind. Remove it so the next run retries
        # the download instead of treating the broken file as "already present".
        if os.path.exists(ZIP_PATH):
            os.remove(ZIP_PATH)
        print("Download failed. Please check network or URL.")
else:
    print(f"[IoMT-TrafficData] Zip already present at:\n  {ZIP_PATH}")

# ----------------------------------------------------------------------
# STEP 1B: Extract only if flows CSV is not already available
# ----------------------------------------------------------------------
if os.path.exists(FLOWS_CSV_PATH):
    print("[IoMT-TrafficData] Flows CSV already extracted.")
elif not os.path.exists(ZIP_PATH):
    print("WARNING: zip archive not available; skipping extraction.")
else:
    print(f"[IoMT-TrafficData] Extracting zip into:\n  {DATA_DIR}")
    try:
        # -n : do not overwrite existing files
        result = subprocess.run(
            ["unzip", "-n", ZIP_PATH, "-d", DATA_DIR],
            check=False,
            text=True,
            capture_output=True,
        )
    except OSError as e:
        # e.g. the unzip executable is not available in this runtime
        print(f"Exception during extraction: {e}")
    else:
        if result.returncode != 0:
            print("WARNING: unzip returned a non-zero code. "
                  "If the CSV already exists, this may be harmless.")
            print("STDOUT:", result.stdout[:1000])
            print("STDERR:", result.stderr[:1000])

# ----------------------------------------------------------------------
# STEP 1C: Load IP-Based Flows Dataset.csv into a dataframe
# ----------------------------------------------------------------------
# df stays None when the CSV is unavailable; later cells check for that.
df = None

if os.path.exists(FLOWS_CSV_PATH):
    print(f"[IoMT-TrafficData] Loading flows dataset from:\n  {FLOWS_CSV_PATH}")
    df = pd.read_csv(FLOWS_CSV_PATH)
    print(f"Loaded IoMT-TrafficData flows with shape: {df.shape}")
else:
    print("ERROR: IP-Based Flows Dataset.csv not found after extraction.")
    print("Please inspect the directory structure under:", DATA_DIR)

# ----------------------------------------------------------------------
# STEP 1D: Standardise label column to 'label'
# ----------------------------------------------------------------------
if df is not None:
    # Try to detect the label column automatically if needed
    label_candidates = ["label", "Label", "attack_type", "Attack_type", "Attack Type"]
    detected_label_col = None

    for cand in label_candidates:
        if cand in df.columns:
            detected_label_col = cand
            break

    if detected_label_col is None:
        print(
            "WARNING: could not automatically detect the label column.\n"
            "Please inspect df.columns and manually rename the correct "
            "label column to 'label' before continuing."
        )
    else:
        if detected_label_col != "label":
            df = df.rename(columns={detected_label_col: "label"})
            print(f"Renamed label column from '{detected_label_col}' to 'label'.")

        # Normalise benign class name(s), if present
        if "label" in df.columns:
            df["label"] = df["label"].replace(
                {
                    "BENIGN": "Normal",
                    "Benign": "Normal",
                    "benign": "Normal",
                }
            )
            print("Standardised benign label(s) to 'Normal' where applicable.")

    # ------------------------------------------------------------------
    # STEP 1E: Memory optimisation, summary and shuffle
    # ------------------------------------------------------------------
    print("\nOptimizing memory usage of the IoMT-TrafficData dataframe...")
    df = reduce_mem_usage(df)

    print("\n--- Info on the final, optimized dataframe ---")
    # DataFrame.info() prints its report itself and returns None, so it is
    # called directly rather than passed to display().
    df.info(memory_usage="deep")

    if "label" in df.columns:
        print("\n--- Class distribution in the final dataframe (label) ---")
        display(df["label"].value_counts())
    else:
        print("\nWARNING: 'label' column is missing; downstream cells may fail.")

    # Shuffle the dataframe for downstream splits
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    print("\nFinal dataframe has been shuffled.")

    print("\n--- First 5 rows of the final dataframe ---")
    display(df.head())
else:
    print("\nNo dataframe is available. Please fix the download/extraction issues above.")


In [None]:
# ==============================================================================
# STEP 2: LOAD MULTICLASS PACKET-LEVEL DATASET FROM PICKLE AND BUILD `df` WITH `label`
# ==============================================================================

import os
import pickle
import pandas as pd
import numpy as np

try:
    import scipy.sparse as sp
except ImportError:
    sp = None  # SciPy is optional; it is only needed for sparse feature matrices

BASE_DIR = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
DATA_DIR = os.path.join(BASE_DIR, "data_iomt_traffic")

print(f"[INFO] Looking for 'Dataset_Multiclass.pkl' under:\n  {DATA_DIR}")

# ----------------------------------------------------------------------
# STEP 2A: Locate the main multiclass pickle dataset
# ----------------------------------------------------------------------
# One pass over DATA_DIR records every pickle file. If more than one file is
# named Dataset_Multiclass.pkl, the last one visited by os.walk is used.
dataset_pkl_path = None
all_pkl_found = []

for root, dirs, files in os.walk(DATA_DIR):
    for fname in files:
        if not fname.lower().endswith((".pkl", ".pickle")):
            continue
        full_path = os.path.join(root, fname)
        all_pkl_found.append(full_path)
        if fname == "Dataset_Multiclass.pkl":
            dataset_pkl_path = full_path

print("\n[INFO] All pickle files found:")
for p in all_pkl_found:
    print("  -", p)

if dataset_pkl_path is None:
    raise RuntimeError(
        "'Dataset_Multiclass.pkl' not found automatically.\n"
        "Please check the printed list above and set 'dataset_pkl_path' manually."
    )

print(f"\n[INFO] Using multiclass dataset pickle:\n  {dataset_pkl_path}")

# ----------------------------------------------------------------------
# STEP 2B: Load the pickle object and inspect its structure
# ----------------------------------------------------------------------
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. The file is read from DATA_DIR (presumably extracted from the
# downloaded dataset archive -- confirm); only load pickles from a trusted source.
with open(dataset_pkl_path, "rb") as f:
    obj = pickle.load(f)

print("\n[INFO] Type of loaded object:", type(obj))

# Filled in by the interpretation step below with the packet-level dataframe.
df_pkt = None

# ----------------------------------------------------------------------
# Helper: convert a generic feature matrix into a DataFrame
# ----------------------------------------------------------------------
def to_dataframe(X):
    """
    Build a pandas DataFrame from an arbitrary feature container.

    A DataFrame input is returned as a copy. A SciPy sparse matrix (only
    recognised when SciPy is importable) is densified first. Anything else --
    NumPy arrays, lists/tuples of rows, or other objects pandas accepts -- is
    passed straight to ``pd.DataFrame``.
    """
    if isinstance(X, pd.DataFrame):
        return X.copy()

    is_sparse = sp is not None and sp.issparse(X)
    # Densifying can be memory heavy for very large sparse matrices.
    return pd.DataFrame(X.toarray() if is_sparse else X)

# ----------------------------------------------------------------------
# Interpret the loaded object
# ----------------------------------------------------------------------
if isinstance(obj, pd.DataFrame):
    # Case 1: pickle already contains a full DataFrame
    print("[INFO] Pickle contains a pandas DataFrame.")
    df_pkt = obj.copy()

elif isinstance(obj, (list, tuple)):
    print(f"[INFO] Pickle is a list/tuple of length {len(obj)}.")
    # A sequence whose first item is a dict is treated as a list of records.
    # A 2-item sequence is only treated as records when BOTH items are dicts;
    # otherwise it is interpreted as (X, Y).
    looks_like_records = len(obj) > 0 and isinstance(obj[0], dict)

    # Most probable case: (X, Y_onehot) with Y_onehot 2D
    if len(obj) == 2 and not (looks_like_records and isinstance(obj[1], dict)):
        X_raw, y_raw = obj[0], obj[1]
        print("[INFO] Interpreting pickle as (X, Y).")
        print("       type(X_raw) =", type(X_raw))
        print("       type(y_raw) =", type(y_raw))

        # Convert X_raw to DataFrame
        df_features = to_dataframe(X_raw)
        print("[INFO] Features dataframe shape:", df_features.shape)

        # Convert y_raw to 1D class vector
        if isinstance(y_raw, (pd.DataFrame, pd.Series)):
            y_arr = y_raw.to_numpy()
        else:
            y_arr = np.asarray(y_raw)

        print("[INFO] Raw label array shape:", y_arr.shape)

        if y_arr.ndim == 2 and y_arr.shape[1] == 1:
            # A single label column, shape (n, 1), is a plain label vector, not
            # one-hot: argmax over one column would turn every label into 0.
            print("[INFO] Detected (n, 1) label array, flattening to 1D.")
            y_arr = y_arr.reshape(-1)

        if y_arr.ndim == 1:
            # Already a 1D vector of labels
            y_series = pd.Series(y_arr).reset_index(drop=True)
        elif y_arr.ndim == 2:
            # Most likely one-hot or multi-output: convert via argmax
            print("[INFO] Detected 2D label array, converting via argmax along axis=1.")
            y_indices = np.argmax(y_arr, axis=1)
            y_series = pd.Series(y_indices).reset_index(drop=True)
        else:
            raise RuntimeError(
                f"Unsupported label array with ndim={y_arr.ndim}. "
                "Please inspect 'y_raw' manually."
            )

        print("[INFO] Effective labels length:", len(y_series))

        if len(df_features) != len(y_series):
            raise RuntimeError(
                f"Length mismatch between X ({len(df_features)}) and labels ({len(y_series)}). "
                "Please inspect the pickle structure manually."
            )

        df_pkt = df_features.copy()
        df_pkt["label"] = y_series.values
        print("[INFO] Built df_pkt with 'label' column from (X, Y).")

    elif looks_like_records:
        # Case 2: list of dicts
        print("[INFO] Interpreting pickle as list of dicts.")
        df_pkt = pd.DataFrame(obj)
    else:
        raise RuntimeError(
            "Unsupported list/tuple structure for automatic conversion. "
            "Please inspect 'obj' manually (e.g., print(obj[0]) in a separate cell)."
        )

else:
    raise RuntimeError(
        "Unsupported pickle object type for automatic conversion. "
        "Please inspect 'obj' manually."
    )

# ----------------------------------------------------------------------
# STEP 2C: Standardise label column name to 'label' (defensive)
# ----------------------------------------------------------------------
print("\n[INFO] Packet-level dataframe initial shape:", df_pkt.shape)
print("[INFO] Packet-level columns (first 20):", df_pkt.columns.tolist()[:20])

# ----------------------------------------------------------------------
# Rename the first recognised label-like column to 'label' if needed.
# ----------------------------------------------------------------------
if "label" not in df_pkt.columns:
    label_candidates = [
        "Label", "attack_type", "Attack_type", "Attack Type",
        "Attack", "class", "Class", "target"
    ]
    detected_label_col = None
    for cand in label_candidates:
        if cand in df_pkt.columns:
            detected_label_col = cand
            break

    if detected_label_col is not None:
        df_pkt = df_pkt.rename(columns={detected_label_col: "label"})
        print(f"[INFO] Renamed label column from '{detected_label_col}' to 'label'.")
    else:
        print(
            "[WARNING] Could not automatically detect a label column. "
            "Please inspect df_pkt.columns and rename the correct label column to 'label'."
        )

# ----------------------------------------------------------------------
# STEP 2D: Optional normalisation of benign class names (only if string labels)
# ----------------------------------------------------------------------
if "label" in df_pkt.columns:
    # Text labels may be stored as object dtype or as a pandas string dtype;
    # numeric (argmax-derived) labels are left untouched.
    if df_pkt["label"].dtype == object or pd.api.types.is_string_dtype(df_pkt["label"].dtype):
        df_pkt["label"] = df_pkt["label"].replace(
            {
                "BENIGN": "Normal",
                "Benign": "Normal",
                "benign": "Normal",
                "BenignTraffic": "Normal",
            }
        )
        print("\n[INFO] Standardised benign classes to 'Normal' where applicable.")

    print("\n--- Class distribution (packet-level, df_pkt['label']) ---")
    display(df_pkt["label"].value_counts())
else:
    print(
        "\n[WARNING] df_pkt has no 'label' column. "
        "Supervised models will not work until the label column is correctly identified."
    )

print("\n--- First 5 rows of packet-level dataframe df_pkt ---")
display(df_pkt.head())

# ----------------------------------------------------------------------
# STEP 2E: Set `df` as the main dataframe for downstream steps
# ----------------------------------------------------------------------
df = df_pkt.copy()
print("\n[INFO] Set global dataframe 'df' = df_pkt for downstream ML pipeline.")
print("[INFO] Final df shape:", df.shape)
# Only report the label dtype when the column exists; the missing-label case
# was already warned about above and must not crash with a KeyError here.
if "label" in df.columns:
    print("[INFO] 'label' dtype:", df['label'].dtype)
else:
    print("[WARNING] Final df has no 'label' column; downstream supervised steps will fail.")


In [None]:
# ==============================================================================
# SAVE THE LABELED IOMT-TRAFFICDATA DATAFRAME
# Run this cell after `df` has been built from Dataset_Multiclass.pkl
# ==============================================================================

import os

# Ensure the final labeled dataframe 'df' exists. STEP 1 deliberately sets
# df = None when loading fails, so check the value, not just the name.
if globals().get("df") is not None:
    # Define base project directory and ensure it exists
    base_dir = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
    os.makedirs(base_dir, exist_ok=True)

    # Define the path where the file will be saved
    save_path = os.path.join(
        base_dir,
        "iomt_traffic_multiclass_packets.parquet"  # you can rename if you prefer
    )

    # Save the dataframe to a Parquet file
    print(f"Saving the labeled IoMT-TrafficData dataframe to:\n  {save_path} ...")
    df.to_parquet(save_path)
    print("Save complete!")
else:
    print("ERROR: Dataframe 'df' not found. Please run the data loading "
          "and labeling steps (STEP 1 and STEP 2) first.")


In [None]:
# ==============================================================================
# STEP A: LIST ALL FILES UNDER IOMT-TRAFFICDATA ROOT FOLDER
# ==============================================================================

import os

root_dir = "/content/drive/MyDrive/Conference_paper_ICCC_2026/data_iomt_traffic"

# Deepest directory printed, measured as os.sep count in the path relative to
# root_dir (root_dir and its direct children are both 0).
max_list_depth = 4

print(f"Listing files under: {root_dir}\n")

for dirpath, dirnames, filenames in os.walk(root_dir):
    rel = os.path.relpath(dirpath, root_dir)
    depth = rel.count(os.sep)

    print(f"[DIR] {dirpath}")
    for f in filenames:
        print("   -", f)
    print()

    # Children of a directory at max_list_depth would exceed the limit, so stop
    # os.walk from descending further (in-place edit of dirnames prunes the walk).
    # Previously deeper folders were still traversed and merely skipped, which
    # is slow on a Drive mount; the printed output is unchanged.
    if depth >= max_list_depth:
        dirnames[:] = []


In [None]:
# ==============================================================================
# BUILD attack_name_mapping FROM FLOWS/DATASETS CSV FILES
# ==============================================================================

import os
import pandas as pd

root_dir = "/content/drive/MyDrive/Conference_paper_ICCC_2026/data_iomt_traffic"

print(f"Searching for Flows/Datasets folder under:\n  {root_dir}\n")

flows_datasets_dir = None

# 1) Locate the directory that contains ApacheKiller.csv, Normal.csv, etc.
required_files = {"ApacheKiller.csv", "Normal.csv", "RUDY.csv", "SlowRead.csv"}
for dirpath, dirnames, filenames in os.walk(root_dir):
    if required_files.issubset(filenames):
        flows_datasets_dir = dirpath
        break

if flows_datasets_dir is None:
    raise RuntimeError(
        "Could not automatically find the Flows/Datasets directory with ApacheKiller.csv, "
        "Normal.csv, etc. Please check the path manually."
    )

print("Found Flows/Datasets directory:\n  ", flows_datasets_dir, "\n")

# 2) List the attack-specific CSV files
csv_files = [
    f
    for f in os.listdir(flows_datasets_dir)
    if f.lower().endswith(".csv")
]

print("CSV files found in Flows/Datasets:")
for f in sorted(csv_files):
    print("  -", f)

# 3) For each CSV, read a few rows and extract the numeric attack_type
attack_name_mapping = {}

for fname in sorted(csv_files):
    full_path = os.path.join(flows_datasets_dir, fname)
    df_tmp = pd.read_csv(full_path, nrows=10)

    if "attack_type" not in df_tmp.columns:
        print(f"\n[WARNING] File {fname} has no 'attack_type' column. Skipping.")
        continue

    # We assume all rows in this file share the same attack_type
    unique_vals = df_tmp["attack_type"].dropna().unique()
    if len(unique_vals) == 0:
        # Empty file or only missing values: nothing to map (avoids IndexError).
        print(f"\n[WARNING] File {fname} has no non-missing attack_type values. Skipping.")
        continue
    if len(unique_vals) != 1:
        print(f"\n[WARNING] File {fname} has multiple attack_type values:", unique_vals)
        print("  Taking the first one for mapping.")

    try:
        attack_idx = int(unique_vals[0])
    except (TypeError, ValueError):
        # e.g. a textual attack_type such as "ApacheKiller"
        print(f"\n[WARNING] File {fname}: attack_type {unique_vals[0]!r} "
              "is not an integer index. Skipping.")
        continue

    # Derive a clean attack name from the filename (strip .csv)
    base_name = os.path.splitext(fname)[0]  # e.g. "ApacheKiller"
    if attack_idx in attack_name_mapping:
        # Two files claim the same index; the later file (sorted order) wins.
        print(f"\n[WARNING] attack_type {attack_idx} already mapped to "
              f"{attack_name_mapping[attack_idx]!r}; overriding with {base_name!r}.")
    attack_name_mapping[attack_idx] = base_name

print("\nDerived attack_name_mapping from Flows/Datasets:")
print("attack_name_mapping = {")
for k in sorted(attack_name_mapping.keys()):
    print(f"    {k}: {repr(attack_name_mapping[k])},")
print("}")


In [None]:
# ==============================================================================
# STEP C2: EXTRACT REAL ATTACK NAMES FROM RAW IP-BASED PACKETS DATASET
# ==============================================================================

import os
import pandas as pd

# Project and data locations (same folder layout as STEP 1)
base_dir = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
data_dir = os.path.join(base_dir, "data_iomt_traffic")

# Raw IP-based packet CSV; its 'attack_type' column is summarised below.
csv_raw_path = os.path.join(
    data_dir, "Dataset & Captures", "Datasets", "IP-Based", "Packets", "IP-Based Packets Dataset.csv"
)

print("Loading RAW packet dataset from:\n ", csv_raw_path)
df_raw = pd.read_csv(csv_raw_path)

print("\nShape of RAW CSV:", df_raw.shape)
print("Columns (first 20):", df_raw.columns.tolist()[:20])

# Without an 'attack_type' column there is nothing to extract.
if "attack_type" not in df_raw.columns:
    raise RuntimeError("Column 'attack_type' not found in RAW CSV.")

print("\nattack_type dtype:", df_raw["attack_type"].dtype)
print("\nUnique values in 'attack_type' (with counts):")
print(df_raw["attack_type"].value_counts())


In [None]:
# ==============================================================================
# STEP D: BUILD attack_name_mapping DICT FROM unique_mapping
# ==============================================================================

# `unique_mapping` (rows with 'label_index' and 'attack_type') is not created by
# any earlier cell of this notebook, so on a fresh "Run all" the unguarded
# comprehension raised NameError and halted every later step. Rebuild the
# mapping only when `unique_mapping` is actually available.
if "unique_mapping" in globals():
    attack_name_mapping = {
        int(row["label_index"]): str(row["attack_type"])
        for _, row in unique_mapping.iterrows()
    }

    print("attack_name_mapping = {")
    for k in sorted(attack_name_mapping.keys()):
        print(f"    {k}: {repr(attack_name_mapping[k])},")
    print("}")
else:
    print(
        "[WARNING] 'unique_mapping' is not defined; skipping STEP D. "
        "Create it (columns 'label_index', 'attack_type') in an earlier cell "
        "to rebuild attack_name_mapping from the RAW packet data."
    )


In [None]:
# ==============================================================================
# STEP 2-A (UPDATED FOR IOMT-TRAFFICDATA): PREPROCESSING FOR BINARY VALIDATION
# ==============================================================================
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

# STEP 1 sets df = None when loading fails, so check the value, not just the name.
if globals().get("df") is not None:
    print("Starting data preprocessing for BINARY classification on IoMT-TrafficData...")

    if "label" not in df.columns:
        raise RuntimeError("Column 'label' not found in df. Please check STEP 2.")

    # --- Basic cleaning: replace +/- inf with NaN (imputed after the split) ---
    df_clean = df.copy()
    df_clean.replace([np.inf, -np.inf], np.nan, inplace=True)

    # Class labels must never be mean-imputed: astype(int) would truncate the
    # mean into an arbitrary class id, silently fabricating labels.
    if df_clean["label"].isna().any():
        raise RuntimeError(
            "Column 'label' contains missing/infinite values. "
            "Please clean the labels in STEP 2."
        )

    # --- Separate features and target ---
    # For IoMT-TrafficData, columns are all numeric and already standardised.
    # We simply remove the 'label' column from the feature matrix.
    X = df_clean.drop(columns=["label"])
    y_multiclass = df_clean["label"].astype(int)

    print(f"Feature matrix shape: {X.shape}")
    print(f"Multiclass labels dtype: {y_multiclass.dtype}")

    # --- BINARY MODE: 0 -> 'Normal', 1 -> 'Attack'
    # Assumption: class 0 corresponds to benign traffic, all other classes are attacks.
    y_binary = np.where(y_multiclass == 0, 0, 1)
    y = y_binary
    print("-> Created BINARY target: 0 = Normal (class 0), 1 = Attack (classes 1..N)")

    # Optional: quick sanity check on class balance
    unique, counts = np.unique(y, return_counts=True)
    print("-> Binary class distribution (value: count):")
    for val, cnt in zip(unique, counts):
        print(f"   {val}: {cnt}")

    # --- Train/Test split (stratified) on row positions ---
    # The partition produced by train_test_split depends only on the number of
    # rows, `stratify`, `test_size` and `random_state`. Splitting positions
    # therefore selects exactly the same train/test rows as splitting the
    # scaled matrix. It also lets the imputation means and scaler be fit on
    # training rows only, so no test-set statistics leak into the features.
    train_idx, test_idx = train_test_split(
        np.arange(len(X)),
        test_size=0.3,
        random_state=42,
        stratify=y,
    )

    # --- Missing values: impute numeric features with TRAINING-set means ---
    numeric_cols = X.select_dtypes(include=np.number).columns
    train_means = X.iloc[train_idx][numeric_cols].mean()
    X[numeric_cols] = X[numeric_cols].fillna(train_means)
    df_clean[numeric_cols] = X[numeric_cols]

    # --- Scaling ---
    # Even if features are already normalised in the original dataset,
    # we keep a StandardScaler step for consistency with previous pipelines.
    # X_scaled covers ALL rows (STEP 2-B re-splits it), but the scaler is fit
    # on training rows only.
    scaler = StandardScaler()
    scaler.fit(X.iloc[train_idx])
    X_scaled = scaler.transform(X)

    X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    print("-> Data prepared for binary classification.")
    print(f"   X_train: {X_train.shape}, X_test: {X_test.shape}")
    print(f"   y_train: {y_train.shape}, y_test: {y_test.shape}")

else:
    print("ERROR: Dataframe 'df' not found. Please run the data loading and labeling step first.")


In [None]:
# ==============================================================================
# STEP 2-B (QUICK REBUILD): GROUPED (MULTICLASS) SPLITS FOR XGBOOST
# ==============================================================================

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# At notebook top level globals() is the same namespace locals() returns.
if all(name in globals() for name in ("df", "X_scaled")):
    print("Rebuilding grouped (multiclass) labels and train/test splits...")

    df_group = df.copy()
    if "label" not in df_group.columns:
        raise RuntimeError("Column 'label' not found in df. Please check STEP 2.")

    # ------------------------------------------------------------------
    # 1. Grouped labels are simply the integer class ids
    # ------------------------------------------------------------------
    # Convention assumed by this cell: 0 is Normal traffic and every id >= 1
    # is a separate attack category.
    df_group["grouped_label"] = df_group["label"].astype(int)

    def label_to_str(v: int) -> str:
        """Return a readable tag: 'Normal' for 0, 'Attack_<v>' otherwise."""
        if v == 0:
            return "Normal"
        return f"Attack_{v}"

    df_group["grouped_label_str"] = df_group["grouped_label"].apply(label_to_str)

    print("-> Created 'grouped_label' (int) and 'grouped_label_str' (string).")
    print("\n--- Distribution of grouped_label_str ---")
    display(df_group["grouped_label_str"].value_counts())

    # ------------------------------------------------------------------
    # 2. Encode to contiguous ids 0..K-1 as expected by sklearn / XGBoost
    # ------------------------------------------------------------------
    grouped_label_encoder = LabelEncoder()
    y_grouped = grouped_label_encoder.fit_transform(df_group["grouped_label"])

    print("\nEncoded grouped classes (original -> encoded):")
    for orig, enc in zip(grouped_label_encoder.classes_,
                         grouped_label_encoder.transform(grouped_label_encoder.classes_)):
        print(f"  {orig} -> {enc}")

    # ------------------------------------------------------------------
    # 3. Stratified 70/30 split of the scaled features from STEP 2-A
    # ------------------------------------------------------------------
    (X_train_grouped, X_test_grouped,
     y_train_grouped, y_test_grouped) = train_test_split(
        X_scaled, y_grouped, test_size=0.3, random_state=42, stratify=y_grouped
    )

    print("\n-> Grouped (multiclass) train/test split complete.")
    print(f"   X_train_grouped: {X_train_grouped.shape}")
    print(f"   X_test_grouped:  {X_test_grouped.shape}")
    print(f"   y_train_grouped: {y_train_grouped.shape}")
    print(f"   y_test_grouped:  {y_test_grouped.shape}")
    print(f"   Number of grouped classes: {len(np.unique(y_grouped))}")

else:
    print("ERROR: df or X_scaled not found. Please run STEP 2 and STEP 2-A first.")


In [None]:
# ==============================================================================
# STEP 3-A (UPDATED, WITH AUC-ROC): XGBOOST FOR BINARY CLASSIFICATION (IoMT-TrafficData)
# ==============================================================================
import xgboost as xgb
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    roc_curve,
    auc,
)
import matplotlib.pyplot as plt

# Check if training data exists
if all(name in globals() for name in ("X_train", "y_train", "X_test", "y_test")):
    print("--- Training XGBoost Model for BINARY Classification on IoMT-TrafficData ---")

    # Initialize the XGBoost classifier for binary classification.
    # `device` is an XGBoost >= 2.0 parameter. Those versions no longer have
    # `use_label_encoder` (passing it only triggers an "unused parameter"
    # warning), so it is not passed here.
    xgb_model_binary = xgb.XGBClassifier(
        objective="binary:logistic",
        eval_metric="logloss",
        device="cuda",      # Use GPU if available (Colab with GPU)
        # With XGBoost >= 2.0, GPU training is tree_method="hist" + device="cuda";
        # the older tree_method="gpu_hist" is deprecated.
        # tree_method="hist",
    )

    print("Starting training...")
    # y_train already contains the binary labels: 0 = Normal, 1 = Attack
    xgb_model_binary.fit(X_train, y_train)
    print("Training complete.")

    # --- Evaluation on the Test Set ---
    print("\n--- Evaluating on Test Set ---")
    y_pred_binary = xgb_model_binary.predict(X_test)
    # Probabilities for the positive class (Attack)
    y_pred_proba_binary = xgb_model_binary.predict_proba(X_test)[:, 1]

    acc = accuracy_score(y_test, y_pred_binary)
    print(f"Overall Accuracy: {acc * 100:.2f}%")

    # --- AUC-ROC Calculation ---
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba_binary)
    roc_auc = auc(fpr, tpr)
    print(f"AUC-ROC Score: {roc_auc:.4f}")

    print("\nDetailed Classification Report (Normal vs Attack):")
    print(classification_report(y_test, y_pred_binary, target_names=["Normal", "Attack"]))

    # --- Plotting ROC Curve ---
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], lw=2, linestyle="--")  # Diagonal line (random guess)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) - Binary (Normal vs Attack)")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()

else:
    print("ERROR: Training data not found. Please run the preprocessing step (STEP 2-A) first.")


In [None]:
# ==============================================================================
# STEP 3-B (FINAL FOR PAPER): XGBOOST FOR GROUPED (MULTICLASS) CLASSES + ROC
# ==============================================================================

import xgboost as xgb
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np

# ----------------------------------------------------------------------
# 0. Sanity check: required variables from STEP 2-B
# ----------------------------------------------------------------------
try:
    X_train_grouped
    X_test_grouped
    y_train_grouped
    y_test_grouped
    grouped_label_encoder
except NameError as e:
    raise RuntimeError(
        f"Missing grouped variable: {e}. "
        "Please run STEP 2-A and STEP 2-B before STEP 3-B."
    )

print("--- Training and Evaluating XGBoost Model for GROUPED (multiclass) classification ---")

# ----------------------------------------------------------------------
# Human-readable mapping (0..8 -> IoMT-TrafficData scenarios)
attack_name_mapping = {
    0: "Normal",
    1: "ApacheKiller",
    2: "ARP",           # ARP spoofing
    3: "CAM",           # CAM table overflow
    4: "Malaria",       # MQTT Malaria
    5: "Netscan",       # Recon / scanning
    6: "RUDY",
    7: "SlowLoris",
    8: "SlowRead",
}

def class_id_to_name(c: int) -> str:
    """Return the scenario name for encoded class ``c``, or ``Class_<c>`` when unmapped."""
    fallback = f"Class_{c}"
    return attack_name_mapping[c] if c in attack_name_mapping else fallback

# `original_classes` holds the encoded class ids in encoder order and must
# already be defined before this point.
class_names = [class_id_to_name(c) for c in original_classes]


print("Class mapping (encoder classes -> names):")
for c, name in zip(original_classes, class_names):
    print(f"  {c} -> {name}")

# ----------------------------------------------------------------------
# 2. Model definition and training
# ----------------------------------------------------------------------
# `use_label_encoder` is intentionally not passed. It was deprecated in
# XGBoost 1.6 and removed in 2.0, the release that introduced `device=`, where
# it is only reported back as an unused parameter. Labels are expected to be
# label-encoded integers already (see grouped_label_encoder).
xgb_model_grouped = xgb.XGBClassifier(
    objective="multi:softprob",
    eval_metric="mlogloss",
    num_class=len(original_classes),
    device="cuda",  # Use GPU if available
    # tree_method="hist",  # Uncomment depending on XGBoost version
)

print("\nStarting training for grouped (multiclass) classification...")
xgb_model_grouped.fit(X_train_grouped, y_train_grouped)
print("Training complete.")

# ----------------------------------------------------------------------
# 3. Evaluation: accuracy, macro AUC (with full precision), classification report
# ----------------------------------------------------------------------
y_pred_grouped = xgb_model_grouped.predict(X_test_grouped)
# One probability column per class, in encoder order (multi:softprob).
y_pred_proba_grouped = xgb_model_grouped.predict_proba(X_test_grouped)

acc_grouped = accuracy_score(y_test_grouped, y_pred_grouped)
roc_auc_macro = roc_auc_score(
    y_test_grouped,
    y_pred_proba_grouped,
    multi_class="ovr",
    average="macro",
)

# Full-precision string for paper copy-paste
roc_auc_macro_str = f"{roc_auc_macro:.6f}"

print(f"\nOverall Accuracy (grouped multiclass): {acc_grouped:.6f}")
print(f"Macro-average AUC-ROC (OVR): {roc_auc_macro_str}  <-- main grouped metric")

print("\nDetailed Classification Report (grouped classes):")
print(
    classification_report(
        y_test_grouped,
        y_pred_grouped,
        target_names=class_names,
        digits=4,
    )
)

# ----------------------------------------------------------------------
# 4. Multi-class ROC curves (per-class + macro-average)
# ----------------------------------------------------------------------
print("\n--- Generating multi-class ROC curves (per class + macro-average) ---")

# Binarize test labels for one-vs-rest ROC (assumes labels are 0..K-1).
y_test_binarized = label_binarize(
    y_test_grouped,
    classes=np.arange(len(original_classes)),
)
n_classes = y_test_binarized.shape[1]

fpr, tpr = {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(
        y_test_binarized[:, i],
        y_pred_proba_grouped[:, i],
    )

# Compute macro-average ROC curve: interpolate every per-class curve onto the
# union of all FPR points and average the TPRs.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes

# ----------------------------------------------------------------------
# 5. Plot (style tuned for paper)
# ----------------------------------------------------------------------
plt.figure(figsize=(8, 7))

# Per-class ROC (thin, semi-transparent)
color_cycle = cycle(
    ["tab:blue", "tab:orange", "tab:green", "tab:red",
     "tab:purple", "tab:brown", "tab:pink", "tab:gray", "tab:olive"]
)

for i, color in zip(range(n_classes), color_cycle):
    plt.plot(
        fpr[i],
        tpr[i],
        color=color,
        lw=1.5,
        alpha=0.7,
        label=f"{class_names[i]}",
    )

# Macro-average ROC (thick, highlighted). The legend shows the sklearn macro
# AUC computed above, not the area under the interpolated curve.
plt.plot(
    all_fpr,
    mean_tpr,
    color="black",
    lw=3,
    linestyle="-",
    label=f"Macro-average ROC (AUC = {roc_auc_macro_str})",
)

# Chance line
plt.plot([0, 1], [0, 1], "k--", lw=1.5, label="Chance")

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate", fontsize=12)
plt.ylabel("True Positive Rate", fontsize=12)
plt.title(
    "XGBoost – Multi-class ROC Curves (Grouped IoMT-TrafficData)",
    fontsize=14,
)
plt.legend(loc="lower right", fontsize=9, frameon=True)
plt.grid(True, linestyle=":", linewidth=0.7)
plt.tight_layout()
plt.show()


In [None]:
# ==============================================================================
# STEP 4: VISUALIZING MODEL PERFORMANCE AND INTERPRETABILITY
# ==============================================================================
# Plots a confusion matrix, XGBoost feature importances and a SHAP summary for
# the grouped model trained in STEP 3-B. Also leaves `shap_values` and
# `X_test_sample_shap` in the namespace; the FINAL STEP cell checks for both.
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
import shap

# Ensure that the necessary model and data variables exist
# NOTE(review): `y_pred_grouped`, `y_test_grouped`, `grouped_label_encoder` and
# `X` are used below but not checked here; STEP 3-B / preprocessing create them.
if 'xgb_model_grouped' in locals() and 'X_test_grouped' in locals():

    # --- PLOT 1: CONFUSION MATRIX ---
    # This plot shows in detail where the model performs well and where it makes mistakes.
    print("--- Generating Confusion Matrix ---")

    cm = confusion_matrix(y_test_grouped, y_pred_grouped)
    # Overwrites the global `class_names` (human-readable names from STEP 3-B)
    # with the raw label-encoder classes, so axis labels show encoder ids.
    class_names = grouped_label_encoder.classes_

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix for Grouped Classes', fontsize=16)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.show()

    # --- PLOT 2: FEATURE IMPORTANCE ---
    # This plot shows which features the model considered most important for making its decisions.
    print("\n--- Generating Feature Importance Plot ---")

    feature_importances = xgb_model_grouped.feature_importances_
    # Create a DataFrame for easier plotting
    # NOTE(review): assumes X.columns is in the same order as the columns of
    # X_test_grouped / X_train_grouped — TODO confirm against preprocessing.
    importance_df = pd.DataFrame({
        'Feature': X.columns, # Using column names from 'X' before scaling
        'Importance': feature_importances
    }).sort_values(by='Importance', ascending=False)

    # Display the top 20 most important features
    plt.figure(figsize=(12, 10))
    sns.barplot(x='Importance', y='Feature', data=importance_df.head(20))
    plt.title('Top 20 Most Important Features (XGBoost)', fontsize=16)
    plt.xlabel('Importance', fontsize=12)
    plt.ylabel('Feature', fontsize=12)
    plt.grid(axis='x')
    plt.show()

    # --- PLOT 3: SHAP SUMMARY PLOT ---
    # This is the most advanced plot: it shows not only WHICH features are important,
    # but also HOW their values impact the prediction for each class.
    print("\n--- Generating SHAP Summary Plot (this may take a moment) ---")

    # Use a subset of the test set to speed up SHAP calculations
    # (the first 2000 rows of X_test_grouped, not a random sample).
    X_test_sample_shap = pd.DataFrame(X_test_grouped[:2000], columns=X.columns)

    explainer = shap.TreeExplainer(xgb_model_grouped)
    # NOTE(review): for a multiclass model the layout of `shap_values` depends
    # on the installed SHAP version (list of per-class arrays vs. one 3D array).
    shap_values = explainer.shap_values(X_test_sample_shap)

    # Plot the SHAP summary
    shap.summary_plot(shap_values, X_test_sample_shap,
                      class_names=class_names,
                      show=False)
    plt.title("SHAP Summary Plot - Feature Impact on Model Output", fontsize=16)
    plt.show()

else:
    print("ERROR: Make sure Step 3-B has been run successfully to generate the model and results.")

In [None]:
# ==============================================================================
# STEP 5: CROSS-VALIDATION FOR XGBOOST MODEL (SEQUENTIAL EXECUTION)
# ==============================================================================
from sklearn.model_selection import cross_val_score
import numpy as np
import xgboost as xgb

# Make sure the required data exists
if 'X_scaled' in locals() and 'y_grouped' in locals():
    print("--- Starting 5-Fold Cross-Validation (Sequential Mode) ---")
    print("This process will be slower but more memory-efficient.")

    # Define the model again to ensure it's a fresh instance
    xgb_model_for_cv = xgb.XGBClassifier(
        objective='multi:softprob',        # Multiclass classification
        use_label_encoder=False,
        eval_metric='mlogloss',
        tree_method='hist',                # Histogram-based (faster on large datasets)
        device='cuda'                      # Use GPU if available
    )

    # Perform 5-fold cross-validation sequentially
    # 'n_jobs=-1' was removed to avoid memory crashes
    scores = cross_val_score(
        estimator=xgb_model_for_cv,
        X=X_scaled,
        y=y_grouped,
        cv=5,
        scoring='accuracy'
    )

    print("\nCross-Validation complete.")
    print(f"Scores for each of the {len(scores)} folds: {scores}")
    print(f"Mean Accuracy: {np.mean(scores) * 100:.2f}%")
    print(f"Standard Deviation: {np.std(scores) * 100:.2f}%")

else:
    print("ERROR: Data for cross-validation not found. Please run the previous steps.")


In [None]:
# ==============================================================================
# FINAL STEP: INTERPRETABILITY WITH SHAP
# ==============================================================================
import shap
import pandas as pd
import matplotlib.pyplot as plt

# Ensure SHAP values have already been computed (STEP 4 creates both names).
if 'shap_values' in locals() and 'X_test_sample_shap' in locals():
    print("--- Generating Individual SHAP Plots for Each Class (Corrected Slicing) ---")

    # Initialize SHAP visualization
    shap.initjs()

    # Retrieve class names from our label encoder
    class_names = grouped_label_encoder.classes_

    # === LOOP TO GENERATE A SEPARATE PANEL FOR EACH CLASS ===
    for i, class_name in enumerate(class_names):

        print(f"\n--- SHAP Summary Plot for Class: '{class_name}' ---")

        # Multiclass TreeExplainer output has two layouts depending on the
        # SHAP version: a list with one (n_samples, n_features) array per
        # class, or a single (n_samples, n_features, n_classes) array.
        # Take class i from whichever layout we received.
        if isinstance(shap_values, list):
            class_shap_values = shap_values[i]
        else:
            class_shap_values = shap_values[:, :, i]  # ALL samples, ALL features, class i

        # show=False so the per-class title is attached to the same figure
        # before it is displayed once by plt.show().
        shap.summary_plot(
            class_shap_values,
            X_test_sample_shap,
            show=False
        )
        plt.title(f"SHAP Summary Plot - Class '{class_name}'")
        plt.show()

else:
    print("ERROR: Make sure SHAP values were computed successfully.")


LSTM - TRAINING AND SEQUENTIAL WINDOWS

In [None]:
import numpy as np

def create_sequences(X_data, y_data, time_steps=10):
    """
    Create sequential series on 2D data.

    Window ``i`` is ``X_data[i : i + time_steps]`` and is labelled with
    ``y_data[i + time_steps - 1]``, the label of its last row. Start indices
    run from 0 to ``len(X_data) - time_steps - 1``, so
    ``len(X_data) - time_steps`` windows are returned (none when the input is
    not longer than ``time_steps``).

    Inputs are converted with ``np.asarray`` so rows and labels are always
    addressed by position. A pandas Series with a shuffled index (e.g. after
    ``train_test_split``) would otherwise be indexed by label.
    """
    X_arr = np.asarray(X_data)
    y_arr = np.asarray(y_data)
    Xs, ys = [], []
    for i in range(len(X_arr) - time_steps):
        # Time window extraction
        Xs.append(X_arr[i:(i + time_steps)])
        # Assigning labels on the last window element
        ys.append(y_arr[i + time_steps - 1])
    return np.array(Xs), np.array(ys)

In [None]:
# ==============================================================================
# FINAL & COMPLETE: CNN-LSTM WITH BALANCED CLASSES AND AUC-ROC PLOT
# (GROUPED MULTICLASS ON IOMT-TRAFFICDATA)
# ==============================================================================

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense, Dropout
from tensorflow.keras.utils import Sequence
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    roc_auc_score,
    roc_curve,
    auc,
)
from sklearn.preprocessing import label_binarize
from sklearn.utils import class_weight
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle

# ----------------------------------------------------------------------
# Helper functions and classes
# ----------------------------------------------------------------------
def create_sequences(X_data, y_data, time_steps=10):
    """
    Create overlapping sequences of length `time_steps` from a 2D feature matrix.

    Parameters
    ----------
    X_data : array-like, shape (n_samples, n_features)
        Feature matrix (already scaled/encoded).
    y_data : array-like, shape (n_samples,)
        Label vector.
    time_steps : int
        Length of each time window.

    Returns
    -------
    X_seq : ndarray, shape (n_sequences, time_steps, n_features)
    y_seq : ndarray, shape (n_sequences,)
        Label associated with the last element of each sequence.

    Notes
    -----
    ``n_sequences == len(X_data) - time_steps``: windows start at
    ``0 .. len(X_data) - time_steps - 1``. This is the same range that
    ``TimeSeriesGenerator`` uses. Inputs are converted with ``np.asarray`` so
    indexing is positional even for pandas objects whose index is shuffled.
    """
    X_arr = np.asarray(X_data)
    y_arr = np.asarray(y_data)
    Xs, ys = [], []
    for i in range(len(X_arr) - time_steps):
        Xs.append(X_arr[i : (i + time_steps)])
        ys.append(y_arr[i + time_steps - 1])
    return np.array(Xs), np.array(ys)


class TimeSeriesGenerator(Sequence):
    """
    Simple Keras Sequence to generate sliding windows over a 2D feature matrix.

    Window ``i`` is ``X_data[i : i + time_steps]`` and is labelled with
    ``y_data[i + time_steps - 1]``, the label of its last row. Window start
    indices are stored in ``self.indices`` (``0 .. len(X_data) - time_steps - 1``)
    and are served in order, ``batch_size`` at a time, without shuffling.

    Parameters
    ----------
    X_data : array-like, shape (n_samples, n_features)
    y_data : array-like, shape (n_samples,)
    batch_size : int
    time_steps : int
    **kwargs
        Forwarded to the ``Sequence`` base class (under Keras 3 this is
        ``PyDataset``, which accepts e.g. ``workers`` / ``use_multiprocessing``).
    """

    def __init__(self, X_data, y_data, batch_size, time_steps, **kwargs):
        # Keras 3's PyDataset warns when its __init__ is not called.
        super().__init__(**kwargs)
        # Positional arrays: pandas objects would otherwise be indexed by
        # label in __getitem__.
        self.X_data = np.asarray(X_data)
        self.y_data = np.asarray(y_data)
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.indices = np.arange(len(self.X_data) - time_steps)

    def __len__(self):
        # Number of batches; the last one may be smaller than batch_size.
        return int(np.ceil(len(self.indices) / self.batch_size))

    def __getitem__(self, index):
        start_idx = index * self.batch_size
        end_idx = (index + 1) * self.batch_size
        batch_indices = self.indices[start_idx:end_idx]

        X_batch = np.array(
            [self.X_data[i : (i + self.time_steps)] for i in batch_indices]
        )
        # Label of the last row of each window, gathered in one indexing op.
        y_batch = self.y_data[batch_indices + self.time_steps - 1]

        return X_batch, y_batch


# ----------------------------------------------------------------------
# Training and evaluation
# ----------------------------------------------------------------------
# Requires the grouped split and label encoder produced by STEP 2-A / 2-B.
try:
    X_train_grouped
    X_test_grouped
    y_train_grouped
    y_test_grouped
    grouped_label_encoder
except NameError:
    print("ERROR: Grouped data not found. Please re-run STEP 2-A and STEP 2-B.")
else:
    # 1. Calculate class weights for grouped labels
    print("--- Calculating class weights for grouped labels... ---")
    train_classes = np.unique(y_train_grouped)
    weights = class_weight.compute_class_weight(
        class_weight="balanced",
        classes=train_classes,
        y=y_train_grouped,
    )
    # Key each weight by its class id rather than its position, so a class
    # missing from the training split cannot shift later weights onto the
    # wrong labels.
    class_weights = {int(c): w for c, w in zip(train_classes, weights)}
    print("Class weights:", class_weights)

    # 2. Prepare time-series generators and parameters
    TIME_STEPS = 20
    BATCH_SIZE = 1024

    training_generator = TimeSeriesGenerator(
        X_train_grouped, y_train_grouped, BATCH_SIZE, TIME_STEPS
    )
    test_generator = TimeSeriesGenerator(
        X_test_grouped, y_test_grouped, BATCH_SIZE, TIME_STEPS
    )

    n_features = X_train_grouped.shape[1]
    n_outputs = len(grouped_label_encoder.classes_)

    # Define human-readable class names for reports and plots
    original_classes = grouped_label_encoder.classes_

    def class_id_to_name(c: int) -> str:
        return "Normal" if c == 0 else f"Attack_{c}"

    class_names = [class_id_to_name(c) for c in original_classes]

    print("\nClass mapping (encoder classes -> names):")
    for c, name in zip(original_classes, class_names):
        print(f"  {c} -> {name}")

    # 3. Build the CNN-LSTM model
    print("\n--- Building Time-Series CNN-LSTM Model ---")
    cnn_lstm_model_balanced = Sequential(
        [
            Conv1D(
                filters=64,
                kernel_size=3,
                activation="relu",
                input_shape=(TIME_STEPS, n_features),
                padding="same",
            ),
            MaxPooling1D(pool_size=2),
            LSTM(100, activation="relu"),
            Dropout(0.5),
            Dense(n_outputs, activation="softmax"),
        ]
    )
    cnn_lstm_model_balanced.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # 4. Train the model with class weights
    print("\n--- Training CNN-LSTM Model with Class Weights ---")
    history_balanced = cnn_lstm_model_balanced.fit(
        training_generator,
        epochs=5,
        verbose=1,
        class_weight=class_weights,
    )

    # 5. Evaluation on the test set
    print("\n--- Full Evaluation on Test Set ---")
    y_pred_proba_cnn = cnn_lstm_model_balanced.predict(test_generator)
    y_pred_cnn = np.argmax(y_pred_proba_cnn, axis=1)

    # Align test labels with the generated windows. Window i covers rows
    # [i, i + TIME_STEPS) and is labelled with row i + TIME_STEPS - 1, so the
    # labels are gathered directly from the generator's start indices. This
    # avoids materialising every test window (n x TIME_STEPS x n_features)
    # just to read the labels.
    y_test_aligned = np.asarray(y_test_grouped)[
        test_generator.indices + TIME_STEPS - 1
    ]
    y_test_final = y_test_aligned[: len(y_pred_cnn)]

    # Macro AUC-ROC over all grouped classes
    roc_auc_macro = roc_auc_score(
        y_test_final, y_pred_proba_cnn, multi_class="ovr", average="macro"
    )

    acc = accuracy_score(y_test_final, y_pred_cnn)
    print(f"\nOverall Accuracy (CNN-LSTM grouped): {acc * 100:.2f}%")
    print(f"Macro-Average AUC-ROC Score: {roc_auc_macro:.4f}")

    print("\nDetailed Classification Report (grouped classes):")
    print(
        classification_report(
            y_test_final,
            y_pred_cnn,
            target_names=class_names,
            digits=4,
        )
    )

    # 6. Multi-class ROC curve (one-vs-rest per class)
    print("\n--- Generating Multi-Class ROC Curve Plot ---")
    y_test_binarized = label_binarize(y_test_final, classes=np.arange(n_outputs))
    fpr, tpr, roc_auc_dict = {}, {}, {}

    for i in range(n_outputs):
        fpr[i], tpr[i], _ = roc_curve(
            y_test_binarized[:, i], y_pred_proba_cnn[:, i]
        )
        roc_auc_dict[i] = auc(fpr[i], tpr[i])

    colors = cycle(
        ["aqua", "darkorange", "cornflowerblue", "green", "red", "purple", "brown", "olive", "gray"]
    )
    plt.figure(figsize=(10, 8))

    for i, color in zip(range(n_outputs), colors):
        plt.plot(
            fpr[i],
            tpr[i],
            color=color,
            lw=2,
            label=f"ROC of class {class_names[i]} (area = {roc_auc_dict[i]:.2f})",
        )

    plt.plot([0, 1], [0, 1], "k--", lw=2, label="Chance")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate", fontsize=12)
    plt.ylabel("True Positive Rate", fontsize=12)
    plt.title(
        f"CNN-LSTM - Multi-Class ROC Curve\n(Macro-Average AUC = {roc_auc_macro:.4f})",
        fontsize=16,
    )
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()


In [None]:
# ==============================================================================
# SAVE TRAINED XGBOOST MODEL (GROUPED MULTICLASS, IOMT-TRAFFICDATA)
# ==============================================================================

import os
import joblib

# Base project directory on Google Drive
base_dir = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
models_dir = os.path.join(base_dir, "models")
os.makedirs(models_dir, exist_ok=True)

model_path = os.path.join(models_dir, "xgb_iomt_traffic_grouped.joblib")

if "xgb_model_grouped" in locals():
    joblib.dump(xgb_model_grouped, model_path)
    print(f"XGBoost grouped model saved successfully to:\n  {model_path}")
else:
    print("ERROR: 'xgb_model_grouped' not found. Please run STEP 3-B first.")


In [None]:
# ==============================================================================
# SAVE GROUPED TEST DATA (FEATURES + LABELS)
# ==============================================================================

import os
import numpy as np

base_dir = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
data_dir = os.path.join(base_dir, "saved_test_data")
os.makedirs(data_dir, exist_ok=True)

x_test_path = os.path.join(data_dir, "X_test_grouped_iomt_traffic.npy")
y_test_path = os.path.join(data_dir, "y_test_grouped_iomt_traffic.npy")

# Check if the variables exist before saving
if "X_test_grouped" in locals() and "y_test_grouped" in locals():
    print("Saving grouped test data to files...")
    np.save(x_test_path, X_test_grouped)
    np.save(y_test_path, y_test_grouped)
    print("Test data saved successfully:")
    print(f"  X_test_grouped -> {x_test_path}")
    print(f"  y_test_grouped -> {y_test_path}")
else:
    print("ERROR: 'X_test_grouped' or 'y_test_grouped' not found. "
          "Please run the grouped preprocessing cells (STEP 2-B).")


## 2. Physiological Module (patched to use VitalDB instead of MIMIC-IV Demo)

In [None]:
# ==============================================================================
# STEP 0: ENVIRONMENT SETUP
# ==============================================================================
# Setup for the physiological module: mounts Drive and installs `vitaldb` in
# addition to xgboost/shap. The `!pip` line is IPython shell syntax, so this
# cell only runs inside a notebook (e.g. Colab).

# --- 1. Mount Google Drive ---
# This command connects your Colab notebook to your Google Drive.
# You will be prompted to authorize the connection.
from google.colab import drive
drive.mount('/content/drive')
print("Google Drive mounted successfully!")

# --- 2. Install necessary libraries (if not already installed) ---
!pip install xgboost shap vitaldb -q

print("Setup complete. You can now proceed with Step 1.")

In [None]:
# ==========================
# T1_fix — Initialization
# ==========================
import vitaldb

# Candidate aliases for each vital. Track names vary across sites/devices,
# so we keep broad lists and pick the first available in each case.
# NOTE(review): these lists are not referenced by smart_pick() in the STEP 1
# cell, which selects tracks with token heuristics instead.
CAND_HR = [
    "ECG_HR", "HR", "ECG/HR", "ECG_II_HR", "HR_ECG", "HR1"
]
CAND_SPO2 = [
    "SpO2", "SPO2", "PLETH_SPO2", "PLETH/SpO2", "Masimo_SpO2", "Saturation"
]
CAND_BP = [
    # Invasive arterial mean pressure (map/mean)
    "ART", "ABP", "ART_MBP", "ABP_M", "ART_Mean", "ABP_Mean", "ART_MAP", "ABP_MAP",
    # Non-invasive mean pressure
    "NBP_Mean", "NIBP_M", "NIBP_Mean", "NBP_MAP"
]

# Track sets to quickly "probe" VitalDB for case IDs.
# We try several common combinations; the first that returns non-empty wins.
PROBE_TRACK_SETS = [
    ["ECG_II", "ART"],
    ["ECG", "ART"],
    ["ECG_II", "ABP"],
    ["PLETH", "ART"],
    ["ECG", "PLETH"],
    ["ECG", "ABP"],
]

# Build `probe` by trying the above sets in order.
probe = None
for tracks in PROBE_TRACK_SETS:
    try:
        res = vitaldb.find_cases(tracks)
        # Explicit None/len test works for list, set and ndarray results. A
        # bare truth test raises on multi-element arrays, and that error would
        # be reported below as a probe failure.
        if res is not None and len(res) > 0:
            probe = res
            print(f"Probe OK with tracks {tracks} → {len(res)} case IDs found.")
            break
    except Exception as e:
        # One failing track set (network/API error) should not stop the search.
        print(f"[Probe] Failed with {tracks}: {type(e).__name__}: {e}")

if probe is None:
    raise RuntimeError(
        "Could not build 'probe' with default track sets. "
        "Inspect a known case and extend CAND_* lists with the actual track names."
    )

# Optional helper to inspect the available track names for a case ID.
def inspect_tracks(case_id, limit=50):
    """
    List the track names recorded for one VitalDB case.

    Prints the total track count followed by at most ``limit`` names, which
    helps when extending the CAND_* alias lists.

    Returns
    -------
    list
        Every track name of the case, not only the printed ones.
    """
    track_names = vitaldb.VitalFile(int(case_id)).get_track_names()
    shown = min(limit, len(track_names))
    print(f"Case {case_id} → {len(track_names)} tracks. First {shown}:")
    for track in track_names[:limit]:
        print("  -", track)
    return track_names


In [None]:
# ==============================================================================
# STEP 1  VitalDB end-to-end (single cell, strict validation)
# - Build probe, conservative smart-pick, Drive cache
# - Fast pass at 60 min, optional refine at 15 min
# - Strong data validation before counting a subject as valid
# - Parallel threads, early-stop, heartbeats & summaries
# ==============================================================================

import os, time, random, hashlib, threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice

import numpy as np
import pandas as pd
import vitaldb
from tqdm.auto import tqdm

# ------------------------------------------------------------------------------
# CONFIG
# ------------------------------------------------------------------------------
# TARGET_N_CASES caps how many probe case IDs are considered (see the probe
# section below). TARGET_SUBJECTS is presumably the early-stop goal for valid
# subjects; its use is outside this view, so confirm there.
TARGET_N_CASES       = 1000
TARGET_SUBJECTS      = 250

# Fast + refine
# Values are in seconds, as the "min" comments indicate; they are passed as
# `interval=` to VitalFile.to_pandas in process_case_with_interval.
FAST_SCAN_INTERVAL   = 60 * 60      # 60 min
REFINE_INTERVAL      = 15 * 60      # 15 min
ENABLE_FAST_SCAN     = True
ENABLE_REFINE_PASS   = False

# Parallelism & pacing
# CASE_SOFT_TIMEOUT (seconds): in the visible part of process_case_with_interval
# it is only compared with elapsed download time to set a `slow` flag.
MAX_WORKERS          = 8
CASE_SOFT_TIMEOUT    = 25
PRINT_EVERY_SUBJECTS = 10
PRINT_EVERY_SECONDS  = 30
SUMMARY_EVERY_CASES  = 25
HEARTBEAT_SECONDS    = 15
SHUFFLE_CASES        = True

# Validation thresholds (tune for your dataset)
# These feed validate_case_df. Row counts are rows after resampling at the
# chosen interval, with all-NaN rows dropped.
# NOTE(review): at FAST_SCAN_INTERVAL (one row per 60 min) MIN_SAMPLES = 6
# requires roughly 6 hours of recording per case. Confirm this is intended.
MIN_SAMPLES          = 6            # minimum rows after resampling
MIN_OVERLAP          = 4            # rows where HR & SpO2 both present
MIN_COVERAGE         = 0.25         # fraction of non-NaN per vital
MIN_INRANGE_FRAC_HR  = 0.70         # fraction of HR values within 20..250
MIN_INRANGE_FRAC_SPO2= 0.70         # fraction of SpO2 within 50..100
HR_MEDIAN_RANGE      = (30, 150)    # plausible median HR
SPO2_MEDIAN_RANGE    = (80, 100)    # plausible median SpO2

# Cache on Google Drive
# resolve_cache_dir uses the first existing base in GDRIVE_BASES.
# CACHE_BACKEND "parquet" falls back to gzip pickle if the parquet write fails.
PROJECT_SUBDIR       = "Conference_paper_ICCC_2026"
GDRIVE_BASES         = ["/content/drive/MyDrive", "/content/drive/MyDrive/Apps"]
CACHE_BACKEND        = "parquet"

# ------------------------------------------------------------------------------
# Cache helpers
# ------------------------------------------------------------------------------
def resolve_cache_dir():
    """
    Pick the directory for the per-case VitalDB cache.

    Returns ``<base>/<PROJECT_SUBDIR>`` for the first entry of GDRIVE_BASES
    that exists as a directory. If none exists (Drive not mounted), prints a
    warning with mount instructions and returns ``/content/<PROJECT_SUBDIR>``.
    That local fallback is described in the warning as ephemeral. The
    directory is not created here.
    """
    mounted_base = next((b for b in GDRIVE_BASES if os.path.isdir(b)), None)
    if mounted_base is not None:
        return os.path.join(mounted_base, PROJECT_SUBDIR)

    local_dir = os.path.join("/content", PROJECT_SUBDIR)
    print(
        f"[Warning] Google Drive not mounted.\n"
        f"Using local ephemeral cache: {local_dir}\n"
        f"To mount in Colab:\n"
        f"  from google.colab import drive\n"
        f"  drive.mount('/content/drive')"
    )
    return local_dir

# Resolve and create the cache directory once, when the cell runs; the cache
# helpers in this cell read CACHE_DIR as a module-level global.
CACHE_DIR = resolve_cache_dir()
os.makedirs(CACHE_DIR, exist_ok=True)
print("Cache dir →", os.path.abspath(CACHE_DIR))

def _cache_key(cid, tracks, interval):
    """
    Build the cache file stem for one (case, track selection, interval) triple.

    The stem is ``case_<cid>_`` followed by the first 12 hex characters of the
    MD5 of the joined parameters, so a different track choice or interval
    yields a different cache file. Track order matters.
    """
    joined_tracks = ",".join(tracks)
    source = f"cid={cid}|tracks={joined_tracks}|interval={interval}"
    short_digest = hashlib.md5(source.encode("utf-8")).hexdigest()[:12]
    return f"case_{cid}_{short_digest}"

def cache_paths(cid, tracks, interval):
    """
    Return the cache file paths for one case selection.

    Keys: ``"parquet"`` (``.parquet``), ``"pickle"`` (``.pkl.gz``) and
    ``"tmp"`` (``.tmp``, staging file used for atomic replace). All share the
    stem ``CACHE_DIR/_cache_key(...)``.
    """
    stem = os.path.join(CACHE_DIR, _cache_key(cid, tracks, interval))
    suffixes = {"parquet": ".parquet", "pickle": ".pkl.gz", "tmp": ".tmp"}
    return {kind: stem + ext for kind, ext in suffixes.items()}

def load_from_cache(cid, tracks, interval):
    """
    Return the cached DataFrame for this selection, or None on a cache miss.

    The parquet file is tried first, then the gzip pickle. A file that exists
    but cannot be read is silently skipped (best-effort cache).
    """
    paths = cache_paths(cid, tracks, interval)
    readers = (
        ("parquet", lambda p: pd.read_parquet(p)),
        ("pickle", lambda p: pd.read_pickle(p, compression="gzip")),
    )
    for kind, read in readers:
        candidate = paths[kind]
        if not os.path.exists(candidate):
            continue
        try:
            return read(candidate)
        except Exception:
            continue
    return None

def save_to_cache(df, cid, tracks, interval):
    """
    Write ``df`` to the cache via a temporary file plus ``os.replace``.

    With ``CACHE_BACKEND == "parquet"`` a parquet write is attempted first;
    if it fails for any reason, a gzip pickle is written instead. The staging
    file is removed afterwards on every path. Errors from the pickle fallback
    propagate to the caller.
    """
    targets = cache_paths(cid, tracks, interval)
    staging = targets["tmp"]
    try:
        stored_as_parquet = False
        if CACHE_BACKEND == "parquet":
            try:
                df.to_parquet(staging)
                os.replace(staging, targets["parquet"])
                stored_as_parquet = True
            except Exception:
                pass
        if not stored_as_parquet:
            df.to_pickle(staging, compression="gzip")
            os.replace(staging, targets["pickle"])
    finally:
        if os.path.exists(staging):
            try:
                os.remove(staging)
            except Exception:
                pass

# ------------------------------------------------------------------------------
# Conservative smart selector (regex + negative filters)
# ------------------------------------------------------------------------------
def smart_pick(tracks, kind):
    """
    Heuristically pick one track of a given kind in {'hr','spo2','map'}.

    Conservative filtering avoids alarm, derived and non-physiological
    channels: any track whose lower-cased name contains one of ``bad_tokens``
    is ignored. The remaining candidates are scored per kind. The highest
    score wins; ties go to the shorter name, then to the earlier position in
    ``tracks`` (the sort is stable).

    Parameters
    ----------
    tracks : iterable of str
        Track names of one case (e.g. ``VitalFile.get_track_names()``).
    kind : {'hr', 'spo2', 'map'}
        Vital to look for; any other value yields no candidates.

    Returns
    -------
    str or None
        The chosen track name, or None if nothing matches.

    Notes
    -----
    For ``'map'`` a name must carry a mean-pressure marker: "map", "mean",
    "mbp" or a trailing "_m". VitalDB names its mean-pressure numerics with
    "_MBP" (e.g. ``Solar8000/ART_MBP``, ``Solar8000/NIBP_MBP``); without
    "mbp" those tracks were never selected. Non-invasive names are classified
    before invasive ones because "nibp" contains the invasive token "ibp".
    "heart" is stripped before testing for "art" so heart-rate channels are
    not mistaken for arterial lines.
    """
    bad_tokens = ["alarm", "arr", "arrhythm", "resp", "rr", "quality", "flag", "beat-to-beat", "status"]
    cands = []
    for t in tracks:
        n = t.lower()
        if any(b in n for b in bad_tokens):
            continue

        if kind == "hr":
            if ("hr" in n or "heart" in n):
                score = 0
                if "ecg" in n: score += 4
                if "/hr" in n or n.endswith("_hr"): score += 2
                if "calc" in n or "derived" in n: score -= 1
                if "nibp" in n or "nbp" in n: score -= 2
                cands.append((score, t))

        elif kind == "spo2":
            if ("spo2" in n or "saturation" in n):
                score = 0
                if "pleth" in n: score += 1  # hint, but SpO2 is numeric channel
                if "masimo" in n or "mindray" in n or "solar8000" in n: score += 1
                if "calc" in n or "derived" in n: score -= 1
                cands.append((score, t))

        elif kind == "map":
            is_mean = ("map" in n or "mean" in n or "mbp" in n or n.endswith("_m"))
            if not is_mean:
                continue
            # "heart" contains "art": strip it so HR channels don't look arterial.
            has_art = "art" in n.replace("heart", "")
            if "nibp" in n or "nbp" in n:
                # Non-invasive mean: accepted, but ranked below any invasive line.
                cands.append((1, t))
            elif has_art or "abp" in n or "ibp" in n or "a-line" in n:
                # prefer invasive mean first
                score = 5
                if has_art or "abp" in n: score += 2
                cands.append((score, t))

    if not cands:
        return None
    cands.sort(key=lambda x: (-x[0], len(x[1])))
    return cands[0][1]

# ------------------------------------------------------------------------------
# Strict validation of a per-case dataframe
# ------------------------------------------------------------------------------
def _frac_in_range(s, lo, hi):
    """
    Fraction of the non-NaN values of Series ``s`` that lie in ``[lo, hi]``.

    Both bounds are inclusive. Returns 0.0 when ``s`` has no non-NaN values.
    """
    present = s.dropna()
    if present.empty:
        return 0.0
    return present.between(lo, hi).mean()

def validate_case_df(df):
    """
    Return (True, info) if valid, else (False, reason).
    Requires: enough rows, coverage, overlap, plausible medians and in-range fractions.

    Checks run in this order, and the first failure is returned as the reason
    string: "too-few-cols", "too-short", "insufficient-coverage",
    "insufficient-overlap", "out-of-range", "hr-median-implausible",
    "spo2-median-implausible". Thresholds are the module-level MIN_* and
    *_MEDIAN_RANGE config constants.

    Side effect: when fewer than half of the non-NaN "Arterial_BP_Mean" values
    lie in 20..200, the out-of-range values are set to NaN in `df` itself, so
    the caller's frame is modified. MAP never causes a rejection.

    On success `info` is a dict with keys rows, cov_hr, cov_spo2, overlap,
    med_hr, med_spo2, frac_hr_ok and frac_spo2_ok.
    """
    cols = [c for c in ["Heart_Rate", "SpO2", "Arterial_BP_Mean"] if c in df.columns]
    if len(cols) < 2:
        return False, "too-few-cols"

    n = len(df)
    if n < MIN_SAMPLES:
        return False, "too-short"

    # coverage per vital
    # A missing HR or SpO2 column gets coverage 0.0 and fails here (as long as
    # MIN_COVERAGE > 0). This is what protects the direct column access below.
    cov_hr = df["Heart_Rate"].notna().mean() if "Heart_Rate" in df else 0.0
    cov_s  = df["SpO2"].notna().mean() if "SpO2" in df else 0.0
    if cov_hr < MIN_COVERAGE or cov_s < MIN_COVERAGE:
        return False, "insufficient-coverage"

    # overlap rows
    overlap = ((~df["Heart_Rate"].isna()) & (~df["SpO2"].isna())).sum()
    if overlap < MIN_OVERLAP:
        return False, "insufficient-overlap"

    # physiological sanity
    frac_hr_ok = _frac_in_range(df["Heart_Rate"], 20, 250) if "Heart_Rate" in df else 0.0
    frac_s_ok  = _frac_in_range(df["SpO2"], 50, 100) if "SpO2" in df else 0.0
    if frac_hr_ok < MIN_INRANGE_FRAC_HR or frac_s_ok < MIN_INRANGE_FRAC_SPO2:
        return False, "out-of-range"

    med_hr = np.nanmedian(df["Heart_Rate"].values) if "Heart_Rate" in df else np.nan
    med_s  = np.nanmedian(df["SpO2"].values) if "SpO2" in df else np.nan
    if not (HR_MEDIAN_RANGE[0] <= med_hr <= HR_MEDIAN_RANGE[1]):
        return False, "hr-median-implausible"
    if not (SPO2_MEDIAN_RANGE[0] <= med_s <= SPO2_MEDIAN_RANGE[1]):
        return False, "spo2-median-implausible"

    # optional: MAP sanity if present (do not fail hard; only tag)
    if "Arterial_BP_Mean" in df:
        frac_map_ok = _frac_in_range(df["Arterial_BP_Mean"], 20, 200)
        # we don't require MAP; if wildly off, we could drop the column:
        # Masking only happens when most values are out of range. With >= 50%
        # in range, out-of-range MAP values are left untouched.
        if frac_map_ok < 0.5:
            df["Arterial_BP_Mean"] = np.where(
                (df["Arterial_BP_Mean"] >= 20) & (df["Arterial_BP_Mean"] <= 200),
                df["Arterial_BP_Mean"],
                np.nan
            )

    # info summary to print on FOUND
    info = {
        "rows": n,
        "cov_hr": cov_hr,
        "cov_spo2": cov_s,
        "overlap": int(overlap),
        "med_hr": float(med_hr),
        "med_spo2": float(med_s),
        "frac_hr_ok": float(frac_hr_ok),
        "frac_spo2_ok": float(frac_s_ok),
    }
    return True, info

# ------------------------------------------------------------------------------
# Build `probe`
# ------------------------------------------------------------------------------
# Same probing strategy as the T1 initialization cell, repeated so this cell
# is self-contained.
PROBE_TRACK_SETS = [
    ["ECG_II", "ART"], ["ECG", "ART"], ["ECG_II", "ABP"],
    ["PLETH", "ART"], ["ECG", "PLETH"], ["ECG", "ABP"],
]
probe = None
for tracks in PROBE_TRACK_SETS:
    try:
        res = vitaldb.find_cases(tracks)
        # Explicit None/len test works for list, set and ndarray results. A
        # bare truth test raises on multi-element arrays, and that error would
        # be reported below as a probe failure.
        if res is not None and len(res) > 0:
            probe = res
            print(f"Probe OK with tracks {tracks} → {len(res)} case IDs.")
            break
    except Exception as e:
        print(f"[Probe] Failed with {tracks}: {type(e).__name__}: {e}")
if probe is None:
    raise RuntimeError("Could not build 'probe'. Adjust PROBE_TRACK_SETS.")

all_case_ids = [int(x) for x in probe]
# NOTE: truncation happens before shuffling, so SHUFFLE_CASES only reorders
# the first TARGET_N_CASES ids. No random seed is set in this cell, so the
# order can differ between runs.
case_ids = all_case_ids[:TARGET_N_CASES]
if SHUFFLE_CASES:
    random.shuffle(case_ids)

print(f"Total cases from probe: {len(all_case_ids)}")
print(f"Will inspect up to: {len(case_ids)} cases\n")

# ------------------------------------------------------------------------------
# Per-case worker (parametric interval) + strict validation
# ------------------------------------------------------------------------------
def _sample_index_to_charttime(df_case, interval):
    """Return ``df_case`` with a numeric ``charttime`` turned into timestamps.

    After ``reset_index()`` the ``charttime`` column holds the frame's
    positional sample index (0, 1, 2, ...). If such integers reach
    ``pd.to_datetime`` they are read as *nanoseconds*, so a whole case would
    fall into a single resampling bin. Here each index is scaled by
    ``interval`` seconds and converted to a timestamp counted from the Unix
    epoch, i.e. the elapsed time since the first sample of the case.

    NOTE(review): assumes ``vf.to_pandas(..., interval=...)`` returns one row
    per ``interval`` seconds with a 0-based RangeIndex — confirm against the
    installed vitaldb version.

    Frames without ``charttime``, or whose ``charttime`` is already
    non-numeric (e.g. timestamps read back from the cache), are returned
    unchanged.
    """
    if "charttime" not in df_case.columns:
        return df_case
    if not pd.api.types.is_numeric_dtype(df_case["charttime"]):
        return df_case
    elapsed_seconds = df_case["charttime"] * interval
    return df_case.assign(charttime=pd.to_datetime(elapsed_seconds, unit="s"))


def process_case_with_interval(cid, interval):
    """
    Load, clean and strictly validate one VitalDB case sampled every
    ``interval`` seconds.

    Return (cid, df_case or None, src_str, reason_or_info).
    On success: src_str in {'cache','download'}; df_case has the columns
    ['subject_id', 'charttime', <vitals>] with ``charttime`` as timestamps
    (see ``_sample_index_to_charttime``); reason_or_info is the metrics dict
    returned by ``validate_case_df`` (plus ``slow=True`` on the download path
    when the case took longer than CASE_SOFT_TIMEOUT seconds).
    On skip: src_str is None and reason_or_info is a string reason.

    Any exception is turned into the skip reason ``"error:<ExcType>"`` so a
    single bad case never aborts a parallel pass.
    """
    try:
        # Start timing before opening the file so the "slow" tag covers the
        # download as well as the DataFrame conversion.
        t_start = time.time()
        # NOTE(review): the VitalFile is opened before the cache lookup because
        # the cache key depends on the per-case track names, so a cache hit
        # does not avoid this step.
        vf = vitaldb.VitalFile(cid)
        case_tracks = vf.get_track_names()

        hr_track = smart_pick(case_tracks, "hr")
        spo2_track = smart_pick(case_tracks, "spo2")
        map_track = smart_pick(case_tracks, "map")  # optional

        if (hr_track is None) or (spo2_track is None):
            return cid, None, None, "skip-missing-tracks"

        selected = [t for t in [hr_track, spo2_track, map_track] if t is not None]

        # Cache first. Older cache entries may still store the raw sample
        # index, so charttime is normalised before the same strict validation
        # used for downloads.
        df_case = load_from_cache(cid, selected, interval)
        if df_case is not None:
            df_case = _sample_index_to_charttime(df_case, interval)
            ok, info = validate_case_df(df_case)
            if not ok:
                return cid, None, None, info
            return cid, df_case, "cache", info

        df_case = vf.to_pandas(selected, interval=interval)
        if df_case is None or df_case.empty:
            return cid, None, None, "empty"

        rename_map = {hr_track: "Heart_Rate", spo2_track: "SpO2"}
        if map_track is not None:
            rename_map[map_track] = "Arterial_BP_Mean"
        df_case = df_case.rename(columns=rename_map)

        keep_cols = [c for c in ["Heart_Rate", "SpO2", "Arterial_BP_Mean"] if c in df_case.columns]
        if len(keep_cols) < 2:
            return cid, None, None, "too-few-cols"

        df_case = df_case.reset_index().rename(columns={"index": "charttime"})
        # Sample index -> real elapsed-time timestamps (interval seconds per row).
        df_case = _sample_index_to_charttime(df_case, interval)
        df_case["subject_id"] = cid
        df_case = df_case.dropna(subset=keep_cols, how="all")
        df_case = df_case[["subject_id", "charttime"] + keep_cols]
        if df_case.empty:
            return cid, None, None, "all-nan"

        slow = (time.time() - t_start) > CASE_SOFT_TIMEOUT

        # strict validation
        ok, info = validate_case_df(df_case)
        if not ok:
            return cid, None, None, info  # info is reason string

        # cache best-effort: a failed write is logged, never fatal
        try:
            save_to_cache(df_case, cid, selected, interval)
        except Exception as ce:
            tqdm.write(f"[Cache] write failed for case {cid}: {type(ce).__name__}: {ce}")

        # tag slow if needed
        if slow and isinstance(info, dict):
            info = {**info, "slow": True}

        return cid, df_case, "download", info

    except Exception as e:
        return cid, None, None, f"error:{type(e).__name__}"

# ------------------------------------------------------------------------------
# Parallel pass with early-stop + strict validation
# ------------------------------------------------------------------------------
def run_parallel_pass(interval, target_subjects, tag="FAST"):
    """Fetch and validate the cases in the global `case_ids` in parallel.

    Submits `process_case_with_interval(cid, interval)` for every ID in the
    module-level `case_ids` to a ThreadPoolExecutor with MAX_WORKERS threads.
    Results are consumed in completion order. Once `target_subjects` valid
    cases have been collected, not-yet-started futures are cancelled and the
    loop stops.

    Args:
        interval: sampling interval forwarded to `process_case_with_interval`;
            `interval//60` is shown as minutes in the progress-bar label.
        target_subjects: number of valid cases after which the pass stops early.
        tag: label used in the progress bar and in the final log line.

    Returns:
        (found, skip_counts): `found` is a list of (case_id, DataFrame) tuples
        in completion order; `skip_counts` is a Counter of skip reasons.
    """
    found = []
    skip_counts = Counter()
    t0 = time.time()
    # Bookkeeping for the periodic [Progress] and [Heartbeat] log lines.
    last_log_time = t0
    last_log_subjects = 0
    last_heartbeat = t0
    total = len(case_ids)
    cases_seen = 0
    # Set just before the early-stop `break` below, so the check at the top
    # of the loop does not fire in the current control flow.
    stop_flag = threading.Event()

    def heartbeat():
        # Emit a status line at most once every HEARTBEAT_SECONDS; reads the
        # enclosing counters (cases_seen, found, skip_counts) via the closure.
        nonlocal last_heartbeat
        now = time.time()
        if (now - last_heartbeat) >= HEARTBEAT_SECONDS:
            pct = (cases_seen / max(1,total)) * 100.0
            top = ", ".join(f"{k}:{v}" for k,v in skip_counts.most_common(3)) or "—"
            tqdm.write(f"[Heartbeat] processed={cases_seen}/{total} ({pct:.1f}%) | valid={len(found)} | skips: {top}")
            last_heartbeat = now

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures = {ex.submit(process_case_with_interval, cid, interval): cid for cid in case_ids}
        with tqdm(total=total, desc=f"{tag} pass (interval={interval//60}min) x{MAX_WORKERS}") as pbar:
            try:
                for fut in as_completed(futures):
                    if stop_flag.is_set():
                        break
                    cid = futures[fut]
                    # process_case_with_interval converts its own exceptions
                    # into skip reasons, so result() normally does not raise.
                    cid_out, df_case, src, meta = fut.result()
                    pbar.update(1)
                    cases_seen += 1

                    if df_case is None:
                        # Skipped case: meta is the string reason.
                        reason = meta if isinstance(meta, str) else "unknown"
                        skip_counts[reason] += 1
                        if SUMMARY_EVERY_CASES and cases_seen % SUMMARY_EVERY_CASES == 0:
                            top = ", ".join(f"{k}:{v}" for k,v in skip_counts.most_common(6))
                            tqdm.write(f"[Summary] processed={cases_seen} | valid={len(found)} | skips: {top}")
                        heartbeat()
                        continue

                    found.append((cid_out, df_case))
                    # meta is dict with validation info
                    info = ", ".join([
                        f"rows={meta.get('rows')}",
                        f"cov_hr={meta.get('cov_hr'):.2f}",
                        f"cov_spo2={meta.get('cov_spo2'):.2f}",
                        f"overlap={meta.get('overlap')}",
                        f"med_hr={meta.get('med_hr'):.1f}",
                        f"med_spo2={meta.get('med_spo2'):.1f}",
                    ])
                    if meta.get("slow", False):
                        info += " | slow"
                    tqdm.write(f"[FOUND] subject_id={cid_out} | {info} | src={src} | total_valid={len(found)}")

                    now = time.time()
                    if len(found) >= target_subjects:
                        # Early stop: cancel() only affects futures that have
                        # not started; running ones are left to finish.
                        tqdm.write(f"\nReached target of {target_subjects} subjects – stopping early & cancelling remaining.")
                        stop_flag.set()
                        for f in futures:
                            if not f.done():
                                f.cancel()
                        break

                    # Progress line every PRINT_EVERY_SUBJECTS new valid
                    # subjects or every PRINT_EVERY_SECONDS (either may be None).
                    by_subj = (PRINT_EVERY_SUBJECTS is not None and len(found) > 0 and
                               (len(found) - last_log_subjects) >= PRINT_EVERY_SUBJECTS)
                    by_sec  = (PRINT_EVERY_SECONDS is not None and (now - last_log_time) >= PRINT_EVERY_SECONDS)
                    if by_subj or by_sec:
                        elapsed = now - t0
                        rate = len(found) / elapsed if elapsed > 0 else 0.0
                        pct = (cases_seen / max(1,total)) * 100.0
                        tqdm.write(
                            f"[Progress] valid={len(found)}/{target_subjects} "
                            f"({len(found)/max(1,target_subjects):.0%}) | "
                            f"processed={cases_seen}/{total} ({pct:.1f}%) | "
                            f"elapsed={elapsed/60:.1f}m | rate={rate:.2f} subj/s"
                        )
                        last_log_time = now
                        last_log_subjects = len(found)
                    heartbeat()
            finally:
                # Also runs on errors/interrupts: drop every pending future
                # (cancel_futures requires Python 3.9+).
                ex.shutdown(cancel_futures=True)

    elapsed = time.time() - t0
    tqdm.write(f"[{tag}] Done in {elapsed/60:.1f} minutes. Found {len(found)} valid subjects.")
    return found, skip_counts

# ------------------------------------------------------------------------------
# RUN
# ------------------------------------------------------------------------------
if MAX_WORKERS <= 0:
    raise RuntimeError("Set MAX_WORKERS > 0 for the speed-optimized path.")

# First pass: coarse FAST_SCAN_INTERVAL when the fast scan is enabled,
# otherwise directly at REFINE_INTERVAL (tagged "MAIN").
first_interval = FAST_SCAN_INTERVAL if ENABLE_FAST_SCAN else REFINE_INTERVAL
found_fast, skip_counts = run_parallel_pass(first_interval, TARGET_SUBJECTS, tag="FAST" if ENABLE_FAST_SCAN else "MAIN")

if len(found_fast) == 0:
    top = ", ".join(f"{k}:{v}" for k,v in skip_counts.most_common(10)) or "—"
    raise RuntimeError(f"No valid subjects found. Skip reasons: {top}")

df_wide = pd.concat([df for _, df in found_fast], ignore_index=True)
print(f"\nBuilt 'df_wide' (interval={first_interval//60}min) with shape: {df_wide.shape}")
print(df_wide.head())

# Optional refine pass at REFINE_INTERVAL, restricted to the subjects found above.
if ENABLE_REFINE_PASS and ENABLE_FAST_SCAN and (REFINE_INTERVAL != first_interval):
    refine_minutes = REFINE_INTERVAL // 60
    tqdm.write(f"\n[REFINE] Re-fetching found subjects at {refine_minutes}-min interval…")
    subject_ids = [cid for cid, _ in found_fast]

    def refine_case(cid):
        """Re-fetch and re-validate one already-found subject at REFINE_INTERVAL."""
        _, df_case, src, meta = process_case_with_interval(cid, REFINE_INTERVAL)
        return cid, df_case, src, meta

    refined, skip_ref = [], Counter()
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures = {ex.submit(refine_case, cid): cid for cid in subject_ids}
        for fut in tqdm(as_completed(futures), total=len(futures), desc=f"REFINE pass (interval={refine_minutes}min)"):
            cid_out, df_case, src, meta = fut.result()
            if df_case is None:
                reason = meta if isinstance(meta, str) else "unknown"
                skip_ref[reason] += 1
            else:
                tqdm.write(f"[REFINE-FOUND] subject_id={cid_out} | rows={len(df_case)} | src={src}")
                refined.append(df_case)

    top = ", ".join(f"{k}:{v}" for k,v in skip_ref.most_common(10)) or "—"
    if refined:
        df_wide = pd.concat(refined, ignore_index=True)
        print(f"[REFINE] Rebuilt 'df_wide' (interval={REFINE_INTERVAL//60}min) with shape: {df_wide.shape}")
        print(df_wide.head())
        # df_wide now holds only the subjects that survived refinement; make
        # any loss visible instead of silently shrinking the cohort.
        if skip_ref:
            n_dropped = sum(skip_ref.values())
            tqdm.write(
                f"[REFINE] {n_dropped} of {len(subject_ids)} subjects dropped during refinement. "
                f"Skip reasons: {top}"
            )
    else:
        tqdm.write(f"[REFINE] No subjects refined. Skip reasons: {top}")


In [None]:
# ==============================================================================
# VITALDB - STEP 2: TIME-SERIES PREPROCESSING AND IMPUTATION
# ==============================================================================
from sklearn.preprocessing import StandardScaler

# Uniform sampling grid for every patient. '15min' is the current pandas alias
# (the older '15T' spelling is deprecated since pandas 2.2).
RESAMPLE_RULE = '15min'

if 'df_wide' in locals():
    print("--- Starting time-series preprocessing ---")

    # 1. Convert 'charttime' column to datetime objects.
    # A plain numeric column is interpreted by pd.to_datetime as *nanoseconds*
    # since the epoch, which squeezes a whole recording into one resampling
    # bin. Warn so this is visible rather than silently giving 1 row/patient.
    if pd.api.types.is_numeric_dtype(df_wide['charttime']):
        print("WARNING: 'charttime' is numeric; pd.to_datetime will interpret it as "
              "nanoseconds since the epoch. Provide real timestamps upstream.")
    df_wide['charttime'] = pd.to_datetime(df_wide['charttime'])
    print("-> Converted 'charttime' to datetime objects.")

    # 2. Resample and impute each patient separately so timelines never mix.
    processed_patients = []
    for patient_id, group in df_wide.groupby('subject_id'):
        # Time column as index for time-based operations
        group = group.set_index('charttime').drop('subject_id', axis=1)

        # Resample to the fixed grid and average the samples in each bin,
        # giving every patient a uniform timeline.
        group_resampled = group.resample(RESAMPLE_RULE).mean()

        # Impute gaps: forward-fill, then backward-fill the leading gap.
        group_imputed = group_resampled.ffill().bfill()

        # Add the patient_id back
        group_imputed['subject_id'] = patient_id

        processed_patients.append(group_imputed)

    # Concatenate all processed patient dataframes back into one
    df_processed = pd.concat(processed_patients).reset_index()
    print("-> Resampled to a 15-minute frequency and imputed missing values.")

    # 3. Drop rows that still contain NaNs, then scale the features.
    # After per-patient ffill/bfill a value can only remain NaN when that
    # patient never recorded that vital at all. dropna() removes a row if ANY
    # column is NaN, so a patient lacking an optional vital (e.g.
    # 'Arterial_BP_Mean') is removed entirely: report how much is lost.
    n_rows_before = len(df_processed)
    n_patients_before = df_processed['subject_id'].nunique()
    df_processed.dropna(inplace=True)
    print(f"-> Dropped {n_rows_before - len(df_processed)} rows containing NaNs "
          f"({n_patients_before - df_processed['subject_id'].nunique()} patients removed entirely).")
    if df_processed.empty:
        raise RuntimeError(
            "All rows were dropped as incomplete; no data left to scale. "
            "Check which vital columns are missing for most patients."
        )

    vital_cols = [col for col in df_processed.columns if col not in ['subject_id', 'charttime']]

    # NOTE(review): the scaler is fit on all patients before the AE train/test
    # split, so test statistics leak into the scaling.
    scaler = StandardScaler()
    df_processed[vital_cols] = scaler.fit_transform(df_processed[vital_cols])
    print("-> Scaled vital sign features.")

    # 4. Display final result (info() prints directly and returns None)
    print("\n--- Preprocessing complete. Data is now clean and uniform. ---")
    df_processed.info()
    display(df_processed.head())

else:
    print("ERROR: Wide-format dataframe 'df_wide' not found. Please run Step 1 successfully.")

In [None]:
# Save the fully processed VitalDB dataframe used by the AE.
# Guarded like the other save cells so a missing dataframe gives a clear
# message instead of a NameError.
processed_clinical_path = "/content/drive/MyDrive/Conference_paper_ICCC_2026/vitaldb_df_processed_final.parquet"
if 'df_processed' in locals():
    df_processed.to_parquet(processed_clinical_path, index=False)
    print("✅ Saved VitalDB processed clinical dataframe to:", processed_clinical_path)
else:
    print("ERROR: Dataframe 'df_processed' not found. Please ensure the preprocessing cell has been run successfully.")


In [None]:
# ==============================================================================
# SAVE THE PROCESSED DATAFRAME (Corrected Variable Name)
# ==============================================================================

# Persist 'df_processed' (output of the preprocessing cell). The fusion module
# reloads it from this same file name.
if 'df_processed' not in locals():
    print("ERROR: Dataframe 'df_processed' not found. Please ensure the preprocessing cell has been run successfully.")
else:
    save_path = '/content/drive/MyDrive/Conference_paper_ICCC_2026/df_processed_final.parquet'
    print(f"Saving the processed dataframe ('df_processed') to {save_path}...")
    df_processed.to_parquet(save_path)  # default index handling, as before
    print("Save complete!")

In [None]:
# ==============================================================================
# VITALDB - STEP 3: UNSUPERVISED ANOMALY DETECTION WITH LSTM AUTOENCODER
# (now using VitalDB-derived df_processed)
# ==============================================================================
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, RepeatVector, TimeDistributed
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- Helper Function to create time-series sequences ---
def create_sequences(X_data, y_data, time_steps=10):
    """Create overlapping fixed-length windows from 2D data.

    Every window ``X_data[i:i + time_steps]`` for
    ``i = 0 .. len(X_data) - time_steps`` is emitted, including the final
    window ending on the last row. The paired target is ``y_data`` at the
    window's last row (a placeholder; not used for AE training).

    Args:
        X_data: array-like of shape (n_rows, n_features).
        y_data: array-like indexed along axis 0 like ``X_data``.
        time_steps: window length; must be >= 1.

    Returns:
        ``(Xs, ys)`` with ``Xs`` of shape (n_windows, time_steps, n_features)
        and ``ys`` of shape (n_windows, ...). When ``X_data`` has fewer than
        ``time_steps`` rows both are empty (n_windows == 0) but keep their
        trailing dimensions, so they can still be concatenated or inspected.

    Raises:
        ValueError: if ``time_steps`` < 1.
    """
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")
    X_arr = np.asarray(X_data)
    y_arr = np.asarray(y_data)
    n_windows = len(X_arr) - time_steps + 1
    if n_windows <= 0:
        return (
            np.empty((0, time_steps) + X_arr.shape[1:], dtype=X_arr.dtype),
            np.empty((0,) + y_arr.shape[1:], dtype=y_arr.dtype),
        )
    Xs = np.stack([X_arr[i:i + time_steps] for i in range(n_windows)])
    # Target of window i is the value at its last row, index i + time_steps - 1.
    ys = y_arr[time_steps - 1:time_steps - 1 + n_windows].copy()
    return Xs, ys

# --- Main Logic ---
if 'df_processed' in locals():
    # 1. Prepare Data for the Autoencoder
    print("--- Preparing data for Autoencoder ---")

    # Select only the vital sign columns for training.
    # With VitalDB we may have fewer vital signs than the original MIMIC-IV demo,
    # so only the candidate columns actually present in df_processed are used.
    candidate_vital_cols = [
        'Arterial_BP_Mean',
        'Heart_Rate',
        'SpO2',
        'Arterial_BP_Diastolic',
        'Arterial_BP_Systolic',
        'Respiratory_Rate',
        'Temperature_C',
    ]
    vital_cols = [c for c in candidate_vital_cols if c in df_processed.columns]
    print("Using vital columns for AE:", vital_cols)
    if len(vital_cols) == 0:
        raise RuntimeError(
            "No valid vital sign columns found in df_processed. "
            "Please check preprocessing / column names."
        )

    data_for_model = df_processed[vital_cols].values

    # Create sequences. We don't need the 'y' labels for autoencoder training.
    TIME_STEPS = 24  # Using a window of 24 samples (e.g., 6 hours if data is every 15 mins)

    # df_processed concatenates independent patient timelines, so windows are
    # built per subject: a window straddling two patients would glue unrelated
    # recordings together and teach the AE a spurious "transition" pattern.
    if 'subject_id' in df_processed.columns:
        subject_sequences = []
        n_short_subjects = 0
        for _, subject_df in df_processed.groupby('subject_id', sort=False):
            if 'charttime' in subject_df.columns:
                subject_df = subject_df.sort_values('charttime')
            subject_values = subject_df[vital_cols].values
            subject_seqs, _ = create_sequences(
                subject_values,
                subject_values[:, 0],
                time_steps=TIME_STEPS
            )
            if len(subject_seqs) == 0:
                n_short_subjects += 1
                continue
            subject_sequences.append(subject_seqs)
        if n_short_subjects:
            print(f"-> Skipped {n_short_subjects} subject(s) too short for a {TIME_STEPS}-step window.")
        if not subject_sequences:
            raise RuntimeError(
                f"No {TIME_STEPS}-step sequences could be built: every subject timeline in "
                "df_processed is too short. Check the preprocessing (resampling) output."
            )
        X_sequences = np.concatenate(subject_sequences, axis=0)
    else:
        # No subject column: fall back to windowing the frame as one series.
        X_sequences, _ = create_sequences(
            data_for_model,
            data_for_model[:, 0],
            time_steps=TIME_STEPS
        )
    print(f"Created {X_sequences.shape[0]} sequences of shape {X_sequences.shape[1:]}")

    # Split into training and testing sets to test reconstruction of unseen data.
    # NOTE(review): the split is random over overlapping windows, so test
    # windows share time steps (and patients) with training windows.
    X_train, X_test = train_test_split(
        X_sequences,
        test_size=0.2,
        random_state=42
    )
    print(f"Training data shape: {X_train.shape}")
    print(f"Test data shape: {X_test.shape}")

    # 2. Build the LSTM Autoencoder Model
    n_features = X_train.shape[2]
    timesteps = X_train.shape[1]

    print("\n--- Building LSTM Autoencoder Model ---")
    inputs = Input(shape=(timesteps, n_features))
    # Encoder: compress the whole window into one 64-d vector
    encoded = LSTM(64, activation='relu')(inputs)
    # Bottleneck: repeat the code once per output time step
    bottleneck = RepeatVector(timesteps)(encoded)
    # Decoder
    decoded = LSTM(64, activation='relu', return_sequences=True)(bottleneck)
    # Output Layer: one value per feature per time step
    outputs = TimeDistributed(Dense(n_features))(decoded)

    autoencoder = Model(inputs, outputs)
    autoencoder.compile(optimizer='adam', loss='mae')  # Using Mean Absolute Error as the loss function
    autoencoder.summary()

    # 3. Train the Autoencoder
    print("\n--- Training Autoencoder on normal patterns ---")
    history = autoencoder.fit(
        X_train, X_train,  # The model learns to predict its own input
        epochs=10,
        batch_size=64,
        validation_split=0.1,
        verbose=1
    )

    # 4. Detect Anomalies based on Reconstruction Error
    print("\n--- Detecting anomalies based on reconstruction error ---")
    X_test_pred = autoencoder.predict(X_test)

    # Reconstruction error per sample (mean over time steps and features)
    test_mae_loss = np.mean(np.abs(X_test_pred - X_test), axis=(1, 2))  # shape: (n_samples,)

    # Threshold at the 95th percentile of the sample-level errors
    threshold = np.quantile(test_mae_loss, 0.95)
    print(f"Reconstruction error threshold for anomalies set to: {threshold:.3f}")

    # Indices of the anomalous test sequences
    anomaly_indices = np.where(test_mae_loss > threshold)[0]
    print(f"Found {len(anomaly_indices)} potential anomalies in the test set.")


    # 5. Visualize a detected anomaly
    if len(anomaly_indices) > 0:
        print("\n--- Visualizing a detected anomaly ---")
        idx_to_plot = anomaly_indices[0]

        # Heart_Rate channel when available, otherwise the first feature
        # (same fallback as the Step 4 analysis cell).
        if "Heart_Rate" in vital_cols:
            hr_idx = vital_cols.index("Heart_Rate")
        else:
            print("Warning: 'Heart_Rate' not found in vital_cols, using the first feature instead.")
            hr_idx = 0

        plt.figure(figsize=(14, 6))
        plt.plot(X_test[idx_to_plot, :, hr_idx], label=f'Original {vital_cols[hr_idx]}')
        plt.plot(X_test_pred[idx_to_plot, :, hr_idx], label=f'Reconstructed {vital_cols[hr_idx]}', linestyle='--')
        plt.title('Example of an Anomalous Sequence (High Reconstruction Error)')
        plt.legend()
        plt.show()

else:
    print("ERROR: Processed dataframe 'df_processed' not found. Please run the preprocessing steps first.")


In [None]:
# ==============================================================================
# VITALDB – GLOBAL AE ERRORS ON TEST SET
# ==============================================================================

# Inputs: `X_test`, the VitalDB test windows created by the AE training cell
# (presumably [N_test, TIME_STEPS, n_features] — TODO confirm if reused elsewhere).

print("Computing AE reconstruction errors on full VitalDB test set...")

# Plain alias (no copy) under the name used by the fusion evaluation.
X_cli_test_full = X_test

# Progress bars for the batched inference helper defined next.
from tqdm.auto import tqdm

# Number of windows handed to model.predict per call.
BATCH_SIZE_AE = 256

def predict_ae_batched(model, X, batch_size=256, desc="AE inference"):
    """Run ``model.predict`` over ``X`` in consecutive slices of ``batch_size``.

    A transient tqdm progress bar labelled ``desc`` tracks the slices. The
    per-slice predictions are concatenated along axis 0 and returned. An
    empty ``X`` yields no slices, so ``np.concatenate`` raises ValueError.
    """
    slice_starts = range(0, len(X), batch_size)
    predictions = [
        model.predict(X[start:start + batch_size], verbose=0)
        for start in tqdm(slice_starts, desc=desc, leave=False)
    ]
    return np.concatenate(predictions, axis=0)

X_cli_recon_full = predict_ae_batched(autoencoder, X_cli_test_full, batch_size=BATCH_SIZE_AE)

# One scalar error per test sequence: mean absolute error over time and features.
clin_errors_full = np.mean(np.abs(X_cli_recon_full - X_cli_test_full), axis=(1, 2))

print(f"VitalDB test sequences: {len(clin_errors_full)}")
print(f"AE error mean={clin_errors_full.mean():.4f}, std={clin_errors_full.std():.4f}")

# The anomaly threshold must come from data the AE was trained on, not from the
# test errors being flagged. If no earlier cell has defined it, calibrate it
# here as the 95th percentile (the quantile used in Step 3) of the
# per-sequence training reconstruction errors.
if 'clinical_threshold' not in locals():
    X_cli_train_recon = predict_ae_batched(
        autoencoder, X_train, batch_size=BATCH_SIZE_AE, desc="AE inference (train)"
    )
    clin_errors_train = np.mean(np.abs(X_cli_train_recon - X_train), axis=(1, 2))
    clinical_threshold = float(np.quantile(clin_errors_train, 0.95))
    print("clinical_threshold was not defined; calibrated on the AE training sequences.")
print(f"clinical_threshold (from train/val) = {clinical_threshold:.4f}")

# Simple binary anomaly flag for scenarios:
# 0 = low error, 1 = high error w.r.t. calibrated threshold
clin_flag_full = (clin_errors_full > clinical_threshold).astype(int)
# minlength=2 so both classes are always reported, even when one count is 0.
print("clin_flag_full counts (0=low-error, 1=high-error):", np.bincount(clin_flag_full, minlength=2))


In [None]:
# ==============================================================================
# VITALDB - STEP 4: ANALYSIS OF DETECTED ANOMALIES
# ==============================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns  # optional, for nicer histograms

# Only run when the Step 3 training cell has produced a model and test windows.
if not ('autoencoder' in locals() and 'X_test' in locals()):
    print("ERROR: Autoencoder results not found. Please re-run the training cell (Step 3) successfully.")
else:
    print("--- Analyzing the results of the LSTM Autoencoder ---")

    # 1. Re-run inference and reduce every window to a single MAE score
    #    (average over both the time and the feature axis).
    X_test_pred = autoencoder.predict(X_test)
    sample_errors = np.abs(X_test_pred - X_test).mean(axis=(1, 2))

    # 2. Histogram (+ KDE) of the per-sequence errors
    plt.figure(figsize=(10, 6))
    sns.histplot(sample_errors, bins=50, kde=True)
    plt.xlabel("Mean Absolute Error (per sequence)")
    plt.ylabel("Number of Sequences")
    plt.title("Distribution of Reconstruction Errors on Test Set")
    plt.show()

    # 3. Anything above the 95th percentile of the test errors is flagged
    threshold = np.quantile(sample_errors, 0.95)
    print(f"\nReconstruction error threshold for anomalies set to: {threshold:.4f} (95th percentile)")

    anomaly_indices = np.where(sample_errors > threshold)[0]
    print(f"Found {len(anomaly_indices)} potential anomalies in the test set "
          f"({len(anomaly_indices) / len(X_test) * 100:.2f}% of test data).")

    # 4. Plot the flagged window with the largest error
    if anomaly_indices.size:
        print("\n--- Visualizing a top detected anomaly ---")

        # Position (within the test set) of the worst flagged window
        top_anomaly_idx = anomaly_indices[np.argmax(sample_errors[anomaly_indices])]

        # Heart_Rate channel when present, otherwise the first feature
        if "Heart_Rate" not in vital_cols:
            print("Warning: 'Heart_Rate' not found in vital_cols, using the first feature instead.")
            hr_idx = 0
        else:
            hr_idx = vital_cols.index("Heart_Rate")

        plt.figure(figsize=(14, 6))
        plt.plot(X_test[top_anomaly_idx, :, hr_idx], label=f'Original {vital_cols[hr_idx]}')
        plt.plot(
            X_test_pred[top_anomaly_idx, :, hr_idx],
            label=f'Reconstructed {vital_cols[hr_idx]}',
            linestyle='--',
        )
        plt.title(f'Example of a Detected Anomaly (Sequence Index: {top_anomaly_idx})')
        plt.legend()
        plt.grid(True)
        plt.show()


In [None]:
# ==============================================================================
# FINAL ANALYSIS: FEATURE-LEVEL RECONSTRUCTION ERROR (SHAP Alternative)
# ==============================================================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Requires the autoencoder results and the anomaly selected in the previous cell.
if not ('autoencoder' in locals() and 'X_test' in locals()):
    print("ERROR: Autoencoder results not found.")
elif 'top_anomaly_idx' not in locals():
    print("No anomaly was previously identified to analyze.")
else:
    print(f"--- Analyzing feature contribution for anomaly index: {top_anomaly_idx} ---")

    # Same anomaly as visualised before: original window and its reconstruction
    original_sequence = X_test[top_anomaly_idx]
    reconstructed_sequence = X_test_pred[top_anomaly_idx]

    # Per-feature mean absolute error across the time steps of this window
    feature_errors = np.abs(original_sequence - reconstructed_sequence).mean(axis=0)

    # Tabulate, largest contributor first
    error_df = pd.DataFrame(
        {'Feature': vital_cols, 'Reconstruction_Error': feature_errors}
    ).sort_values(by='Reconstruction_Error', ascending=False)

    # Horizontal bar chart of the per-feature errors
    plt.figure(figsize=(10, 6))
    sns.barplot(x='Reconstruction_Error', y='Feature', data=error_df, palette='viridis')
    plt.title(f'Feature Contribution to Anomaly Score (Sequence {top_anomaly_idx})', fontsize=16)
    plt.xlabel('Mean Absolute Error', fontsize=12)
    plt.ylabel('Vital Sign Feature', fontsize=12)
    plt.grid(axis='x')
    plt.show()

    print("\n--- Interpretation ---")
    display(error_df)
    print("\nThe feature(s) with the highest reconstruction error are the primary drivers of this anomaly.")

In [None]:
# ==============================================================================
# VITALDB - FINAL EXPERIMENT: DETECTING A SYNTHETIC SENSOR FAULT
# ==============================================================================

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Check if the required variables from previous steps exist
if 'autoencoder' in locals() and 'X_test' in locals():
    print("--- Starting Synthetic Anomaly Detection Test ---")

    # ------------------------------------------------------------------
    # 0. Ensure X_test is a proper NumPy array with shape (N, T, F)
    # ------------------------------------------------------------------
    X_test_np = np.asarray(X_test, dtype=np.float32)

    if X_test_np.ndim != 3:
        raise ValueError(
            f"Expected X_test with 3 dimensions (n_samples, timesteps, features), "
            f"got shape {X_test_np.shape}."
        )

    # 1. Recompute reconstruction errors on the test set (one scalar per sequence)
    X_test_pred = autoencoder.predict(X_test_np, verbose=0)
    sample_errors = np.mean(np.abs(X_test_pred - X_test_np), axis=(1, 2))  # shape: (n_samples,)

    # 2. Find a "normal" sequence from the test set (with the lowest reconstruction error)
    normal_sequence_idx = int(np.argmin(sample_errors))
    original_normal_sequence = X_test_np[normal_sequence_idx]

    # Calculate its original low error
    original_normal_pred = autoencoder.predict(
        np.expand_dims(original_normal_sequence, axis=0),
        verbose=0
    )
    original_error = np.mean(np.abs(original_normal_pred - original_normal_sequence))
    print(
        f"Selected a normal sequence (index {normal_sequence_idx}) with a low "
        f"reconstruction error of: {original_error:.4f}"
    )

    # 3. Create a synthetic anomaly: a sensor fault on SpO2
    faulty_sequence = original_normal_sequence.copy()

    # Determine the index of SpO2 in the current vital_cols
    if "SpO2" in vital_cols:
        spo2_index = vital_cols.index("SpO2")
    else:
        print("Warning: 'SpO2' not found in vital_cols. Using the last feature as a proxy.")
        spo2_index = len(vital_cols) - 1

    # Simulate the SpO2 sensor flat-lining at a very low value for 5 time steps
    # (values are in scaled space, so -3 is a strong deviation).
    # The window starts at step 10; for sequences too short for that it is
    # shifted left so a fault is always injected (and plotted) somewhere.
    fault_length = 5
    n_steps = faulty_sequence.shape[0]
    start_fault = max(0, min(10, n_steps - fault_length))
    end_fault = min(start_fault + fault_length, n_steps)
    faulty_sequence[start_fault:end_fault, spo2_index] = -3.0
    print("\n-> Injected a synthetic SpO2 sensor fault into the normal sequence.")

    # 4. Test the model on the faulty sequence
    faulty_sequence_reshaped = np.expand_dims(faulty_sequence, axis=0)
    reconstructed_faulty_sequence = autoencoder.predict(
        faulty_sequence_reshaped,
        verbose=0
    )

    # 5. Compare the reconstruction error
    faulty_error = np.mean(np.abs(reconstructed_faulty_sequence - faulty_sequence))
    print(f"Reconstruction error on the faulty sequence: {faulty_error:.4f}")
    if faulty_error > original_error * 2:
        print("SUCCESS: The model clearly reacts to the synthetic fault with a much higher error.")
    elif faulty_error > original_error:
        print("Note: The error increased, but not dramatically. "
              "You may adjust the fault magnitude or window.")
    else:
        print("Note: The error did not increase on the faulty sequence. "
              "You may adjust the fault magnitude or window.")

    # 6. Monochrome, journal-style figure for the SpO2 channel
    time_axis = np.arange(n_steps)

    # Compact scientific style, applied only to this figure (rc_context
    # restores the previous rcParams afterwards instead of changing them
    # for every later plot in the session).
    journal_style = {
        "font.size": 9,
        "axes.labelsize": 9,
        "axes.titlesize": 10,
        "legend.fontsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
    }

    with plt.rc_context(journal_style):
        fig, ax = plt.subplots(figsize=(5.5, 2.8))  # suitable for single-column width

        # Observed (faulty) SpO2: solid black with markers
        ax.plot(
            time_axis,
            faulty_sequence[:, spo2_index],
            marker='o',
            linestyle='-',
            linewidth=1.2,
            markersize=3.5,
            color='black',
            label='Observed SpO$_2$',
        )

        # Reconstructed SpO2: black dashed
        ax.plot(
            time_axis,
            reconstructed_faulty_sequence[0, :, spo2_index],
            linestyle='--',
            linewidth=1.2,
            color='black',
            label='Reconstructed SpO$_2$',
        )

        # Fault window: light grey band over the injected time steps
        ax.axvspan(
            start_fault,
            end_fault - 1,
            facecolor='0.9',   # light grey
            edgecolor='none',
            alpha=1.0,
        )

        # Axis labels (short, neutral)
        ax.set_xlabel('Time step (15 min)')
        ax.set_ylabel('Scaled SpO$_2$')

        # Remove top/right spines for a cleaner scientific look
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(axis='both', which='both', direction='out')

        # Legend without frame
        ax.legend(loc='best', frameon=False)

        fig.tight_layout()

        # Save figure for paper usage. PROJECT_DIR wins if defined; otherwise
        # use the project's Drive folder (BASE_DIR) so the file survives the
        # Colab session; the working directory is the last resort.
        if "PROJECT_DIR" in locals():
            project_dir = Path(PROJECT_DIR)
        elif "BASE_DIR" in locals():
            project_dir = Path(BASE_DIR)
        else:
            project_dir = Path(".")

        output_dir = project_dir / "output"
        output_dir.mkdir(parents=True, exist_ok=True)

        fig_path = output_dir / "vitaldb_spo2_synthetic_fault_monochrome.png"
        fig.savefig(fig_path, dpi=300, bbox_inches="tight")
        print(f"[INFO] Figure saved to: {fig_path}")

        plt.show()

else:
    print("ERROR: Autoencoder results not found. Please re-run the previous steps.")


In [None]:
# Persist the trained LSTM autoencoder in the native Keras format; the fusion
# setup cell reloads it from this same path.
autoencoder.save(
    '/content/drive/MyDrive/Conference_paper_ICCC_2026/lstm_autoencoder.keras'
)
print("LSTM Autoencoder model saved successfully.")

In [None]:
# ==========================================================
# SAVE CLINICAL TEST DATA (CORRECTED VARIABLE NAME)
# Run this cell once after the autoencoder training is complete
# ==========================================================
import numpy as np

# Export the clinical test windows for the fusion notebook, which looks for
# them under the file name 'X_test_sequences.npy'.
if 'X_test' not in locals():
    print("❌ ERROR: 'X_test' not found. Please ensure the autoencoder training cell has been run successfully.")
else:
    print("Saving clinical test sequences to file...")
    np.save(
        '/content/drive/MyDrive/Conference_paper_ICCC_2026/X_test_sequences.npy',
        X_test,
    )
    print("✅ Clinical test sequences saved successfully.")

## 3. Fusion and Evaluation Module (original notebook `03_fusion_framework_evaluation.ipynb`)

In [None]:
# ==============================================================================
# FUSION FRAMEWORK - STEP 1: SETUP AND LOAD ASSETS (IoMT-TrafficData + VitalDB)
# ==============================================================================
import os
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

# --- 1. Mount Google Drive ---
from google.colab import drive
drive.mount('/content/drive')

# --- 2. Define a single project path (both security + clinical live here) ---
project_path = "/content/drive/MyDrive/Conference_paper_ICCC_2026"

# Paths for models and data inside the project
models_dir = os.path.join(project_path, "models")
saved_test_dir = os.path.join(project_path, "saved_test_data")  # not used within this cell

xgb_model_path = os.path.join(models_dir, "xgb_iomt_traffic_grouped.joblib")
autoencoder_path = os.path.join(project_path, "lstm_autoencoder.keras")
clinical_data_path = os.path.join(project_path, "df_processed_final.parquet")

print(f"Project path: {project_path}")
print(f"  - XGBoost security model: {xgb_model_path}")
print(f"  - LSTM Autoencoder (clinical): {autoencoder_path}")
print(f"  - Processed clinical data: {clinical_data_path}")

# --- 3. Load the trained models (each failure is reported and leaves None) ---
security_model = None
clinical_model = None

try:
    security_model = joblib.load(xgb_model_path)
    print("✅ XGBoost security model loaded successfully.")
except Exception as e:
    print("❌ ERROR: could not load XGBoost security model.")
    print(e)

try:
    clinical_model = tf.keras.models.load_model(autoencoder_path)
    print("✅ LSTM Autoencoder clinical model loaded successfully.")
except Exception as e:
    print("❌ ERROR: could not load LSTM Autoencoder clinical model.")
    print(e)

# --- 4. Load the processed clinical (VitalDB) data ---
df_featured = None
try:
    df_featured = pd.read_parquet(clinical_data_path)
    print(f"\n✅ Processed clinical data loaded successfully from:\n   {clinical_data_path}")
    print("   Shape:", df_featured.shape)
except Exception as e:
    print("❌ ERROR: processed clinical data file could not be loaded.")
    print(e)

# --- 5. Report the real outcome ---
# The loaders above swallow their exceptions and leave the asset as None, so
# only announce readiness when all three assets actually loaded.
_missing_assets = [
    label
    for label, asset in (
        ("XGBoost security model", security_model),
        ("LSTM Autoencoder clinical model", clinical_model),
        ("processed clinical data", df_featured),
    )
    if asset is None
]
if _missing_assets:
    print("\n--- Setup INCOMPLETE: could not load " + ", ".join(_missing_assets)
          + ". Fix the errors above before running the fusion pipeline. ---")
else:
    print("\n--- Setup complete. Models and clinical data are ready for the fusion pipeline. ---")


In [None]:
# ==============================================================================
# FUSION FRAMEWORK - STEP 2: FAULT RESILIENCE TEST (VitalDB, using AE test sequences)
# ==============================================================================
# Takes the best-reconstructed AE test window, zeroes a random subset of its
# vital dimensions over the whole window, and reports how much the
# autoencoder's mean absolute reconstruction error grows.

import numpy as np
import matplotlib.pyplot as plt

# Seed for choosing which vital dimensions are corrupted, so reruns reproduce
# the same numbers (the global, unseeded RNG was used before).
FAULT_SEED = 42

# We rely on the clinical autoencoder and the test sequences used in Step 3.
# Step 1 binds clinical_model = None when loading fails, so presence in
# locals() alone is not enough.
if locals().get('clinical_model') is not None and 'X_test' in locals():
    print("--- Fault resilience test on clinical AE (VitalDB) ---")

    # Safety check (before touching X_test.shape, which fails on None)
    if X_test is None or len(X_test) == 0:
        print("⚠️ X_test is empty. Fault resilience test skipped. "
              "Please ensure the AE training cell (Step 3) ran correctly.")
    else:
        print(f"Using existing AE test sequences: X_test.shape = {X_test.shape}")
        fault_rng = np.random.default_rng(FAULT_SEED)

        # 1. Compute reconstruction errors on the test set (one MAE per window)
        reconstructions = clinical_model.predict(X_test, verbose=0)
        mae_loss = np.mean(np.abs(reconstructions - X_test), axis=(1, 2))

        # 2. Pick the most "normal" sequence (lowest reconstruction error)
        normal_idx = int(np.argmin(mae_loss))
        baseline_seq = X_test[normal_idx]
        baseline_error = float(mae_loss[normal_idx])

        print(f"Selected baseline sequence index: {normal_idx}")
        print(f"Baseline reconstruction error: {baseline_error:.4f}")

        # 3. Define fault percentages to test
        fault_percentages = [0.1, 0.2, 0.3, 0.4, 0.5]
        results = {}

        print("\nInjecting synthetic sensor faults and measuring error increase...")
        for fault_rate in fault_percentages:
            seq_faulty = baseline_seq.copy()
            num_features = seq_faulty.shape[1]
            # int() truncates and at least one dimension is always corrupted,
            # so with few vitals (e.g. 3) several requested rates corrupt the
            # same number of dimensions; the effective count is printed below.
            num_faulty_features = max(1, int(num_features * fault_rate))

            # Randomly choose which vital dimensions to corrupt
            faulty_features_indices = fault_rng.choice(
                num_features, num_faulty_features, replace=False
            )

            # Set those features to zero for the whole window.
            # NOTE(review): if the inputs are standardised (Step 5 states the
            # VitalDB vitals are), 0.0 is the feature mean, not a physical zero.
            seq_faulty[:, faulty_features_indices] = 0.0

            # 4. Reconstruct and compute error
            recon_faulty = clinical_model.predict(
                np.expand_dims(seq_faulty, axis=0),
                verbose=0
            )
            faulty_error = float(np.mean(np.abs(recon_faulty - seq_faulty)))

            results[f"{int(fault_rate*100)}% Fault"] = faulty_error
            print(f"  {int(fault_rate*100)}% requested -> {num_faulty_features}/{num_features} "
                  f"vital dimension(s) set to 0.0")

        # 5. Print and plot results
        print("\n--- Fault Resilience Test Results (on VitalDB AE test set) ---")
        print(f"Baseline Normal Error: {baseline_error:.4f}")
        for fault_level, error in results.items():
            if baseline_error > 0:
                ratio_txt = f"{(error / baseline_error):.2f}x increase"
            else:
                # A perfect baseline reconstruction makes the ratio undefined.
                ratio_txt = "baseline error is 0, ratio undefined"
            print(f"Error with {fault_level}: {error:.4f} ({ratio_txt})")

        plt.figure(figsize=(10, 6))
        plt.plot(
            [f * 100 for f in fault_percentages],
            list(results.values()),
            marker='o',
            linestyle='--'
        )
        plt.title("Clinical AE Reconstruction Error vs Percentage of Faulty Vitals (VitalDB)", fontsize=14)
        plt.xlabel("Percentage of Faulty Vitals (%)", fontsize=12)
        plt.ylabel("Reconstruction Error", fontsize=12)
        plt.grid(True)
        plt.show()

else:
    print("❌ ERROR: 'clinical_model' or 'X_test' not found.\n"
          "Please ensure that the clinical AE training cell (Step 3) has been run "
          "and that X_test is defined.")


In [None]:
# ==============================================================================
# FUSION FRAMEWORK - STEP 3: TIME EFFICIENCY (LATENCY) TEST
# ==============================================================================

import time
import numpy as np


def _mean_latency_ms(fn, n_runs):
    """Return the mean wall-clock latency of ``fn()`` in milliseconds.

    ``fn`` is called once as an untimed warm-up, then ``n_runs`` times under
    ``time.perf_counter`` (monotonic, high resolution; ``time.time`` is the
    adjustable system clock and is not meant for measuring short intervals).
    """
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return (time.perf_counter() - start) * 1000.0 / n_runs


# We need:
# - security_model and X_test_grouped (from the IoMT-TrafficData security preparation)
# - clinical_model and X_test (from VitalDB AE module)
# Step 1 binds the models to None when loading fails, hence the "is not None" checks.
if (
    'security_model' in locals() and
    security_model is not None and
    'X_test_grouped' in locals() and
    isinstance(X_test_grouped, np.ndarray) and
    len(X_test_grouped) > 0 and
    'clinical_model' in locals() and
    clinical_model is not None and
    'X_test' in locals() and
    isinstance(X_test, np.ndarray) and
    len(X_test) > 0
):
    print("--- Starting Time Efficiency (Latency) Test ---")
    print(f"Security sample shape: {X_test_grouped[0].shape}")
    print(f"Clinical AE sample shape: {X_test[0].shape}")

    # --- 1. Prepare single samples ---
    # Security model: single feature vector
    security_sample = X_test_grouped[0].reshape(1, -1)

    # Clinical model (Autoencoder): single 3D sequence (1, TIME_STEPS, n_features)
    clinical_sample = np.expand_dims(X_test[0], axis=0)

    # Number of repeated runs for timing
    N_RUNS_SECURITY = 1000
    N_RUNS_CLINICAL = 500
    N_RUNS_PIPELINE = 500

    # --- 2. Measure latency for the Security Model (XGBoost) ---
    avg_sec_time_ms = _mean_latency_ms(
        lambda: security_model.predict_proba(security_sample), N_RUNS_SECURITY
    )
    print(f"\nAverage security model latency: {avg_sec_time_ms:.4f} ms per inference "
          f"(over {N_RUNS_SECURITY} runs)")

    # --- 3. Measure latency for the Clinical AE model ---
    # NOTE: Keras' Model.predict is designed for batched inputs and adds
    # per-call overhead, so single-sample numbers may overstate model compute.
    avg_clin_time_ms = _mean_latency_ms(
        lambda: clinical_model.predict(clinical_sample, verbose=0), N_RUNS_CLINICAL
    )
    print(f"Average clinical AE latency: {avg_clin_time_ms:.4f} ms per inference "
          f"(over {N_RUNS_CLINICAL} runs)")

    # --- 4. End-to-end fusion pipeline latency (security + clinical) ---
    # Simple sequential composition to approximate end-to-end runtime
    def run_fusion_once():
        # Security branch
        _ = security_model.predict_proba(security_sample)
        # Clinical branch
        _ = clinical_model.predict(clinical_sample, verbose=0)
        # We are not computing the actual fused score here, only measuring time.

    avg_pipe_time_ms = _mean_latency_ms(run_fusion_once, N_RUNS_PIPELINE)
    print(f"\nApproximate end-to-end pipeline latency: {avg_pipe_time_ms:.4f} ms per cycle "
          f"(security + clinical, over {N_RUNS_PIPELINE} runs)")

else:
    print("⚠️ Time Efficiency test skipped: one of the required objects is missing or empty.\n"
          "Expected: security_model, X_test_grouped (non-empty), clinical_model, X_test (non-empty).")


In [None]:
# ==============================================================================
# DATA PREPARATION FOR SECURITY DATASET (IoMT-TrafficData)
# ==============================================================================
# This cell prepares the data needed for the Fidelity Test
# ==============================================================================

import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split

# ----------------------------------------------------------------------
# 1. Load the merged and optimized dataframe for the IoMT-TrafficData dataset
# ----------------------------------------------------------------------
# We assume the packet-level dataframe was saved to Parquet by the IoMT
# notebook as "iomt_traffic_multiclass_packets.parquet" in the project folder.

if "project_path" not in locals():
    project_path = "/content/drive/MyDrive/Conference_paper_ICCC_2026"

security_df_path = os.path.join(project_path, "iomt_traffic_multiclass_packets.parquet")

print(f"Trying to load IoMT-TrafficData security dataframe from:\n  {security_df_path}")

try:
    df_security = pd.read_parquet(security_df_path)
    print("✅ Successfully loaded the merged IoMT-TrafficData security dataset.")
    print("   Shape:", df_security.shape)
except Exception as e:
    print(f"❌ ERROR: Could not load the pre-processed security dataframe.\n{e}")
    # Name the file this cell actually reads (the old message named a different file).
    print(f"Please ensure you have saved '{os.path.basename(security_df_path)}' from the IoMT notebook.")
    df_security = None

# ----------------------------------------------------------------------
# 2. Apply the same style of preprocessing as in the IoMT notebook
# ----------------------------------------------------------------------
if df_security is not None:
    if "label" not in df_security.columns:
        raise RuntimeError("Column 'label' not found in df_security. Please check the IoMT preprocessing step.")

    # a) Separate features (X) and target (y)
    #    In the IoMT packet dataset, 'label' is an integer in [0..8]
    X_sec = df_security.drop(columns=["label"])
    y_sec = df_security["label"].astype(int)

    correct_feature_names = X_sec.columns.tolist()
    print("\nNumber of features:", len(correct_feature_names))
    print("First 10 feature names:", correct_feature_names[:10])

    # b) Scale features (standardisation).
    #    NOTE(review): the scaler is fitted on all rows before the split below;
    #    tree models are insensitive to this, but it is not a leakage-free setup.
    scaler_sec = StandardScaler()
    X_scaled_sec = scaler_sec.fit_transform(X_sec)

    # c) Define human-readable class names for the 9 IoMT scenarios
    #    This mapping follows the official IoMT-TrafficData taxonomy.
    attack_name_mapping = {
        0: "Normal",
        1: "ApacheKiller",
        2: "ARP",        # ARP spoofing
        3: "CAM",        # CAM table overflow
        4: "Malaria",    # MQTT Malaria
        5: "Netscan",    # Recon / network scanning
        6: "RUDY",
        7: "SlowLoris",
        8: "SlowRead",
    }

    df_security["grouped_label_str"] = y_sec.map(attack_name_mapping).fillna("Unknown")

    print("\n--- Distribution of grouped_label_str (IoMT-TrafficData) ---")
    display(df_security["grouped_label_str"].value_counts())

    # d) Encode grouped labels for downstream models/tests.
    #    LabelEncoder sorts the class names, so the encoded ids differ from the
    #    raw 0..8 labels (see the printed mapping).
    grouped_label_encoder = LabelEncoder()
    y_grouped = grouped_label_encoder.fit_transform(df_security["grouped_label_str"])

    print("\nEncoded grouped classes (string -> encoded):")
    for cls, enc in zip(
        grouped_label_encoder.classes_,
        grouped_label_encoder.transform(grouped_label_encoder.classes_),
    ):
        print(f"  {cls} -> {enc}")

    # e) Create a train/test split for any downstream fidelity tests
    X_train_grouped, X_test_grouped, y_train_grouped, y_test_grouped = train_test_split(
        X_scaled_sec,
        y_grouped,
        test_size=0.3,
        random_state=42,
        stratify=y_grouped,
    )

    print("\n✅ Security data (IoMT-TrafficData) successfully preprocessed and split.")
    print(f"   X_train_grouped: {X_train_grouped.shape}")
    print(f"   X_test_grouped:  {X_test_grouped.shape}")
    print(f"   y_train_grouped: {y_train_grouped.shape}")
    print(f"   y_test_grouped:  {y_test_grouped.shape}")
else:
    print("\n⚠️ Security dataset not available; preprocessing was skipped.")


In [None]:
# ==============================================================================
# FUSION FRAMEWORK - STEP 4: EXPLAINABILITY (FIDELITY TEST)
# ==============================================================================
# Trains a simple surrogate (Decision Tree) on only the top-N XGBoost features
# and compares its accuracy with the complex model's accuracy.
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd

# The train/test split comes from the security data preparation cell.
# Previously, random placeholder arrays were generated when it was missing,
# which produced a meaningless "fidelity" score; the test is now skipped instead.
_fidelity_ns = locals()
_missing_split = [
    split_name
    for split_name in ('X_train_grouped', 'y_train_grouped', 'X_test_grouped', 'y_test_grouped')
    if _fidelity_ns.get(split_name) is None
]

# Step 1 binds security_model = None when loading fails.
if 'security_model' not in locals() or security_model is None:
    print("❌ ERROR: Security model not loaded. Please run Step 1 successfully first.")
elif _missing_split:
    print(f"NOTE: {', '.join(_missing_split)} not found. This test requires the test set "
          "from the security notebook. Fidelity test skipped.")
else:
    print("--- Starting Explainability Fidelity Test ---")
    try:
        # --- 1. Get the most important features from the complex model ---
        importances = np.asarray(security_model.feature_importances_)

        # Feature names from the security preparation; generic names otherwise.
        if 'correct_feature_names' not in locals():
            correct_feature_names = [f'feature_{i}' for i in range(security_model.n_features_in_)]

        X_train_arr = np.asarray(X_train_grouped)
        X_test_arr = np.asarray(X_test_grouped)

        # Names, importances and data columns must all describe the same features.
        n_model_features = len(importances)
        if len(correct_feature_names) != n_model_features or X_train_arr.shape[1] != n_model_features:
            raise ValueError(
                f"feature count mismatch: model has {n_model_features} importances, "
                f"{len(correct_feature_names)} feature names and "
                f"{X_train_arr.shape[1]} columns in X_train_grouped"
            )

        importance_df = pd.DataFrame({
            'Feature': correct_feature_names,
            'Importance': importances
        }).sort_values(by='Importance', ascending=False)

        # Select the Top N features
        top_n = 10
        top_features = importance_df.head(top_n)['Feature'].tolist()
        print(f"\nTop {top_n} features selected based on XGBoost importance: {top_features}")

        # --- 2. Create reduced datasets with only the top features ---
        top_features_indices = [correct_feature_names.index(f) for f in top_features]

        X_train_reduced = X_train_arr[:, top_features_indices]
        X_test_reduced = X_test_arr[:, top_features_indices]

        # --- 3. Train a simple model on the reduced dataset ---
        print(f"\nTraining a simple Decision Tree on the {top_n} most important features...")
        simple_model = DecisionTreeClassifier(random_state=42)
        simple_model.fit(X_train_reduced, y_train_grouped)
        print("Simple model trained.")

        # --- 4. Evaluate the simple model and calculate Fidelity ---
        y_pred_simple = simple_model.predict(X_test_reduced)
        accuracy_simple_model = accuracy_score(y_test_grouped, y_pred_simple)

        # NOTE(review): hard-coded headline accuracy of the XGBoost model; it is
        # not measured on X_test_grouped here, so update it if the model changes.
        accuracy_complex_model = 0.9988

        # Fidelity here = ratio of the simple model's accuracy to the complex one's
        fidelity_score = accuracy_simple_model / accuracy_complex_model

        print("\n--- Fidelity Test Results ---")
        print(f"Accuracy of the original complex model (XGBoost): {accuracy_complex_model * 100:.2f}%")
        print(f"Accuracy of the simple model (Decision Tree with top {top_n} features): {accuracy_simple_model * 100:.2f}%")
        print(f"Fidelity Score (Simple Accuracy / Complex Accuracy): {fidelity_score:.3f}")

    except NameError as e:
        print(f"❌ ERROR: A necessary variable is missing: {e}. Please ensure the test data from the security notebook is loaded.")
    except ValueError as e:
        print(f"❌ ERROR: Fidelity test aborted: {e}.")

In [None]:
# ==============================================================================
# FUSION FRAMEWORK - STEP 5: PREPARE ALL TEST DATA (VitalDB-adapted)
# ==============================================================================
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

# --- Helper function to create time windows ---
def create_sequences(X_data, time_steps=10):
    """Slice ``X_data`` into overlapping windows of ``time_steps`` consecutive rows.

    Parameters
    ----------
    X_data : array-like
        Rows in temporal order, typically shaped (n_rows, n_features).
    time_steps : int
        Window length; must be >= 1.

    Returns
    -------
    np.ndarray
        Shape (n_windows, time_steps, *X_data.shape[1:]), one window per start
        position with stride 1, where ``n_windows = max(n_rows - time_steps + 1, 0)``.
        The window ending on the last row is included; the previous
        ``range(len - time_steps)`` dropped it. Input shorter than
        ``time_steps`` yields an empty array that keeps the window dimensions,
        so callers can still check ``ndim``.

    Raises
    ------
    ValueError
        If ``time_steps < 1``.
    """
    if time_steps < 1:
        raise ValueError("time_steps must be >= 1")
    X_arr = np.asarray(X_data)
    n_windows = len(X_arr) - time_steps + 1
    if n_windows <= 0:
        return np.empty((0, time_steps) + X_arr.shape[1:], dtype=X_arr.dtype)
    return np.stack([X_arr[i:i + time_steps] for i in range(n_windows)])

# --- A: Prepare Clinical Test Data (VitalDB, already processed & scaled) ---
try:
    print("--- 1. Preparing Clinical Test Data ---")
    # This MUST be the same processed dataframe used for the AE training
    processed_clinical_path = "/content/drive/MyDrive/Conference_paper_ICCC_2026/vitaldb_df_processed_final.parquet"
    df_clinical_featured = pd.read_parquet(processed_clinical_path)

    # With VitalDB we typically have: Arterial_BP_Mean, Heart_Rate, SpO2 (already scaled).
    candidate_vital_cols = [
        "Arterial_BP_Mean",
        "Heart_Rate",
        "SpO2",
        "Arterial_BP_Diastolic",
        "Arterial_BP_Systolic",
        "Respiratory_Rate",
        "Temperature_C",
    ]
    vital_cols = [c for c in candidate_vital_cols if c in df_clinical_featured.columns]
    print("Using vital columns for clinical test data:", vital_cols)
    if len(vital_cols) == 0:
        raise RuntimeError("No matching vital sign columns found in df_clinical_featured.")

    # IMPORTANT:
    # df_clinical_featured already contains scaled vital signs (StandardScaler applied
    # in the physiological module). We therefore DO NOT rescale them again here.
    X_clinical = df_clinical_featured[vital_cols].values

    # Group by patient if subject_id is available
    if "subject_id" in df_clinical_featured.columns:
        patient_groups = df_clinical_featured["subject_id"].values
    else:
        # fallback: single pseudo-patient
        patient_groups = np.zeros(len(df_clinical_featured), dtype=int)

    unique_patients = np.unique(patient_groups)
    if len(unique_patients) >= 2:
        # Patient-level split: 20% of subjects are reserved as "test patients"
        _, test_patients = train_test_split(unique_patients, test_size=0.2, random_state=42)
        test_indices = np.isin(patient_groups, test_patients)
    else:
        # train_test_split cannot split a single group (its train side would be
        # empty, so it raises); hold out the last 20% of rows instead.
        test_patients = unique_patients
        n_rows = len(patient_groups)
        test_indices = np.arange(n_rows) >= int(n_rows * 0.8)
    X_test_clinical = X_clinical[test_indices]
    test_patient_ids = patient_groups[test_indices]

    # TIME_STEPS must match what the AE was trained with
    TIME_STEPS = 24

    # Window each test patient separately so no sequence straddles two patients.
    # Rows keep their file order within a patient (assumed chronological - TODO confirm).
    patient_windows = [
        create_sequences(patient_rows.to_numpy(), time_steps=TIME_STEPS)
        for _, patient_rows in pd.DataFrame(X_test_clinical).groupby(
            test_patient_ids, sort=False, dropna=False
        )
    ]
    # Patients with too few rows produce no window; drop those entries.
    patient_windows = [w for w in patient_windows if w.ndim == 3 and len(w) > 0]
    if patient_windows:
        X_test_sequences = np.concatenate(patient_windows, axis=0)
    else:
        X_test_sequences = np.empty(
            (0, TIME_STEPS, X_test_clinical.shape[1]), dtype=X_test_clinical.dtype
        )
    print("✅ Clinical test data prepared successfully with shape:", X_test_sequences.shape)

except Exception as e:
    print(f"❌ ERROR: Could not prepare clinical data. Error: {e}")
    X_test_sequences = None

# --- B: Prepare Security Test Data (IoMT-TrafficData) ---
# NOTE: this section used to reload the CICIoMT2024 dataframe
# ("merged_ciciomt_data.parquet" from another project folder). The security
# model loaded in Step 1 is the IoMT-TrafficData XGBoost
# ("xgb_iomt_traffic_grouped.joblib"), so the test data now comes from the
# IoMT-TrafficData parquet. It is preprocessed exactly as in the
# "DATA PREPARATION FOR SECURITY DATASET" cell, with the same scaler setup,
# label names, encoder and split parameters.
import os

try:
    print("\n--- 2. Preparing Security Test Data ---")
    if "project_path" not in locals():
        project_path = "/content/drive/MyDrive/Conference_paper_ICCC_2026"
    security_df_path = os.path.join(project_path, "iomt_traffic_multiclass_packets.parquet")
    df_security = pd.read_parquet(security_df_path)
    if "label" not in df_security.columns:
        raise RuntimeError("Column 'label' not found in df_security. Please check the IoMT preprocessing step.")

    # Features = every column except the integer class label (0..8)
    X_sec = df_security.drop(columns=["label"])
    y_sec = df_security["label"].astype(int)

    # Raw-label encoder, kept so that code reading these names still works
    label_encoder_sec = LabelEncoder()
    df_security["label_encoded"] = label_encoder_sec.fit_transform(y_sec)

    # Preprocessing: standardisation fitted on all rows (as in the preparation cell)
    scaler_sec = StandardScaler()
    X_scaled_sec = scaler_sec.fit_transform(X_sec)

    # Human-readable IoMT-TrafficData class names
    attack_name_mapping = {
        0: "Normal",
        1: "ApacheKiller",
        2: "ARP",
        3: "CAM",
        4: "Malaria",
        5: "Netscan",
        6: "RUDY",
        7: "SlowLoris",
        8: "SlowRead",
    }
    df_security["grouped_label"] = y_sec.map(attack_name_mapping).fillna("Unknown")
    grouped_label_encoder = LabelEncoder()
    y_grouped_sec = grouped_label_encoder.fit_transform(df_security["grouped_label"])

    # Final train/test split for security module
    _, X_test_grouped, _, y_test_grouped = train_test_split(
        X_scaled_sec,
        y_grouped_sec,
        test_size=0.3,
        random_state=42,
        stratify=y_grouped_sec,
    )

    # Save full security test set for fusion/scenario evaluation
    X_sec_test_full = X_test_grouped.copy()
    y_sec_test_full = y_test_grouped.copy()

    print("✅ Security test data prepared successfully.")

except Exception as e:
    print(f"❌ ERROR: Could not prepare security data. Error: {e}")
    X_sec_test_full = None
    y_sec_test_full = None

print("\n--- All test data is now ready. ---")


In [None]:
# ==============================================================================
# Rebuild clinical_threshold for the VitalDB Autoencoder
# - Prefer a validation set if available, otherwise fall back to X_test_sequences
# ==============================================================================

import numpy as np

# 1) Check that we have a clinical AE model.
#    Step 1 binds clinical_model = None when loading fails, so None counts as missing.
if locals().get('clinical_model') is None:
    # In many cells the AE is called 'autoencoder' – keep a safety alias
    if locals().get('autoencoder') is not None:
        clinical_model = autoencoder
        print("INFO: 'clinical_model' not found, using 'autoencoder' as clinical_model.")
    else:
        raise RuntimeError("No clinical_model / autoencoder found. Please run the AE training cell first.")

# 2) Choose the dataset on which to estimate the threshold
base_sequences = None
source_name = None

# Prefer a validation set if it exists (names bound to None are skipped)
for cand_name in ['X_val_sequences', 'X_val_clinical', 'X_val']:
    if locals().get(cand_name) is not None:
        base_sequences = locals()[cand_name]
        source_name = cand_name
        break

# Fallback: use X_test_sequences (Step 5 sets it to None on failure).
# NOTE: a 95th-percentile threshold estimated on the evaluation windows
# themselves flags roughly 5% of them by construction.
if base_sequences is None:
    if locals().get('X_test_sequences') is None:
        raise RuntimeError("Neither validation nor test sequences found. Please run the clinical AE preprocessing/training cells.")
    base_sequences = X_test_sequences
    source_name = 'X_test_sequences'

base_sequences = np.asarray(base_sequences)
if base_sequences.ndim != 3:
    raise RuntimeError(f"Expected 3D sequences (N, T, F), got shape {base_sequences.shape} from {source_name}.")
if base_sequences.shape[0] == 0:
    # np.quantile on an empty array would fail with a less helpful message.
    raise RuntimeError(f"{source_name} contains no sequences; cannot estimate clinical_threshold.")

print(f"Computing clinical_threshold from {source_name} with shape {base_sequences.shape} ...")

# 3) Reconstruction errors (one MAE per window)
base_pred = clinical_model.predict(base_sequences, verbose=0)
mae_loss = np.mean(np.abs(base_pred - base_sequences), axis=(1, 2))

# 4) Threshold at 95th percentile (same logic as anomaly detection step)
clinical_threshold = float(np.quantile(mae_loss, 0.95))
print(f"✅ clinical_threshold set to: {clinical_threshold:.4f} (95th percentile of AE reconstruction error)")


In [None]:
# ==============================================================================
# FUSION FRAMEWORK – CANONICAL 4 SCENARIOS (value-aware, scenario thresholds)
# Version with risk based on percentiles (rank-normalisation)
# ==============================================================================

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# ---------------- Safety checks on required objects ----------------
# Earlier cells bind some of these to None on failure (Step 1 models,
# Step 5 X_sec_test_full / y_sec_test_full), so None also counts as missing.
needed = [
    "security_model",
    "grouped_label_encoder",
    "X_sec_test_full",
    "y_sec_test_full",
    "clin_errors_full",
    "clinical_threshold",
    "X_cli_test_full",
]
for name in needed:
    if locals().get(name) is None:
        raise RuntimeError(f"Required object '{name}' not found or is None. Please run previous cells first.")

# ---------------- Helper: derive security scores & flags ----------------
classes = list(grouped_label_encoder.classes_)

# Try to identify the "Normal" class in a robust way
normal_label = None
for cand in classes:
    if "normal" in str(cand).lower():
        normal_label = cand
        break
if normal_label is None:
    # Fallback: the most frequent test class is assumed to be "Normal"
    counts = np.bincount(y_sec_test_full.astype(int))
    normal_idx_tmp = int(np.argmax(counts))
    normal_label = classes[normal_idx_tmp]

normal_idx = classes.index(normal_label)

proba_sec = security_model.predict_proba(X_sec_test_full)
# The "Normal" column is located through the encoder, so both must agree on
# the number of classes. NOTE(review): this also assumes the model's column
# order follows the encoder's class order - confirm against the training code.
if proba_sec.shape[1] != len(classes):
    raise RuntimeError(
        f"security_model outputs {proba_sec.shape[1]} class probabilities but "
        f"grouped_label_encoder has {len(classes)} classes; cannot locate the "
        "'Normal' probability column reliably."
    )
# Attack score = 1 - P(Normal)
sec_attack_score = 1.0 - proba_sec[:, normal_idx]
# Detection-level flag (threshold 0.5, used only for reporting, not for fusion scenarios)
sec_attack_flag = (sec_attack_score >= 0.5).astype(int)

print(f"[INFO] Detection-level security flag (thr=0.5): "
      f"normal={(sec_attack_flag == 0).sum()}, attack={(sec_attack_flag == 1).sum()}")

# Scenario-level security flag based on score quantile (for fusion scenarios).
# With ties at the threshold, ">=" can flag slightly more than 20%.
SEC_ATTACK_SCENARIO_QUANTILE = 0.80  # top 20% highest scores considered "attack" at scenario level
sec_thr_scenario = np.quantile(sec_attack_score, SEC_ATTACK_SCENARIO_QUANTILE)
sec_flag_scenario = (sec_attack_score >= sec_thr_scenario).astype(int)

print(f"[INFO] Scenario-level security flag (quantile {SEC_ATTACK_SCENARIO_QUANTILE:.2f}): "
      f"threshold={sec_thr_scenario:.4f}, "
      f"normal={(sec_flag_scenario == 0).sum()}, attack={(sec_flag_scenario == 1).sum()}")

# ---------------- Helper: derive clinical severity & flags ----------------
clin_errors_full = np.asarray(clin_errors_full, dtype=float)
if clin_errors_full.ndim != 1:
    raise RuntimeError(f"Expected 1D array for clin_errors_full, got shape {clin_errors_full.shape}")

err_min = float(clin_errors_full.min())
err_max = float(clin_errors_full.max())
if not np.isfinite(err_min) or not np.isfinite(err_max) or err_max <= err_min:
    raise RuntimeError("Invalid clinical error range; please check clin_errors_full.")

# Min-max normalised severity in [0, 1]
clin_severity_norm = (clin_errors_full - err_min) / (err_max - err_min)
clin_severity_norm = np.clip(clin_severity_norm, 0.0, 1.0)
clin_flag_full = (clin_errors_full > clinical_threshold).astype(int)

thr_severity_norm = float((clinical_threshold - err_min) / (err_max - err_min))
thr_severity_norm = float(np.clip(thr_severity_norm, 0.0, 1.0))

print(f"[DEBUG] Clinical AE MAE – thr={clinical_threshold:.4f}, "
      f"min={err_min:.4f}, median={np.median(clin_errors_full):.4f}, max={err_max:.4f}")
print(f"[DEBUG] Normalised severity threshold ≈ {thr_severity_norm:.3f}")
print("[DEBUG] clin_flag_full counts (0=low-error, 1=high-error):", np.bincount(clin_flag_full))

# ---------------- Percentile-based normalisation (key change) ----------------
# Sorted reference populations used to map raw scores to empirical percentiles
sec_sorted_all = np.sort(sec_attack_score)
clin_sorted_all = np.sort(clin_severity_norm)

def _to_percentile(values, sorted_all):
    """Map raw values to empirical percentiles in [0, 1]."""
    values = np.asarray(values, dtype=float)
    ranks = np.searchsorted(sorted_all, values, side="right")
    return ranks / len(sorted_all)

# ---------------- Fusion risk function on percentiles ----------------
def compute_fusion_risk(sec_score_raw: float, clin_sev_raw: float) -> float:
    """
    Fusion risk in [0, 1] for a single (security, clinical) pair.

    Both raw scores are mapped to empirical percentiles of the security and
    clinical reference populations, then combined as
    ``0.05 + 0.55*sec_q + 0.25*clin_q + 0.15*sec_q*clin_q`` and clipped to [0, 1].
    Security is more impactful than clinical, and both dimensions are monotonic.

    This scalar entry point delegates to ``compute_fusion_risk_vec``, so the
    weights are defined in one place and the scalar and vectorised paths
    cannot drift apart. They previously duplicated the constants.

    Parameters
    ----------
    sec_score_raw : float
        Raw security attack score (1 - P(Normal)).
    clin_sev_raw : float
        Min-max normalised clinical severity.

    Returns
    -------
    float
        Fused risk in [0, 1].
    """
    risk = compute_fusion_risk_vec([sec_score_raw], [clin_sev_raw])
    return float(risk[0])

def compute_fusion_risk_vec(sec_scores, clin_sevs):
    """Element-wise fusion risk in [0, 1] from percentile-normalised scores.

    Security scores are ranked against ``sec_sorted_all`` and clinical
    severities against ``clin_sorted_all``. The resulting percentiles are then
    combined with a security-dominant linear formula plus a mild interaction term.
    """
    # Fusion weights: offset, security, clinical, security x clinical synergy
    base, w_sec, w_clin, synergy_coef = 0.05, 0.55, 0.25, 0.15

    sec_q = np.clip(
        _to_percentile(np.asarray(sec_scores, dtype=float), sec_sorted_all), 0.0, 1.0
    )
    clin_q = np.clip(
        _to_percentile(np.asarray(clin_sevs, dtype=float), clin_sorted_all), 0.0, 1.0
    )

    risk_raw = base + w_sec * sec_q + w_clin * clin_q + synergy_coef * sec_q * clin_q
    return np.clip(risk_raw, 0.0, 1.0)

# ---------------- Helper: build synthetic fused pairs ----------------
def build_fusion_pairs(n_pairs: int, seed: int = 1234):
    """
    Sample ``n_pairs`` synthetic fused IoMT x VitalDB cases.

    Security rows and clinical sequences are drawn independently and
    uniformly with replacement, then paired position-wise. The fused risk
    of every pair comes from ``compute_fusion_risk_vec``.

    Ground-truth scenario per pair (from the CLEAN scenario-level flags):
      0 = Stable     both sec_flag_scenario and clin_flag are 0
      1 = High       the two flags differ
      2 = Critical   both flags are 1

    Returns
    -------
    dict
        Keys "sec_score", "sec_flag", "clin_sev", "clin_flag", "risk",
        "scenario_true"; each value is an array of length ``n_pairs``.
    """
    generator = np.random.default_rng(seed)

    n_security_rows = X_sec_test_full.shape[0]
    n_clinical_rows = len(clin_severity_norm)

    pick_sec = generator.integers(0, n_security_rows, size=n_pairs)
    pick_cli = generator.integers(0, n_clinical_rows, size=n_pairs)

    sec_scores = sec_attack_score[pick_sec]
    sec_flags = sec_flag_scenario[pick_sec]      # scenario-level flag
    clin_scores = clin_severity_norm[pick_cli]
    clin_flags = clin_flag_full[pick_cli]

    risks = compute_fusion_risk_vec(sec_scores, clin_scores)

    # Disjoint conditions: flags differ -> High, both set -> Critical, else Stable.
    labels = np.select(
        [sec_flags != clin_flags, (sec_flags == 1) & (clin_flags == 1)],
        [1, 2],
        default=0,
    )

    return {
        "sec_score": sec_scores,
        "sec_flag": sec_flags,
        "clin_sev": clin_scores,
        "clin_flag": clin_flags,
        "risk": risks,
        "scenario_true": labels,
    }

# ---------------- Threshold selection on calibration set ----------------
# Calibrate the two risk cut-points on a synthetic cohort of fused pairs.
# Decision rule used throughout:
#   risk <  thr_stable                  -> 0 (Stable)
#   thr_stable <= risk < thr_critical   -> 1 (High)
#   risk >= thr_critical                -> 2 (Critical)
N_PAIRS_CAL = 2000
pairs_cal = build_fusion_pairs(N_PAIRS_CAL, seed=2025)
risk_cal = pairs_cal["risk"]
y_cal_true = pairs_cal["scenario_true"]

stable_mask = (y_cal_true == 0)
nonstable_mask = ~stable_mask

if (not stable_mask.any()) or (not nonstable_mask.any()):
    # Degenerate cohort (no Stable or no non-Stable pairs): per-group
    # quantiles are undefined, so fall back to global risk quantiles.
    thr_stable = float(np.quantile(risk_cal, 0.40))
    thr_critical = float(np.quantile(risk_cal, 0.80))
    print("[WARN] Some scenario missing in calibration; using global quantiles as fallback.")

    y_cal_pred = np.where(
        risk_cal < thr_stable, 0,
        np.where(risk_cal < thr_critical, 1, 2)
    )
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_cal_true, y_cal_pred, labels=[0, 1, 2], zero_division=0
    )
    macro_f1 = float(f1.mean())
    chosen_prec, chosen_rec, chosen_f1 = prec, rec, f1
else:
    # High/Critical boundary: fixed at the 80th percentile of risk among
    # non-Stable (High + Critical) calibration pairs.
    thr_critical = float(np.quantile(risk_cal[nonstable_mask], 0.80))

    # Stable/High candidates: 25 global risk quantiles in [q_min, q_max].
    # NOTE(review): q_max is derived from thr_critical, which is a risk value,
    # but it is used here as a quantile level; confirm this mixing is intended.
    q_min = 0.05
    q_max = min(0.75, max(0.20, float(thr_critical - 0.05)))
    grid_q = np.linspace(q_min, q_max, 25)
    cand_thr_s = np.quantile(risk_cal, grid_q)

    best = None
    for thr_s in cand_thr_s:
        # Require a gap of at least 0.02 between the two thresholds.
        if thr_s >= thr_critical - 0.02:
            continue

        y_cal_pred = np.where(
            risk_cal < thr_s, 0,
            np.where(risk_cal < thr_critical, 1, 2)
        )

        prec, rec, f1, _ = precision_recall_fscore_support(
            y_cal_true, y_cal_pred, labels=[0, 1, 2], zero_division=0
        )

        prec_stable, rec_stable = float(prec[0]), float(rec[0])
        rec_critical = float(rec[2])

        # Feasibility constraints: Critical recall >= 0.80 and Stable recall >= 0.50 ...
        if rec_critical < 0.80:
            continue
        if rec_stable < 0.50:
            continue

        macro_f1 = float(f1.mean())

        # ... and Stable precision strictly below 0.999: candidates with
        # (near-)perfect Stable precision are rejected.
        if prec_stable >= 0.999:
            continue

        # Keep the feasible candidate with the highest macro-F1.
        if (best is None) or (macro_f1 > best["macro_f1"]):
            best = {
                "thr_stable": float(thr_s),
                "thr_critical": float(thr_critical),
                "macro_f1": macro_f1,
                "prec": prec,
                "rec": rec,
                "f1": f1,
            }

    if best is None:
        # No candidate satisfied the constraints: use the 80th percentile of
        # risk within each group, then enforce a minimum 0.02 gap (widening to
        # thr_stable + 0.10, capped at the maximum calibration risk).
        thr_stable = float(np.quantile(risk_cal[stable_mask], 0.80))
        thr_critical = float(np.quantile(risk_cal[nonstable_mask], 0.80))
        if thr_critical <= thr_stable + 0.02:
            thr_critical = float(min(risk_cal.max(), thr_stable + 0.10))
        print("[WARN] Could not find thresholds with non-perfect prec_stable; using quantile-based thresholds.")

        y_cal_pred = np.where(
            risk_cal < thr_stable, 0,
            np.where(risk_cal < thr_critical, 1, 2)
        )
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_cal_true, y_cal_pred, labels=[0, 1, 2], zero_division=0
        )
        macro_f1 = float(f1.mean())
        chosen_prec, chosen_rec, chosen_f1 = prec, rec, f1
    else:
        thr_stable = best["thr_stable"]
        thr_critical = best["thr_critical"]
        macro_f1 = best["macro_f1"]
        chosen_prec, chosen_rec, chosen_f1 = best["prec"], best["rec"], best["f1"]

# Freeze the selected thresholds as module-level constants; the evaluation
# cell below reads FUSION_RISK_THR_STABLE / FUSION_RISK_THR_CRITICAL.
FUSION_RISK_THR_STABLE = float(thr_stable)
FUSION_RISK_THR_CRITICAL = float(thr_critical)

print("\n[INFO] Fusion risk thresholds (calibration with constraint on Stable precision):")
print(f"  Stable / High boundary   ≈ {FUSION_RISK_THR_STABLE:.3f}")
print(f"  High / Critical boundary ≈ {FUSION_RISK_THR_CRITICAL:.3f}")

print("\n[DEBUG] Calibration cohort performance with selected thresholds:")
for label_id, label_name in zip([0, 1, 2], ["Stable", "High", "Critical"]):
    print(f"  {label_name:<8} | P={chosen_prec[label_id]:.3f} | "
          f"R={chosen_rec[label_id]:.3f} | F1={chosen_f1[label_id]:.3f}")
print(f"  Macro-F1 = {macro_f1:.3f}")

# Debug: performance on calibration cohort (explicit recomputation)
# Re-derives predictions from the frozen constants as a cross-check of
# chosen_prec / chosen_rec / chosen_f1 above, and also reports accuracy.
# NOTE(review): the header printed below says "per-class quantiles / final"
# even when the thresholds came from the grid-search branch.
y_cal_pred = np.where(
    risk_cal < FUSION_RISK_THR_STABLE, 0,
    np.where(risk_cal < FUSION_RISK_THR_CRITICAL, 1, 2)
)

prec_cal, rec_cal, f1_cal, support_cal = precision_recall_fscore_support(
    y_cal_true, y_cal_pred, labels=[0, 1, 2], zero_division=0
)
acc_cal = accuracy_score(y_cal_true, y_cal_pred)
macro_f1_cal = f1_cal.mean()

print("\n[INFO] Fusion risk thresholds (per-class quantiles / final):")
print(f"  Stable / High boundary   ≈ {FUSION_RISK_THR_STABLE:.3f}")
print(f"  High / Critical boundary ≈ {FUSION_RISK_THR_CRITICAL:.3f}")

print("\n[DEBUG] Calibration cohort performance with these thresholds:")
for label_id, label_name in zip([0, 1, 2], ["Stable", "High", "Critical"]):
    print(f"  {label_name:<8} | P={prec_cal[label_id]:.3f} | "
          f"R={rec_cal[label_id]:.3f} | F1={f1_cal[label_id]:.3f}")
print(f"  Macro-F1 = {macro_f1_cal:.3f} | Acc = {acc_cal:.3f}")

# ---------------- Canonical 4 scenarios chosen by TARGET RISK ----------------
# NOTE(review): `rng` is not used in the remainder of this cell
# (pick_pair_by_risk creates its own seeded generator).
rng = np.random.default_rng(42)

# Security pools based on scenario-level flag (with fallback)
sec_idx_normal = np.where(sec_flag_scenario == 0)[0]
sec_idx_attack = np.where(sec_flag_scenario == 1)[0]

print(f"[DEBUG] Security pools from scenario-level flags: normal={len(sec_idx_normal)}, attack={len(sec_idx_attack)}")
if len(sec_idx_normal) == 0 or len(sec_idx_attack) == 0:
    print("[WARN] Security pools for normal/attack are empty or degenerate.")
    print("[WARN] Using score-based pseudo-pools from sec_attack_score (quantiles).")
    # Scores at or below the 20th percentile -> pseudo-normal pool,
    # at or above the 80th percentile -> pseudo-attack pool.
    q_low, q_high = np.quantile(sec_attack_score, [0.2, 0.8])
    sec_idx_normal = np.where(sec_attack_score <= q_low)[0]
    sec_idx_attack = np.where(sec_attack_score >= q_high)[0]
    print(f"[DEBUG] Security pools from scores: normal={len(sec_idx_normal)}, attack={len(sec_idx_attack)}")

if len(sec_idx_normal) == 0 or len(sec_idx_attack) == 0:
    raise RuntimeError("Even score-based security pools are empty; cannot build canonical scenarios.")

# Both aliases refer to the same attack pool; "weak" vs "strong" is decided
# only by the target risk passed to pick_pair_by_risk below.
weak_attack_idx = sec_idx_attack
strong_attack_idx = sec_idx_attack

# Clinical pools split by the AE flag (0 = low-error, 1 = high-error).
stable_cli_idx = np.where(clin_flag_full == 0)[0]
anomal_cli_idx = np.where(clin_flag_full == 1)[0]

if len(stable_cli_idx) == 0 or len(anomal_cli_idx) == 0:
    raise RuntimeError("Clinical pools for stable/anomalous sequences are empty; please verify the VitalDB AE step.")

def pick_pair_by_risk(pool_sec_idx, pool_cli_idx, target_risk, n_samples=5000, seed=123):
    """
    Find a (security, clinical) pair whose fused risk is closest to a target.

    ``n_samples`` random candidates are drawn with replacement from the two
    index pools and scored with ``compute_fusion_risk_vec``. The candidate
    minimising |risk - target| wins; the target is first clipped to [0, 1].
    On ties the first candidate wins.

    Returns
    -------
    tuple[int, int, float]
        (security index, clinical index, fused risk of that pair).

    Raises
    ------
    RuntimeError
        If either pool is empty.
    """
    generator = np.random.default_rng(seed)

    sec_pool = np.asarray(pool_sec_idx, dtype=int)
    cli_pool = np.asarray(pool_cli_idx, dtype=int)
    if not len(sec_pool) or not len(cli_pool):
        raise RuntimeError("Empty pool in pick_pair_by_risk.")

    cand_sec = generator.choice(sec_pool, size=n_samples, replace=True)
    cand_cli = generator.choice(cli_pool, size=n_samples, replace=True)

    cand_risk = compute_fusion_risk_vec(
        sec_attack_score[cand_sec],
        clin_severity_norm[cand_cli],
    )
    goal = float(np.clip(target_risk, 0.0, 1.0))

    closest = int(np.abs(cand_risk - goal).argmin())
    return int(cand_sec[closest]), int(cand_cli[closest]), float(cand_risk[closest])

# NOTE(review): min_risk / max_risk are not used later in this cell.
min_risk = float(risk_cal.min())
max_risk = float(risk_cal.max())

# Target risks for the four canonical scenarios, placed relative to the
# calibrated thresholds: S1 well below Stable/High, S2 and S3 aimed inside
# the High band (S3 up to 0.05 below S2), S4 just above High/Critical.
target_s1 = max(0.0, FUSION_RISK_THR_STABLE - 0.25)
target_s2 = min(FUSION_RISK_THR_CRITICAL - 0.05, FUSION_RISK_THR_STABLE + 0.20)
target_s3 = max(FUSION_RISK_THR_STABLE + 0.05, target_s2 - 0.05)
target_s4 = min(1.0, FUSION_RISK_THR_CRITICAL + 0.05)

# One pair per scenario, searched within the matching security/clinical pools.
idx_s1_sec, idx_s1_cli, risk_s1 = pick_pair_by_risk(
    sec_idx_normal, stable_cli_idx, target_s1, n_samples=5000, seed=1001
)
idx_s2_sec, idx_s2_cli, risk_s2 = pick_pair_by_risk(
    weak_attack_idx, stable_cli_idx, target_s2, n_samples=5000, seed=1002
)
idx_s3_sec, idx_s3_cli, risk_s3 = pick_pair_by_risk(
    sec_idx_normal, anomal_cli_idx, target_s3, n_samples=5000, seed=1003
)
idx_s4_sec, idx_s4_cli, risk_s4 = pick_pair_by_risk(
    strong_attack_idx, anomal_cli_idx, target_s4, n_samples=5000, seed=1004
)

# (scenario id, security label, clinical label, sec index, clin index, picked risk)
scenario_defs = [
    (1, "Normal",                  "Normal",            idx_s1_sec, idx_s1_cli, risk_s1),
    (2, "Attack Detected (weak)",  "Normal",            idx_s2_sec, idx_s2_cli, risk_s2),
    (3, "Normal",                  "Anomaly Detected",  idx_s3_sec, idx_s3_cli, risk_s3),
    (4, "Attack Detected (strong)","Anomaly Detected",  idx_s4_sec, idx_s4_cli, risk_s4),
]

# Build the canonical scenario table plus a per-scenario decomposition.
rows = []
decomp_rows = []

# NOTE: the `risk` value carried in scenario_defs is not used below; the risk
# is recomputed from the raw scores with compute_fusion_risk.
for scen_id, sec_status_str, clin_status_str, i_sec, i_cli, risk in scenario_defs:
    sec_score = float(sec_attack_score[i_sec])
    clin_sev = float(clin_severity_norm[i_cli])
    clin_mae = float(clin_errors_full[i_cli])

    risk_val = compute_fusion_risk(sec_score, clin_sev)

    # Same 3-way thresholding as the calibration step.
    if risk_val < FUSION_RISK_THR_STABLE:
        alert = "System Stable"
    elif risk_val < FUSION_RISK_THR_CRITICAL:
        alert = "High Risk Detected"
    else:
        alert = "Critical Alert"

    rows.append({
        "Scenario": scen_id,
        "Security Status": sec_status_str,
        "Physio/Technical Status": clin_status_str,
        "Calculated Risk Score": round(risk_val, 3),
        "Final Framework Decision": alert,
    })

    decomp_rows.append({
        "Scenario": scen_id,
        "sec_score": sec_score,
        "clin_MAE": clin_mae,
        "clin_sev": clin_sev,
        "risk": risk_val,
        "raw_alert": alert,
    })

df_scenarios = pd.DataFrame(rows, columns=[
    "Scenario",
    "Security Status",
    "Physio/Technical Status",
    "Calculated Risk Score",
    "Final Framework Decision",
])

print("\nCanonical 4-scenario table (value-aware selection with scenario thresholds):")
print(df_scenarios.to_string(index=False))

df_decomp = pd.DataFrame(decomp_rows, columns=[
    "Scenario", "sec_score", "clin_MAE", "clin_sev", "risk", "raw_alert"
])
print("\n[DEBUG] Decomposed contributions for the 4 scenarios:")
print(df_decomp.to_string(index=False))


In [None]:
# ==============================================================================
# Scenario-based Fusion Evaluation (IoMT-TrafficData × VitalDB)
# Fixed thresholds (from canonical fusion cell) + multi-run robustness
# with risk noise + threshold jitter + coloured confusion matrix
# ==============================================================================

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_recall_fscore_support,
    accuracy_score,
)
from tqdm.auto import tqdm
from matplotlib.patches import Rectangle
from pathlib import Path

# ---------------- Resolve project directory ----------------
# Use `project_path` if an earlier cell defined it, otherwise the default
# Drive location. (At notebook top level, locals() is the global namespace.)
if "project_path" not in locals():
    PROJECT_DIR = Path("/content/drive/MyDrive/Conference_paper_ICCC_2026")
else:
    PROJECT_DIR = Path(project_path)

# Figures from this cell are written under <PROJECT_DIR>/output.
OUTPUT_DIR = PROJECT_DIR / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------- Try to rebuild security TEST arrays if missing ----------------
# Only touches disk when X_sec_test_full / y_sec_test_full are not already in
# memory; a missing file is reported but not fatal here (the safety check
# below raises if the arrays are still absent).
if "X_sec_test_full" not in locals() or "y_sec_test_full" not in locals():
    saved_test_dir = PROJECT_DIR / "saved_test_data"
    x_test_path = saved_test_dir / "X_test_grouped_iomt_traffic.npy"
    y_test_path = saved_test_dir / "y_test_grouped_iomt_traffic.npy"

    if x_test_path.exists() and y_test_path.exists():
        print(f"[INFO] Loading security test data from:\n  {x_test_path}\n  {y_test_path}")
        X_sec_test_full = np.load(x_test_path)
        y_sec_test_full = np.load(y_test_path)
        print(f"[INFO] Loaded X_sec_test_full: {X_sec_test_full.shape}")
        print(f"[INFO] Loaded y_sec_test_full: {y_sec_test_full.shape}")
    else:
        print("[WARNING] Could not find saved IoMT test arrays "
              "(X_test_grouped_iomt_traffic.npy / y_test_grouped_iomt_traffic.npy).")
        print("         If X_sec_test_full / y_sec_test_full are already in memory, this is fine.")

# ---------------- Safety checks ----------------
# Fail fast if any object produced by earlier cells is missing.
# NOTE(review): "X_cli_test_full" is required here but not referenced
# elsewhere in this cell.
needed = [
    "security_model",
    "grouped_label_encoder",
    "X_sec_test_full",
    "y_sec_test_full",
    "clin_errors_full",
    "clinical_threshold",
    "X_cli_test_full",
    "FUSION_RISK_THR_STABLE",
    "FUSION_RISK_THR_CRITICAL",
    "compute_fusion_risk",
]
for name in needed:
    if name not in locals():
        raise RuntimeError(
            f"Required object '{name}' not found. "
            "Make sure you executed the canonical fusion cell first, "
            "and that IoMT test data + VitalDB AE metrics are available."
        )

# --- Rebuild security scores & flags (IoMT-TrafficData) ---
# Locate the "normal" class: the first encoder class whose name contains
# "normal" (case-insensitive); otherwise fall back to the most frequent
# label in y_sec_test_full.
classes = list(grouped_label_encoder.classes_)
normal_label = None
for cand in classes:
    if "normal" in str(cand).lower():
        normal_label = cand
        break

if normal_label is None:
    counts = np.bincount(y_sec_test_full.astype(int))
    normal_idx = int(np.argmax(counts))
    normal_label = classes[normal_idx]

normal_idx = classes.index(normal_label)

# Attack score = 1 - P(normal); higher means more attack-like.
proba_sec = security_model.predict_proba(X_sec_test_full)
sec_attack_score = 1.0 - proba_sec[:, normal_idx]
sec_attack_flag = (sec_attack_score >= 0.5).astype(int)

print(f"[INFO] Detection-level security flag (thr=0.5): "
      f"normal={(sec_attack_flag == 0).sum()}, attack={(sec_attack_flag == 1).sum()}")

# Scenario-level security flag for scenarios (same quantile as canonical cell)
# Scores at or above the chosen quantile of sec_attack_score are flagged as attack.
if "SEC_ATTACK_SCENARIO_QUANTILE" not in locals():
    SEC_ATTACK_SCENARIO_QUANTILE = 0.80

sec_thr_scenario = np.quantile(sec_attack_score, SEC_ATTACK_SCENARIO_QUANTILE)
sec_flag_scenario = (sec_attack_score >= sec_thr_scenario).astype(int)

print(f"[INFO] Scenario-level security flag (quantile {SEC_ATTACK_SCENARIO_QUANTILE:.2f}): "
      f"threshold={sec_thr_scenario:.4f}, "
      f"normal={(sec_flag_scenario == 0).sum()}, attack={(sec_flag_scenario == 1).sum()}")

# --- Rebuild clinical severity & flags (VitalDB AE) ---
# Min-max scale the AE reconstruction errors to [0, 1] (severity) and flag
# sequences whose error exceeds clinical_threshold (1 = anomalous).
clin_errors_full = np.asarray(clin_errors_full, dtype=float)
err_min = float(clin_errors_full.min())
err_max = float(clin_errors_full.max())
err_span = err_max - err_min
if err_span == 0.0:
    # Constant errors would make min-max scaling 0/0 (NaN severities);
    # treat every sequence as severity 0 instead.
    print("[WARN] clin_errors_full is constant; setting clinical severity to 0 for all sequences.")
    clin_severity_norm = np.zeros_like(clin_errors_full)
else:
    clin_severity_norm = (clin_errors_full - err_min) / err_span
clin_severity_norm = np.clip(clin_severity_norm, 0.0, 1.0)
clin_flag_full = (clin_errors_full > clinical_threshold).astype(int)

# ---------------- Helper: build fused pairs again ----------------
def build_fusion_pairs_eval(n_pairs: int, seed: int = 1234):
    """
    Sample ``n_pairs`` synthetic fused cases for the robustness evaluation.

    IoMT security rows and VitalDB clinical sequences are drawn independently
    and uniformly with replacement, then paired position-wise. Each pair's
    risk is computed with the scalar ``compute_fusion_risk``.

    Ground-truth scenario per pair (from the scenario-level flags):
      0 = Stable     both flags are 0
      1 = High       the two flags differ
      2 = Critical   both flags are 1

    Returns
    -------
    dict
        Keys "sec_score", "sec_flag", "clin_sev", "clin_flag", "risk",
        "scenario_true"; each value is an array of length ``n_pairs``.
    """
    generator = np.random.default_rng(seed)
    total_sec = X_sec_test_full.shape[0]
    total_cli = len(clin_severity_norm)

    sec_rows = generator.integers(0, total_sec, size=n_pairs)
    cli_rows = generator.integers(0, total_cli, size=n_pairs)

    sec_scores = sec_attack_score[sec_rows]
    sec_flags = sec_flag_scenario[sec_rows]    # scenario-level flag
    clin_scores = clin_severity_norm[cli_rows]
    clin_flags = clin_flag_full[cli_rows]

    risks = np.fromiter(map(compute_fusion_risk, sec_scores, clin_scores), dtype=float)

    # Disjoint conditions: flags differ -> High, both set -> Critical, else Stable.
    labels = np.select(
        [sec_flags != clin_flags, (sec_flags == 1) & (clin_flags == 1)],
        [1, 2],
        default=0,
    )

    return {
        "sec_score": sec_scores,
        "sec_flag": sec_flags,
        "clin_sev": clin_scores,
        "clin_flag": clin_flags,
        "risk": risks,
        "scenario_true": labels,
    }

# ---------------- Multi-run robustness evaluation ----------------
# Each run draws a fresh synthetic cohort, adds Gaussian noise to the fused
# risk and jitters both thresholds, then scores the 3-class decision
# (0=Stable, 1=High, 2=Critical) against the flag-derived ground truth.
N_RUNS = 3
N_EVAL_PAIRS = 2000

RISK_NOISE_STD = 0.06  # std of additive Gaussian noise on the fused risk
THR_NOISE_STD = 0.02   # std of per-run Gaussian jitter on each threshold

BASE_THR_STABLE = float(FUSION_RISK_THR_STABLE)
BASE_THR_CRITICAL = float(FUSION_RISK_THR_CRITICAL)

metrics_runs = []

print("\n=== Scenario-based fusion evaluation (IoMT-TrafficData × VitalDB) ===")
print(f"Base thresholds: Stable/High={BASE_THR_STABLE:.3f}, High/Critical={BASE_THR_CRITICAL:.3f}")
print(f"Evaluating over {N_RUNS} synthetic runs × {N_EVAL_PAIRS} fused cases each.")
print(f"Risk noise std (Gaussian)        = {RISK_NOISE_STD:.3f}")
print(f"Threshold jitter std (Gaussian)  = {THR_NOISE_STD:.3f}\n")

# Confusion matrix and thresholds of the final run, kept for plotting below.
last_cm = None
last_thr_pair = None

for run_id in range(N_RUNS):
    seed = 3000 + run_id
    pairs_eval = build_fusion_pairs_eval(N_EVAL_PAIRS, seed=seed)
    risk_eval = pairs_eval["risk"]
    y_true = pairs_eval["scenario_true"]

    rng = np.random.default_rng(seed + 42)

    # Add Gaussian noise on risk
    if RISK_NOISE_STD > 0:
        noise = rng.normal(loc=0.0, scale=RISK_NOISE_STD, size=risk_eval.shape)
        risk_eval_used = np.clip(risk_eval + noise, 0.0, 1.0)
    else:
        risk_eval_used = risk_eval

    # Jitter thresholds run-by-run; keep High/Critical clearly above Stable/High.
    thr_s_run = np.clip(BASE_THR_STABLE + rng.normal(0.0, THR_NOISE_STD), 0.0, 1.0)
    thr_c_run = np.clip(BASE_THR_CRITICAL + rng.normal(0.0, THR_NOISE_STD), 0.0, 1.0)
    if thr_c_run <= thr_s_run + 0.02:
        thr_c_run = min(1.0, thr_s_run + 0.05)

    # Scenario predictions: 0=Stable, 1=High, 2=Critical
    y_pred = np.where(
        risk_eval_used < thr_s_run, 0,
        np.where(risk_eval_used < thr_c_run, 1, 2)
    )

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=[0, 1, 2], zero_division=0
    )
    acc = accuracy_score(y_true, y_pred)

    metrics_runs.append({
        "acc": acc,
        "prec_stable":   prec[0], "rec_stable":   rec[0], "f1_stable":   f1[0],
        "prec_high":     prec[1], "rec_high":     rec[1], "f1_high":     f1[1],
        "prec_critical": prec[2], "rec_critical": rec[2], "f1_critical": f1[2],
    })

    if run_id == N_RUNS - 1:
        last_cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
        last_thr_pair = (thr_s_run, thr_c_run)

df_runs = pd.DataFrame(metrics_runs)
# Report the actual number of runs rather than a hard-coded count.
print(f"\n=== Multi-run robustness over scenario sampling ({N_RUNS} runs, 3-class) ===")
print(df_runs.describe().T[["mean", "std"]])

# ---------------- Plot confusion matrix with coloured diagonals ----------------
# Grey confusion matrix of the last run, with each diagonal cell shaded in
# its class colour (Stable=green, High=gold, Critical=red); saved as PNG.
if last_cm is not None:
    fig, ax = plt.subplots(figsize=(5, 4))
    disp = ConfusionMatrixDisplay(
        confusion_matrix=last_cm,
        display_labels=["Stable", "High", "Critical"]
    )
    disp.plot(ax=ax, colorbar=False, cmap="Greys")

    # facecolor must be given explicitly: with only edgecolor set, fill=True
    # paints the face in the default patch colour, not the class colour.
    diag_colors = ["green", "gold", "red"]
    for k, color in enumerate(diag_colors):
        rect = Rectangle(
            (k - 0.5, k - 0.5),
            1, 1,
            fill=True,
            facecolor=color,
            alpha=0.25,
            edgecolor=color,
            linewidth=2,
        )
        ax.add_patch(rect)

    ax.set_title("Scenario-level confusion matrix (IoMT × VitalDB, last run)", fontsize=11)
    plt.tight_layout()

    cm_path = OUTPUT_DIR / "fusion_scenario_confusion_matrix_iomt_vitaldb.png"
    fig.savefig(cm_path, dpi=300, bbox_inches="tight")
    print(f"\n[INFO] Confusion matrix figure saved to: {cm_path}")

    if last_thr_pair is not None:
        ts, tc = last_thr_pair
        print(f"[DEBUG] Last-run thresholds used: Stable/High={ts:.3f}, High/Critical={tc:.3f}")


In [None]:
# ==============================================================================
# EXTRA EXPERIMENT (PROBABILISTIC FUSION RISK MODEL — RISK-ONLY, TRAIN/CALIB/TEST)
#
# Goal:
#   - Build a synthetic fusion dataset (IoMT-TrafficData × VitalDB) where:
#       * Ground-truth labels (Stable/High/Critical) are defined EXACTLY as in the
#         scenario-based evaluation:
#           0 = Stable   (sec_flag_scenario=0, clin_flag=0)
#           1 = High     (exactly one of sec_flag_scenario, clin_flag is 1)
#           2 = Critical (sec_flag_scenario=1, clin_flag=1)
#       * Classes are rebalanced to an approximate 1:2:1 ratio
#         (Stable / High / Critical).
#       * The feature given to the model is the fused risk with Gaussian noise,
#         i.e. the same object used by the scenario-based thresholding.
#   - Train a probabilistic model (XGBoost) on the noisy fused risk.
#   - Calibrate probabilities on a separate calibration set (isotonic).
#   - Evaluate on an independent test set (baseline argmax + optional Stable-aware rule).
# ==============================================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path

from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    ConfusionMatrixDisplay,
    classification_report,
)
from sklearn.model_selection import train_test_split
from matplotlib.patches import Rectangle
import xgboost as xgb

# ---------------- Resolve project directory and output folder ----------------
# Use `project_path` if an earlier cell defined it, otherwise the default
# Drive location; outputs go under <PROJECT_DIR>/output.
if "project_path" not in locals():
    PROJECT_DIR = Path("/content/drive/MyDrive/Conference_paper_ICCC_2026")
else:
    PROJECT_DIR = Path(project_path)

OUTPUT_DIR = PROJECT_DIR / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------- Safety checks ----------------
# Fail fast if any object produced by earlier cells is missing.
# NOTE(review): "compute_fusion_risk" is listed here, so the backup
# _default_compute_fusion_risk selected later in this cell can never be
# used: a missing compute_fusion_risk raises here first.
needed = [
    "security_model",
    "grouped_label_encoder",
    "X_sec_test_full",
    "y_sec_test_full",
    "clin_errors_full",
    "clinical_threshold",
    "SEC_ATTACK_SCENARIO_QUANTILE",
    "compute_fusion_risk",
]
for name in needed:
    if name not in locals():
        raise RuntimeError(
            f"Required object '{name}' not found. Please run the main pipeline "
            "and the canonical fusion cells first."
        )

# ---------------- Rebuild security scores & scenario flags (CLEAN) ----------------
# Same procedure as the evaluation cell, stored under *_clean names.
# "normal" class = first encoder class containing "normal" (case-insensitive),
# otherwise the most frequent label in y_sec_test_full.
classes = list(grouped_label_encoder.classes_)
normal_label = None
for cand in classes:
    if "normal" in str(cand).lower():
        normal_label = cand
        break
if normal_label is None:
    counts = np.bincount(y_sec_test_full.astype(int))
    normal_idx = int(np.argmax(counts))
    normal_label = classes[normal_idx]
normal_idx = classes.index(normal_label)

proba_sec = security_model.predict_proba(X_sec_test_full)
sec_attack_score_clean = 1.0 - proba_sec[:, normal_idx]  # higher = more attack-like
sec_attack_flag = (sec_attack_score_clean >= 0.5).astype(int)

print(f"[INFO] Detection-level security flag (thr=0.5): "
      f"normal={(sec_attack_flag == 0).sum()}, attack={(sec_attack_flag == 1).sum()}")

# Scenario-level flag: scores at or above the SEC_ATTACK_SCENARIO_QUANTILE
# quantile of the attack score are labelled as attack (1).
sec_thr_scenario_clean = np.quantile(sec_attack_score_clean, SEC_ATTACK_SCENARIO_QUANTILE)
sec_flag_scenario_clean = (sec_attack_score_clean >= sec_thr_scenario_clean).astype(int)

print(f"[INFO] Scenario-level security flag for fusion dataset (quantile {SEC_ATTACK_SCENARIO_QUANTILE:.2f}): "
      f"threshold={sec_thr_scenario_clean:.4f}, "
      f"normal={(sec_flag_scenario_clean == 0).sum()}, attack={(sec_flag_scenario_clean == 1).sum()}")

# ---------------- Rebuild clinical severity & flags (CLEAN) ----------------
# Min-max scale the AE reconstruction errors to [0, 1] (severity) and flag
# sequences whose error exceeds clinical_threshold (1 = anomalous).
clin_errors_full = np.asarray(clin_errors_full, dtype=float)
err_min = float(clin_errors_full.min())
err_max = float(clin_errors_full.max())
err_span = err_max - err_min
if err_span == 0.0:
    # Constant errors would make min-max scaling 0/0 (NaN severities);
    # treat every sequence as severity 0 instead.
    print("[WARN] clin_errors_full is constant; setting clinical severity to 0 for all sequences.")
    clin_severity_norm_clean = np.zeros_like(clin_errors_full)
else:
    clin_severity_norm_clean = (clin_errors_full - err_min) / err_span
clin_severity_norm_clean = np.clip(clin_severity_norm_clean, 0.0, 1.0)
clin_flag_full = (clin_errors_full > clinical_threshold).astype(int)

print("[INFO] Clean security/clinical arrays ready for probabilistic fusion experiment.")
print(f"  IoMT security samples: {X_sec_test_full.shape[0]}")
print(f"  VitalDB clinical sequences: {len(clin_severity_norm_clean)}")

# ---------------- Backup fusion risk if compute_fusion_risk is missing ----------------
def _default_compute_fusion_risk(sec_score_norm: float, clin_sev_norm: float) -> float:
    """
    Backup fused risk computed directly on already-normalised scores.

    Uses the same affine-plus-synergy form as the main fusion risk, but with
    no percentile mapping (inputs are only clipped to [0, 1]) and with its
    own weights: base=0.10, w_sec=0.50, w_clin=0.30, synergy=0.10. Security
    still carries the larger weight.

    Returns
    -------
    float
        Risk clipped to [0, 1].
    """
    s = float(np.clip(sec_score_norm, 0.0, 1.0))
    c = float(np.clip(clin_sev_norm, 0.0, 1.0))

    intercept, sec_weight, clin_weight, interaction = 0.10, 0.50, 0.30, 0.10

    raw = intercept + sec_weight * s + clin_weight * c + interaction * s * c
    return float(np.clip(raw, 0.0, 1.0))

# Choose the fusion function used to build the dataset below.
# NOTE(review): the else-branch is unreachable as written, because the safety
# check at the top of this cell raises when 'compute_fusion_risk' is missing.
if "compute_fusion_risk" in locals():
    _fusion_fun = compute_fusion_risk
else:
    _fusion_fun = _default_compute_fusion_risk
    print("[WARN] 'compute_fusion_risk' not found. Using default backup fusion risk for this experiment.")

# ---------------- Build fusion dataset: SCENARIO LABELS + RISK-ONLY FEATURE --------
def build_fusion_dataset_risk_only(n_pairs: int = 30000,
                                   seed: int = 2027,
                                   risk_noise_std: float = 0.06,
                                   balance_ratio=(1, 2, 1),
                                   oversample_factor: int = 4):
    """
    Build a synthetic fusion dataset by:

      1) Randomly pairing IoMT security samples and VitalDB clinical sequences
         (max(n_pairs * oversample_factor, 20000) candidate pairs).
      2) Defining 3-class labels DIRECTLY from scenario-level flags:
           y = 0 (Stable)   if sec_flag_scenario=0 and clin_flag=0
           y = 1 (High)     if exactly one of (sec_flag_scenario, clin_flag) is 1
           y = 2 (Critical) if sec_flag_scenario=1 and clin_flag=1
      3) Rebalancing the dataset to an approximate Stable:High:Critical ratio
         given by 'balance_ratio' (default 1:2:1). Any rounding remainder goes
         to the High class.
      4) Computing CLEAN fused risk with _fusion_fun(sec, clin) for the
         selected pairs only. Labels do not depend on the risk, so scoring
         the discarded candidates would be wasted work.
      5) Adding Gaussian noise to the fused risk (same spirit as the
         scenario-based evaluation, where risk noise is applied), clipped to [0, 1].
      6) Returning X (noisy risk) and y (scenario labels), shuffled together.

    The resulting problem is directly comparable with the scenario-based
    evaluation: labels are the same scenario labels, while the model learns
    a probabilistic mapping from noisy fused risk to Stable/High/Critical.

    Parameters
    ----------
    n_pairs : total number of samples returned.
    seed : seed of the single generator driving all random draws.
    risk_noise_std : std of the additive Gaussian noise (0 disables noise).
    balance_ratio : length-3 relative class weights for [Stable, High, Critical].
    oversample_factor : candidate pool size multiplier before balancing.

    Returns
    -------
    X : ndarray, shape (n_pairs, 1)
        Noisy fused risk.
    y : ndarray of int, shape (n_pairs,)
        Scenario labels aligned with X.

    Raises
    ------
    ValueError
        If balance_ratio does not have exactly 3 entries.
    RuntimeError
        If some class has no candidate pairs.
    """
    rng = np.random.default_rng(seed)
    label_names = {0: "Stable", 1: "High", 2: "Critical"}

    n_sec = X_sec_test_full.shape[0]
    n_cli = len(clin_severity_norm_clean)

    # Step 1: many candidate pairs
    n_candidates = max(n_pairs * oversample_factor, 20000)
    sec_idx = rng.integers(0, n_sec, size=n_candidates)
    cli_idx = rng.integers(0, n_cli, size=n_candidates)

    sec_clean = sec_attack_score_clean[sec_idx]
    clin_clean = clin_severity_norm_clean[cli_idx]

    sec_f = sec_flag_scenario_clean[sec_idx]
    clin_f = clin_flag_full[cli_idx]

    # Step 2: scenario-based labels (EXACTLY as in build_fusion_pairs_eval)
    y_clean = np.zeros_like(sec_f, dtype=int)
    y_clean[(sec_f == 1) & (clin_f == 1)] = 2
    y_clean[(sec_f != clin_f)] = 1
    # Remaining (sec_f=0 & clin_f=0) are Stable (0)

    # Quick summary of class counts before balancing
    uniq, cnts = np.unique(y_clean, return_counts=True)
    print("\n[DEBUG] Scenario labels from flags (before balancing):")
    for v, c in zip(uniq, cnts):
        label = label_names.get(int(v), str(v))
        print(f"  Class {v} ({label}): {c} candidates")

    # Step 3: rebalance according to balance_ratio (default 1:2:1)
    balance_ratio = np.asarray(balance_ratio, dtype=float)
    if balance_ratio.shape[0] != 3:
        raise ValueError("balance_ratio must have length 3 for classes [0, 1, 2].")
    balance_ratio = balance_ratio / balance_ratio.sum()

    target_counts = (balance_ratio * n_pairs).astype(int)
    # Put any rounding difference into the High class (index 1) so sum == n_pairs.
    diff = n_pairs - int(target_counts.sum())
    target_counts[1] += diff

    idx_by_class = [np.where(y_clean == k)[0] for k in range(3)]
    sampled_idx_list = []

    for k in range(3):
        pool = idx_by_class[k]
        n_target = int(target_counts[k])
        if len(pool) == 0:
            raise RuntimeError(f"No candidates for class {k} during balancing.")
        if len(pool) >= n_target:
            sel = rng.choice(pool, size=n_target, replace=False)
        else:
            print(f"[WARN] Class {k}: pool size {len(pool)} < target {n_target}; sampling with replacement.")
            sel = rng.choice(pool, size=n_target, replace=True)
        sampled_idx_list.append(sel)

    idx_sel = np.concatenate(sampled_idx_list)
    y = y_clean[idx_sel]

    # Step 4: CLEAN fused risk for the selected pairs only
    risk_clean_sel = np.array(
        [_fusion_fun(s, c) for s, c in zip(sec_clean[idx_sel], clin_clean[idx_sel])],
        dtype=float
    )

    # Per-class summary on CLEAN risk. It is computed here, before the shuffle,
    # so y and risk_clean_sel are still index-aligned.
    print("\n[DEBUG] Risk summary on CLEAN risk (selected, before noise):")
    for lab, name in [(0, "Stable"), (1, "High"), (2, "Critical")]:
        mask = (y == lab)
        if mask.any():
            print(f"  {name:<8} → mean risk≈{risk_clean_sel[mask].mean():.3f} | "
                  f"min≈{risk_clean_sel[mask].min():.3f} | max≈{risk_clean_sel[mask].max():.3f}")

    # Step 5: add Gaussian noise on risk
    if risk_noise_std > 0:
        risk_obs = risk_clean_sel + rng.normal(0.0, risk_noise_std, size=risk_clean_sel.shape)
        risk_obs = np.clip(risk_obs, 0.0, 1.0)
    else:
        risk_obs = risk_clean_sel.copy()

    # Feature matrix: risk-only, as 2D array (n_samples, 1)
    X = risk_obs.reshape(-1, 1)

    # Step 6: shuffle X and y together to remove the per-class block structure
    perm = rng.permutation(X.shape[0])
    X = X[perm]
    y = y[perm]

    return X, y

# Noise on fused risk, aligned with scenario-based evaluation
# (same value as RISK_NOISE_STD in the evaluation cell; this rebinds that global).
RISK_NOISE_STD = 0.06
N_PAIRS = 30000

X_fusion, y_fusion = build_fusion_dataset_risk_only(
    n_pairs=N_PAIRS,
    seed=2027,
    risk_noise_std=RISK_NOISE_STD,
    balance_ratio=(1, 2, 1),
    oversample_factor=4,
)

# Report final per-class sample counts after rebalancing.
print(f"\n[INFO] Built risk-based noisy fusion dataset (balanced 1:2:1):")
unique_labels_global, counts_global = np.unique(y_fusion, return_counts=True)
for v, c in zip(unique_labels_global, counts_global):
    label = {0: "Stable", 1: "High", 2: "Critical"}.get(int(v), str(v))
    print(f"  Class {v} ({label}): {c} samples")

# ---------------- Split into train / calibration / test ----------------
# Two stratified splits: 60% train, then the remaining 40% halved into
# 20% calibration and 20% test.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X_fusion,
    y_fusion,
    test_size=0.40,
    random_state=2027,
    stratify=y_fusion,
)

X_calib, X_test, y_calib, y_test = train_test_split(
    X_tmp,
    y_tmp,
    test_size=0.50,
    random_state=2027,
    stratify=y_tmp,
)

print("\n[INFO] Fusion splits (stratified, noisy fused risk):")
print(f"  Train: {X_train.shape[0]} samples")
print(f"  Calib: {X_calib.shape[0]} samples")
print(f"  Test : {X_test.shape[0]} samples")

# ---------------- Class remapping: original labels -> 0..K-1 ----------------
# Build a mapping from the labels actually present in y_fusion to contiguous
# indices 0..K-1 (identity when all of {0, 1, 2} are present), so the model
# is trained on contiguous class indices.
unique_labels = np.sort(np.unique(y_fusion))
n_classes = len(unique_labels)

label_to_index = {lab: i for i, lab in enumerate(unique_labels)}
index_to_label = {i: lab for lab, i in label_to_index.items()}

label_name_map = {0: "Stable", 1: "High", 2: "Critical"}
target_names = [label_name_map.get(lab, f"Class {lab}") for lab in unique_labels]

print(f"\n[INFO] Effective number of classes in this experiment: {n_classes}")
print("[INFO] Label mapping (original -> internal index):")
for lab in unique_labels:
    print(f"  {lab} ({label_name_map.get(lab, f'Class {lab}')}) -> {label_to_index[lab]}")

def remap_labels(y):
    """Translate original class labels into contiguous internal indices.

    Every element of ``y`` is looked up in the module-level ``label_to_index``
    table; a label missing from that table raises ``KeyError``.

    Returns:
        A 1-D integer ``np.ndarray`` holding one internal index per element.
    """
    return np.fromiter((label_to_index[v] for v in y), dtype=int)

# Internal 0..K-1 targets for every split (via remap_labels).
y_train_m = remap_labels(y_train)
y_calib_m = remap_labels(y_calib)
y_test_m  = remap_labels(y_test)

# ----------------------------------------------------------------------
# Base model: XGBoost (GPU-aware, XGBoost >= 2 API) on 1D fused risk
# ----------------------------------------------------------------------

USE_GPU = True  # try to use GPU; fallback to CPU if not available
tree_method = "hist"
device = "cuda" if USE_GPU else "cpu"

# Hyper-parameters shared by the first attempt and the CPU fallback, defined
# once so both paths are guaranteed to train the same model configuration.
xgb_params = dict(
    objective="multi:softprob",
    num_class=n_classes,
    eval_metric="mlogloss",
    max_depth=4,
    n_estimators=300,
    learning_rate=0.08,
    subsample=0.9,
    colsample_bytree=1.0,  # 1D feature
    tree_method=tree_method,
    reg_lambda=1.0,
    reg_alpha=0.0,
)

base_clf = xgb.XGBClassifier(**xgb_params, device=device)

print(f"\n[INFO] Training XGBoost fusion model (tree_method={tree_method}, device={device}, n_classes={n_classes})...")

try:
    base_clf.fit(X_train, y_train_m)
except xgb.core.XGBoostError as e:
    # Only a failed GPU attempt can be rescued by retrying on CPU. A failure
    # that already happened on CPU would just repeat, so surface it instead.
    if device != "cuda":
        raise
    print(f"[WARN] GPU training failed with error: {e}")
    print("[WARN] Falling back to CPU training (device='cpu').")
    device = "cpu"
    base_clf = xgb.XGBClassifier(**xgb_params, device=device)
    base_clf.fit(X_train, y_train_m)

# ---------------- Diagnostic: RAW XGBoost (no calibration) ----------------
# Score the uncalibrated model with a plain argmax decision. This gives a
# reference point for the isotonic-calibrated model evaluated afterwards.
probs_test_raw = base_clf.predict_proba(X_test)
y_pred_raw = probs_test_raw.argmax(axis=1)

acc_raw = accuracy_score(y_test_m, y_pred_raw)
prec_raw, rec_raw, f1_raw, support_raw = precision_recall_fscore_support(
    y_test_m,
    y_pred_raw,
    labels=list(range(n_classes)),
    zero_division=0,
)

print("\n=== RAW XGBoost probabilistic fusion model (risk-only) on TEST fusion set ===")
print(f"Accuracy: {acc_raw:.4f}\n")

per_class_raw = zip(prec_raw, rec_raw, f1_raw, support_raw)
for idx, (p_val, r_val, f_val, n_val) in enumerate(per_class_raw):
    orig_lab = index_to_label[idx]
    name = label_name_map.get(orig_lab, f"Class {orig_lab}")
    print(f"{name:<8} | P={p_val:.4f} | R={r_val:.4f} | F1={f_val:.4f} | support={n_val}")

macro_f1_raw = float(np.mean(f1_raw))
print(f"\nMacro-F1 (raw): {macro_f1_raw:.4f}\n")

# ---------------- Probability calibration on calib set (isotonic) ----------------
# base_clf is already fitted on the training split; only the isotonic mapping
# is learned here, on the held-out calibration split.
# cv="prefit" is deprecated in scikit-learn 1.6 in favour of wrapping the
# fitted model in FrozenEstimator. Use that when available (ensemble=False
# yields a single calibrator fitted on all calibration data, matching the
# prefit behaviour), and keep cv="prefit" for older scikit-learn versions.
try:
    from sklearn.frozen import FrozenEstimator
except ImportError:  # scikit-learn < 1.6
    FrozenEstimator = None

if FrozenEstimator is not None:
    calibrated_clf = CalibratedClassifierCV(
        FrozenEstimator(base_clf),
        method="isotonic",
        ensemble=False,
    )
else:
    calibrated_clf = CalibratedClassifierCV(
        base_clf,
        method="isotonic",
        cv="prefit",
    )
calibrated_clf.fit(X_calib, y_calib_m)

# ---------------- Evaluation on test set (calibrated, baseline argmax) ----------------
probs_test = calibrated_clf.predict_proba(X_test)
# Decode argmax columns through classes_ so predictions are class indices even
# if the column order were ever not exactly 0..K-1.
y_pred = calibrated_clf.classes_[np.argmax(probs_test, axis=1)]

# Every metric below is computed over the same fixed label set, so reports stay
# aligned with target_names even if a class is absent from the test split.
class_ids = list(range(n_classes))

acc = accuracy_score(y_test_m, y_pred)
prec, rec, f1, support = precision_recall_fscore_support(
    y_test_m,
    y_pred,
    labels=class_ids,
    zero_division=0,
)

print("\n=== CALIBRATED probabilistic fusion model (risk-only, XGBoost) on TEST fusion set — BASELINE argmax ===")
print(f"Accuracy: {acc:.4f}\n")

for idx in range(n_classes):
    orig_lab = index_to_label[idx]
    name = label_name_map.get(orig_lab, f"Class {orig_lab}")
    print(
        f"{name:<8} | P={prec[idx]:.4f} | R={rec[idx]:.4f} "
        f"| F1={f1[idx]:.4f} | support={support[idx]}"
    )

macro_f1 = float(f1.mean())
print(f"\nMacro-F1 (calibrated, baseline argmax): {macro_f1:.4f}\n")

print("Detailed classification report (calibrated, baseline argmax, original class names):\n")
print(
    classification_report(
        y_test_m,
        y_pred,
        labels=class_ids,
        target_names=target_names,
        zero_division=0,
    )
)

# ---------------- Confusion matrix (baseline argmax) ----------------
cm_base = confusion_matrix(y_test_m, y_pred, labels=list(range(n_classes)))

fig, ax = plt.subplots(figsize=(5, 4))
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm_base,
    display_labels=target_names,
)
disp.plot(ax=ax, colorbar=False, cmap="Greys")

# Tint each diagonal (correctly classified) cell with its class colour:
# internal index 0/1/2 -> green/gold/red. facecolor must be given explicitly:
# with fill=True and only edgecolor set, matplotlib fills the patch with its
# default patch colour, so every cell would get the same tint.
diag_colors = ["green", "gold", "red"]
for k, color in enumerate(diag_colors[:n_classes]):
    rect = Rectangle(
        (k - 0.5, k - 0.5),
        1,
        1,
        fill=True,
        facecolor=color,
        edgecolor=color,
        alpha=0.25,
        linewidth=2,
    )
    ax.add_patch(rect)

ax.set_title("Probabilistic fusion (risk-only, XGBoost) – baseline confusion matrix (TEST)")
plt.tight_layout()

# NOTE(review): OUTPUT_DIR must already be a pathlib.Path at this point; the
# export cell that assigns it comes later in the notebook — confirm it is
# also set in an earlier cell.
fig_path_base = OUTPUT_DIR / "fusion_probabilistic_risk_only_confusion_matrix_xgb_iomt_baseline.png"
fig.savefig(fig_path_base, dpi=300, bbox_inches="tight")
print(f"\n[INFO] Baseline probabilistic fusion confusion matrix saved to: {fig_path_base}")


In [None]:
# ==============================================================================
# Export small DataFrames (tables) and matplotlib figures to Google Drive /output
# Project root: /content/drive/MyDrive/Conference_paper_ICCC_2026
# - Only DataFrames up to ~5 MB are exported (tables, not raw data)
# - All matplotlib figures are exported as PNG (dpi=300)
# ==============================================================================

import os
import re
from pathlib import Path

import matplotlib.pyplot as plt

try:
    import pandas as pd
except ImportError:
    pd = None  # Safety fallback

from tqdm.auto import tqdm

# Project root on Google Drive and the folder that receives every export.
PROJECT_DIR = Path("/content/drive/MyDrive/Conference_paper_ICCC_2026")
OUTPUT_DIR = PROJECT_DIR.joinpath("output")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Status lines gathered during the export; echoed at the end of the cell and
# written to export_manifest.txt.
log_df, log_fig = [], []

# ------------------------------------------------------------------
# 1) Save only "small" pandas DataFrames as CSV (tables)
# ------------------------------------------------------------------
MAX_DF_BYTES = 5 * 1024 * 1024  # ≈ 5 MB in memory


def _collect_unique_dataframes(namespace):
    """Return ``[(name, DataFrame), ...]`` listing each DataFrame object once.

    IPython/Colab output-history aliases (``_``, ``__``, ``_<n>``) often point
    at a DataFrame that already has a real variable name. Public names are
    visited first, and ``sorted`` is stable, so the original order is kept
    within each group. Later aliases of an already-seen object are dropped,
    which avoids duplicate ``_*.csv`` exports. A DataFrame reachable only
    through an underscore name is still exported under that name.
    """
    found, seen_ids = [], set()
    for var_name, obj in sorted(namespace.items(), key=lambda kv: kv[0].startswith("_")):
        if isinstance(obj, pd.DataFrame) and id(obj) not in seen_ids:
            seen_ids.add(id(obj))
            found.append((var_name, obj))
    return found


if pd is not None:
    df_items = _collect_unique_dataframes(globals())

    if not df_items:
        print("No pandas DataFrames found in the current namespace.")
    else:
        print(f"Found {len(df_items)} DataFrames – exporting only small ones (≤ ~5 MB) to {OUTPUT_DIR} ...")
        # Loop variables are `var_name` / `frame`, not `name` / `df`. This loop
        # runs at notebook top level, so its variables are globals and must not
        # overwrite a user DataFrame conventionally called `df`.
        for var_name, frame in tqdm(df_items, desc="Saving small DataFrames"):
            approx_bytes = frame.memory_usage(index=True, deep=True).sum()
            if approx_bytes > MAX_DF_BYTES:
                log_df.append(
                    f"[SKIP-LARGE] '{var_name}' not exported "
                    f"(~{approx_bytes / (1024*1024):.1f} MB)."
                )
                continue

            safe_name = re.sub(r"[^0-9a-zA-Z_]+", "_", var_name)
            out_path = OUTPUT_DIR / f"{safe_name}.csv"
            try:
                frame.to_csv(out_path, index=False)
                log_df.append(f"[OK] '{var_name}' -> {out_path.name}")
            except Exception as e:
                # Best-effort export: record the failure and continue.
                log_df.append(f"[ERROR] '{var_name}' not saved: {e}")
else:
    print("pandas is not available; skipping DataFrame export.")

# ------------------------------------------------------------------
# 2) Save all current matplotlib figures
# ------------------------------------------------------------------
# Only figures still open at this point are visible here.
# NOTE(review): with Colab's default inline backend, figures created in earlier
# cells are typically closed once their cell finishes, so this step may find
# nothing — confirm if no figure_*.png files appear.
fig_nums = plt.get_fignums()
if not fig_nums:
    print("No open matplotlib figures to export.")
else:
    print(f"Found {len(fig_nums)} matplotlib figures – exporting to {OUTPUT_DIR} ...")
    for num in tqdm(fig_nums, desc="Saving figures"):
        fig = plt.figure(num)  # fetch the existing figure with this number
        out_path = OUTPUT_DIR / f"figure_{num}.png"
        try:
            fig.savefig(out_path, dpi=300, bbox_inches="tight")
        except Exception as e:
            log_fig.append(f"[ERROR] Figure {num} not saved: {e}")
        else:
            log_fig.append(f"[OK] Figure {num} -> {out_path.name}")

# ------------------------------------------------------------------
# 3) Write a small manifest with a summary of what was exported
# ------------------------------------------------------------------
manifest_path = OUTPUT_DIR / "export_manifest.txt"
manifest_lines = [
    "Export summary for this notebook run\n\n",
    "=== DataFrames (tables) ===\n",
    *(line + "\n" for line in log_df),
    "\n=== Figures ===\n",
    *(line + "\n" for line in log_fig),
]
# Explicit UTF-8: entries embed variable names and exception messages that may
# contain non-ASCII text, which a non-UTF-8 default locale could not encode.
with open(manifest_path, "w", encoding="utf-8") as manifest_file:
    manifest_file.writelines(manifest_lines)

print(f"\n✅ Export completed. Files saved under: {OUTPUT_DIR}")

print("\n=== DataFrames export log (up to 20 entries) ===")
print("\n".join(log_df[:20]) if log_df else "(no DataFrames exported)")
if len(log_df) > 20:
    print(f"... ({len(log_df) - 20} more entries in export_manifest.txt)")

print("\n=== Figures export log (all entries) ===")
print("\n".join(log_fig) if log_fig else "(no figures found)")
