Skip to content

Commit

Permalink
Merge pull request #645 from imatiach-msft/ilmat/linear-sparse
Browse files Browse the repository at this point in the history
adding sparse support to shap linear explainer
  • Loading branch information
slundberg committed Jun 18, 2019
2 parents 9866ae7 + dc309d1 commit 05aa606
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
42 changes: 32 additions & 10 deletions shap/explainers/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import scipy as sp
import warnings
from tqdm.autonotebook import tqdm
from .explainer import Explainer
Expand All @@ -19,7 +20,7 @@ class LinearExplainer(Explainer):
model : (coef, intercept) or sklearn.linear_model.*
User supplied linear model either as either a parameter pair or sklearn object.
data : (mean, cov), numpy.array, pandas.DataFrame, or iml.DenseData
data : (mean, cov), numpy.array, pandas.DataFrame, iml.DenseData or scipy.csr_matrix
The background dataset to use for computing conditional expectations. Note that only the
mean and covariance of the dataset are used. This means passing a raw data matrix is just
a convienent alternative to passing the mean and covariance directly.
Expand All @@ -35,7 +36,8 @@ class LinearExplainer(Explainer):
input is correlated with another input, then both get some credit for the model's behavior. The
independent option stays "true to the model" meaning it will only give credit to features that are
actually used by the model, while the correlation option stays "true to the data" in the sense that
it only considers how the model would behave when respecting the correlations in the input data.
it only considers how the model would behave when respecting the correlations in the input data.
For sparse case only independent option is supported.
"""

def __init__(self, model, data, nsamples=1000, feature_dependence=None):
Expand Down Expand Up @@ -76,11 +78,23 @@ def __init__(self, model, data, nsamples=1000, feature_dependence=None):
elif data is None:
raise Exception("A background data distribution must be provided!")
else:
self.mean = np.array(np.mean(data, 0)).flatten() # assumes it is an array
if feature_dependence == "correlation":
self.cov = np.cov(data, rowvar=False)
if sp.sparse.issparse(data):
self.mean = data.mean(0)
if feature_dependence != "independent":
raise Exception("Only feature_dependence = 'independent' is supported for sparse data")
else:
self.mean = np.array(np.mean(data, 0)).flatten() # assumes it is an array
if feature_dependence == "correlation":
self.cov = np.cov(data, rowvar=False)
#print(self.coef, self.mean.flatten(), self.intercept)
self.expected_value = np.dot(self.coef, self.mean) + self.intercept
# Note: mean can be numpy.matrixlib.defmatrix.matrix or numpy.matrix type depending on numpy version
if sp.sparse.issparse(self.mean) or str(type(self.mean)).endswith("matrix'>"):
# accept both sparse and dense coef
if not sp.sparse.issparse(self.coef):
self.coef = np.asmatrix(self.coef)
self.expected_value = self.coef.dot(self.mean.T) + self.intercept
else:
self.expected_value = np.dot(self.coef, self.mean) + self.intercept
self.M = len(self.mean)

# if needed, estimate the transform matrices
Expand Down Expand Up @@ -179,7 +193,7 @@ def shap_values(self, X):
Parameters
----------
X : numpy.array or pandas.DataFrame
X : numpy.array, pandas.DataFrame or scipy.csr_matrix
A matrix of samples (# samples x # features) on which to explain the model's output.
Returns
Expand All @@ -200,6 +214,8 @@ def shap_values(self, X):
assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"

if self.feature_dependence == "correlation":
if sp.sparse.issparse(X):
raise Exception("Only feature_dependence = 'independent' is supported for sparse data")
phi = np.matmul(np.matmul(X[:,self.valid_inds], self.avg_proj.T), self.x_transform.T) - self.mean_transformed
phi = np.matmul(phi, self.avg_proj)

Expand All @@ -209,10 +225,16 @@ def shap_values(self, X):
return full_phi

elif self.feature_dependence == "independent":
if len(self.coef.shape) == 1:
return np.array(X - self.mean) * self.coef
if sp.sparse.issparse(X):
if len(self.coef.shape) == 1:
return np.array(np.multiply(X - self.mean, self.coef[0]))
else:
return [np.array(np.multiply(X - self.mean, self.coef[i])) for i in range(self.coef.shape[0])]
else:
return [np.array(X - self.mean) * self.coef[i] for i in range(self.coef.shape[0])]
if len(self.coef.shape) == 1:
return np.array(X - self.mean) * self.coef
else:
return [np.array(X - self.mean) * self.coef[i] for i in range(self.coef.shape[0])]

def duplicate_components(C):
D = np.diag(1/np.sqrt(np.diag(C)))
Expand Down
26 changes: 25 additions & 1 deletion tests/explainers/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,28 @@ def test_single_feature():
explainer = shap.LinearExplainer(model, X)
shap_values = explainer.shap_values(X)
assert np.abs(explainer.expected_value - model.predict(X).mean()) < 1e-6
assert np.max(np.abs(explainer.expected_value + shap_values.sum(1) - model.predict(X))) < 1e-6
assert np.max(np.abs(explainer.expected_value + shap_values.sum(1) - model.predict(X))) < 1e-6

def test_sparse():
""" Validate running LinearExplainer on scipy sparse data
"""
import sklearn.linear_model
from sklearn.datasets import make_multilabel_classification
from scipy.special import expit

np.random.seed(0)
n_features = 20
X, y = make_multilabel_classification(n_samples=100,
sparse=True,
n_features=n_features,
n_classes=1,
n_labels=2)

# train linear model
model = sklearn.linear_model.LogisticRegression()
model.fit(X, y)

# explain the model's predictions using SHAP values
explainer = shap.LinearExplainer(model, X)
shap_values = explainer.shap_values(X)
assert np.max(np.abs(expit(explainer.expected_value + shap_values[0].sum(1)) - model.predict_proba(X)[:, 1])) < 1e-6

0 comments on commit 05aa606

Please sign in to comment.