In [None]:
import lightgbm as lgbm
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
from gamexplainer import GamExplainer
from collections import defaultdict
from sklearn.model_selection import GridSearchCV
from math import comb
from sklearn.inspection import plot_partial_dependence
import pickle
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import accuracy_score, classification_report
%load_ext autoreload
%autoreload 2

In [None]:
# Render figure text with matplotlib's built-in text engine rather than an external LaTeX installation.
plt.rcParams['text.usetex'] = False

In [None]:
# Download the UCI Adult data file and store it locally as adult.csv (needs `wget` and network access).
!wget -O adult.csv https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data

## Read the dataset

In [None]:
# The downloaded file has no header row, so the column names are supplied explicitly.
col_names = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education-num",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
    "native-country",
    "class",
]
df = pd.read_csv(
    "adult.csv",
    names=col_names,
    header=None,
    sep=",",
    index_col=False,
)

In [None]:
# Unshuffled split: the first 70% of the rows are used for training, the remaining rows for testing.
resp_var = "class"
n_train = int(len(df) * 0.7)
train = df.iloc[:n_train]
test = df.iloc[n_train:]
# Separate the response column from the predictors in both partitions.
X_train, y_train = train.drop(columns=resp_var), train[resp_var]
X_test, y_test = test.drop(columns=resp_var), test[resp_var]

In [None]:
# Quick visual check of the training predictors.
X_train

In [None]:
# Build one ColumnTransformer entry per input column:
#   - categorical columns are one-hot encoded,
#   - columns listed in `to_drop` are removed,
#   - every other column is passed through unchanged.
final_cols = []
categorical_feats = ["workclass", "marital-status", "occupation", "relationship", "race", "sex", "native-country"]
to_drop = ["education"]
transformers = []
for column in X_train.columns:
    name = column
    trans = "passthrough"
    if column in categorical_feats:
        # handle_unknown="ignore": a category that never occurs in the training rows
        # (the split is positional, so rare values may appear only in the test rows)
        # is encoded as all zeros instead of raising inside ct.transform.
        trans = OneHotEncoder(handle_unknown="ignore")
        name = f"{column}_class"
    elif column in to_drop:
        trans = "drop"

    transformers.append((name, trans, [column]))

    # Keep track of the original columns that survive the transformation.
    if trans != "drop":
        final_cols.append(column)
ct = ColumnTransformer(transformers, remainder="passthrough")
ct.fit(X_train)
# Encoder for the labels
le = LabelEncoder()
le.fit(y_train)

In [None]:
# Encode predictors with the fitted ColumnTransformer and labels with the fitted LabelEncoder
# (same evaluation order as before: train features, test features, train labels, test labels).
X_train_trans, X_test_trans = (ct.transform(features) for features in (X_train, X_test))
y_train_trans, y_test_trans = (le.transform(labels) for labels in (y_train, y_test))

In [None]:
# Names of the transformed columns; their positions are the feature indices used by the model.
ct.get_feature_names_out(X_train.columns)

In [None]:
# (number of training rows, number of columns after encoding)
X_train_trans.shape

## Train the forest

In [None]:
lgbm_info = {}
# Hyper-parameter grid, geometrically spaced: 3 forest sizes x 4 leaf counts x 3 learning rates.
parameters = dict(
    n_estimators=np.geomspace(100, 10000, num=3, dtype=int),
    num_leaves=np.geomspace(32, 256, num=4, dtype=int),
    learning_rate=np.geomspace(1e-3, 1e-1, num=3),
)
# Exhaustive cross-validated search, scored by accuracy.
base_classifier = lgbm.LGBMClassifier(n_jobs=16)
CV_classifier = GridSearchCV(
    estimator=base_classifier,
    param_grid=parameters,
    scoring="accuracy",
    verbose=3,
)
CV_classifier.fit(X_train_trans, y_train_trans)

In [None]:
# Best hyper-parameter combination found by the grid search.
print(CV_classifier.best_params_)

Best params:
{'learning_rate': 0.01, 'n_estimators': 1000, 'num_leaves': 32}

In [None]:
# Train the forest with fixed hyper-parameters (hard-coded, so the grid search need not be re-run).
best_params = {"learning_rate": 0.01, "n_estimators": 1000, "num_leaves": 32}
forest = lgbm.LGBMClassifier(verbose=2, **best_params)
forest.fit(X_train_trans, y_train_trans)
# The underlying LightGBM booster is the object handed to GamExplainer later on.
forest_to_explain = forest.booster_

In [None]:
# Predict once and reuse the result for both reports (the forest was previously run twice on the same data).
y_test_pred = forest.predict(X_test_trans)
print(accuracy_score(y_test_trans, y_test_pred))
print(classification_report(y_test_trans, y_test_pred))

## Feature selection

In [None]:
# Grid explored by the feature-selection experiment:
# 1..10 spline terms and 0..8 interaction terms.
range_n_splines = range(1, 11)
range_n_inter = range(0, 9)

In [None]:
import warnings
# Suppress every warning from here on (keeps the long-running loops below quiet).
# NOTE(review): this also hides warnings that may matter; consider narrowing the filter.
warnings.filterwarnings('ignore')

### Compute the results array (one entry per spline/interaction combination)

In [None]:
# Fit one GAM explanation per (n_spline_terms, n_inter_terms) pair and store explainer.loss_res.
# NOTE(review): the array is called `acc` but holds `loss_res`; confirm which metric that attribute is.
explanation_params = dict(
    verbose=False,
    sample_method="all",
    classification=True,
    inter_max_distance=32,
)

acc = np.zeros((len(range_n_splines), len(range_n_inter)))
for i, n_splines in enumerate(range_n_splines):
    explanation_params["n_spline_terms"] = n_splines
    # n_splines features can form at most C(n_splines, 2) distinct pairs.
    max_inter = comb(n_splines, 2)
    for j, n_inter in enumerate(range_n_inter):
        if n_inter <= max_inter:
            explanation_params["n_inter_terms"] = n_inter
            explainer = GamExplainer(**explanation_params)
            gam = explainer.explain(forest_to_explain, lam_search_space=[0.1, 1])
            print(f"Fit {n_splines=}, {n_inter=} completed")
            acc[i, j] = explainer.loss_res
        # Infeasible combinations are skipped and keep their initial value of 0.

In [None]:
# Persist the results grid; np.save appends the ".npy" extension to the file name.
np.save("feat_selection", acc)

### Or load it if already saved

In [None]:
# Reload a previously saved results grid instead of recomputing it.
acc = np.load("feat_selection.npy")

## Plot the results in a heatmap

In [None]:
# Mark the (n_splines, n_inter) cells that were never fitted so the heatmap hides them:
# a cell is masked when more interactions are requested than feature pairs exist.
dimension = (len(range_n_splines), len(range_n_inter))
mask = np.zeros(dimension)
for i, n_splines in enumerate(tqdm(range_n_splines)):
    max_inter = comb(n_splines, 2)
    for j, n_inter in enumerate(range_n_inter):
        if n_inter > max_inter:
            mask[i, j] = True

In [None]:
# Heatmap of the results grid: rows = number of spline terms, columns = number of interaction terms.
# Cells flagged in `mask` are left blank.
accuracy_df = pd.DataFrame(acc, columns=range_n_inter, index=range_n_splines)
ax = sns.heatmap(accuracy_df, annot=True, mask=mask, cmap=sns.color_palette("Blues", as_cmap=True),
                 cbar_kws={'label': 'accuracy'})  # label typo fixed (was 'accuarcy')
ax.set_xlabel("Number of interaction terms used")
ax.set_ylabel("Number of splines used")

## Sampling strategy 

### Find the range to analyze

In [None]:
# Fit one explanation that keeps every threshold ("all") so the next cell can report
# how many thresholds each selected feature has.
explanation_params = {"verbose": False,
                      "interaction_importance_method":"count_path",
                      "feat_importance_method": "gain",
                      "n_spline_terms": 6,
                      "sample_method": "all",
                      "n_spline_per_term": 50,
                      "inter_max_distance": 32,
                      "n_inter_terms": 0,
                      "n_sample_gam":int(1e5),
                      "portion_sample_test":0.3,
                      "classification": True
                      }
explainer = GamExplainer(**explanation_params)
# Pass the booster, as every other explain() call in this notebook does
# (this call previously passed the sklearn wrapper `forest`).
gam = explainer.explain(forest_to_explain, lam_search_space=[0.1, 1])

In [None]:
# For each entry of feature_dict whose key is listed in explainer.mif, print the key and the entry's length.
for feat_key, feat_values in explainer.feature_dict.items():
    if feat_key not in explainer.mif:
        continue
    print(f"{feat_key}: {len(feat_values)}")

In [None]:
# Sampling strategies to compare (each is passed to GamExplainer as `sample_method`).
sampling_methods = ["all", "quantile", "equal", "kmeans", "equi_size"]
# Values tried for `sample_n`: 50, 300, ..., 4800 (step 250).
range_m = range(50, 5001, 250)

### Compute it

In [None]:
# For every sample size and every sampling strategy, fit an explanation and record explainer.loss_res.
explanation_params = dict(
    verbose=False,
    interaction_importance_method="count_path",
    feat_importance_method="gain",
    n_spline_terms=4,
    sample_method="all",  # placeholder, overwritten inside the loop
    n_spline_per_term=50,
    inter_max_distance=64,
    n_inter_terms=0,
    n_sample_gam=int(1e5),
    portion_sample_test=0.3,
    classification=True,
)
# Maps each sampling method to its list of results, ordered like range_m.
acc_methods = defaultdict(list)
for m in tqdm(range_m):
    explanation_params["sample_n"] = m
    for sampling_method in sampling_methods:
        explanation_params["sample_method"] = sampling_method
        explainer = GamExplainer(**explanation_params)
        gam = explainer.explain(forest_to_explain, lam_search_space=[0.1, 1])
        acc_methods[sampling_method].append(explainer.loss_res)

In [None]:
# Cache the per-method results on disk so the comparison loop need not be re-run.
with open('sampling_comparison.pickle', 'wb') as f:
    pickle.dump(acc_methods, f)

### Or load it

In [None]:
# Reload the cached results. pickle.load can execute arbitrary code: only load files you produced yourself.
with open('sampling_comparison.pickle', 'rb') as f:
    acc_methods = pickle.load(f)

### Plot

In [None]:
# One curve per sampling strategy: recorded result as a function of the sample size K.
# All labels are raw strings: the last one was missing the `r` prefix, which made "\e"
# an invalid escape sequence (the resulting text is unchanged).
# NOTE(review): the labels use LaTeX `\emph{...}`; with text.usetex set to False earlier in
# this notebook they are presumably drawn literally — enable usetex for the final figure.
labels = [r"\emph{All-Thresholds}", r"\emph{Quantile}", r"\emph{Equi-Width}", r"\emph{$k$-Means}", r"\emph{Equi-Size}"]
colors = sns.color_palette(n_colors=len(sampling_methods))
for i, sampling_method in enumerate(sampling_methods):
    plt.plot(range_m, acc_methods[sampling_method], 'o--', color=colors[i], label=labels[i])
plt.xlabel("$K$")
plt.ylabel("Accuracy")
plt.legend()

## Splines investigation

In [None]:
# Explanation used by the spline plots below: 5 spline terms plus 1 interaction term.
# The dict literal previously listed "sample_method" twice ("equi_size", then "all");
# Python keeps the last value, so "all" was the effective setting and is kept here.
# NOTE(review): "sample_n": 2000 suggests "equi_size" may have been intended — confirm.
explanation_params = {"verbose": False,
                      "interaction_importance_method":"count_path",
                      "feat_importance_method": "gain",
                      "n_spline_terms": 5,
                      "sample_method": "all",
                      "sample_n": 2000,
                      "n_spline_per_term": 50,
                      "inter_max_distance": 64,
                      "n_inter_terms": 1,
                      "n_sample_gam":int(1e5),
                      "portion_sample_test":0.3,
                      "classification": True
                      }
explainer = GamExplainer(**explanation_params)
explainer.explain(forest_to_explain)

## With sample highlighting

In [None]:
# Start from the transformer's output names and replace a few of them with
# shorter titles for the plots. Keys are positions in the transformed feature matrix.
final_cols = ct.get_feature_names_out().copy()
display_names = {
    14: "MS-Married",
    47: "CapitalGain",
    11: "EducationNum",
    0: "Age",
}
for col_idx, display_name in display_names.items():
    final_cols[col_idx] = display_name

In [None]:
# Pick one training row and reshape it to a single-row 2-D input (1 x n_features).
sample_index = 0
sample = X_train_trans[sample_index].reshape(1, -1)

In [None]:
# Inspect the selected row.
sample

In [None]:
# Plot the partial dependence (with its 95% interval) of the first n_row * n_col
# univariate GAM terms on a grid of subplots and save the figure to `file_out`.
file_out = "generators.pdf"
params = {'legend.fontsize': 18,
          'figure.figsize': (20, 5),
          'axes.titlesize': 18,
          'xtick.labelsize': 18,
          'ytick.labelsize': 18}
# rcParams are read when titles, ticks and legends are created, so they must be set
# BEFORE the figure is built. They were previously applied after plotting, which only
# took effect when the cell was executed a second time.
plt.rcParams.update(params)

n_row, n_col = 2, 2

fig = plt.figure(figsize=(13, 8), tight_layout=False)

# `lines` and `terms` are not used below; they are kept so the cell still defines them.
lines = []

terms = [(i, x) for i, x in enumerate(explainer.gam.terms) if not x.isintercept and not x.istensor]
terms.sort(key=lambda x: x[1].feature)
c1, c2, c3 = sns.color_palette(n_colors=3)

plot_index = 0
axes = []
for i, term in enumerate(explainer.gam.terms):
    # Stop once every subplot slot is filled.
    if plot_index == n_row * n_col:
        break
    # Only univariate spline terms are drawn.
    if term.isintercept or term.istensor:
        continue

    # Subplots in the same row share the y axis of the previous subplot;
    # only the first subplot of each row shows y tick labels.
    ax = fig.add_subplot(n_row, n_col, plot_index + 1, sharey = axes[-1] if plot_index % n_col != 0 else None)
    plt.setp(ax.get_yticklabels(), visible=plot_index % n_col == 0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    print(term.feature)

    # Spline print
    grid = explainer.gam.generate_X_grid(term=i, meshgrid=term.istensor)
    pdep, confi = explainer.gam.partial_dependence(term=i, X=grid, width=0.95, meshgrid=term.istensor)

    conf_u = ax.plot(grid[:, term.feature], confi[:, 0], ls="--", c=c2, zorder=1)
    conf_l = ax.plot(grid[:, term.feature], confi[:, 1], label="95% width confidence interval", ls="--", c=c2, zorder=1)
    l1 = ax.plot(grid[:, term.feature], pdep, label="Spline learned", lw=2, c=c1, zorder=2)
    ax.set_title(final_cols[term.feature])

    # Optional: highlight the sample under investigation (disabled).
    # x_point = sample[0, term.feature]  # col vector
    # y_point = explainer.gam.partial_dependence(term=i, X=sample)
    #
    # plt.vlines(x_point, ax.get_ylim()[0], y_point, linestyle="dashed", color=c3)
    # plt.hlines(y_point, ax.get_xlim()[0], x_point, linestyle="dashed", color=c3)
    # ax.scatter(x_point, y_point, label="Sample under investigation", color=c3, zorder=3)

    plot_index += 1
    axes.append(ax)

plt.subplots_adjust(hspace=0.3)
# The legend is attached to the last subplot and positioned above the grid via bbox_to_anchor.
plt.legend(loc='upper center', bbox_to_anchor=(-0.35, 2.7), ncol=3)
plt.savefig(file_out)

## Results with SHAP

### Load them

In [None]:
# Load previously computed SHAP values and the explainer that produced them.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open('shapley_values_training.pickle', 'rb') as f:
    shap_values = pickle.load(f)
with open('shap_explainer_training.pickle', 'rb') as f:
    shap_explainer = pickle.load(f)

### Or compute them

In [None]:
import shap

# Build a SHAP explainer for the fitted classifier, using the display names as feature names.
shap_explainer = shap.Explainer(forest, feature_names=final_cols)
# .toarray() converts the encoded training matrix to a dense array before explaining it.
shap_values = shap_explainer(X_train_trans.toarray())


In [None]:
# Visualize the first prediction's explanation.
# Index 1 selects the second output class (presumably the positive class — confirm against le.classes_).
# Requires the 3-D SHAP values, i.e. run this before the class axis is sliced away.
shap.plots.force(shap_explainer.expected_value[1], shap_values.values[0, :, 1], matplotlib=True)

In [None]:
# Keep only the SHAP values of output class 1 (drops the trailing class axis).
# The guard makes the cell idempotent: re-running it on the already-sliced 2-D
# object would otherwise index a third axis that no longer exists.
if len(shap_values.shape) == 3:
    shap_values = shap_values[:, :, 1]

In [None]:
# SHAP counterpart of the spline figure: scatter the SHAP values of the features behind the
# first n_row * n_col univariate GAM terms on a grid of subplots, then save the figure.
file_out = "shap.pdf"
params = {'legend.fontsize': 18,
          'figure.figsize': (20, 5),
          'axes.titlesize': 18}
# rcParams are read when titles and legends are created, so they must be set BEFORE
# plotting. They were previously applied afterwards, so the title size depended on an
# earlier cell (or an earlier run of this one) having set it.
plt.rcParams.update(params)

n_row, n_col = 2, 2

fig = plt.figure(figsize=(13, 8))

# `lines` and `terms` are not used below; they are kept so the cell still defines them.
lines = []

terms = [(i, x) for i, x in enumerate(explainer.gam.terms) if not x.isintercept and not x.istensor]
terms.sort(key=lambda x: x[1].feature)
c1, c2, c3 = sns.color_palette(n_colors=3)

plot_index = 0
axes = []
for i, term in enumerate(explainer.gam.terms):
    # Stop once every subplot slot is filled.
    if plot_index == n_row * n_col:
        break
    # Only features of univariate spline terms are drawn.
    if term.isintercept or term.istensor:
        continue

    # Subplots in the same row share the y axis of the previous subplot.
    ax = fig.add_subplot(n_row, n_col, plot_index + 1, sharey = axes[-1] if plot_index % n_col != 0 else None)

    # Shap scatter print
    shap.plots.scatter(shap_values[:, term.feature], ax=ax, show=False, hist=False, color=c1)
    shap_plot = ax

    # Only the first subplot of each row shows y tick labels.
    plt.setp(ax.get_yticklabels(), visible=plot_index % n_col == 0)
    ax.tick_params(labelsize=18)

    # Drop the axis labels and use the feature's display name as title instead.
    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.set_title(final_cols[term.feature])

    # Optional: highlight the sample under investigation (disabled).
    # x_point = shap_values[sample_index, term.feature].data
    # y_point = shap_values[sample_index, term.feature].values
    #
    # plt.vlines(x_point, ax.get_ylim()[0], y_point, linestyle="dashed", color=c2)
    # plt.hlines(y_point, ax.get_xlim()[0], x_point, linestyle="dashed", color=c2)
    # sample_plot = ax.scatter(x_point, y_point, label="Sample under investigation", color=c2, zorder=3)

    plot_index += 1
    axes.append(ax)

plt.subplots_adjust(hspace=0.3)
# The scatter plots carry no legend handle, so a proxy artist is used for the legend entry.
dummy_shap_plot = Line2D([0], [0], marker='o', color=c1, label='SHAP values', lw=0)
plt.legend(handles=[dummy_shap_plot], loc='upper center', bbox_to_anchor=(-1.0, 2.7), ncol=3, fontsize=14)
plt.savefig(file_out)