In [14]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.preprocessing import StandardScaler
import shap
import joblib
import matplotlib.pyplot as plt
import mlflow
import mlflow.sklearn
import os

# --- Configuration ---
# Define the root directory of your project.
# This assumes the script is run from a location where 'data/', 'models/', 'shap_plots/'
# can be found relative to this project_root.
project_root = '/Users/sangeethgeorge/MyProjects/oncoai-patient-outcome-navigator'

# Define paths for data, models, and SHAP plots
data_file_path = os.path.join(project_root, "data", "onco_features_cleaned.parquet")
model_save_base_path = os.path.join(project_root, "models")
shap_plots_base_path = os.path.join(project_root, "shap_plots")

# Ensure necessary directories exist
os.makedirs(os.path.dirname(data_file_path), exist_ok=True)
os.makedirs(model_save_base_path, exist_ok=True)
os.makedirs(shap_plots_base_path, exist_ok=True)

# --- Data Loading Function ---
def load_dataset(path: str = data_file_path) -> pd.DataFrame:
    """
    Loads the dataset from a specified parquet file.

    Args:
        path (str): The path to the parquet file.

    Returns:
        pd.DataFrame: The loaded DataFrame, or an empty DataFrame if the file is not found.
    """
    try:
        df = pd.read_parquet(path)
        print(f"✅ Dataset loaded successfully from {path}")
        return df
    except FileNotFoundError:
        print(f"❌ Error: Dataset not found at {path}. Please ensure the file exists and the path is correct.")
        return pd.DataFrame() # Return an empty DataFrame to indicate failure

# --- Data Preprocessing Functions ---
def train_test_impute_split(df: pd.DataFrame, label_col: str = "mortality_30d") -> tuple:
    """
    Splits the data into training and testing sets, and imputes missing values
    using medians calculated from the training set.

    Args:
        df (pd.DataFrame): The input DataFrame.
        label_col (str): The name of the target column.

    Returns:
        tuple: X_train, X_test, y_train, y_test DataFrames.
    """
    # Drop identifiers and timestamps
    df = df.drop(columns=['icustay_id', 'subject_id', 'hadm_id', 'admittime', 'dob', 'dod', 'intime', 'outtime', 'icd9_code'], errors='ignore')

    y = df[label_col]
    X = df.drop(columns=[label_col])

    # Split before imputation to prevent data leakage
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # Impute missing values using training set statistics only
    for col in X_train.select_dtypes(include=np.number).columns:
        if X_train[col].isnull().any():
            median_val = X_train[col].median()
            X_train[col] = X_train[col].fillna(median_val)
            X_test[col] = X_test[col].fillna(median_val)
    
    print("✅ Data split and imputed successfully.")
    return X_train, X_test, y_train, y_test

def check_for_leakage(X: pd.DataFrame, y: pd.Series) -> pd.DataFrame:
    """
    Checks for potential data leakage by identifying highly correlated features
    with the target variable. Drops highly correlated columns if detected.

    Args:
        X (pd.DataFrame): Features DataFrame.
        y (pd.Series): Target Series.

    Returns:
        pd.DataFrame: Features DataFrame after dropping highly correlated columns
                      if leakage is detected.
    """
    X_copy = X.copy()
    y_copy = y.copy()
    # Reset indices to ensure correct concatenation for correlation calculation
    X_copy.index = range(len(X_copy))
    y_copy.index = range(len(y_copy))

    combined_df = pd.concat([X_copy, y_copy], axis=1)
    
    # Ensure all columns used for correlation are numeric
    numeric_cols = combined_df.select_dtypes(include=np.number).columns
    combined_df_numeric = combined_df[numeric_cols]

    if y.name not in combined_df_numeric.columns:
        print(f"⚠️ Warning: Target column '{y.name}' not found in numeric columns for leakage check. Skipping correlation check.")
        return X

    corr = combined_df_numeric.corr()[y.name].drop(y.name, errors='ignore') # drop target if it's there
    high_corr_threshold = 0.95
    high_corr = corr[abs(corr) > high_corr_threshold]

    if not high_corr.empty:
        print(f"\n⚠️ Potential Leakage Detected (correlation > {high_corr_threshold}):")
        print(high_corr)
        leaky_columns = high_corr.index.tolist()
        X = X.drop(columns=leaky_columns, errors='ignore')
        print(f"Dropped potential leakage columns: {leaky_columns}")
    else:
        print("\nNo significant data leakage detected based on high correlation.")
    return X

# --- Model Training and Evaluation ---
def train_logistic_regression(X_train: np.ndarray, y_train: pd.Series, 
                              X_test: np.ndarray, y_test: pd.Series) -> tuple:
    """
    Trains a Logistic Regression model and evaluates its performance.

    Args:
        X_train (np.ndarray): Scaled training features.
        y_train (pd.Series): Training labels.
        X_test (np.ndarray): Scaled testing features.
        y_test (pd.Series): Testing labels.

    Returns:
        tuple: Trained model, X_train, y_train, X_test, y_test, y_pred, y_prob.
    """
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    print("\n🧠 Classification Report:")
    print(classification_report(y_test, y_pred))
    print("\n📊 ROC AUC Score:", roc_auc_score(y_test, y_prob))
    print("✅ Logistic Regression model trained and evaluated.")
    return model, X_train, y_train, X_test, y_test, y_pred, y_prob

# --- SHAP Explanation and Plotting ---
def explain_predictions(model, X_scaled: np.ndarray, X_df: pd.DataFrame, 
                        output_dir: str = "shap_plots", top_n: int = 10):
    """
    Generates and saves SHAP plots for overall interpretability and
    for top N high-risk patients.

    Args:
        model: The trained machine learning model.
        X_scaled (np.ndarray): Scaled features (e.g., X_test_scaled).
        X_df (pd.DataFrame): Original DataFrame of features corresponding to X_scaled
                              (e.g., X_test_leakage_checked). Used for feature names and values.
        output_dir (str): Directory to save the SHAP plots.
        top_n (int): Number of top high-risk patients for whom to generate individual plots.
    """
    os.makedirs(output_dir, exist_ok=True)

    if X_scaled.shape[0] == 0:
        print("Skipping SHAP explanation: No data in X_scaled for explanation.")
        return # Exit if no data to explain

    # Initialize Explainer: Pass X_scaled (numpy array) and feature_names (list)
    explainer = shap.Explainer(model, X_scaled, feature_names=X_df.columns.tolist())
    shap_values = explainer(X_scaled)

    # Explicitly set feature_names on the Explanation object (redundant but robust)
    shap_values.feature_names = X_df.columns.tolist()
    print(f"DEBUG: shap_values.feature_names after explicit assignment: {shap_values.feature_names}")


    # Prepare display_data for more readable categorical values
    # IMPORTANT: Adapt these mappings to your actual data and features
    display_data_df = X_df.copy()
    # Example mappings (customize these based on your actual data encoding)
    if 'gender_encoded' in display_data_df.columns:
        display_data_df['gender_encoded'] = display_data_df['gender_encoded'].map({0: 'Female', 1: 'Male'}).fillna(display_data_df['gender_encoded'])
    if 'ethnicity_encoded' in display_data_df.columns:
        ethnicity_map = {0: 'White', 1: 'Black', 2: 'Asian', 3: 'Other'}
        display_data_df['ethnicity_encoded'] = display_data_df['ethnicity_encoded'].map(ethnicity_map).fillna(display_data_df['ethnicity_encoded'])
    if 'admission_type_encoded' in display_data_df.columns:
        admission_type_map = {0: 'Emergency', 1: 'Elective', 2: 'Urgent'}
        display_data_df['admission_type_encoded'] = display_data_df['admission_type_encoded'].map(admission_type_map).fillna(display_data_df['admission_type_encoded'])
    
    # Assign the display_data to the shap_values object (expects a NumPy array)
    shap_values.display_data = display_data_df.values
    print(f"DEBUG: shap_values.display_data shape: {shap_values.display_data.shape}")


    # 📈 SHAP Summary Plot (overall feature importance)
    print("\n📈 Generating overall SHAP Summary Plot...")
    plt.figure(figsize=(12, 8))
    shap.summary_plot(shap_values, features=X_df, show=False)
    plt.savefig(os.path.join(output_dir, "shap_summary_overall.png"), bbox_inches='tight')
    plt.close()
    print("✅ Overall SHAP Summary Plot generated.")

    # Generate dependence plots for top K overall features
    if len(X_df.columns) > 0:
        abs_shap_means = np.abs(shap_values.values).mean(axis=0)
        top_overall_features_indices = np.argsort(abs_shap_means)[::-1][:min(3, len(X_df.columns))] # Top 3 features
        print("\n🔍 Generating SHAP Dependence Plots for top overall features...")
        for feat_idx in top_overall_features_indices:
            top_feature_name = X_df.columns[feat_idx] # Get the actual feature name
            plt.figure(figsize=(8, 6))
            shap.dependence_plot(top_feature_name, shap_values.values, features=X_df, feature_names=X_df.columns.tolist(), show=False)
            plt.title(f"SHAP Dependence Plot: {top_feature_name}")
            plt.savefig(os.path.join(output_dir, f"shap_dependence_{top_feature_name}.png"), bbox_inches='tight')
            plt.close()
        print("✅ SHAP Dependence Plots for top overall features generated.")
    else:
        print("\nSkipping SHAP Dependence Plots for overall features: No features available.")

    # Rank by predicted probability for individual plots
    if X_scaled.shape[0] > 0:
        risk_scores = model.predict_proba(X_scaled)[:, 1]
        num_samples = X_scaled.shape[0]
        actual_top_n = min(top_n, num_samples)
        top_indices = np.argsort(risk_scores)[-actual_top_n:][::-1]

        print(f"\n📊 Generating SHAP plots for top {actual_top_n} high-risk patients...")

        for i_idx, original_index in enumerate(top_indices):
            patient_identifier = f"patient_idx_{original_index}_rank_{i_idx+1}"

            # 🧬 SHAP Waterfall Plot
            plt.figure(figsize=(10, 6))
            shap.plots.waterfall(shap_values[original_index], show=False)
            plt.title(f"SHAP Waterfall Plot for {patient_identifier}\nPredicted Risk: {risk_scores[original_index]:.4f}")
            plt.savefig(os.path.join(output_dir, f"waterfall_{patient_identifier}.png"), bbox_inches='tight')
            plt.close()

            # Dependence Plots for top features of THIS SPECIFIC patient
            relevant_shap_values_patient = shap_values.values[original_index]
            num_features = len(X_df.columns)
            if num_features > 0:
                top_patient_features_indices = np.argsort(np.abs(relevant_shap_values_patient))[::-1][:min(3, num_features)]

                for feat_idx in top_patient_features_indices:
                    feat_name = X_df.columns[feat_idx]
                    plt.figure(figsize=(8, 6))
                    shap.dependence_plot(feat_name, shap_values.values, features=X_df, feature_names=X_df.columns.tolist(), show=False)
                    plt.title(f"SHAP Dependence Plot for {patient_identifier} - {feat_name}")
                    plt.savefig(os.path.join(output_dir, f"dependence_{patient_identifier}_{feat_name}.png"), bbox_inches='tight')
                    plt.close()
            else:
                print(f"No features available for dependence plots for {patient_identifier}.")

        print(f"✅ Individual SHAP plots for top {actual_top_n} patients saved in {output_dir}/")
    else:
        print("\nSkipping individual SHAP plots: No test data available.")

    print("✅ All SHAP plots generated and saved.")
    return shap_values

# --- Model Saving Function ---
def save_model(model, scaler, output_path: str):
    """
    Saves the trained model and scaler to a joblib file.

    Args:
        model: The trained machine learning model.
        scaler: The fitted StandardScaler object.
        output_path (str): The path to save the model.
    """
    joblib.dump({"model": model, "scaler": scaler}, output_path)
    print(f"\n✅ Saved model and scaler to {output_path}")

# --- Main MLflow Execution Block ---
if __name__ == "__main__":
    # Set up MLflow tracking
    mlflow.set_experiment("OncoAI-Mortality-Prediction")

    # Start an MLflow run
    with mlflow.start_run() as run:
        run_id = run.info.run_id
        
        # Make SHAP plots output directory specific to the run
        run_shap_output_dir = os.path.join(shap_plots_base_path, run_id)
        os.makedirs(run_shap_output_dir, exist_ok=True)
        
        # Define the full model save path for this run
        model_save_path_for_run = os.path.join(model_save_base_path, f"logreg_model_run_{run_id}.joblib")

        print(f"Starting MLflow Run with ID: {run_id}")
        print(f"SHAP plots will be saved to: {run_shap_output_dir}")
        print(f"Model will be saved to: {model_save_path_for_run}")

        # 1. Load dataset
        df = load_dataset()

        if df.empty: # Check if DataFrame is empty
            print("❌ Dataset is empty. Cannot proceed with training and explanation. Exiting MLflow run.")
            mlflow.end_run(status="FAILED")
        else:
            # 2. Train-test split and imputation
            X_train, X_test, y_train, y_test = train_test_impute_split(df)

            # 3. One-hot encode categorical columns
            X_train_ohe = pd.get_dummies(X_train, drop_first=True)
            X_test_ohe = pd.get_dummies(X_test, drop_first=True)

            # Align columns after one-hot encoding to ensure same features in train/test
            missing_cols_in_test = set(X_train_ohe.columns) - set(X_test_ohe.columns)
            for c in missing_cols_in_test:
                X_test_ohe[c] = 0
            # Ensure the order of columns is the same
            X_test_ohe = X_test_ohe[X_train_ohe.columns]

            # 4. Check for data leakage on the one-hot encoded training data
            X_train_leakage_checked = check_for_leakage(X_train_ohe, y_train)

            # Apply the same column selection (after leakage check) to the test set
            X_test_leakage_checked = X_test_ohe[X_train_leakage_checked.columns]

            # DEBUG: Print column names to verify before scaling and SHAP
            print(f"\nDEBUG: Columns of X_train_leakage_checked before scaling:\n{X_train_leakage_checked.columns.tolist()}")
            print(f"DEBUG: Columns of X_test_leakage_checked before scaling:\n{X_test_leakage_checked.columns.tolist()}")
            
            # 5. Scale features
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train_leakage_checked)
            X_test_scaled = scaler.transform(X_test_leakage_checked)

            print("✅ Features prepared (one-hot encoded and scaled).")

            # Log parameters to MLflow
            mlflow.log_param("scaler", "StandardScaler")
            mlflow.log_param("model_type", "LogisticRegression")

            # 6. Train Logistic Regression model
            model, X_train_final_scaled, y_train_final, X_test_final_scaled, y_test_final, y_pred, y_prob = \
                train_logistic_regression(X_train_scaled, y_train, X_test_scaled, y_test)

            # Log ROC AUC metric to MLflow
            auc = roc_auc_score(y_test_final, y_prob)
            mlflow.log_metric("roc_auc", auc)

            # 7. Save model and scaler
            save_model(model, scaler, output_path=model_save_path_for_run)

            # 8. Log model with input_example for signature inference
            if X_train_leakage_checked.shape[0] > 0:
                mlflow.sklearn.log_model(model, "logreg_model", 
                                         input_example=X_train_leakage_checked.head(10))
            else:
                mlflow.sklearn.log_model(model, "logreg_model") 

            # 9. Explain with SHAP
            # Pass X_test_scaled (numpy array) and X_test_leakage_checked (DataFrame)
            # X_test_leakage_checked is critical for feature names and display data.
            explain_predictions(model, X_test_scaled, X_test_leakage_checked, output_dir=run_shap_output_dir)

            # 10. Log SHAP plots as MLflow artifacts
            shap_plot_files = [f for f in os.listdir(run_shap_output_dir) if f.endswith('.png')]
            for plot_file in shap_plot_files:
                mlflow.log_artifact(os.path.join(run_shap_output_dir, plot_file), artifact_path="shap_plots")

    print("\n✨ MLflow run completed successfully. Check your MLflow UI for details.")

  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0
  X_test_ohe[c] = 0


Starting MLflow Run with ID: 533af684cc7142de8053496a19833157
SHAP plots will be saved to: /Users/sangeethgeorge/MyProjects/oncoai-patient-outcome-navigator/shap_plots/533af684cc7142de8053496a19833157
Model will be saved to: /Users/sangeethgeorge/MyProjects/oncoai-patient-outcome-navigator/models/logreg_model_run_533af684cc7142de8053496a19833157.joblib
✅ Dataset loaded successfully from /Users/sangeethgeorge/MyProjects/oncoai-patient-outcome-navigator/data/onco_features_cleaned.parquet
✅ Data split and imputed successfully.

No significant data leakage detected based on high correlation.

DEBUG: Columns of X_train_leakage_checked before scaling:
['age', 'mean_heart_rate', 'mean_respiratory_rate', 'min_heart_rate', 'min_respiratory_rate', 'max_heart_rate', 'max_respiratory_rate', 'slope_heart_rate', 'slope_respiratory_rate', 'mean_anion_gap', 'mean_bicarbonate', 'mean_calcium,_total', 'mean_chloride', 'mean_creatinine_y', 'mean_glucose', 'mean_hematocrit', 'mean_hemoglobin_y', 'mean_mch



DEBUG: shap_values.feature_names after explicit assignment: ['age', 'mean_heart_rate', 'mean_respiratory_rate', 'min_heart_rate', 'min_respiratory_rate', 'max_heart_rate', 'max_respiratory_rate', 'slope_heart_rate', 'slope_respiratory_rate', 'mean_anion_gap', 'mean_bicarbonate', 'mean_calcium,_total', 'mean_chloride', 'mean_creatinine_y', 'mean_glucose', 'mean_hematocrit', 'mean_hemoglobin_y', 'mean_mch', 'mean_mchc', 'mean_mcv', 'mean_magnesium_y', 'mean_phosphate', 'mean_platelet_count', 'mean_potassium', 'mean_rdw', 'mean_red_blood_cells', 'mean_sodium', 'mean_urea_nitrogen', 'mean_white_blood_cells', 'min_anion_gap', 'min_bicarbonate', 'min_calcium,_total', 'min_chloride', 'min_creatinine_y', 'min_glucose', 'min_hematocrit', 'min_hemoglobin_y', 'min_mch', 'min_mchc', 'min_mcv', 'min_magnesium_y', 'min_phosphate', 'min_platelet_count', 'min_potassium', 'min_rdw', 'min_red_blood_cells', 'min_sodium', 'min_urea_nitrogen', 'min_white_blood_cells', 'max_anion_gap', 'max_bicarbonate', 'm