## SHAP Plots

#### Compute SHAP Values

In [1]:
import shap
from sklearn.svm import OneClassSVM
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

np.random.seed(0)
#path = '../results/2d stereo model performance/'
#path = '../results/2d single model performance/'
path = '../results/2d single far model performance/'

shap_train = np.load(path + 'shap_train.npy', allow_pickle=True)
shap_test = np.load(path + 'shap_test.npy', allow_pickle=True)

columns = shap_train[0].columns

folds = [0, 22]

shap_values = []
shap_data = []


for fold in folds:
	print(f' ---- COMPUTING FOLD NUMBER: {fold} -----')

	clf = OneClassSVM(nu=0.25, kernel='rbf', gamma=((1/49)**2), degree=3) #gamma=((1/48)**2), degree=3)
	clf.fit(shap_train[fold])
	data = shap_test[fold].sample(frac=(1/3))
	explainer = shap.Explainer(clf.predict, data)
	shap_values.append(explainer(data))
	shap_data.append(data)

shap_values = np.array(shap_values, dtype=object)
shap_data = np.array(shap_data, dtype=object)

  from .autonotebook import tqdm as notebook_tqdm


 ---- COMPUTING FOLD NUMBER: 0 -----


Permutation explainer: 563it [16:55,  1.81s/it]                           


 ---- COMPUTING FOLD NUMBER: 22 -----


Permutation explainer: 365it [14:34,  2.43s/it]                         


#### Summary Plots

In [None]:
plt.rcParams.update({'font.size': 9})
for index, shap_value in enumerate(shap_values):
	fig = plt.figure(figsize=(20,15), dpi=400)
	ax = fig.add_subplot(1,1,1)

	shap.summary_plot(shap_value, show=False, max_display=80, plot_size=[20,15], alpha=0.7)
	if index == 0:
		plt.title('SHAP Summary Plot - Best performing fold')
	else:
		plt.title('SHAP Summary Plot - Worst performing fold')
	plt.show()

#### Bar Plots

In [None]:
plt.rcParams.update({'font.size': 9})
for index, shap_value in enumerate(shap_values):
	fig = plt.figure(figsize=(12,8), dpi=400)
	ax = fig.add_subplot(1,1,1)

	shap.plots.bar(shap_value, max_display=80, show=False)

	if index == 0:
		plt.title('SHAP Bar Plot - Best performing fold')
	else:
		plt.title('SHAP Bar Plot - Worst performing fold')
	plt.show()


#### Violin Plots

In [None]:
plt.rcParams.update({'font.size': 9})
for index, shap_value in enumerate(shap_values):
	fig = plt.figure(figsize=(20, 15), dpi=400)
	ax = fig.add_subplot(1,1,1)

	shap.summary_plot(shap_value, show=False, plot_size=[20,15], plot_type='violin', max_display=80)

	if index == 0:
		plt.title('SHAP Violin Plot - Best performing fold')
	else:
		plt.title('SHAP Violin Plot - Worst performing fold')
	plt.show()

In [2]:
plt.rcParams.update({'font.size': 9})

for index, shap_value in enumerate(shap_values):
    if index == 0:
        for col in sorted(columns):
            plt.figure()

            shap.plots.scatter(shap_value[:,col], color=shap_value[:,col], alpha=0.7, show=False)
            plt.plot([0 for _ in range(2)], [min(shap_value[:,col].values), max(shap_value[:,col].values)], 'k', linewidth=0.5)
            plt.plot([min(shap_value[:,col].data), max(shap_value[:,col].data)], [0 for _ in range(2)], 'k', linewidth=0.5)
            plt.savefig(path + f'best fold\\{col}.png')
            plt.close()





<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>