Update LinearExplainer and add sentiment analysis example
slundberg committed Dec 22, 2018
1 parent db6c4db commit cf56810
Showing 6 changed files with 472 additions and 34 deletions.


5 changes: 5 additions & 0 deletions shap/common.py
@@ -242,6 +242,11 @@ def convert_name(ind, shap_values, feature_names):
# we allow rank based indexing using the format "rank(int)"
if ind.startswith("rank("):
return np.argsort(-np.abs(shap_values).mean(0))[int(ind[5:-1])]

# we allow the sum of all the SHAP values to be specified with "sum()"
# assuming here that the calling method can deal with this case
elif ind == "sum()":
return "sum()"
else:
print("Could not find feature named: " + ind)
return None
13 changes: 13 additions & 0 deletions shap/datasets.py
@@ -34,6 +34,19 @@ def boston(display=False):
df = pd.DataFrame(data=d.data, columns=d.feature_names) # pylint: disable=E1101
return df, d.target # pylint: disable=E1101

def imdb(display=False):
""" Return the clssic IMDB sentiment analysis training data in a nice package.
Full data is at: http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Paper to cite when using the data is: http://www.aclweb.org/anthology/P11-1015
"""

with open(cache(github_data_url + "imdb_train.txt")) as f:
data = f.readlines()
y = np.ones(25000, dtype=np.bool)
y[:12500] = 0
return data, y
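A rough sketch of how the new helper might be used for the sentiment analysis example this commit refers to; the vectorizer settings, regularization strength, and number of background rows below are illustrative choices, not the ones from the notebook:

```python
import sklearn.feature_extraction.text
import sklearn.linear_model
import shap

# 25,000 training reviews and their positive/negative labels
corpus, y = shap.datasets.imdb()

# simple bag-of-words sentiment model
vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(min_df=10)
X = vectorizer.fit_transform(corpus)
model = sklearn.linear_model.LogisticRegression(C=0.1)
model.fit(X, y)

# explain a few reviews with the updated LinearExplainer
# (dense copies of a background sample and a few rows keep this sketch simple)
background = X[:100].toarray()
explainer = shap.LinearExplainer(model, background, feature_dependence="independent")
shap_values = explainer.shap_values(X[:5].toarray())
```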

def communitiesandcrime(display=False):
""" Predict total number of non-violent crimes per 100K popuation.
57 changes: 34 additions & 23 deletions shap/explainers/linear.py
@@ -26,17 +26,26 @@ class LinearExplainer(Explainer):
nsamples : int
Number of samples to use when estimating the transformation matrix used to account for
feature correlations.
feature_dependence : "correlation" (default) or "interventional"
feature_dependence : "independent" (default) or "correlation"
There are two ways we might want to compute SHAP values, either the full conditional SHAP
values or the interventional SHAP values. For interventional SHAP values we break any
dependence structure in the model and so uncover how the model would behave if we
values or the independent SHAP values. For independent SHAP values we break any
dependence structure between features in the model and so uncover how the model would behave if we
intervened and changed some of the inputs. For the full conditional SHAP values we respect
the correlations among the input features, so if the model depends on one input but that
input is correlated with another input, then both get some credit for the model's behavior.
input is correlated with another input, then both get some credit for the model's behavior. The
independent option stays "true to the model" meaning it will only give credit to features that are
actually used by the model, while the correlation option stays "true to the data" in the sense that
it only considers how the model would behave when respecting the correlations in the input data.
"""

def __init__(self, model, data, nsamples=1000, feature_dependence="correlation"):
def __init__(self, model, data, nsamples=1000, feature_dependence=None):
self.nsamples = nsamples
if feature_dependence == "interventional":
warnings.warn('The option feature_dependence="interventional" has been renamed to feature_dependence="independent"!')
feature_dependence = "independent"
elif feature_dependence is None:
warnings.warn('The default value for feature_dependence has been changed to "independent"!')
feature_dependence = "independent"
self.feature_dependence = feature_dependence

# raw coefficients
@@ -64,22 +73,22 @@ def __init__(self, model, data, nsamples=1000, feature_dependence="correlation")
if type(data) == tuple and len(data) == 2:
self.mean = data[0]
self.cov = data[1]
elif str(type(data)).endswith("'numpy.ndarray'>"):
self.mean = data.mean(0)
self.cov = np.cov(data, rowvar=False)
elif data is None:
raise Exception("A background data distribution must be provided!")

else:
self.mean = np.array(np.mean(data, 0)).flatten() # assumes it is an array
if feature_dependence == "correlation":
self.cov = np.cov(data, rowvar=False)
#print(self.coef, self.mean.flatten(), self.intercept)
self.expected_value = np.dot(self.coef, self.mean) + self.intercept

self.M = len(self.mean)
self.valid_inds = np.where(np.diag(self.cov) > 1e-8)[0]
self.mean = self.mean[self.valid_inds]
self.cov = self.cov[:,self.valid_inds][self.valid_inds,:]
self.coef = self.coef[self.valid_inds]

# if needed, estimate the transform matrices
if feature_dependence == "correlation":
self.valid_inds = np.where(np.diag(self.cov) > 1e-8)[0]
self.mean = self.mean[self.valid_inds]
self.cov = self.cov[:,self.valid_inds][self.valid_inds,:]
self.coef = self.coef[self.valid_inds]

# group perfectly redundant variables together
self.avg_proj,sum_proj = duplicate_components(self.cov)
@@ -95,9 +104,9 @@ def __init__(self, model, data, nsamples=1000, feature_dependence="correlation")
mean_transform, x_transform = self._estimate_transforms(nsamples)
self.mean_transformed = np.matmul(mean_transform, self.mean)
self.x_transform = x_transform
elif feature_dependence == "interventional":
elif feature_dependence == "independent":
if nsamples != 1000:
warnings.warn("Setting nsamples has no effect when feature_dependence = 'interventional'!")
warnings.warn("Setting nsamples has no effect when feature_dependence = 'independent'!")
else:
raise Exception("Unknown type of feature_dependence provided: " + feature_dependence)

@@ -187,18 +196,20 @@ def shap_values(self, X):
elif str(type(X)).endswith("'pandas.core.frame.DataFrame'>"):
X = X.values

assert str(type(X)).endswith("'numpy.ndarray'>"), "Unknown instance type: " + str(type(X))
#assert str(type(X)).endswith("'numpy.ndarray'>"), "Unknown instance type: " + str(type(X))
assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"

if self.feature_dependence == "correlation":
phi = np.matmul(np.matmul(X[:,self.valid_inds], self.avg_proj.T), self.x_transform.T) - self.mean_transformed
phi = np.matmul(phi, self.avg_proj)
elif self.feature_dependence == "interventional":
phi = self.coef * (X[:,self.valid_inds] - self.mean)

full_phi = np.zeros(((phi.shape[0], self.M)))
full_phi[:,self.valid_inds] = phi
return full_phi

full_phi = np.zeros(((phi.shape[0], self.M)))
full_phi[:,self.valid_inds] = phi

return full_phi

elif self.feature_dependence == "independent":
return np.array(X - self.mean) * self.coef

def duplicate_components(C):
D = np.diag(1/np.sqrt(np.diag(C)))
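For the "independent" option the returned values reduce to coef * (x - mean of the background data), so they can be checked directly against the model output. A small sanity-check sketch on made-up data:

```python
import numpy as np
import sklearn.linear_model
import shap

rng = np.random.RandomState(0)
X = rng.randn(200, 4)
y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + 0.1 * rng.randn(200)
model = sklearn.linear_model.LinearRegression().fit(X, y)

explainer = shap.LinearExplainer(model, X, feature_dependence="independent")
shap_values = explainer.shap_values(X)

# independent SHAP values of a linear model are coef * (x - mean of the background data) ...
assert np.allclose(shap_values, model.coef_ * (X - X.mean(0)))

# ... so per sample they sum to the model output minus the explainer's expected value
assert np.allclose(shap_values.sum(1) + explainer.expected_value, model.predict(X))
```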
14 changes: 10 additions & 4 deletions shap/plots/dependence.py
@@ -224,10 +224,16 @@ def dependence_plot(ind, shap_values, features, feature_names=None, display_feat
# plot any nan feature values as tick marks along the y-axis
xv_nans = np.isnan(xv)
xlim = pl.xlim()
pl.scatter(
xlim[0] * np.ones(xv_nans.sum()), s[xv_nans], marker=1,
linewidth=2, c=cvals[xv_nans], cmap=colors.red_blue, alpha=alpha
)
if interaction_index is not None:
pl.scatter(
xlim[0] * np.ones(xv_nans.sum()), s[xv_nans], marker=1,
linewidth=2, c=cvals[xv_nans], cmap=colors.red_blue, alpha=alpha
)
else:
pl.scatter(
xlim[0] * np.ones(xv_nans.sum()), s[xv_nans], marker=1,
linewidth=2, color="#1E88E5", alpha=alpha
)
pl.xlim(*xlim)

# make the plot more readable
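The change above draws NaN feature values as single-colored tick marks when no interaction feature is selected. A rough sketch of a call that exercises that path (the data and SHAP values below are stand-ins):

```python
import numpy as np
import shap

rng = np.random.RandomState(0)
X = rng.randn(100, 3)
X[::10, 0] = np.nan                      # some missing values in the plotted feature
shap_values = 0.1 * rng.randn(100, 3)    # stand-in SHAP values for illustration

# with interaction_index=None the NaN tick marks are drawn in a single solid color
shap.dependence_plot(0, shap_values, X, feature_names=["f0", "f1", "f2"],
                     interaction_index=None)
```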
20 changes: 13 additions & 7 deletions shap/plots/embedding.py
@@ -18,7 +18,8 @@ def embedding_plot(ind, shap_values, feature_names=None, method="pca", alpha=1.0
If this is an int it is the index of the feature to use to color the embedding.
If this is a string it is either the name of the feature, or it can have the
form "rank(int)" to specify the feature with that rank (ordered by mean absolute
SHAP value over all the samples).
SHAP value over all the samples), or "sum()" to mean the sum of all the SHAP values,
which is the model's output (minus its expected value).
shap_values : numpy.array
Matrix of SHAP values (# samples x # features).
@@ -40,24 +41,30 @@ def embedding_plot(ind, shap_values, feature_names=None, method="pca", alpha=1.0
feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1])]

ind = convert_name(ind, shap_values, feature_names)
if ind == "sum()":
cvals = shap_values.sum(1)
fname = "sum(SHAP values)"
else:
cvals = shap_values[:,ind]
fname = feature_names[ind]

# see if we need to compute the embedding
if method == "pca":
if type(method) == str and method == "pca":
pca = sklearn.decomposition.PCA(2)
embedding_values = pca.fit_transform(shap_values)
elif type(method) == np.array and method.shape[1] == 2:
elif hasattr(method, "shape") and method.shape[1] == 2:
embedding_values = method
else:
print("Unsupported embedding method:", method)

pl.scatter(
embedding_values[:,0], embedding_values[:,1], c=shap_values[:,ind],
embedding_values[:,0], embedding_values[:,1], c=cvals,
cmap=colors.red_blue_solid, alpha=alpha, linewidth=0
)
pl.axis("off")
#pl.title(feature_names[ind])
cb = pl.colorbar()
cb.set_label("SHAP value for\n"+feature_names[ind], size=13)
cb.set_label("SHAP value for\n"+fname, size=13)
cb.outline.set_visible(False)


@@ -66,5 +73,4 @@ def embedding_plot(ind, shap_values, feature_names=None, method="pca", alpha=1.0
cb.ax.set_aspect((bbox.height - 0.7) * 10)
cb.set_alpha(1)
if show:
pl.show()

pl.show()
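
With the changes above, the embedding can be colored either by a single feature's SHAP value (for example via "rank(0)") or by the summed SHAP values via "sum()". A small sketch, assuming the function is exposed as shap.embedding_plot (otherwise import it from shap.plots.embedding); the SHAP matrix below is a stand-in:

```python
import numpy as np
import shap

rng = np.random.RandomState(0)
shap_values = 0.1 * rng.randn(200, 10)   # stand-in SHAP value matrix

# color the 2-D PCA embedding of the SHAP values by the most important feature ...
shap.embedding_plot("rank(0)", shap_values)

# ... or by the sum of all SHAP values, i.e. the model output minus its expected value
shap.embedding_plot("sum()", shap_values)
```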
