In [None]:
import lime
import shap
import sklearn
from lime import lime_tabular
from sklearn.linear_model import LogisticRegression

In [None]:
# Load shap's bundled "adult" census-income dataset as (feature frame, target).
# display=False is shap's default (the model-ready encoding); it is spelled out here
# so the choice of encoding is visible to the reader.
X_adult, y_adult = shap.datasets.adult(display=False)

In [None]:
# Logistic-regression classifier that the rest of the notebook explains, fitted on
# the raw (unscaled) feature frame. LogisticRegression is imported by name: a bare
# `import sklearn` does not guarantee that the `sklearn.linear_model` submodule is
# loaded, so attribute access through it would depend on some other import having
# loaded it first. The large max_iter presumably gives the default solver room to
# converge on the unscaled features.
model_adult = LogisticRegression(max_iter=10000)
model_adult.fit(X_adult, y_adult)

In [None]:
# Class labels, in the same order as the columns of predict_proba's output
# (LIME's `labels=[0, 1]` below index into those columns).
model_adult.classes_.tolist()

In [None]:
# LIME tabular explainer built from the raw training matrix. Continuous features
# are not discretized, so (with no categorical_features given) explanation keys are
# the column names themselves.
# LIME draws random perturbations around each instance; a fixed random_state makes
# that sampling, and so every explanation produced from this explainer below,
# reproducible under Restart & Run All. 42 is an arbitrary fixed seed.
explainer = lime_tabular.LimeTabularExplainer(
    training_data=X_adult.values,
    feature_names=X_adult.columns.tolist(),
    class_names=model_adult.classes_,
    discretize_continuous=False,
    random_state=42,
)

In [None]:
# NOTE(review): `importance` is not read by any later cell in this notebook as shown;
# it looks like leftover scaffolding and could probably be deleted.
# One fresh empty list per feature position 0 .. n_features - 1.
importance = {position: list() for position in range(len(X_adult.columns))}

In [None]:
# Explain the first row for both classes (labels 0 and 1), asking LIME for a
# weight on every feature rather than its default top-k.
exp = explainer.explain_instance(
    data_row=X_adult.iloc[0],
    predict_fn=model_adult.predict_proba,
    labels=[0, 1],
    num_features=len(X_adult.columns),
)

In [None]:
# Print each (feature, weight) pair of the row-0 explanation for label 1,
# which is the label as_list() reports when none is given.
for feature_name, weight in exp.as_list(label=1):
    print(feature_name, " - ", weight)

In [None]:
from collections import defaultdict

import numpy as np
from tqdm.auto import tqdm


# Hand-rolled global importance: explain the first `n_samples` rows (no `labels`
# passed, so LIME's default label 1) and collect, per feature, the absolute LIME
# weight from every explained row.
n_samples = 1000
global_explanation = defaultdict(list)

# Slicing positionally stops at the last row when the frame has fewer than
# n_samples rows, instead of raising IndexError from .iloc[i].
rows_to_explain = X_adult.iloc[:n_samples]

for _row_label, row in tqdm(rows_to_explain.iterrows(), total=len(rows_to_explain)):
    # num_features = all columns, so every feature receives a weight for every row.
    sample_explanation = explainer.explain_instance(
        row, model_adult.predict_proba, num_features=X_adult.shape[1]
    ).as_list()

    for feature, weight in sample_explanation:
        global_explanation[feature].append(np.abs(weight))

In [None]:
# Collapse each feature's collected |weight| values into their mean.
# Note: this rebinds `global_explanation` from lists of values to scalars.
global_explanation = dict(
    (feature, np.mean(abs_weights))
    for feature, abs_weights in global_explanation.items()
)

In [None]:
# Display the per-feature mean |LIME weight| aggregated in the previous cell.
global_explanation

In [None]:
import seaborn as sns


# Bar chart of mean |LIME weight| per feature, most important feature first.
# Explicit x/y vectors are passed instead of `data=<dict>`: how seaborn interprets
# a bare dict of scalars as wide-form data has varied between releases, whereas
# x/y vectors with orient="h" are the documented long-standing form.
ranked_features = sorted(global_explanation.items(), key=lambda item: item[1], reverse=True)
feature_labels = [feature for feature, _ in ranked_features]
mean_abs_weights = [weight for _, weight in ranked_features]

ax = sns.barplot(x=mean_abs_weights, y=feature_labels, orient="h")
ax.set(title="Global LIME importance (hand-rolled)", xlabel="Mean |LIME weight|", ylabel="Feature");

In [None]:
import warnings

from src.explanation.local.lime_explainer import LimeExplainer

warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

In [None]:
# Project-side LIME wrapper. Arguments: the model's probability function and the
# feature frame (presumably used as LIME's training/background data; confirm
# against src.explanation.local.lime_explainer).
lime_explainer = LimeExplainer(
    model_adult.predict_proba,
    X_adult,
)

In [None]:
# Aggregate LIME explanation over the first 1000 rows via the project helper
# (DataFrame.head(n) selects exactly the rows .iloc[:n] does).
res = lime_explainer.get_global_explanation(X_adult.head(1000))

In [None]:
# Display the helper's global explanation (its structure is defined in
# src.explanation.local.lime_explainer, which is not visible here).
res

In [None]:
import seaborn as sns


# Horizontal bar chart of the helper's aggregate, handed to seaborn as `data=`.
# NOTE(review): the type of `res` comes from src (not visible here); if it is a
# plain dict of scalars, seaborn's interpretation may vary between versions.
sns.barplot(orient="h", data=res)

In [None]:
from src.explanation.local.utils import plot_bar, plot_scatter

In [None]:
# Horizontal bar chart of the helper's aggregate via the project's plot_bar utility.
plot_bar(orient="h", data=res)

In [None]:
from collections import defaultdict


# Per-feature LIME weights for the first 1000 rows, kept position-aligned with
# X_adult.iloc[:1000]: each list holds exactly one entry per explained row, with
# NaN where that row's explanation omitted the feature (e.g. if the helper only
# reports its top-k features -- NOTE(review): confirm num_features in LimeExplainer).
# Appending only the features that appear would leave some lists shorter than the
# row count, shifting weights onto the wrong rows in any later feature-vs-weight plot.
# The per-row results are deliberately not bound to `exp`, so the Explanation object
# created earlier under that name stays intact.
per_row_weights = [
    dict(lime_explainer.get_lime_explanation(row).as_list())
    for _row_label, row in X_adult.iloc[:1000].iterrows()
]

# Every key seen in any row, in first-seen order (the same order a row-by-row
# append loop would produce).
explained_features = dict.fromkeys(
    feature for weights in per_row_weights for feature in weights
)

explanations = defaultdict(list)
for feature in explained_features:
    explanations[feature] = [weights.get(feature, float("nan")) for weights in per_row_weights]

In [None]:
# Keys produced by the explanation loop; check they match the column names used
# for plotting below (e.g. "Age").
explanations.keys()

In [None]:
# Feature column names, for comparison with the explanation keys above.
X_adult.columns

In [None]:
col = "Age"

# The explanation loop indexes rows by position (X_adult.iloc over the first 1000
# rows), so select the feature values positionally too. Label-based, end-inclusive
# `.loc[:999]` only picks the same rows when X_adult has a default 0..n-1 RangeIndex.
if col not in explanations:  # `in` does not insert a key into a defaultdict
    raise KeyError(f"{col!r} is not an explanation key; available: {list(explanations)}")
feature_values = X_adult[col].iloc[:1000]
if len(explanations[col]) != len(feature_values):
    raise ValueError(
        f"{len(explanations[col])} LIME weights for {col!r} but {len(feature_values)} rows; "
        "the weights are not aligned with the rows they explain"
    )
plot_scatter(feature_values, explanations[col], feature_name=col, title="LIME")