In [None]:
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
from math import comb
import pickle
from gamexplainer.utils import plot_local_all_terms
import matplotlib.pyplot as plt
import shap
import lime
import lime.lime_tabular
import lightgbm as lgbm
%load_ext autoreload
%autoreload 2

In [None]:
# Render all figure text through LaTeX (the plot labels below use commands such as \emph{...}).
plt.rc('text', usetex=True)

### Feature selection

To compute the results, run:
```
python forest_train.py
python feat_selection.py
```

Load the precomputed results:

In [None]:
# RMSE grid written by feat_selection.py; it is indexed by the spline / interaction-term
# ranges defined in the next cell (presumably copied from that script — keep in sync).
acc = np.load(file="precomputed/feat_selection_superconduct.npy")

## Plot the results in a heatmap

In [None]:
# Hyper-parameter grids explored by feat_selection.py.
# NOTE(review): copied by hand from that script — keep the two in sync.
range_n_splines = range(1, 11)
range_n_inter = range(0, 9)
dimension = (len(range_n_splines), len(range_n_inter))

# True marks a heatmap cell to hide: with n splines there are only C(n, 2)
# pairs of splines, so asking for more interaction terms than that is impossible.
mask = np.zeros(dimension, dtype=bool)
for i, n_splines in enumerate(range_n_splines):
    for j, n_inter in enumerate(range_n_inter):
        mask[i, j] = n_inter > comb(n_splines, 2)

In [None]:
# `acc` holds RMSE values (see the colour-bar label): rows are the number of splines,
# columns the number of interaction terms. Impossible combinations are hidden by `mask`.
accuracy_df = pd.DataFrame(acc, columns=range_n_inter, index=range_n_splines)

# Draw on an explicitly created figure rather than whatever figure is current,
# so a figure left open by an earlier cell (e.g. when run as a script) is never reused.
fig, ax = plt.subplots()
sns.heatmap(
    accuracy_df,
    annot=True,
    mask=mask,
    cmap=sns.color_palette("Blues", as_cmap=True),
    cbar_kws={'label': 'RMSE'},
    ax=ax,
)
ax.set_xlabel("Number of interaction terms used")
ax.set_ylabel("Number of splines used")
file_out = "plots/heatmap_splines_inter.pdf"
fig.savefig(file_out)

## Sampling strategy 

### Analyze the maximum number of splits per feature

To replicate the experiments run:
```
python forest_train.py
python sampling_analysis.py
```

### Setup

In [None]:
# Sampling strategies compared below; used as keys into the pickled results.
sampling_methods = ["all", "quantile", "equal", "kmeans", "equi_size"]
# Values of K on the x axis (presumably mirrors sampling_analysis.py — keep in sync).
range_m = range(50, 17000, 750)

# NOTE: pickle.load can execute arbitrary code — only load files you generated yourself.
with open('precomputed/sampling_comparison.pickle', 'rb') as results_file:
    acc_methods = pickle.load(results_file)

### Plot

In [None]:
# LaTeX display names, one per entry of `sampling_methods` (same order).
# Every entry is a raw string: in a plain literal "\e" is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning on Python 3.12+).
labels = [r"\emph{All-Thresholds}", r"\emph{Quantile}", r"\emph{Equi-Width}", r"\emph{$k$-Means}", r"\emph{Equi-Size}"]
assert len(labels) == len(sampling_methods), "one label per sampling method"
colors = sns.color_palette(n_colors=len(sampling_methods))

fig, ax = plt.subplots()
for sampling_method, label, color in zip(sampling_methods, labels, colors):
    # One RMSE curve per strategy, as a function of K (the values in `range_m`).
    ax.plot(range_m, acc_methods[sampling_method], 'o--', color=color, label=label)
ax.set_xlabel("$K$")
ax.set_ylabel("RMSE")
ax.legend()
ax.grid(visible=True)
file_out = "plots/sampling_comparison.pdf"
fig.savefig(file_out)

## Global explanation with GEF

To replicate the results run:
```
python forest_train.py
python final_explainer.py
```

In [None]:
# Fitted GEF explainer produced by final_explainer.py; its `.gam` attribute is used below.
# NOTE: pickle.load can execute arbitrary code — only load files you generated yourself.
with open("precomputed/explainer.pickle", "rb") as explainer_file:
    explainer = pickle.load(explainer_file)

In [None]:
# Deterministic 70/30 split by row order (no shuffling): the first 70% of the rows
# are the training set, the remaining rows the test set.
resp_var = "critical_temp"
df = pd.read_csv("train.csv", sep=",")
n_train = int(len(df) * 0.7)
train, test = df.iloc[:n_train], df.iloc[n_train:]

# Features are every column except the response.
X_train, y_train = train.drop(columns=resp_var), train[resp_var]
X_test, y_test = test.drop(columns=resp_var), test[resp_var]

### Global GEF plot with the sample under investigation highlighted

In [None]:
# Map column position -> name shown in plots; starts as the raw column names.
feature_names_display = dict(enumerate(X_train.columns))

In [None]:
# Short display labels for selected columns (keyed by column position in X_train),
# used as panel titles / labels in the plots below.
# NOTE(review): these look like abbreviations of the original column names — confirm against train.csv.
feature_names_display.update({
    6: "WEAM",
    62: "WMTC",
    70: "WSTC",
    76: "WEV",
    74: "WGV",
    9: "SAM",
    33: "GMD",
    64: "WGTC",
    44: "WGEA",
    72: "WMV",
    27: "RAR",
    80: "WSV",
})

In [None]:
# Training row whose explanations are highlighted below, as a single-row 2-D array.
sample_index = 0
sample = X_train.iloc[sample_index].to_numpy().reshape(1, -1)

In [None]:
n_row, n_col = 2, 3
n_panels = n_row * n_col  # only the first n_panels GAM terms are inspected

# Apply the font sizes BEFORE creating any artist: rcParams are read when titles,
# tick labels and legends are built, so updating them after plotting only took
# effect on a second run of this cell. The update is still global, as before.
params = {'legend.fontsize': 18,
          'figure.figsize': (20, 5),
          'axes.titlesize': 18,
          'xtick.labelsize': 18,
          'ytick.labelsize': 18}
plt.rcParams.update(params)

fig = plt.figure(figsize=(13, 10), tight_layout=False)
c1, c2 = sns.color_palette(n_colors=3)[:2]

plot_index = 0
axes = []
points = []
for i, term in enumerate(explainer.gam.terms):
    if i == n_panels:
        break
    # Only univariate spline terms get a panel; intercept and tensor terms are skipped.
    if term.isintercept or term.istensor:
        continue

    # Panels after the first in a row share their y axis with the previous panel.
    ax = fig.add_subplot(n_row, n_col, plot_index + 1,
                         sharey=axes[-1] if plot_index % n_col != 0 else None)
    # Only the left-most panel of each row shows y tick labels.
    plt.setp(ax.get_yticklabels(), visible=plot_index % n_col == 0)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Learned spline with its 95% confidence band.
    grid = explainer.gam.generate_X_grid(term=i, meshgrid=term.istensor)
    pdep, confi = explainer.gam.partial_dependence(term=i, X=grid, width=0.95, meshgrid=term.istensor)
    x_grid = grid[:, term.feature]
    ax.plot(x_grid, confi[:, 0], ls="--", c=c2, zorder=1)
    ax.plot(x_grid, confi[:, 1], label="95% width confidence interval", ls="--", c=c2, zorder=1)
    ax.plot(x_grid, pdep, label="Spline learned", lw=2, c=c1, zorder=2)
    ax.set_title(feature_names_display[term.feature])

    # Where the sample under investigation falls on this spline.
    x_point = sample[0, term.feature]
    y_point = explainer.gam.partial_dependence(term=i, X=sample)
    points.append((x_point, y_point))

    plot_index += 1
    axes.append(ax)

# Lowest y limit of each row of panels (bottom of the sample's vertical marker).
# Row r is made of axes[r * n_col:(r + 1) * n_col]; each row gets its own minimum.
min_y_rows = tuple(
    min(row_ax.get_ylim()[0] for row_ax in axes[start:start + n_col])
    for start in range(0, len(axes), n_col)
)

# Highlight the sample on every panel.
for i, (ax, (x_point, y_point)) in enumerate(zip(axes, points)):
    ax.vlines(x_point, min_y_rows[i // n_col], y_point, linestyle="dashed", color="black")
    ax.hlines(y_point, ax.get_xlim()[0], x_point, linestyle="dashed", color="black")
    ax.scatter(x_point, y_point, label="Sample under investigation", color="black", zorder=3)

file_out = "plots/global_gef.pdf"
# The legend is attached to the last panel and shifted above the whole grid.
axes[-1].legend(loc='upper center', bbox_to_anchor=(-0.7, 2.5), ncol=3)
fig.savefig(file_out)

## Global SHAP

### Setup

In [None]:
# Precomputed SHAP values and SHAP explainer (training set).
# NOTE: pickle.load can execute arbitrary code — only load files you generated yourself.
with open('precomputed/shap_values_training.pickle', 'rb') as values_file:
    shap_values = pickle.load(values_file)
with open('precomputed/shap_explainer_training.pickle', 'rb') as shap_explainer_file:
    shap_explainer = pickle.load(shap_explainer_file)

### Plots

In [None]:
# Force plot of the SHAP explanation stored in row 0 of the precomputed SHAP values.
first_row_shap = shap_values.values[0, :]
shap.plots.force(shap_explainer.expected_value, first_row_shap, matplotlib=True)

In [None]:
# NOTE(review): the full slice `[:, :]` selects every row and column, so this looks like a
# leftover no-op; presumably safe to drop — confirm nothing relies on getting a new object.
shap_values = shap_values[:, :]

In [None]:
n_row, n_col = 2, 3
n_panels = n_row * n_col  # same GAM terms as the global GEF figure

# Apply the sizes BEFORE creating artists (titles and legends read rcParams when
# they are built); updating them afterwards only affected a second run. Still global.
params = {'legend.fontsize': 18,
          'figure.figsize': (20, 5),
          'axes.titlesize': 18}
plt.rcParams.update(params)

fig = plt.figure(figsize=(13, 10))
c1 = sns.color_palette(n_colors=3)[0]

plot_index = 0
axes = []
points = []
for i, term in enumerate(explainer.gam.terms):
    if i == n_panels:
        break
    # Only univariate spline terms get a panel; intercept and tensor terms are skipped.
    if term.isintercept or term.istensor:
        continue

    # Panels after the first in a row share their y axis with the previous panel.
    ax = fig.add_subplot(n_row, n_col, plot_index + 1,
                         sharey=axes[-1] if plot_index % n_col != 0 else None)

    # SHAP dependence scatter for this feature, drawn into our own axes.
    shap.plots.scatter(shap_values[:, term.feature], ax=ax, show=False, hist=False, color=c1)

    # Only the left-most panel of each row shows y tick labels.
    plt.setp(ax.get_yticklabels(), visible=plot_index % n_col == 0)
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.tick_params(labelsize=18)
    ax.set_title(feature_names_display[term.feature])

    # NOTE(review): y_point is the GAM partial dependence of the sample, not its SHAP value
    # (shap_values.values[sample_index, term.feature]) — confirm this is intended on a SHAP plot.
    x_point = sample[0, term.feature]
    y_point = explainer.gam.partial_dependence(term=i, X=sample)
    points.append((x_point, y_point))

    plot_index += 1
    axes.append(ax)

# Lowest y limit of each row of panels (bottom of the sample's vertical marker).
# Row r is made of axes[r * n_col:(r + 1) * n_col]; each row gets its own minimum.
min_y_rows = tuple(
    min(row_ax.get_ylim()[0] for row_ax in axes[start:start + n_col])
    for start in range(0, len(axes), n_col)
)

# Highlight the sample on every panel.
for i, (ax, (x_point, y_point)) in enumerate(zip(axes, points)):
    ax.vlines(x_point, min_y_rows[i // n_col], y_point, linestyle="dashed", color="black")
    ax.hlines(y_point, ax.get_xlim()[0], x_point, linestyle="dashed", color="black")
    sample_plot = ax.scatter(x_point, y_point, label="Sample under investigation", color="black", zorder=3)

file_out = "plots/global_shap.pdf"
# shap draws the scatter itself, so a proxy artist stands in for it in the legend.
dummy_shap_plot = Line2D([0], [0], marker='o', color=c1, label='SHAP values', lw=0)
axes[-1].legend(handles=[dummy_shap_plot, sample_plot], loc='upper center',
                bbox_to_anchor=(-1.47, 2.5), ncol=3, fontsize=14)
fig.savefig(file_out)

## Local explanations

### SHAP

In [None]:
# Display names ordered by column position, with "_" escaped for LaTeX (text.usetex is on;
# a bare "_" is an error in LaTeX text mode). The replacement is a raw string because a
# plain "\_" literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
feature_names_display_local = [feature_names_display[i].replace("_", r"\_") for i in range(len(feature_names_display))]

In [None]:
# Use the LaTeX-escaped display names as SHAP feature labels for the local SHAP plot below.
# This sets the attribute on `shap_values` itself, so any later SHAP plot also sees these names.
shap_values.feature_names = feature_names_display_local

In [None]:
# Set the font size first: rcParams are read when text artists are created, so
# updating them after plotting had no effect on this figure (only on later ones).
# The update stays global, as before.
plt.rcParams.update({'font.size': 14})

plt.figure()
# Waterfall of the SHAP contributions for the sample under investigation (top 7 shown).
shap.plots.waterfall(shap_values[sample_index], max_display=7, show=False)
plt.tight_layout()
file_out = "plots/local_shap.pdf"
plt.savefig(file_out)

### GEF

In [None]:
# Local GEF explanation of the sample under investigation, drawn by the
# `plot_local_all_terms` helper from gamexplainer.utils.
plot_local_all_terms(
    explainer.gam,
    feature_names_display,
    X_train.values,
    sample_index,
    range_perc=20,
    figsize=(9, 15),
)
file_out = "plots/local_gef.pdf"
plt.savefig(file_out)

### LIME

In [None]:
explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values, 
                                                   feature_names=feature_names_display_local, 
                                                   class_names=['critical_temp'],
                                                   verbose=True, 
                                                   mode='regression')

In [None]:
rf = lgbm.Booster(model_file="precomputed/forest.lgbm")

In [None]:
exp = explainer.explain_instance(X_train.values[sample_index], rf.predict, num_features=5)

In [None]:
# Lime has some issues with Latex, the greater and equal symbol is not shown correctly
plt.rcParams['text.usetex'] = False

In [None]:
exp.as_pyplot_figure()
plt.savefig("plots/local_lime.pdf", bbox_inches="tight")

In [None]:
exp.show_in_notebook()

In [None]:
exp.save_to_file("plots/local_lime_html.html")

In [None]:
# Lime has some issues with Latex, the greater and equal symbol is not shown correctly
plt.rcParams['text.usetex'] = False

In [None]:
exp.as_pyplot_figure()
plt.savefig("plots/local_lime.pdf", bbox_inches="tight")

In [None]:
exp.show_in_notebook()

In [None]:
exp.save_to_file("plots/local_lime_html.html")