In [None]:
import shap
import sklearn

In [None]:
# Load shap's built-in "adult" dataset: feature frame X_adult and labels y_adult.
X_adult, y_adult = shap.datasets.adult()

In [None]:
# `import sklearn` alone does not guarantee that the `linear_model` submodule
# is loaded (this depends on the scikit-learn version), so import it
# explicitly. This statement only (re)binds the name `sklearn`.
import sklearn.linear_model

# Logistic regression fitted directly on the raw (unscaled) feature frame; the
# large iteration cap gives the solver room to converge.
model_adult = sklearn.linear_model.LogisticRegression(max_iter=10000)
model_adult.fit(X_adult, y_adult)

In [None]:
# Independent masker drawing at most 100 background rows from X_adult.
# NOTE(review): `background_adult` is not referenced in the cells shown here;
# the KernelExplainer uses `_get_background_example(X_adult)` instead.
background_adult = shap.maskers.Independent(data=X_adult, max_samples=100)

In [None]:
# Inspect column dtypes: `_get_background_example` decides median vs. mode per
# column from these.
X_adult.dtypes

In [None]:
def _get_background_example(df):
    median = df.median(numeric_only=True)
    background_sample = df.mode(dropna=False).head(1)

    # Use median of the numerical features.
    numerical_features = [feature for feature in list(df.columns) if df[feature].dtype == float]
    for feature in numerical_features:
        background_sample[feature] = median[feature]

    background_sample = background_sample.astype(df.dtypes)
    return background_sample


def model_adult_log_odds(x, model=None):
    """Return the positive-class log-odds for each row of ``x``.

    For a binary scikit-learn ``LogisticRegression`` with the default
    ``multi_class`` handling, ``decision_function`` equals ``log(p1 / p0)``.
    Using it directly avoids subtracting two log-probabilities, which loses
    precision and can produce ``inf``/``nan`` when ``predict_proba`` rounds a
    confident prediction to exactly 0 or 1.

    Args:
        x: Feature rows accepted by the model (e.g. a DataFrame with the
            training columns).
        model: Fitted binary classifier exposing ``decision_function``.
            Defaults to the notebook-level ``model_adult``.

    Returns:
        Array of log-odds, one entry per row of ``x``.
    """
    if model is None:
        model = model_adult
    return model.decision_function(x)


def prediction_function(x):
    """Positive-class probability for each row of ``x``.

    This is the callable handed to the SHAP explainers; it reads the
    notebook-level ``model_adult``.
    """
    class_probabilities = model_adult.predict_proba(x)
    # Column 1 holds the probability of model_adult.classes_[1].
    return class_probabilities[:, 1]


# Kernel SHAP over the positive-class probability, using a single "typical"
# row as background; link="logit" expresses the SHAP values in log-odds units.
explainer = shap.KernelExplainer(
    prediction_function,
    _get_background_example(X_adult),
    link="logit",
    feature_names=X_adult.columns,
    keep_index=True,
)

In [None]:
import numpy as np

# Kernel SHAP may evaluate a random subset of feature coalitions, so seed
# numpy's global RNG first to make the values reproducible across runs.
# NOTE(review): assumes the installed shap version draws from numpy's global
# RNG; confirm.
np.random.seed(0)
shap_values_adult = explainer(X_adult.iloc[:1000])

In [None]:
# SHAP values of the "Age" feature across the explained rows.
shap.plots.scatter(shap_values_adult[:, "Age"])

In [None]:
# Per-feature breakdown of the explanation for the first explained row.
shap.plots.waterfall(shap_values_adult[0])

In [None]:
# Cross-check the row-0 waterfall: with link="logit", its output value should
# match the model's log-odds for that row. `iloc[[0]]` keeps a one-row
# DataFrame with column names, so scikit-learn does not warn that a model
# fitted on named features received a bare array.
model_adult_log_odds(X_adult.iloc[[0]])

In [None]:
# Summary of SHAP values over all explained rows and features.
shap.plots.beeswarm(shap_values_adult)

In [None]:
# Per-feature summary of SHAP value magnitudes over the explained rows.
shap.plots.bar(shap_values_adult)

In [None]:
from sklearn.metrics import classification_report, f1_score, roc_auc_score

# NOTE(review): evaluated on the same rows the model was fitted on, so these
# metrics are in-sample and likely optimistic; a held-out split would give an
# honest estimate.
print(classification_report(y_true=y_adult, y_pred=model_adult.predict(X_adult)))

In [None]:
# In-sample ROC AUC of the positive-class probabilities.
print(roc_auc_score(y_true=y_adult, y_score=prediction_function(X_adult)))

In [None]:
# In-sample F1 score of the hard predictions.
print(f1_score(y_true=y_adult, y_pred=model_adult.predict(X_adult)))

In [None]:
# Positive-class probabilities for every row (same values SHAP explains).
prediction_function(X_adult)

In [None]:
from src.explanation.local.shap_explainer import ShapExplainer

In [None]:
# Preview the first rows instead of rendering the whole feature frame.
X_adult.head()

In [None]:
# Project SHAP wrapper built from the same probability function used above.
# NOTE(review): presumably X_adult serves as background/reference data here;
# confirm against ShapExplainer's signature.
shap_explainer = ShapExplainer(prediction_function, X_adult)

In [None]:
# Explain the first 1,000 rows, the same slice explained with KernelExplainer.
shap_explanation = shap_explainer.get_shap_explanation(X_adult.iloc[:1000])

In [None]:
# Same "Age" plot, now for the ShapExplainer output, for comparison.
shap.plots.scatter(shap_explanation[:, "Age"])

In [None]:
# Row-0 waterfall for the ShapExplainer output, for comparison.
shap.plots.waterfall(shap_explanation[0])

In [None]:
# Inspect the raw attributes of the first row's explanation object.
shap_explanation[0].__dict__

In [None]:
import numpy as np

# Mean absolute SHAP value of "Relationship" over the explained rows
# (presumably a manual cross-check of the global explanation computed next).
np.mean(np.abs(shap_explanation[:, "Relationship"].values))

In [None]:
# Global (dataset-level) explanation over the same first 1,000 rows.
res = shap_explainer.get_global_explanation(X_adult.iloc[:1000], normalize=True)

In [None]:
# Display the global explanation returned by get_global_explanation.
res