# Import modules, data, and models

In [2]:
import os, sys
from os.path import join, dirname, abspath, normpath
from pickle import load as pkl_load, dump as pkl_dump
import shap
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from typing import Optional
import sklearn

## Variable definitions

In [11]:
# Root directory under which per-experiment / per-model results are written.
OUTPUT_DIR = "output/shap"

# Experiments to analyse. Each name is also the stem of the "<name>.pkl" file
# that the analysis loop opens. Commented-out entries are currently disabled.
experiment_names = [
    # 'model_results_immune_90',
    # 'model_results_immune_30',
    # 'model_results_all_90',
    # 'model_results_all_30',
    'model_results_labs_90',
    'model_results_labs_30',
]

# Keys looked up inside each loaded experiment (one trial per model).
model_names = ['XGB', 'KNN', 'SVC Linear', 'SVC RBF', 'SVC Poly']

## Utils

In [4]:
def create_if_not_exists(directory):
    """
    Create `directory` (including missing parents) unless it already exists.

    Uses `exist_ok=True` instead of a separate existence check so the call
    cannot fail when another process creates the directory in between, and
    so it matches the other save helpers in this notebook.

    Args:
        directory: path of the directory to create.

    Raises:
        FileExistsError: if `directory` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)

In [18]:
def run_shap_on_model(model, X_preprocessed: pd.DataFrame, background_samples: Optional[int] = 100, seed: int = 42):
    """
    Run SHAP on a fitted model and a dataset.

    For scikit-learn pipelines the explainer wraps a prediction function of the
    whole pipeline, chosen from the final step's capabilities in this order:
    `predict_proba` (column 1 only), `decision_function`, `predict`.
    Any other model is handed to `shap.Explainer` directly.

    NOTE(review): the TypeError recorded in this notebook's output
    ("unsupported operand type(s) for -: 'str' and 'str'") presumably comes
    from string-valued columns in the data passed here - confirm, and pass
    numeric features if so.

    Args:
        model: fitted model object (sklearn/XGBoost/...) that supports prediction.
        X_preprocessed: features as DataFrame or array used for SHAP evaluation.
        background_samples: number of background rows to use for the explainer (int) or None to use all.
        seed: RNG seed for reproducible sampling.

    Returns:
        shap_values: object returned by shap.Explainer(...)(X_preprocessed). Contains .values, .base_values, .data, .feature_names when available.

    Raises:
        NotImplementedError: if a pipeline's final step offers none of
            predict_proba, decision_function or predict.
    """
    # Imported locally: a bare `import sklearn` does not guarantee that the
    # `sklearn.pipeline` submodule has been loaded.
    from sklearn.pipeline import Pipeline

    # Convert to DataFrame if needed (so feature names are preserved)
    if not isinstance(X_preprocessed, pd.DataFrame):
        try:
            X_preprocessed = pd.DataFrame(X_preprocessed)
        except (ValueError, TypeError):
            # Best effort: keep the input unchanged when it is not tabular.
            pass

    # Choose a background dataset for the explainer (smaller is faster)
    if background_samples is None or background_samples >= len(X_preprocessed):
        background = X_preprocessed
    else:
        background = X_preprocessed.sample(background_samples, random_state=seed)

    # isinstance (rather than an exact type match) also covers Pipeline subclasses.
    if isinstance(model, Pipeline):
        final_step = model[-1]
        if hasattr(final_step, 'predict_proba'):
            def predict_fn(X):
                # Explain the score of the class in column 1 only.
                return model.predict_proba(X)[:, 1]
        elif hasattr(final_step, 'decision_function'):
            def predict_fn(X):
                return model.decision_function(X)
        elif hasattr(final_step, 'predict'):
            def predict_fn(X):
                # Pass a plain array when X is a DataFrame; otherwise pass X as-is.
                return model.predict(X.values if hasattr(X, 'values') else X)
        else:
            raise NotImplementedError("Model pipeline's final step does not have predict_proba, decision_function, or predict method.")
        explainer = shap.Explainer(predict_fn, background)
    else:
        # Let shap.Explainer auto-select the explainer for the model type
        explainer = shap.Explainer(model, background)

    # Compute SHAP values for the full input set
    return explainer(X_preprocessed)


def save_shap_results(shap_values, path: str):
    """
    Pickle the array-like parts of a SHAP result to `path`.

    Only a plain dict is written (not the Explanation object itself). Its keys
    are 'values', 'base_values', 'data', 'feature_names' and 'output_names';
    any attribute missing on `shap_values` is stored as None.

    Args:
        shap_values: object whose attributes are copied into the payload.
        path: destination file; its parent directory is created if needed.

    Returns:
        The `path` that was written.
    """
    target_dir = os.path.dirname(path)
    if not target_dir:
        # A bare filename means "current directory".
        target_dir = '.'
    os.makedirs(target_dir, exist_ok=True)

    saved_fields = ('values', 'base_values', 'data', 'feature_names', 'output_names')
    payload = {field: getattr(shap_values, field, None) for field in saved_fields}

    with open(path, 'wb') as handle:
        pkl_dump(payload, handle)

    return path


def plot_shap_summary_plotly(shap_values, max_features: int = 20, title: Optional[str] = None):
    """
    Create a Plotly horizontal bar chart showing mean(|SHAP value|) per feature.

    Multi-output results (3-D values) are first collapsed to one value per
    sample and feature by averaging absolute SHAP values over the output axis.
    The output axis is located by comparing the array shape with the number of
    features: (n_samples, n_features, n_outputs) is the layout produced by
    shap.Explainer(...)(X); (n_samples, n_outputs, n_features) is still
    accepted as a fallback. When the number of features is unknown the
    outputs are assumed to be on the last axis.

    Args:
        shap_values: the result object from shap.Explainer(...)(X)
        max_features: max number of features to show (top by mean abs SHAP)
        title: optional chart title

    Returns:
        fig: plotly.graph_objects.Figure

    Raises:
        ValueError: if `shap_values` has no `.values`, or the values are not
            2-D/3-D, or a 3-D array has no axis matching the feature count.
    """
    vals = getattr(shap_values, 'values', None)
    if vals is None:
        raise ValueError("shap_values has no .values attribute")
    vals = np.asarray(vals)

    feature_names = getattr(shap_values, 'feature_names', None)
    data = getattr(shap_values, 'data', None)

    if vals.ndim == 3:
        # Work out how many features there are so the feature axis can be
        # told apart from the output axis.
        if feature_names is not None:
            n_features = len(feature_names)
        elif isinstance(data, (pd.DataFrame, np.ndarray)) and data.ndim == 2:
            n_features = data.shape[1]
        else:
            n_features = None

        if n_features is None or vals.shape[1] == n_features:
            output_axis = 2
        elif vals.shape[2] == n_features:
            output_axis = 1
        else:
            raise ValueError(f"Unexpected shap values shape: {vals.shape}")
        # mean absolute across outputs -> (n_samples, n_features)
        vals_reduced = np.mean(np.abs(vals), axis=output_axis)
    elif vals.ndim == 2:
        vals_reduced = vals
    else:
        raise ValueError(f"Unexpected shap values shape: {vals.shape}")

    mean_abs = np.mean(np.abs(vals_reduced), axis=0)

    # Feature names: fall back to the data's columns, then to generated names.
    if feature_names is None:
        if isinstance(data, pd.DataFrame):
            feature_names = data.columns.tolist()
        elif isinstance(data, np.ndarray):
            feature_names = [f'f{i}' for i in range(data.shape[1])]
        else:
            feature_names = [f'f{i}' for i in range(len(mean_abs))]

    df = pd.DataFrame({'feature': feature_names, 'mean_abs_shap': mean_abs})
    # Ascending sort + tail keeps the top features; the last row holds the largest value.
    df = df.sort_values('mean_abs_shap', ascending=True).tail(max_features)

    fig = px.bar(df, x='mean_abs_shap', y='feature', orientation='h', title=title,
                 labels={'mean_abs_shap': 'mean(|SHAP value|)', 'feature': 'Feature'})
    # Wide left margin for long feature names; height grows with the number of bars.
    fig.update_layout(margin=dict(l=200, r=20, t=60, b=40), height=50 + 30 * len(df))

    return fig


def save_plotly_figure(fig, path: str):
    """
    Write a Plotly figure to `path` and report where it ended up.

    A path ending in .html (case-insensitive) is written as interactive HTML.
    Any other extension is written with `fig.write_image`; if that raises, the
    figure is written as HTML to `path + '.html'` and a message is printed.

    Args:
        fig: figure object providing `write_html` and `write_image`.
        path: requested output file; its parent directory is created if needed.

    Returns:
        The path actually written.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else '.', exist_ok=True)

    _, suffix = os.path.splitext(path)
    if suffix.lower() != '.html':
        # Static image export (png, svg, jpeg, pdf); needs kaleido per the original note.
        try:
            fig.write_image(path)
        except Exception as err:
            # Fall back to HTML next to the requested file (extension appended).
            fallback = path + '.html'
            fig.write_html(fallback)
            print(f"Could not write image directly ({err}). Saved interactive HTML to {fallback} instead.")
            return fallback
        return path

    fig.write_html(path)
    return path

In [None]:
def check_all_float(arr: np.ndarray) -> bool:
    """
    Check if all elements in a numpy array are of float type.

    Arrays with a floating dtype pass without inspecting elements. Object
    arrays are checked element by element. Any other dtype (int, str, bool,
    ...) fails. An empty float or object array counts as all-float.

    Args:
        arr: numpy array (or array-like) to check.

    Returns:
        True if every element is a Python or NumPy float, otherwise False.
    """
    arr = np.asarray(arr)
    if arr.dtype == object:
        # Mixed-type data ends up as dtype=object; look at each element.
        return all(isinstance(item, (float, np.floating)) for item in arr.ravel())
    return bool(np.issubdtype(arr.dtype, np.floating))
    

# Run SHAP analysis

In [19]:
# How much data should we use when running shap? larger sample = slower runtime.
# None -> use all the available data
SAMPLE_SIZE = None

for experiment_name in experiment_names:
    # NOTE: unpickling executes arbitrary code; only load result files you trust.
    with open(f"{experiment_name}.pkl", 'rb') as f:
        experiment = pkl_load(f)
    #experiment = globals().get(experiment_name)

    for trial_name in model_names:
        # One output folder per experiment/model pair.
        RESULTS_PATH = join(OUTPUT_DIR, experiment_name, trial_name)

        trial = experiment[trial_name]
        X_pre = trial['X_test_preprocessed'] # Column renaming necessary
        X_raw = trial['X_test_raw']
        model = trial['best_model']

        print(f"Running SHAP on {experiment_name}:{trial_name}...")
        #try:
        #if X_raw.values.dtype == 'float64':
        # SHAP is currently always run on the raw test features.
        shap_vals = run_shap_on_model(model, X_raw, background_samples=SAMPLE_SIZE)
        #else:
        #    shap_vals = run_shap_on_model(model, X_pre, background_samples=SAMPLE_SIZE)

        shap_out_path = join(RESULTS_PATH, f"shap_{experiment_name}_{trial_name}.pkl")
        save_shap_results(shap_vals, shap_out_path)
        print('Saved SHAP arrays to', shap_out_path)

        # Title names the experiment actually being processed (it used to be
        # hard-coded to "Immune" regardless of the experiment).
        fig = plot_shap_summary_plotly(shap_vals, max_features=25, title=f"SHAP mean(|value|) - {experiment_name} {trial_name}")
        # Write both an interactive and a static chart, reporting each path written.
        for chart_ext in ("html", "jpeg"):
            chart_out = save_plotly_figure(fig, join(RESULTS_PATH, f"shap_summary_{experiment_name}_{trial_name}.{chart_ext}"))
            print('Saved SHAP chart to', chart_out)
        #except Exception as e:
        #    print(f"Error running SHAP on {experiment_name}:{trial_name}: {e}")


Running SHAP on model_results_labs_90:XGB...


TypeError: unsupported operand type(s) for -: 'str' and 'str'

In [None]:
# Scratch cell: display the model left bound by the last iteration of the SHAP loop.
model

In [None]:
# Scratch cell: call decision_function of the last trial's best model on its raw
# test features (only works if that model exposes decision_function).
trial['best_model'].decision_function(trial['X_test_raw'])

In [None]:
# Scratch cell: display the raw test features of the last trial.
trial['X_test_raw']

In [None]:
# Scratch cell: display the XGB best model of the immune_90 experiment.
# NOTE(review): `model_results_immune_90` is not defined in the visible cells (the
# loop loads experiments from pickle files into `experiment`) - presumably left
# over from an earlier session; confirm or load it first.
model_results_immune_90['XGB']['best_model']