/
image.py
72 lines (64 loc) · 2.99 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import warnings
try:
import matplotlib.pyplot as pl
except ImportError:
warnings.warn("matplotlib could not be loaded!")
pass
from . import colors
def image_plot(shap_values, x, labels=None, show=True, width=20, aspect=0.2, hspace=0.2, labelpad=None):
    """Plot SHAP values for image inputs.

    One figure row is drawn per sample in ``x``. The first column shows the
    input image; each following column shows the SHAP values of one model
    output drawn over a faint grayscale copy of the input. A single horizontal
    colorbar, shared by all panels, is added below the grid.

    Parameters
    ----------
    shap_values : array or list of arrays
        SHAP values for a single output (one array) or for several outputs
        (a list with one array per output). Each array is indexed by sample
        along its first axis; a per-sample slice that is not 2-D is summed
        over its last axis before plotting.
    x : array
        Input images, indexed by sample along the first axis. A trailing
        channel axis of size 1 is dropped, and a trailing axis of size 3 is
        treated as RGB when building the grayscale underlay. Images whose
        maximum exceeds 1 are divided by 255.
    labels : array-like, optional
        Panel titles. For a list of ``shap_values`` this must have one row per
        sample and one column per output; for a single array it must be a
        vector with one entry per sample.
    show : bool
        If True, call ``pl.show()`` once the figure has been built.
    width : float
        Maximum figure width; a wider figure is scaled down proportionally.
    aspect : float
        The colorbar aspect passed to matplotlib is ``figure_width / aspect``.
    hspace : float or 'auto'
        Vertical spacing handed to ``subplots_adjust``; ``'auto'`` uses
        ``tight_layout`` instead.
    labelpad : float, optional
        Passed as ``pad`` to ``set_title`` when given.

    Raises
    ------
    AssertionError
        If ``labels`` does not match the layout described above.
    """
    multi_output = isinstance(shap_values, list)
    if not multi_output:
        shap_values = [shap_values]

    # make sure the labels line up with the SHAP value arrays
    if labels is not None:
        labels = np.asarray(labels)
        assert labels.shape[0] == shap_values[0].shape[0], "Labels must have same row count as shap_values arrays!"
        if multi_output:
            assert labels.shape[1] == len(shap_values), "Labels must have a column for each output in shap_values!"
        else:
            assert len(labels.shape) == 1, "Labels must be a vector for single output shap_values."
            # The plotting loop indexes labels as [row, output]; give the
            # single-output vector an explicit output axis so that works.
            labels = labels.reshape(-1, 1)

    label_kwargs = {} if labelpad is None else {'pad': labelpad}

    # plot our explanations
    n_rows = x.shape[0]
    n_cols = len(shap_values) + 1
    fig_size = np.array([3 * n_cols, 2.5 * (n_rows + 1)])
    if fig_size[0] > width:
        fig_size *= width / fig_size[0]
    fig, axes = pl.subplots(nrows=n_rows, ncols=n_cols, figsize=fig_size)
    if len(axes.shape) == 1:
        # a single row of subplots comes back 1-D; restore the row axis
        axes = axes.reshape(1, axes.size)

    # The color scale is shared by every panel (there is one colorbar), so it
    # is derived once from all samples and outputs rather than once per row.
    if len(shap_values[0].shape) == 3:
        abs_vals = np.stack([np.abs(values) for values in shap_values], 0).flatten()
    else:
        abs_vals = np.stack([np.abs(values.sum(-1)) for values in shap_values], 0).flatten()
    max_val = np.nanpercentile(abs_vals, 99.9)

    gray_cmap = pl.get_cmap('gray')
    for row in range(n_rows):
        x_curr = x[row].copy()

        # drop a singleton channel axis so imshow treats the image as 2-D
        if len(x_curr.shape) == 3 and x_curr.shape[2] == 1:
            x_curr = x_curr.reshape(x_curr.shape[:2])
        if x_curr.max() > 1:
            # Out-of-place division: an in-place `/=` cannot store the float
            # result back into an integer (e.g. uint8) image and raises.
            x_curr = x_curr / 255.

        # get a grayscale version of the image
        if len(x_curr.shape) == 3 and x_curr.shape[2] == 3:
            x_curr_gray = (0.2989 * x_curr[:, :, 0] + 0.5870 * x_curr[:, :, 1] + 0.1140 * x_curr[:, :, 2])  # rgb to gray
        else:
            x_curr_gray = x_curr

        axes[row, 0].imshow(x_curr, cmap=gray_cmap)
        axes[row, 0].axis('off')

        for i, values in enumerate(shap_values):
            ax = axes[row, i + 1]
            if labels is not None:
                ax.set_title(labels[row, i], **label_kwargs)
            sv = values[row] if len(values[row].shape) == 2 else values[row].sum(-1)
            # extent is (left, right, bottom, top): the horizontal span comes
            # from the column count (shape[1]), the vertical from the row
            # count (shape[0]), so non-square images stay aligned.
            ax.imshow(x_curr_gray, cmap=gray_cmap, alpha=0.15, extent=(-1, sv.shape[1], sv.shape[0], -1))
            im = ax.imshow(sv, cmap=colors.red_transparent_blue, vmin=-max_val, vmax=max_val)
            ax.axis('off')

    if hspace == 'auto':
        fig.tight_layout()
    else:
        fig.subplots_adjust(hspace=hspace)

    # one colorbar for the whole grid, attached to the last SHAP image drawn
    cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value", orientation="horizontal", aspect=fig_size[0] / aspect)
    cb.outline.set_visible(False)
    if show:
        pl.show()