# Import modules, data, and models

In [11]:
import os, sys
from os.path import join, dirname, abspath, normpath
from pickle import load as pkl_load, dump as pkl_dump
import shap
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from typing import Optional
import sklearn

Variable definitions

In [2]:
# Root folder that the SHAP output paths in this notebook are built from.
OUTPUT_DIR = "output/shap"

# Base names of the pickled result files; each one is read from "<name>.pkl".
names = [
    'model_results_immune_90',
]

# Labels of the trained models (presumably the keys of each results dict - verify against the pickle).
model_names = [
    'XGB',
    'KNN',
    'SVC Linear',
    'SVC RBF',
    'SVC Poly',
]

In [3]:
# Load every results pickle and expose it as a module-level variable named after its file.
# NOTE(security): unpickling can execute arbitrary code - only load files from a trusted source.
for name in names:
    with open(f"{name}.pkl", mode='rb') as f:
        globals().update({name: pkl_load(f)})

## Utils

In [4]:
def create_if_not_exists(directory):
    """
    Ensure that `directory` exists, creating it and any missing parent folders.

    Args:
        directory: path of the directory to create.

    Raises:
        FileExistsError: if `directory` already exists but is not a directory.
    """
    # exist_ok=True makes the call idempotent and removes the race between
    # checking os.path.exists() and calling os.makedirs().
    os.makedirs(directory, exist_ok=True)

In [17]:
def run_shap_on_model(model, X_preprocessed: pd.DataFrame, background_samples: Optional[int] = 100, seed: int = 42):
    """
    Run SHAP on a fitted model and a dataset.

    Args:
        model: fitted model object (sklearn/XGBoost/...) that supports prediction. A scikit-learn
            Pipeline is explained through column 1 of its predict_proba output, so the data must be
            in whatever form the pipeline itself accepts.
        X_preprocessed: features as DataFrame (or something convertible to one) used for SHAP evaluation.
        background_samples: number of background rows to use for the explainer (positive int) or None to use all.
        seed: RNG seed for reproducible sampling.

    Returns:
        shap_values: object returned by shap.Explainer(...)(X_preprocessed). Contains .values, .base_values, .data, .feature_names when available.

    Raises:
        ValueError: if background_samples is smaller than 1.
        TypeError: if background rows must be sampled but X_preprocessed could not be converted to a DataFrame.
    """
    # Fail early: an empty background set cannot produce meaningful SHAP values.
    if background_samples is not None and background_samples < 1:
        raise ValueError(f"background_samples must be a positive integer or None, got {background_samples!r}")

    # Convert to DataFrame if needed (so feature names are preserved).
    # Best effort: inputs pandas cannot wrap are passed through unchanged.
    if not isinstance(X_preprocessed, pd.DataFrame):
        try:
            X_preprocessed = pd.DataFrame(X_preprocessed)
        except (ValueError, TypeError):
            pass

    # Choose a background dataset for the explainer (smaller is faster)
    if background_samples is None or background_samples >= len(X_preprocessed):
        background = X_preprocessed
    else:
        if not isinstance(X_preprocessed, pd.DataFrame):
            raise TypeError("X_preprocessed must be a DataFrame (or convertible to one) to sample background rows")
        background = X_preprocessed.sample(background_samples, random_state=seed)

    # Imported here because a bare `import sklearn` does not guarantee that the
    # `sklearn.pipeline` submodule is available as an attribute.
    from sklearn.pipeline import Pipeline

    # Use shap.Explainer which auto-selects the best explainer for the model type
    if isinstance(model, Pipeline):
        # A Pipeline is explained as a black-box function returning the positive-class probability.
        def predict_positive(X):
            return model.predict_proba(X)[:, 1]

        explainer = shap.Explainer(predict_positive, background)
    else:
        explainer = shap.Explainer(model, background)

    # Compute SHAP values for the full dataset
    return explainer(X_preprocessed)


def save_shap_results(shap_values, path: str):
    """
    Pickle the main arrays and metadata of a SHAP result to `path`.

    Only a plain dict is written, so loading it later does not require rebuilding the
    shap.Explanation object. Attributes missing on `shap_values` are stored as None.

    Args:
        shap_values: object exposing (some of) .values, .base_values, .data, .feature_names, .output_names.
        path: destination file; its parent directory is created when missing.

    Returns:
        The `path` that was written.
    """
    parent_dir = os.path.dirname(path)
    os.makedirs(parent_dir if parent_dir else '.', exist_ok=True)

    # Attributes copied into the payload, in the order they are stored.
    exported_fields = ('values', 'base_values', 'data', 'feature_names', 'output_names')
    payload = {field: getattr(shap_values, field, None) for field in exported_fields}

    with open(path, 'wb') as handle:
        pkl_dump(payload, handle)

    return path


def plot_shap_summary_plotly(shap_values, max_features: int = 20, title: Optional[str] = None):
    """
    Create a Plotly horizontal bar chart showing mean(|SHAP value|) per feature.

    Args:
        shap_values: the result object from shap.Explainer(...)(X)
        max_features: max number of features to show (top by mean abs SHAP)
        title: optional chart title

    Returns:
        fig: plotly.graph_objects.Figure

    Raises:
        ValueError: if shap_values has no .values, or the values are neither 2-D nor 3-D.
    """
    vals = getattr(shap_values, 'values', None)
    if vals is None:
        raise ValueError("shap_values has no .values attribute")
    vals = np.asarray(vals)

    # shap.Explanation.values is (n_samples, n_features) for single-output models and
    # (n_samples, n_features, n_outputs) for multi-output / multiclass models.
    if vals.ndim == 3:
        # Collapse the trailing output/class axis so the feature axis survives.
        vals_reduced = np.mean(np.abs(vals), axis=-1)  # (n_samples, n_features)
    elif vals.ndim == 2:
        vals_reduced = vals
    else:
        raise ValueError(f"Unexpected shap values shape: {vals.shape}")

    mean_abs = np.mean(np.abs(vals_reduced), axis=0)

    # Feature names
    feature_names = getattr(shap_values, 'feature_names', None)
    if feature_names is None:
        # try to get from data attribute column names
        data = getattr(shap_values, 'data', None)
        if isinstance(data, pd.DataFrame):
            feature_names = data.columns.tolist()
        elif isinstance(data, np.ndarray):
            feature_names = [f'f{i}' for i in range(data.shape[1])]
        else:
            feature_names = [f'f{i}' for i in range(len(mean_abs))]

    df = pd.DataFrame({'feature': feature_names, 'mean_abs_shap': mean_abs})
    # Ascending sort + tail keeps the top features and leaves the largest as the last row.
    df = df.sort_values('mean_abs_shap', ascending=True).tail(max_features)

    fig = px.bar(df, x='mean_abs_shap', y='feature', orientation='h', title=title,
                 labels={'mean_abs_shap': 'mean(|SHAP value|)', 'feature': 'Feature'})
    # Wide left margin for long feature names; height grows with the number of bars.
    fig.update_layout(margin=dict(l=200, r=20, t=60, b=40), height=50 + 30 * len(df))

    return fig


def save_plotly_figure(fig, path: str):
    """
    Write a Plotly figure to `path` and return the file that was actually written.

    A `.html` path (case-insensitive) is written as interactive HTML. Any other extension is
    written as a static image via fig.write_image (requires kaleido); if that raises, the
    figure is written as HTML to `path + '.html'` and that fallback path is returned.
    """
    target_dir = os.path.dirname(path)
    os.makedirs(target_dir or '.', exist_ok=True)
    _, suffix = os.path.splitext(path)

    if suffix.lower() != '.html':
        # static image export (png, svg, jpeg, pdf); needs kaleido
        try:
            fig.write_image(path)
        except Exception as e:
            # image export failed: keep the result by writing HTML next to the requested file
            fallback = path + '.html'
            fig.write_html(fallback)
            print(f"Could not write image directly ({e}). Saved interactive HTML to {fallback} instead.")
            return fallback
        return path

    fig.write_html(path)
    return path

# Run SHAP analysis

In [None]:
# Pull the XGB entry out of the loaded results and unpack the pieces used for inspection.
experiment = model_results_immune_90['XGB']
X_raw, X_preprocessed, model, params = (
    experiment[key]
    for key in ('X_test_raw', 'X_test_preprocessed', 'best_model', 'best_params')
)


In [23]:
# How much data should we use when running shap? larger sample = slower runtime.
# None -> use all the available data
# Number of background rows handed to the SHAP explainer; more rows means a slower run.
# None -> use all the available data
SAMPLE_SIZE = 20

experiment_name = 'model_results_immune_90'

# Results are looked up as module-level variables named after their pickle file.
experiment = globals().get(experiment_name)
if experiment is None:
    raise ValueError("Invalid experiment name")
trial_name = "XGB"
RESULTS_PATH = join(OUTPUT_DIR, experiment_name, trial_name)

trial = experiment[trial_name]
X_pre = trial['X_test_preprocessed']  # Column renaming necessary
X_raw = trial['X_test_raw']
model = trial['best_model']

# The raw test frame is what gets explained here (X_pre is kept only for inspection).
print('Running SHAP...')
shap_vals = run_shap_on_model(model, X_raw, background_samples=SAMPLE_SIZE)

# Persist the SHAP arrays
shap_out_path = join(RESULTS_PATH, f"shap_{experiment_name}_{trial_name}.pkl")
save_shap_results(shap_vals, path=shap_out_path)
print('Saved SHAP arrays to', shap_out_path)

# Build and persist the summary bar chart
fig = plot_shap_summary_plotly(
    shap_vals,
    max_features=25,
    title=f"SHAP mean(|value|) - Immune {trial_name}",
)
chart_out = save_plotly_figure(
    fig,
    path=join(RESULTS_PATH, f"shap_summary_{experiment_name}_{trial_name}.html"),
)
print('Saved SHAP chart to', chart_out)

# Leave the figure as the last expression so Jupyter renders it
fig

Running SHAP...
Saved SHAP arrays to output/shap/model_results_immune_90/XGB/shap_model_results_immune_90_XGB.pkl
Saved SHAP arrays to output/shap/model_results_immune_90/XGB/shap_model_results_immune_90_XGB.pkl
Saved SHAP chart to output/shap/model_results_immune_90/XGB/shap_summary_model_results_immune_90_XGB.html
Saved SHAP chart to output/shap/model_results_immune_90/XGB/shap_summary_model_results_immune_90_XGB.html


In [None]:
# Display the preprocessed test features. The original line `model(X_pre` had an
# unbalanced parenthesis; the recorded output is this DataFrame (TODO confirm intent).
X_pre

Unnamed: 0,num__SIRI,num__Absolute Monocyte Count,num__Absolute Lymphocyte Count,num__Absolute Neutrophil Count
0,-0.104579,0.044254,-0.023154,1.512731
1,-0.073325,0.204830,-0.048654,0.524386
2,-0.222816,-0.421416,0.030677,1.923492
3,-0.484886,-0.694394,0.234671,-0.619765
4,-0.136470,-0.357185,-0.190316,-0.572551
...,...,...,...,...
384,-0.082380,-0.068149,-0.132234,-0.026443
385,0.004861,0.301175,0.159590,3.106988
386,-0.409406,0.285117,-0.007572,-1.137545
387,-0.170623,-0.517761,-0.211565,-0.773997


In [9]:
# Inspect the class of the loaded best model (the recorded output is sklearn.pipeline.Pipeline).
type(model)

sklearn.pipeline.Pipeline

In [13]:
# List the trial names available in the loaded results dict.
model_results_immune_90.keys()

dict_keys(['XGB', 'KNN', 'SVC Linear', 'SVC RBF', 'SVC Poly'])

In [21]:
# Display the raw (unscaled) test features stored for the 'XGB' trial.
model_results_immune_90['XGB']['X_test_raw']

Unnamed: 0,SIRI,Absolute Monocyte Count,Absolute Lymphocyte Count,Absolute Neutrophil Count
366,7.650056,0.79,1.77,20.97
1613,8.222704,0.89,1.59,14.69
1182,5.483721,0.50,2.15,23.58
1254,0.682061,0.33,3.59,7.42
384,7.065763,0.54,0.59,7.72
...,...,...,...,...
973,8.056800,0.72,1.00,11.19
1461,9.655229,0.95,3.06,31.10
1731,2.065000,0.94,1.88,4.13
990,6.440000,0.44,0.44,6.44


In [16]:
# Display the object stored under 'best_model' for the 'XGB' trial (rendered as a Pipeline parameter table).
model_results_immune_90['XGB']['best_model']

0,1,2
,steps,"[('preprocessor', ...), ('model', ...)]"
,transform_input,
,memory,
,verbose,False

0,1,2
,transformers,"[('num', ...), ('cat', ...)]"
,remainder,'drop'
,sparse_threshold,0.3
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True
,force_int_remainder_cols,'deprecated'

0,1,2
,copy,True
,with_mean,True
,with_std,True

0,1,2
,missing_values,
,strategy,'constant'
,fill_value,'missing'
,copy,True
,add_indicator,False
,keep_empty_features,False

0,1,2
,categories,'auto'
,drop,
,sparse_output,False
,dtype,<class 'numpy.float64'>
,handle_unknown,'ignore'
,min_frequency,
,max_categories,
,feature_name_combiner,'concat'

0,1,2
,objective,'binary:logistic'
,base_score,
,booster,
,callbacks,
,colsample_bylevel,
,colsample_bynode,
,colsample_bytree,0.1
,device,
,early_stopping_rounds,
,enable_categorical,False
