In [None]:
import joblib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import warnings
from sklearn.metrics import accuracy_score
import joblib
import duckdb
from sklearn.model_selection import train_test_split

# Silence every warning category so library chatter does not clutter cell output.
# Note that this also hides deprecation/future warnings from pandas, torch, etc.
warnings.filterwarnings("ignore")

# --- Artifact and data locations (relative paths, resolved against the working directory) ---
# Blood-test branch: classifier, its scaler and the list of selected feature names.
BLOOD_MODEL_PATH = 'xgb_cad_severity_model-7.joblib'
SCALER_PATH = 'scaler_cad_severity-7.joblib'
FEATURES_LIST_PATH = 'selected_features-7.joblib'
# ECG branch: checkpoint for the hybrid CNN-LSTM network.
ECG_MODEL_PATH = './cnn_lstm_hybrid_checkpoints/hybrid_model_epoch60.pt'
# Late-fusion meta-model that consumes both branches' class probabilities.
FUSION_MODEL_PATH = 'fusion_1_checkpoints/fusion_meta_model.joblib'
# Tabular sources: DuckDB database and the ECG metadata CSV.
DUCKDB_PATH = '../../final_db/mimic_analysis.db'
CLEANED_CSV_PATH = 'metadata_with_features.csv'

# --- Data / model dimensions ---
LEAD_COUNT = 12               # input channels expected by the first Conv1d layer
SAMPLES = 2500                # every signal is cropped or zero-padded to this length
NUM_CLINICAL_FEATURES = 8     # length of the per-record clinical feature vector
NUM_CLASSES = 3               # Low / Moderate / High
HIDDEN_SIZE = 64              # LSTM hidden size (per direction)
BATCH_SIZE = 64               # DataLoader batch size for ECG inference

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#ECG Model Architecture & Dataset Classes

class CNNLSTM_FeatureExtractor(nn.Module):
    """Encode a multi-lead signal into one fixed-size vector (1-D CNN followed by a BiLSTM).

    The convolutional stack takes ``LEAD_COUNT`` input channels and emits 128
    channels. Its output is treated as a sequence and fed to a bidirectional
    LSTM; the last hidden state of each direction is concatenated, so the
    result has ``2 * hidden_size`` features per sample.

    Args:
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=HIDDEN_SIZE, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size

        # The position of every layer inside the Sequential is part of the
        # checkpoint format (parameter names are index-based), so the order
        # below must not change.
        conv_stack = [
            # Stage 1
            nn.Conv1d(LEAD_COUNT, 32, kernel_size=15, stride=2, padding=7),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            # Stage 2
            nn.Conv1d(32, 64, kernel_size=11, stride=2, padding=5),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            # Stage 3 (no pooling)
            nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Return the concatenated final hidden states of both LSTM directions.

        Args:
            x: Signal tensor laid out channels-first for ``Conv1d``
                (presumably ``(batch, LEAD_COUNT, length)`` -- confirm with caller).

        Returns:
            Tensor with ``2 * hidden_size`` features along dim 1.
        """
        conv_features = self.cnn(x)
        # Conv output is channels-first; the batch_first LSTM wants the
        # channel axis last, hence the swap of dims 1 and 2.
        _, (hidden, _) = self.lstm(conv_features.transpose(1, 2))
        # The last two entries of the hidden-state stack belong to the top
        # layer: forward direction, then backward direction.
        return torch.cat((hidden[-2], hidden[-1]), dim=1)

class FinalMultiInputHybridModel(nn.Module):
    """Classifier that fuses a learned signal embedding with a clinical feature vector.

    The signal goes through ``CNNLSTM_FeatureExtractor``; its output is
    concatenated with the clinical features and passed to a two-layer
    fully connected head that emits one logit per class.

    Args:
        num_classes: Number of output logits.
        num_clinical_features: Length of the clinical feature vector per sample.
    """

    def __init__(self, num_classes=NUM_CLASSES, num_clinical_features=NUM_CLINICAL_FEATURES):
        super().__init__()
        self.signal_extractor = CNNLSTM_FeatureExtractor(hidden_size=HIDDEN_SIZE)

        # The extractor concatenates two hidden states of HIDDEN_SIZE each;
        # the clinical vector is appended to that.
        fused_width = 2 * HIDDEN_SIZE + num_clinical_features
        head_layers = [
            nn.Linear(fused_width, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes),
        ]
        self.final_fc = nn.Sequential(*head_layers)

    def forward(self, signal, clinical_features):
        """Return class logits for a batch of (signal, clinical feature) pairs.

        Args:
            signal: Input forwarded unchanged to the signal extractor.
            clinical_features: Per-sample clinical vector; cast to float before fusion.

        Returns:
            Output of the fully connected head (raw logits, no softmax applied).
        """
        signal_embedding = self.signal_extractor(signal)
        clinical = clinical_features.float()
        fused = torch.cat((signal_embedding, clinical), dim=1)
        return self.final_fc(fused)

class HybridECGDataset(Dataset):
    """Dataset yielding ``((signal, clinical_features), label)`` for each metadata row.

    Each row of ``df`` must provide ``processed_npy_path`` (signal array on
    disk) and ``feature_path`` (clinical feature vector on disk). A
    ``severity_level`` column is optional; when absent or null the label
    defaults to 0.

    Loading is best-effort: any file that cannot be read, or whose content
    does not have the expected shape, is replaced by an all-zero array so a
    single bad record cannot break batch collation. This substitution is
    silent; callers that need to know about bad records must validate the
    paths themselves.
    """

    # Textual severity level -> integer class index. Unknown values map to 0.
    _LABEL_MAP = {"Low": 0, "Moderate": 1, "High": 2}

    def __init__(self, df):
        # Re-index so positional access in __getitem__ matches 0..len-1.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _load_signal(path):
        """Load a signal as a ``(LEAD_COUNT, SAMPLES)`` array, or zeros on any failure."""
        try:
            # Files are stored time-major; transpose to leads-first.
            signal = np.load(path).T
            # A wrong rank or lead count would otherwise slip through whenever
            # the time axis is already >= SAMPLES and crash later at collation.
            if signal.ndim != 2 or signal.shape[0] != LEAD_COUNT:
                raise ValueError(f"unexpected signal shape {signal.shape}")
            if signal.shape[1] > SAMPLES:
                signal = signal[:, :SAMPLES]
            elif signal.shape[1] < SAMPLES:
                padding = np.zeros((LEAD_COUNT, SAMPLES - signal.shape[1]), dtype=np.float32)
                signal = np.concatenate([signal, padding], axis=1)
            return np.nan_to_num(signal, nan=0.0)
        except Exception:
            return np.zeros((LEAD_COUNT, SAMPLES), dtype=np.float32)

    @staticmethod
    def _load_features(path):
        """Load a clinical vector of length ``NUM_CLINICAL_FEATURES``, or zeros on any failure."""
        try:
            features = np.asarray(np.load(path), dtype=np.float32).reshape(-1)
            # A vector of the wrong length cannot be stacked with the others.
            if features.shape[0] != NUM_CLINICAL_FEATURES:
                raise ValueError(f"unexpected feature length {features.shape[0]}")
            # Same NaN policy as the signal: a NaN here would turn every
            # logit (and therefore every probability) of the sample into NaN.
            return np.nan_to_num(features, nan=0.0)
        except Exception:
            return np.zeros(NUM_CLINICAL_FEATURES, dtype=np.float32)

    def __getitem__(self, idx):
        """Return ``((signal_tensor, feature_tensor), label)`` for row ``idx``.

        Both tensors are float32. ``label`` is a plain ``int``.
        """
        row = self.df.iloc[idx]

        signal = self._load_signal(row['processed_npy_path'])
        features = self._load_features(row['feature_path'])

        label = 0  # Placeholder if label is unknown
        if 'severity_level' in row and pd.notna(row['severity_level']):
            label = self._LABEL_MAP.get(row['severity_level'], 0)

        return (torch.tensor(signal, dtype=torch.float32),
                torch.tensor(features, dtype=torch.float32)), label



In [None]:

# Data Loading Function 

def _build_blood_dataframe(con):
    """Query an open DuckDB connection and assemble the per-admission blood table.

    Combines the admission severity rows, per-subject mean values of the
    matching blood lab items, and one 0/1 flag per comorbidity table.

    Args:
        con: Open DuckDB connection. This helper never closes it.

    Returns:
        tuple: ``(df_blood_full, comorbidity_column_names)``.

    Raises:
        RuntimeError: If no blood lab item label matches the keyword pattern.
    """
    df_blood_full = con.execute("""
        SELECT hadm_id, subject_id, anchor_age, gender, cad, severity_level
        FROM admissions_severity
        WHERE severity_level IS NOT NULL
    """).df()

    cbc_keywords = ["hemoglobin", "hematocrit", "rbc", "wbc", "platelet", "mcv", "mch", "mchc", "rdw", "neutrophil", "lymphocyte", "monocyte", "eosinophil", "basophil"]
    bmp_keywords = ["sodium", "potassium", "chloride", "bicarbonate", "co2", "urea", "bun", "creatinine", "glucose", "calcium"]
    lft_keywords = ["albumin", "protein", "bilirubin", "alkaline phosphatase", "ast", "sgot", "alt", "sgpt", "lactate"]
    lipid_keywords = ["cholesterol", "hdl", "ldl", "triglyceride"]
    cardiac_keywords = ["troponin", "ck-mb", "creatine kinase", "ck", "bnp", "nt-probnp", "hs-crp"]
    all_keywords = cbc_keywords + bmp_keywords + lft_keywords + lipid_keywords + cardiac_keywords
    # Substring (regex alternation) match, so short keywords such as "ck" or
    # "alt" also hit longer labels that merely contain them.
    pattern = '|'.join(all_keywords)

    df_labitems = con.execute("SELECT itemid, label, fluid FROM d_labitems WHERE LOWER(fluid) = 'blood'").df()
    mask = df_labitems['label'].str.lower().str.contains(pattern, na=False)
    blood_itemids = df_labitems[mask]['itemid'].tolist()
    if not blood_itemids:
        raise RuntimeError("No matching blood lab items found.")
    # Item ids come from the database itself, not from user input.
    blood_itemids_str = ', '.join(map(str, blood_itemids))

    df_labs = con.execute(f"""
        SELECT l.subject_id, d.label, AVG(l.valuenum) as mean_value
        FROM labevents l JOIN d_labitems d ON l.itemid = d.itemid
        WHERE l.itemid IN ({blood_itemids_str})
        GROUP BY l.subject_id, d.label
    """).df()
    # One row per subject, one column per lab label.
    df_labs_pivot = df_labs.pivot(index='subject_id', columns='label', values='mean_value').reset_index()

    df_blood_full = pd.merge(df_blood_full, df_labs_pivot, on='subject_id', how='left')

    comorb_tables = {
        'diabetes': 'diabetes_adm',
        'hypertension': 'hypertension_adm',
        'renal': 'renal_adm',
        'obesity': 'obesity_icd_adm',
        'smoker': 'smokers_adm'
    }
    # Join keys are compared as strings on both sides.
    df_blood_full['hadm_id'] = df_blood_full['hadm_id'].astype(str)

    for comorb, table in comorb_tables.items():
        df_comorb = con.execute(f"SELECT DISTINCT hadm_id, 1 AS {comorb} FROM {table}").df()
        if not df_comorb.empty:
            df_comorb['hadm_id'] = df_comorb['hadm_id'].astype(str)
            df_blood_full = pd.merge(df_blood_full, df_comorb, on='hadm_id', how='left')
            df_blood_full[comorb] = df_blood_full[comorb].fillna(0).astype(int)
        else:
            # Empty table: nobody has the flag. Either branch creates the column.
            df_blood_full[comorb] = 0

    return df_blood_full, list(comorb_tables)


def load_and_align_test_data(cleaned_csv_path, duckdb_path):
    """Build the aligned blood-feature / ECG-metadata test split used by the fusion model.

    Steps: read the ECG metadata CSV, assemble the blood table from DuckDB,
    impute numeric gaps with column medians, inner-join both sources on
    ``subject_id``, keep rows with a known severity level, and take the 20 %
    stratified hold-out (``random_state=42``).

    Args:
        cleaned_csv_path: CSV with at least ``subject_id``,
            ``processed_npy_path`` and ``feature_path`` columns.
        duckdb_path: Path of the DuckDB database file.

    Returns:
        tuple: ``(X_blood_all_features, X_ecg_metadata, y_test_fusion,
        original_feature_names_list)`` where
        ``X_blood_all_features`` is the hold-out blood table re-indexed to the
        saved full feature list, ``X_ecg_metadata`` is the complete hold-out
        frame, ``y_test_fusion`` holds the integer severity classes and
        ``original_feature_names_list`` is the saved full feature list.

    Raises:
        RuntimeError: If no blood lab items match, or a required joblib
            artifact is missing or unreadable.
    """
    df_ecg_full = pd.read_csv(cleaned_csv_path)

    con = duckdb.connect(duckdb_path)
    try:
        df_blood_full, _comorb_cols = _build_blood_dataframe(con)
    finally:
        # Always release the database handle, including when a query fails.
        con.close()

    df_blood_full['gender'] = df_blood_full['gender'].map({'M': 0, 'F': 1})
    numeric_cols = df_blood_full.select_dtypes(include=['number']).columns.tolist()
    median_values = df_blood_full[numeric_cols].median().to_dict()
    df_blood_full = df_blood_full.fillna(median_values)

    try:
        feature_names_full = joblib.load("feature_names_full-7.joblib")
        # Not used here; loaded so a missing file is reported now with the
        # same RuntimeError rather than later in the pipeline.
        joblib.load("selected_features-7.joblib")
    except Exception as e:
        raise RuntimeError(f"Required joblib artifact missing or unreadable: {e}. Cannot align data.") from e
    if not isinstance(feature_names_full, list):
        feature_names_full = list(feature_names_full)

    # NOTE(review): this join pairs every blood row of a subject with every
    # ECG row of that subject, and the split below is row-wise, so one subject
    # may land on both sides of the split -- confirm this mirrors training.
    df_common = pd.merge(
        df_blood_full,
        df_ecg_full[['subject_id', 'processed_npy_path', 'feature_path']],
        on='subject_id', how='inner'
    )
    df_common = df_common[df_common['severity_level'].isin(['Low', 'Moderate', 'High'])].reset_index(drop=True)
    df_common['severity_class'] = df_common['severity_level'].map({'Low': 0, 'Moderate': 1, 'High': 2})

    # Splitting to get a representative test set; only the hold-out part is used.
    _, df_test_fusion = train_test_split(
        df_common,
        test_size=0.2,
        random_state=42,
        stratify=df_common['severity_class']
    )

    columns_to_drop_final = ['subject_id', 'hadm_id', 'cad', 'severity_level', 'severity_class', 'processed_npy_path', 'feature_path']
    X_blood_all_features = df_test_fusion.drop(columns=columns_to_drop_final, errors='ignore')
    # Force the exact column set/order the scaler was fitted on; names that
    # are absent here appear as all-NaN columns and are filled below.
    X_blood_all_features = X_blood_all_features.reindex(columns=feature_names_full)

    # Median where one is known, 0 otherwise. A single non-inplace fillna is
    # used because inplace fillna on a column selection acts on a temporary
    # under pandas Copy-on-Write and would leave the frame unchanged.
    fill_values = {}
    for col in X_blood_all_features.columns:
        value = median_values.get(col, 0)
        if pd.notna(value):
            fill_values[col] = value
    X_blood_all_features = X_blood_all_features.fillna(fill_values)

    y_test_fusion = df_test_fusion['severity_class'].values
    X_ecg_metadata = df_test_fusion
    original_feature_names_list = feature_names_full

    return X_blood_all_features, X_ecg_metadata, y_test_fusion, original_feature_names_list


# Prediction Generation (P_A and P_B)


def generate_base_predictions(X_blood_raw, X_ecg_metadata, original_feature_names):
    """Build the fusion-model input from both base models' class probabilities.

    ``P_A`` comes from the blood-test classifier (XGBoost), ``P_B`` from the
    ECG CNN-LSTM network. They are stacked column-wise, blood model first.

    Args:
        X_blood_raw: DataFrame containing every column in ``original_feature_names``.
        X_ecg_metadata: DataFrame handed to ``HybridECGDataset``.
        original_feature_names: Column names, in the order the scaler expects.

    Returns:
        numpy.ndarray: ``np.hstack((P_A, P_B))``, one row per sample.
    """
    # Artifacts are loaded in a fixed order: blood model, scaler, feature list, ECG net.
    blood_clf = joblib.load(BLOOD_MODEL_PATH)
    feature_scaler = joblib.load(SCALER_PATH)
    kept_features = joblib.load(FEATURES_LIST_PATH)

    ecg_net = FinalMultiInputHybridModel(num_clinical_features=NUM_CLINICAL_FEATURES).to(DEVICE)
    ecg_state = torch.load(ECG_MODEL_PATH, map_location=DEVICE)
    ecg_net.load_state_dict(ecg_state)
    ecg_net.eval()  # inference mode: disables dropout, freezes batch-norm statistics

    # --- A. Model A: blood-test probabilities (P_A) ---
    # Scale the full feature matrix first, then narrow it to the selected
    # columns; the scaler is applied to the complete set.
    scaled_matrix = feature_scaler.transform(X_blood_raw[original_feature_names].values)
    scaled_frame = pd.DataFrame(scaled_matrix, columns=original_feature_names)
    P_A = blood_clf.predict_proba(scaled_frame[kept_features])

    # --- B. Model B: CNN-LSTM probabilities (P_B) ---
    # shuffle=False keeps P_B row-aligned with P_A.
    loader = DataLoader(HybridECGDataset(X_ecg_metadata), batch_size=BATCH_SIZE, shuffle=False)
    with torch.no_grad():
        batch_probs = [
            F.softmax(ecg_net(signals.to(DEVICE), clinical.to(DEVICE)), dim=1).cpu().numpy()
            for (signals, clinical), _ in tqdm(loader, desc="ECG Model Prediction")
        ]
    P_B = np.concatenate(batch_probs, axis=0)

    # --- C. Fusion input: blood probabilities followed by ECG probabilities ---
    return np.hstack((P_A, P_B))

# Final Fusion Prediction
def run_fusion_prediction(X_fusion_input):
    """Load the fusion meta-model and predict a severity class for every row.

    Args:
        X_fusion_input: Array accepted by the meta-model's ``predict`` method.

    Returns:
        tuple: ``(y_fusion_pred_numeric, y_fusion_pred_label)`` -- the raw
        class indices returned by the meta-model and the matching
        ``'Low'`` / ``'Moderate'`` / ``'High'`` strings as a NumPy array.

    Raises:
        RuntimeError: If the file at ``FUSION_MODEL_PATH`` does not exist.
            An exception is raised rather than a bare ``None`` returned,
            because callers unpack the result into two values.
        KeyError: If the meta-model predicts a class other than 0, 1 or 2.
    """
    try:
        # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
        meta_model = joblib.load(FUSION_MODEL_PATH)
    except FileNotFoundError as exc:
        raise RuntimeError(
            f"Fusion model file '{FUSION_MODEL_PATH}' not found. Cannot proceed."
        ) from exc

    y_fusion_pred_numeric = meta_model.predict(X_fusion_input)

    severity_map = {0: 'Low', 1: 'Moderate', 2: 'High'}
    y_fusion_pred_label = np.array([severity_map[cls] for cls in y_fusion_pred_numeric])

    return y_fusion_pred_numeric, y_fusion_pred_label


In [None]:
#main
# End-to-end driver: align data, run both base models, fuse, print a preview.
try:
    # 1. Load and align data
    X_blood_raw, X_ecg_metadata, y_true_labels, original_feature_names_list = load_and_align_test_data(CLEANED_CSV_PATH, DUCKDB_PATH)

    # 2. Generate input for the fusion model (P_A and P_B)
    X_fusion_input = generate_base_predictions(X_blood_raw, X_ecg_metadata, original_feature_names_list)

    # 3. Make the final prediction with the fusion model
    # The result is checked before unpacking so that a missing-model result
    # of None cannot raise a TypeError here.
    fusion_result = run_fusion_prediction(X_fusion_input) if X_fusion_input is not None else None

    if fusion_result is not None:
        y_fusion_pred_numeric, y_fusion_pred_label = fusion_result

        # 4. Display Results
        print("\n--- FINAL LATE FUSION PREDICTION SUMMARY ---")

        # Combine subject IDs and predictions for a summary table
        subject_ids = X_ecg_metadata['subject_id'].values
        results_df = pd.DataFrame({
            'subject_id': subject_ids,
            'Predicted_Severity_Label': y_fusion_pred_label,
            'True_Severity_level': X_ecg_metadata['severity_level'].values
        })

        print("\nFirst 5 Fused Predictions:")
        print(results_df.head())

# Missing artifacts surface as FileNotFoundError from the base-model and CSV
# loads; report them with the same hint instead of a raw traceback.
except (RuntimeError, FileNotFoundError) as e:
    print(f"\nExecution Failed: {e}")
    print("Please ensure all required data files (DuckDB, .joblib, .csv, and .pt) are correctly placed and named.")

ECG Model Prediction: 100%|██████████| 179/179 [03:19<00:00,  1.11s/it]



--- FINAL LATE FUSION PREDICTION SUMMARY ---

First 5 Fused Predictions:
   subject_id Predicted_Severity_Label True_Severity_level
0    11665092                     High                High
1    11173428                     High                High
2    11697344                     High                High
3    10711042                      Low                 Low
4    11568515                     High                High


In [10]:
# Rows 5-9 of the summary table (positional slice).
print(results_df.iloc[5:10])

   subject_id Predicted_Severity_Label True_Severity_level
5    10485425                 Moderate            Moderate
6    11877234                 Moderate            Moderate
7    11296936                     High                High
8    11576109                     High                High
9    10692761                      Low                 Low
