# Import modules, data, and models

In [None]:
def warn(*args, **kwargs):
    """No-op replacement for ``warnings.warn``: accepts any arguments and does nothing."""
    return None
import warnings
warnings.warn = warn

import os, sys
from os.path import join, dirname, abspath, normpath
from pickle import load as pkl_load, dump as pkl_dump
import shap
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from typing import Optional
import sklearn
from scipy.stats import chi2_contingency

## Utils

In [None]:
def run_shap_on_model(model, X_preprocessed: pd.DataFrame, background_samples: Optional[int] = 100, seed: int = 42):
    """
    Run SHAP on a fitted model and preprocessed dataset.

    Args:
        model: fitted model object (sklearn/XGBoost/...) that supports prediction. May be a bare
            estimator or an sklearn ``Pipeline`` (including subclasses of it).
        X_preprocessed: preprocessed features as DataFrame or array used for SHAP evaluation.
        background_samples: number of background rows to use for the explainer (int) or None to use all.
        seed: RNG seed for reproducible sampling.

    Returns:
        shap_values: object returned by shap.Explainer(...)(X_preprocessed). Contains .values, .base_values, .data, .feature_names when available.

    Raises:
        ValueError: if ``background_samples`` is given but is not positive.
        NotImplementedError: if ``model`` is a Pipeline whose final step exposes none of
            ``predict_proba``, ``decision_function`` or ``predict``.
    """
    # Imported here because a plain `import sklearn` does not guarantee that the
    # `sklearn.pipeline` submodule has been loaded.
    from sklearn.pipeline import Pipeline

    if background_samples is not None and background_samples <= 0:
        raise ValueError(f"background_samples must be positive or None, got {background_samples!r}")

    # Convert to DataFrame if needed (so feature names are preserved)
    if not isinstance(X_preprocessed, pd.DataFrame):
        try:
            X_preprocessed = pd.DataFrame(X_preprocessed)
        except (ValueError, TypeError) as exc:
            # Deliberately best-effort: keep the input untouched, but say so instead of hiding it
            # (warnings are muted in this file, hence print).
            print(f"Could not convert X_preprocessed to a DataFrame ({exc}); using it as-is.")

    # Choose a background dataset for the explainer (smaller is faster)
    n_rows = len(X_preprocessed)
    if background_samples is None or (isinstance(background_samples, int) and background_samples >= n_rows):
        background = X_preprocessed
    else:
        background = X_preprocessed.sample(min(background_samples, n_rows), random_state=seed)

    # Use shap.Explainer which auto-selects the best explainer for the model type
    if isinstance(model, Pipeline):
        final_step = model[-1]
        if hasattr(final_step, 'predict_proba'):
            # Explain the score of the class in column 1 of predict_proba
            explainer = shap.Explainer(lambda X: model.predict_proba(X)[:, 1], background)
        elif hasattr(final_step, 'decision_function'):
            explainer = shap.Explainer(lambda X: model.decision_function(X), background)
        elif hasattr(final_step, 'predict'):
            # The explainer may hand the callable either a DataFrame or a plain array,
            # so only unwrap `.values` when it exists.
            explainer = shap.Explainer(lambda X: model.predict(getattr(X, 'values', X)), background)
        else:
            raise NotImplementedError("Model pipeline's final step does not have predict_proba, decision_function, or predict method.")
    else:
        if hasattr(model, 'decision_function'):
            print("Using decision_function for SHAP explainer.")
            explainer = shap.Explainer(lambda X: model.decision_function(X), background)
        else:
            # NOTE(review): this hands the estimator object itself to shap; estimators that shap
            # cannot introspect and that are not callable may be rejected here - confirm for KNN.
            explainer = shap.Explainer(model, background)

    # Compute SHAP values for the full preprocessed set
    shap_values = explainer(X_preprocessed)

    return shap_values


def save_shap_results(shap_values, path: str):
    """
    Pickle the contents of a SHAP result to ``path`` and return ``path``.

    The pickled payload is a dict with the ``values``, ``base_values``, ``data``,
    ``feature_names`` and ``output_names`` attributes of ``shap_values`` (``None`` for any
    attribute the object does not have), plus the object itself under ``'explanation'``.
    The parent directory of ``path`` is created when it does not exist yet.
    """
    target_dir = os.path.dirname(path)
    if not target_dir:
        # A bare file name has no directory component; fall back to the current directory
        target_dir = '.'
    os.makedirs(target_dir, exist_ok=True)

    # Plain attributes first (also covers the output names, if present) ...
    saved_attrs = ('values', 'base_values', 'data', 'feature_names', 'output_names')
    payload = {attr: getattr(shap_values, attr, None) for attr in saved_attrs}
    # ... then the complete object, so it can be reused without rebuilding it
    payload['explanation'] = shap_values

    with open(path, 'wb') as out_file:
        pkl_dump(payload, out_file)

    return path


def plot_shap_summary_plotly(shap_values, max_features: int = 20, title: Optional[str] = None):
    """
    Create a Plotly horizontal bar chart showing mean(|SHAP value|) per feature.

    Args:
        shap_values: the result object from shap.Explainer(...)(X)
        max_features: max number of features to show (top by mean abs SHAP)
        title: optional chart title

    Returns:
        fig: plotly.graph_objects.Figure

    Raises:
        ValueError: if ``shap_values`` has no ``.values``, if the values are not 2-D or 3-D,
            or if the number of feature names does not match the number of features.
    """
    vals = getattr(shap_values, 'values', None)
    if vals is None:
        raise ValueError("shap_values has no .values attribute")
    vals = np.asarray(vals)
    if vals.ndim not in (2, 3):
        raise ValueError(f"Unexpected shap values shape: {vals.shape}")

    # Feature names: explicit attribute first, then the column names of a DataFrame `data`
    feature_names = getattr(shap_values, 'feature_names', None)
    data = getattr(shap_values, 'data', None)
    if feature_names is None and isinstance(data, pd.DataFrame):
        feature_names = data.columns.tolist()

    if vals.ndim == 3:
        # Multi-output values come as (n_samples, n_features, n_outputs) from shap Explanation
        # objects, while other producers use (n_samples, n_outputs, n_features). Use the known
        # feature count to tell the two layouts apart instead of assuming one of them.
        n_features_hint = None
        if feature_names is not None:
            n_features_hint = len(feature_names)
        elif isinstance(data, np.ndarray) and data.ndim == 2:
            n_features_hint = data.shape[1]

        if n_features_hint == vals.shape[1] and n_features_hint != vals.shape[2]:
            output_axis = 2
        else:
            # Unknown or ambiguous layout: keep the historical assumption (outputs on axis 1)
            output_axis = 1
        # reduce across outputs by taking the mean absolute value -> (n_samples, n_features)
        vals_reduced = np.mean(np.abs(vals), axis=output_axis)
    else:
        vals_reduced = vals

    mean_abs = np.mean(np.abs(vals_reduced), axis=0)

    if feature_names is None:
        feature_names = [f'f{i}' for i in range(len(mean_abs))]
    if len(feature_names) != len(mean_abs):
        raise ValueError(
            f"Got {len(feature_names)} feature names for {len(mean_abs)} features "
            f"(shap values shape: {vals.shape})"
        )

    df = pd.DataFrame({'feature': feature_names, 'mean_abs_shap': mean_abs})
    # Ascending sort + tail keeps the largest values and puts the biggest bar at the top
    df = df.sort_values('mean_abs_shap', ascending=True).tail(max_features)

    fig = px.bar(df, x='mean_abs_shap', y='feature', orientation='h', title=title,
                 labels={'mean_abs_shap': 'mean(|SHAP value|)', 'feature': 'Feature'})
    # Wide left margin for long feature names; height grows with the number of bars
    fig.update_layout(margin=dict(l=200, r=20, t=60, b=40), height=50 + 30 * len(df))

    return fig


def save_plotly_figure(fig, path: str):
    """
    Write a Plotly figure to ``path`` and return the path that was actually written.

    A path ending in ``.html`` (case-insensitive) is written as interactive HTML. Any other
    extension is written with ``fig.write_image`` (requires kaleido); if that raises, the figure
    is written as HTML to ``path + '.html'`` instead and that fallback path is returned.
    The parent directory is created when missing.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else '.', exist_ok=True)

    _, suffix = os.path.splitext(path)
    if suffix.lower() != '.html':
        # static image (png, svg, jpeg, pdf) - only works when kaleido is installed
        try:
            fig.write_image(path)
        except Exception as e:
            # fallback to html: the original file name with '.html' appended
            fallback = path + '.html'
            fig.write_html(fallback)
            print(f"Could not write image directly ({e}). Saved interactive HTML to {fallback} instead.")
            return fallback
        else:
            return path

    fig.write_html(path)
    return path

# Run SHAP analysis

In [None]:
# Root folder under which all SHAP artifacts are written
OUTPUT_DIR = "output/shap"

# One experiment per (feature set, numeric suffix) pair; each name is also the stem of the
# pickle file the run loop opens. The suffix (90 / 30) is presumably a horizon in days.
experiment_names = [
    f"model_results_{feature_set}_{suffix}"
    for feature_set in ("immune", "all", "labs")
    for suffix in (90, 30)
]

# Keys of the trials stored inside each experiment
model_names = ['XGB', 'KNN', 'SVC Linear', 'SVC RBF', 'SVC Poly']

In [None]:
# How much data should we use when running shap? Larger sample = longer runtime.
# None -> use all the available data
SAMPLE_SIZE = None

for experiment_name in experiment_names:
    # NOTE: unpickling can execute arbitrary code - only load experiment files you produced yourself.
    with open(f"{experiment_name}.pkl", 'rb') as f:
        experiment = pkl_load(f)

    for trial_name in model_names:
        RESULTS_PATH = join(OUTPUT_DIR, experiment_name, trial_name)

        trial = experiment[trial_name]
        X_pre = trial['X_test_preprocessed'] # Column renaming necessary
        X_raw = trial['X_test_raw']
        model = trial['best_model']

        print(f"Running SHAP on {experiment_name}:{trial_name}...")
        # Only element 1 of the stored model is explained, on the already-preprocessed features
        # (presumably the estimator step of a [preprocessor, estimator] pipeline - confirm).
        shap_vals = run_shap_on_model(model[1], X_pre, background_samples=SAMPLE_SIZE)
        shap_vals.feature_names = X_pre.columns.tolist()

        # Save SHAP results and plots
        shap_out_path = os.path.join(RESULTS_PATH, f"shap_{experiment_name}_{trial_name}.pkl")
        save_shap_results(shap_vals, shap_out_path)
        print('Saved SHAP arrays to', shap_out_path)

        # The title names the experiment actually being plotted (it used to say "Immune" for all of them)
        fig = plot_shap_summary_plotly(shap_vals, max_features=25, title=f"SHAP mean(|value|) - {experiment_name} {trial_name}")
        chart_stem = os.path.join(RESULTS_PATH, f"shap_summary_{experiment_name}_{trial_name}")
        for chart_ext in ('.html', '.jpeg'):
            # Report every written file: the image export may fall back to a different path
            chart_out = save_plotly_figure(fig, chart_stem + chart_ext)
            print('Saved SHAP chart to', chart_out)



## Create a better shap plot

In [None]:
def load_pickle(path):
    """Read the file at ``path`` in binary mode and return the object unpickled from it."""
    with open(path, mode="rb") as handle:
        loaded = pkl_load(handle)
    return loaded

def _as_2d(values, class_index=None):
    """
    Bring SHAP ``values`` into array form, aiming for shape (n_samples, n_features).

    - A list is treated as multi-output: the element at ``class_index`` is used
      (when ``class_index`` is None: index 1 if the list has more than one element, else 0).
    - A 1-D result gets a leading sample axis, i.e. becomes shape (1, n).
    - Anything else is returned as converted by ``np.asarray``; arrays with more than
      two dimensions are passed through unchanged.
    """
    if isinstance(values, list):
        chosen = class_index
        if chosen is None:
            chosen = 0 if len(values) <= 1 else 1
        values = values[chosen]

    arr = np.asarray(values)
    return arr[np.newaxis, :] if arr.ndim == 1 else arr

def plot_top_pos_neg_from_payload(
    dir: str,
    pickle_name: str,
    row_index: int = 0,
    top_k: int = 5,
    class_index: Optional[int] = None,
    title: Optional[str] = None
):
    """
    Load SHAP payload (values/base_values/data/feature_names/output_names) and
    plot the top-K positive and negative SHAP contributions for a single row.

    Args:
        dir: folder containing the payload pickle; the chart files are written here too.
        pickle_name: file name of the payload pickle inside ``dir``.
        row_index: index of the sample (row) to explain.
        top_k: number of features taken from each end (largest and smallest SHAP values).
        class_index: which output to use when the stored values are a multi-output list.
        title: chart title; a default mentioning ``top_k`` and ``row_index`` is used when None.

    Returns:
        plotly.graph_objects.Figure. As a side effect the figure is also written to
        ``shap_top_pos_neg_row{row_index}.html`` and ``.png`` inside ``dir``
        (the PNG export requires kaleido).

    Raises:
        ValueError: if the payload has no ``values`` or they cannot be reduced to 2-D.
        IndexError: if ``row_index`` is out of range.
    """
    payload = load_pickle(join(dir, pickle_name))

    values = payload.get("values", None)
    if values is None:
        raise ValueError("payload['values'] is missing; please ensure you saved shap_values.values.")

    # Normalize to (n_samples, n_features)
    S = _as_2d(values, class_index=class_index)
    if S.ndim != 2:
        raise ValueError(f"Expected SHAP values of shape (n_samples, n_features), got {S.shape}.")

    # Feature names
    feature_names = payload.get("feature_names", None)
    if feature_names is None:
        feature_names = [f"f{i}" for i in range(S.shape[1])]
    feature_names = np.array(feature_names)

    # Optional raw feature values for hover
    data = payload.get("data", None)
    if data is not None:
        data = np.asarray(data)
        if data.ndim == 1:
            data = data[np.newaxis, :]

    # Select the target row
    if row_index < 0 or row_index >= S.shape[0]:
        raise IndexError(f"row_index {row_index} out of range (0..{S.shape[0]-1}).")
    s = S[row_index]  # shap vector (n_features,)

    # Indices for top positive & negative contributions.
    # Clamp k so a non-positive top_k or an empty row yields an empty selection
    # instead of an invalid `kth` for argpartition.
    k = max(0, min(top_k, s.size))
    if k == 0:
        sel_idx = np.array([], dtype=int)
    else:
        pos_idx = np.argpartition(-s, kth=k - 1)[:k]
        neg_idx = np.argpartition(s, kth=k - 1)[:k]
        # A feature can be in both sets when there are fewer than 2*k features
        sel_idx = np.unique(np.concatenate([pos_idx, neg_idx]))

    # Build plotting frame
    dfp = pd.DataFrame({
        "feature": feature_names[sel_idx],
        "shap": s[sel_idx],
        "direction": np.where(s[sel_idx] >= 0, "positive", "negative"),
    })
    if data is not None and data.ndim == 2 and data.shape[1] == s.size:
        dfp["value"] = data[row_index, sel_idx]
    else:
        dfp["value"] = np.nan

    # Sort so negatives appear together (left) then positives (right)
    dfp = dfp.sort_values("shap")
    # Per-row text label. It is passed to px.bar as a column so every colour trace receives
    # only the labels of its own bars; setting one shared array through update_traces would
    # hand the same leading labels to both traces.
    dfp["label"] = dfp["shap"].map(lambda v: f"{v:+.3f}")

    # Title
    if title is None:
        title = f"Top ±{top_k} SHAP contributions (row {row_index})"

    # Plotly horizontal bar
    fig = px.bar(
        dfp, x="shap", y="feature", orientation="h",
        color="direction",
        color_discrete_map={"positive": "#2ca02c", "negative": "#d62728"},
        text="label",
        hover_data={"value": True, "shap": ":.4f", "direction": False, "label": False}
    )
    fig.update_layout(
        title=title,
        xaxis_title="SHAP value (impact on model output)",
        yaxis_title="",
        showlegend=False,
        bargap=0.25,
    )
    # Show the text labels outside the bars for a quick read
    fig.update_traces(textposition="outside", cliponaxis=False)

    # Save next to the payload (both paths share the same folder)
    save_html = join(dir, f"shap_top_pos_neg_row{row_index}.html")
    save_png = join(dir, f"shap_top_pos_neg_row{row_index}.png")
    os.makedirs(os.path.dirname(save_html) or ".", exist_ok=True)
    fig.write_html(save_html)
    # Requires kaleido installed: pip install -U kaleido
    fig.write_image(save_png, scale=2)

    return fig

def save_fig(fig, dir, name):
    """Write ``fig`` into ``dir`` twice: as ``{name}.html`` and as ``{name}.jpeg``."""
    html_target = join(dir, f"{name}.html")
    jpeg_target = join(dir, f"{name}.jpeg")
    fig.write_html(html_target)
    fig.write_image(jpeg_target)

## Run Plot Generation for the best observed models

1. (f1): All features + SVC RBF + 30 days

2. (recall): Labs + SVC RBF + 30 days

In [None]:
# Output folder and SHAP payload pickle for each of the two models of interest
dir_1 = 'output/shap/model_results_all_30/SVC RBF/'
file_1 = 'output/shap/model_results_all_30/SVC RBF/shap_model_results_all_30_SVC RBF.pkl'
dir_2 = 'output/shap/model_results_labs_30/SVC RBF/'
file_2 = 'output/shap/model_results_labs_30/SVC RBF/shap_model_results_labs_30_SVC RBF.pkl'

# Pick which of the two the following cells work on
file = file_1
dir = dir_1

payload = load_pickle(file)
# NOTE: despite its name, `explainer` holds the SHAP result (explanation) object, not an Explainer.
if payload.get('explanation') is not None:
    explainer = payload['explanation']
else:
    # Rebuild the explanation from the stored arrays; only the values are mandatory
    explainer = shap.Explanation(values=payload['values'],
                                 base_values=payload.get('base_values'),
                                 data=payload.get('data'),
                                 feature_names=payload.get('feature_names'),
                                 output_names=payload.get('output_names'))
# Eliminate extraneous parts of feature names for readability: drop everything up to the first
# "__" and make the name file-system friendly. Names without "__" are kept whole (instead of
# raising IndexError) and text after a later "__" is preserved so distinct features stay distinct.
colnames = [s.split("__", 1)[-1].replace('/', '-') for s in explainer.feature_names]
explainer.feature_names = colnames

In [None]:

values = pd.DataFrame(explainer.values, columns=colnames)


def _save_shap_plot(plot_fn, out_path, *plot_args, tight_layout=False, **plot_kwargs):
    """Draw a SHAP matplotlib plot and save the figure it actually drew on, then close it."""
    # Start from a fresh figure so nothing is drawn on top of a leftover one
    placeholder = plt.figure()
    plot_fn(*plot_args, show=False, **plot_kwargs)
    # A SHAP plot function may open its own figure instead of using the current one
    # (scatter does when no `ax` is passed), so save whatever is current after drawing.
    drawn = plt.gcf()
    if tight_layout:
        plt.tight_layout()
    drawn.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(drawn)
    if drawn is not placeholder:
        plt.close(placeholder)


# 1. Bar plot of mean of abs of shap vals
_save_shap_plot(shap.plots.bar, join(dir, 'shap_abs_bar_plot.png'), explainer)

# 2. Beeswarm plot
_save_shap_plot(shap.plots.beeswarm, join(dir, 'shap_beeswarm_plot.png'), explainer)

# 3. Top and bottom bar plot (signed mean SHAP value per feature)
mean_shap = values.mean().sort_values(ascending=False)
top_5 = mean_shap.head(5)
bottom_5 = mean_shap.tail(5)
top_bottom_5 = pd.concat([top_5, bottom_5])
# With fewer than 10 features head and tail overlap; keep every feature only once
top_bottom_5 = top_bottom_5[~top_bottom_5.index.duplicated(keep='first')]
top_bottom_5 = pd.DataFrame({
    "feature_name": top_bottom_5.index,
    "shap_value": top_bottom_5.values,
    "color": ['positive' if x > 0 else 'negative' for x in top_bottom_5]
})
fig = px.bar(data_frame=top_bottom_5, y="feature_name", x="shap_value", orientation='h', color='color', title='Top and Bottom 5 Features by Mean(SHAP value)')
save_fig(fig, dir, 'shap_top_bottom_5_features')

# 4. Scatter plots, one per feature shown in the bar chart above
for feature in top_bottom_5['feature_name']:
    _save_shap_plot(
        shap.plots.scatter,
        join(dir, f"shap_scatter_{feature.replace('/','-')}.png"),
        explainer[:, feature],
        tight_layout=True,
        title=f'SHAP Scatter Plot for {feature}',
    )




## Calculate Correlation between a feature of interest and 30-day mortality

In [None]:
# Read the table and build one 0/1 indicator column per distinct `last_careunit` value,
# with the `target` column attached under the name "mortality".
df = pd.read_csv("30_day_mort.csv")
last_careunit = (
    pd.get_dummies(df["last_careunit"])
    .astype(int)
    .assign(mortality=df["target"])
)

In [None]:
# For every column of `last_careunit`: the percentage of rows with that column set
# that also have "mortality" set.
d = {}
for col, flags in last_careunit.items():
    pct = (flags & last_careunit["mortality"]).sum() / flags.sum()
    d[col] = pct*100

# Pairwise correlation of every column with "mortality"
corrs = last_careunit.corr()["mortality"]

Get correlation coefficient and Odds Ratio for CVICU

In [None]:
feature = "Cardiac Vascular Intensive Care Unit (CVICU)"
# Contingency table: rows = feature indicator, columns = mortality
table = pd.crosstab(last_careunit[feature], last_careunit["mortality"])
chi2, p_value, dof, expected = chi2_contingency(table)
print(f"Chi-Squared Statistic: {chi2:.4f}")
print(f"**P-Value:** {p_value:.4f}")
# Get phi coefficient for correlation
n = table.sum().sum()  # Total number of observations
# phi = sqrt(chi2 / n) is defined on the uncorrected Pearson statistic. chi2_contingency applies
# Yates' continuity correction by default when dof == 1 (as for a 2x2 table), which would
# understate phi, so the statistic is recomputed without the correction for this purpose only.
chi2_uncorrected, _, _, _ = chi2_contingency(table, correction=False)
phi = np.sqrt(chi2_uncorrected / n)
print(f"Phi Coefficient: {phi:.4f}")

In [None]:
# Cells of the contingency table, indexed as table.loc[feature indicator, mortality]
a = table.loc[1, 1]  # Admitted, Died
b = table.loc[0, 1]  # Not Admitted, Died
c = table.loc[1, 0]  # Admitted, Lived
d = table.loc[0, 0]  # Not Admitted, Lived

# Calculate Odds Ratio: (a/c) / (b/d)  OR  (a*d) / (b*c)
odds_ratio = (a * d) / (b * c)

print(f"**Odds Ratio (OR):** {odds_ratio:.4f}")
# Only describe the odds as "higher" when the ratio is actually above 1
if odds_ratio > 1:
    print(f"Interpretation: The odds of 30-day mortality are {odds_ratio:.2f} times higher "
          "for patients admitted to the unit compared to those not admitted.")
else:
    print(f"Interpretation: The odds of 30-day mortality for patients admitted to the unit are "
          f"{odds_ratio:.2f} times the odds for those not admitted (not higher).")

In [None]:
# 1. Move the odds ratio to the log scale
log_or = np.log(odds_ratio)

# 2. Standard error of the log-OR: square root of the summed reciprocals of the four cell counts
se_log_or = np.sqrt(sum(1/count for count in (a, b, c, d)))

# 3. 95% CI on the log scale: log-OR plus/minus z * SE
z_score = 1.96  # for 95% CI
half_width = z_score * se_log_or
log_ci_lower = log_or - half_width
log_ci_upper = log_or + half_width

# 4. Back-transform both bounds to the odds-ratio scale
ci_lower, ci_upper = np.exp(log_ci_lower), np.exp(log_ci_upper)

print(f"**Odds Ratio (OR):** {odds_ratio:.2f}")
print(f"**95% Confidence Interval (CI):** [{ci_lower:.2f} - {ci_upper:.2f}]")