# In [None]:
# Step 0: If not already installed, install the necessary libraries
# pip install shap scikit-learn pandas matplotlib

# Step 1: Import Necessary Libraries
import shap
import pandas as pd
import numpy as np 
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import fetch_california_housing # Using the California housing dataset as an example
import matplotlib.pyplot as plt

# Report the versions of the core libraries so a run can be reproduced and issues debugged.
_version_report = (
    ("SHAP", shap),
    ("scikit-learn", sklearn),
    ("pandas", pd),
    ("numpy", np),
)
for _lib_label, _lib_module in _version_report:
    print(f"{_lib_label} version: {_lib_module.__version__}")

# Step 2: Load and Prepare Data
print("\n--- Loading and Preparing Data ---")
housing = fetch_california_housing()
# Wrap the raw feature matrix in a DataFrame so downstream plots can show feature names.
feature_matrix = housing.data
X = pd.DataFrame(feature_matrix, columns=housing.feature_names)
y = housing.target

# Hold out 20% of the rows as a test set; random_state=42 makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

summary_lines = (
    f"Training set samples: {len(X_train)}",
    f"Test set samples: {len(X_test)}",
    f"Feature columns: {list(X_train.columns)}",
    "Data loading and preparation complete.\n",
)
print("\n".join(summary_lines))

# Step 3: Train the Model
print("--- Training the Model ---")
# Random-forest hyperparameters: 100 trees, depth capped at 10, at least 5 samples per
# leaf; the fixed random_state makes the fit reproducible.
forest_params = {
    "n_estimators": 100,
    "random_state": 42,
    "max_depth": 10,
    "min_samples_leaf": 5,
}
# fit() returns the estimator itself, so construction and training can be chained.
model = RandomForestRegressor(**forest_params).fit(X_train, y_train)
print("Model training complete.\n")

# Step 4: Initialize the SHAP Explainer
print("--- Initializing SHAP Explainer ---")
# shap.Explainer inspects the model and dispatches to a suitable algorithm; for a tree
# ensemble such as this random forest it is expected to pick its tree explainer.
# X_train is handed over as the background (masker) data.
# NOTE(review): shap may subsample a large background set internally — confirm for the
# installed shap version.
explainer = shap.Explainer(model, X_train)
explainer_class = type(explainer)  # typically shap's tree explainer; exact class name varies by version
print(f"SHAP Explainer type used: {explainer_class}")
print("SHAP Explainer initialization complete.\n")

# Step 5: Calculate SHAP Values
print("--- Calculating SHAP Values (this may take some time) ---")
# check_additivity=False skips shap's check that each row's SHAP values plus the base value
# reproduce the model output; it is disabled here to avoid "Additivity check failed" errors
# caused by minor floating-point differences.
shap_values_obj_test = explainer(X_test, check_additivity=False)

print(f"SHAP values object type: {type(shap_values_obj_test)}")  # expected: a shap Explanation object
if hasattr(shap_values_obj_test, 'values'):
    print(f"SHAP values array shape: {shap_values_obj_test.values.shape}")  # (n_samples, n_features)


def _first_scalar(value):
    """Return the first element of *value* as a scalar, or None if it is None or empty.

    Accepts a plain scalar, a 1-D array (one base value per sample) or a 2-D array
    (samples x outputs). For arrays, the first sample's first output is used; for a
    single-output regression model all per-sample base values are the same.
    """
    if value is None:
        return None
    flat = np.ravel(value)  # also handles 0-d arrays and Python scalars, where len() would fail
    return flat[0] if flat.size else None


# Prefer the base value (expected_value) stored on the Explanation object and fall back to
# the explainer's expected_value; the result is used by the force plots later on.
base_value = _first_scalar(getattr(shap_values_obj_test, 'base_values', None))
if base_value is not None:
    print(f"Base value (expected_value from Explanation object): {base_value}")
else:
    base_value = _first_scalar(getattr(explainer, 'expected_value', None))
    if base_value is not None:
        print(f"Base value (expected_value from explainer): {base_value}")
    else:
        print("Could not retrieve base value (expected_value).")
print("SHAP values calculation complete.\n")

# Step 6: Visualize SHAP Values
print("--- Generating SHAP Visualizations ---")

# Test-set row explained by the single-prediction plots (this waterfall and the force plot).
instance_idx = 0

# Each entry: (heading printed first, figure size, drawing callback, plot title).
#   a. Waterfall: how features push one prediction from the base value (average
#      prediction) to its final predicted value; needs a single-instance slice.
#   b. Bar: mean absolute SHAP value per feature, i.e. overall importance to the model.
#   c. Beeswarm: one dot per (instance, feature) SHAP value, combining importance and effect.
# The shap calls use show=False so the title and layout can be adjusted before display.
_summary_plots = (
    (
        f"\nExplaining prediction for instance {instance_idx} in the test set:",
        (10, 6),
        lambda: shap.plots.waterfall(shap_values_obj_test[instance_idx], show=False),
        f"SHAP Waterfall Plot for Prediction of Instance {instance_idx}",
    ),
    (
        "\nGlobal Feature Importance (Bar Plot):",
        (10, 8),
        lambda: shap.plots.bar(shap_values_obj_test, show=False),
        "SHAP Global Feature Importance (Mean Absolute SHAP Value)",
    ),
    (
        "\nGlobal Feature Insights (Beeswarm Plot):",
        (10, 8),
        lambda: shap.plots.beeswarm(shap_values_obj_test, show=False),
        "SHAP Beeswarm Plot (Feature Importance and Effects)",
    ),
)

for _heading, _figsize, _draw, _title in _summary_plots:
    print(_heading)
    plt.figure(figsize=_figsize)
    _draw()
    plt.title(_title)
    plt.tight_layout()
    plt.show()

# d. Feature Dependence Plot (Scatter Plot)
# Shows how the value of a single feature affects its SHAP value (impact on prediction).
# The most important feature (largest mean absolute SHAP value over the test set) is
# selected automatically.
shap_matrix = getattr(shap_values_obj_test, 'values', None)
if shap_matrix is None or np.ndim(shap_matrix) != 2:
    print("Could not calculate mean absolute SHAP values to determine the most important feature.")
elif np.shape(shap_matrix)[1] == 0:
    print("Could not determine the most important feature to plot for dependence.")
else:
    # Computed directly from the raw (n_samples, n_features) matrix so the result is always a
    # labelled pandas Series, independent of Explanation helper methods or attached names.
    feature_importances = pd.Series(
        np.abs(shap_matrix).mean(axis=0), index=X_test.columns
    )
    important_feature_pos = int(np.argmax(feature_importances.to_numpy()))
    important_feature_name = feature_importances.index[important_feature_pos]
    print(f"\nFeature Dependence Plot for most important feature: {important_feature_name}")
    # shap.plots.scatter creates its own figure when no Axes is given, so build the figure
    # here and pass its Axes in; otherwise the requested size is ignored and an empty extra
    # figure is shown as well.
    _, dep_ax = plt.subplots(figsize=(10, 6))
    # Positional column indexing works whether or not the Explanation carries feature names;
    # color=<full Explanation> lets shap choose the feature with the strongest interaction.
    shap.plots.scatter(
        shap_values_obj_test[:, important_feature_pos],
        color=shap_values_obj_test,
        ax=dep_ax,
        show=False,
    )
    dep_ax.set_title(
        f"SHAP Dependence Plot for {important_feature_name}\n(Color indicates interaction with another feature)"
    )
    plt.tight_layout()
    plt.show()

# e. Force Plot
# For a single instance (matplotlib version for scripts):
if base_value is not None:  # base_value comes from Step 5 and may be None if it could not be retrieved
    print(f"\nForce Plot for instance {instance_idx}:")
    # shap.force_plot takes the base value, the instance's SHAP values (NumPy row) and its
    # feature values (pandas row). With matplotlib=True it draws into a figure it creates
    # itself, so the size is passed via `figsize`; pre-creating a figure with plt.figure()
    # would leave an empty extra figure for plt.show() to display.
    shap.force_plot(base_value,
                    shap_values_obj_test.values[instance_idx, :],
                    X_test.iloc[instance_idx, :],
                    matplotlib=True, show=False,  # static plot; keep it open so a title can be added
                    figsize=(12, 4))  # force plots are typically wide
    plt.title(f"SHAP Force Plot (Matplotlib) for Instance {instance_idx}")
    # tight_layout is deliberately not called; it may not work well with force_plot's fixed layout.
    plt.show()
else:
    print("Cannot generate Force Plot because base_value is undefined.")

# Save an interactive (JavaScript) force plot covering several test rows to an HTML file.
if base_value is None:
    print("Cannot generate multi-sample Force Plot because base_value is undefined.")
else:
    print("\nAttempting to generate and save an interactive Force Plot (first 100 samples)...")
    first_rows = slice(0, 100)  # at most the first 100 test rows
    try:
        # Without matplotlib=True, a multi-row force_plot returns an HTML/JS visualizer object.
        force_plot_html = shap.force_plot(
            base_value,
            shap_values_obj_test.values[first_rows, :],
            X_test.iloc[first_rows, :],
            show=False,
        )
        shap.save_html("force_plot_multiple_samples.html", force_plot_html)
        print("Interactive Force Plot (multiple samples) saved to 'force_plot_multiple_samples.html'")
    except Exception as e:  # best-effort demo step: report the failure and carry on
        print(f"Failed to generate interactive Force Plot: {e}")

print("\n--- SHAP Explainability Demo Complete ---")