# Estimating n-Shapley Values for the kNN classifier on the Folktables Travel data set

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn appearance, set once before any figure is created:
# "whitegrid" style, "notebook" context with thicker axes/grid lines and
# an enlarged font scale.
sns.set_style("whitegrid")
sns.set_context("notebook", rc={'axes.linewidth': 2, 'grid.linewidth': 1},  font_scale=2.1)

import datasets
import nshap

### Load the data

In [None]:
# load_dataset returns five values; only X_test (indexed by row below) and
# feature_names (used to label the plots) are kept, the rest are discarded.
_, X_test, _, _, feature_names = datasets.load_dataset('folk_travel')

### Load the pre-computed n-Shapley Values

In [None]:
# Stored n-Shapley Values for observation 0; the file names end in 500 and
# 5000, matching the "500 Samples" / "5000 Samples" titles used in the plots.
n_shapley_values = nshap.load('../../results/n_shapley_values/folk_travel/knn/observation_0_proba_500.JSON')
n_shapley_values_5000 = nshap.load('../../results/n_shapley_values/folk_travel/knn/observation_0_proba_5000.JSON')

In [None]:
def vfunc(x, S):
    """Look up the stored value of the coalition ``S`` on disk.

    The value is read from the text file ``v{S}.txt`` inside the
    ``observation_0_proba_133549`` results directory, where ``{S}`` is the
    tuple representation of ``S``.

    Args:
        x: Unused; the stored values are tied to a fixed observation. The
            parameter is kept so the function has the ``(x, S)`` signature.
        S: Iterable of feature indices identifying the coalition.

    Returns:
        The file content parsed as a float.
    """
    subset = tuple(S)
    path = f'../../results/n_shapley_values/folk_travel/knn/observation_0_proba_133549/v{subset}.txt'
    with open(path, 'r') as handle:
        return float(handle.read())

In [None]:
# n-Shapley Values for the first row of X_test, using vfunc (which reads
# stored values from the observation_0_proba_133549 directory) as the
# value function.
n_shapley_values_133549 = nshap.n_shapley_values(X_test[0, :], vfunc)

### Plots

In [None]:
# One figure per estimate, saved as knn_estimation_{idx}.pdf. The estimates
# and their sample-count labels are paired explicitly instead of indexing a
# parallel list inside the title string.
estimates = [n_shapley_values, n_shapley_values_5000, n_shapley_values_133549]
sample_counts = ["500", "5000", "133549"]
for idx, (v, num_samples) in enumerate(zip(estimates, sample_counts)):
    fig, ax = plt.subplots(1, 1, figsize=(7, 7.9))
    v.plot(axis=ax, legend=False, feature_names=feature_names, rotation=60)
    # Same y-range on every figure so the three plots are directly comparable.
    plt.ylim([-0.295, 0.29])
    # Title fixed: "Shapley-GAM" (was misspelled "Shapey-GAM").
    plt.title(f'Shapley-GAM, {num_samples} Samples')
    plt.tight_layout()
    plt.savefig(f'../../figures/knn_estimation_{idx}.pdf')
    plt.show()

### LaTeX code for the table in the appendix

In [None]:
# Print one LaTeX table row per non-empty subset of the feature indices
# 0..9: the subset followed by its value under each of the three estimates,
# formatted to four decimals and terminated with a LaTeX line break.
table_columns = (n_shapley_values, n_shapley_values_5000, n_shapley_values_133549)
for S in nshap.powerset(list(range(10))):
    if len(S) > 0:
        cells = [f'{column[S]:0.4f}' for column in table_columns]
        print(f'{S} & ' + ' & '.join(cells) + ' \\\\')