Skip to content

Commit

Permalink
Support multi-class LightGBM output and the LightGBM sklearn interface
Browse files Browse the repository at this point in the history
Fixes #103
  • Loading branch information
slundberg committed May 29, 2018
1 parent 7f95ae8 commit 2d8d103
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 2 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run_setup(with_binary):

setup(
name='shap',
version='0.16.2',
version='0.17.0',
description='A unified approach to explain the output of any machine learning model.',
url='http://github.com/slundberg/shap',
author='Scott Lundberg',
Expand All @@ -23,7 +23,7 @@ def run_setup(with_binary):
packages=['shap', 'shap.explainers'],
install_requires=['numpy', 'scipy', 'iml>=0.6.0', 'scikit-learn', 'matplotlib', 'pandas', 'tqdm'],
test_suite='nose.collector',
tests_require=['nose', 'xgboost'],
tests_require=['nose', 'xgboost', 'lightgbm'],
ext_modules = ext_modules,
zip_safe=False
)
Expand Down
84 changes: 84 additions & 0 deletions shap/explainers/mimic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import multiprocessing
import sys

try:
import xgboost
except ImportError:
pass
except:
print("xgboost is installed...but failed to load!")
pass

class MimicExplainer:
"""Fits a mimic model to the original model and then explains predictions using the mimic model.
Tree SHAP allows for very fast SHAP value explainations of flexible gradient boosted decision
tree (GBDT) models. Since GBDT models are so flexible we can train them to mimic any black-box
model and then using Tree SHAP we can explain them. This won't work well for images, but for
any type of problem that GBDTs do reasonable well on, they should also be able to learn how to
explain black-box models on the data. This mimic explainer also allows you to use a linear model,
but keep in mind that will not do as well at explaining typical non-linear black-box models. In
the future we could include other mimic model types given enough demand/help. Finally, we would
like to note that this explainer is vaugely inspired by https://arxiv.org/abs/1802.07814 where
they learn an explainer that can be applied to any input.
"""

def __init__(self, model, data, mimic_model="xgboost", mimic_model_params={}):
self.mimic_model_type = mimic_model
self.mimic_model_params = mimic_model_params

# convert incoming inputs to standardized iml objects
self.link = convert_to_link(link)
self.model = convert_to_model(model)
self.keep_index = kwargs.get("keep_index", False)
self.data = convert_to_data(data, keep_index=self.keep_index)
match_model_to_data(self.model, self.data)

self.model_out = self.model.f(data.data)

# enforce our current input type limitations
assert isinstance(self.data, DenseData), "Shap explainer only supports the DenseData input currently."
assert not self.data.transposed, "Shap explainer does not support transposed DenseData currently."

# warn users about large background data sets
if len(self.data.weights) < 100:
log.warning("Using only " + str(len(self.data.weights)) + " training data samples could cause " +
"the mimic model poorly to fit the real model. Consider using more training samples " +
"or if you don't have more samples, using shap.inflate(data, N) to generate more.")

self._train_mimic_model()

def _train_mimic_model(self):

if self.mimic_model_type == "xgboost":
self.mimic_model = xgboost.train(self.mimic_model_params, xgboost.DMatrix(data.data))

def shap_values(self, X, **kwargs):
""" Estimate the SHAP values for a set of samples.
Parameters
----------
X : numpy.array or pandas.DataFrame
A matrix of samples (# samples x # features) on which to explain the model's output.
Returns
-------
For a models with a single output this returns a matrix of SHAP values
(# samples x # features + 1). The last column is the base value of the model, which is
the expected value of the model applied to the background dataset. This causes each row to
sum to the model output for that sample. For models with vector outputs this returns a list
of such matrices, one for each output.
"""

phi = None
if self.mimic_model_type == "xgboost":
if not str(type(X)).endswith("xgboost.core.DMatrix'>"):
X = xgboost.DMatrix(X)
phi = self.trees.predict(X, pred_contribs=True)

if phi is not None:
if len(phi.shape) == 3:
return [phi[:, i, :] for i in range(phi.shape[1])]
else:
return phi
8 changes: 8 additions & 0 deletions shap/explainers/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def __init__(self, model, **kwargs):
elif str(type(model)).endswith("lightgbm.basic.Booster'>"):
self.model_type = "lightgbm"
self.trees = model
elif str(type(model)).endswith("lightgbm.sklearn.LGBMRegressor'>"):
self.model_type = "lightgbm"
self.trees = model.booster_
elif str(type(model)).endswith("lightgbm.sklearn.LGBMClassifier'>"):
self.model_type = "lightgbm"
self.trees = model.booster_
elif str(type(model)).endswith("catboost.core.CatBoostRegressor'>"):
self.model_type = "catboost"
self.trees = model
Expand Down Expand Up @@ -100,6 +106,8 @@ def shap_values(self, X, **kwargs):
phi = self.trees.predict(X, pred_contribs=True)
elif self.model_type == "lightgbm":
phi = self.trees.predict(X, pred_contrib=True)
if phi.shape[1] != X.shape[1] + 1:
phi = phi.reshape(X.shape[0], phi.shape[1]//(X.shape[1]+1), X.shape[1]+1)
elif self.model_type == "catboost": # thanks to the CatBoost team for implementing this...
phi = self.trees.get_feature_importance(data=catboost.Pool(X), fstr_type='ShapValues')

Expand Down
27 changes: 27 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,30 @@ def test_mixed_types():
bst = xgboost.train({"learning_rate": 0.01}, xgboost.DMatrix(X, label=y), 1000)
shap_values = shap.TreeExplainer(bst).shap_values(X)
shap.dependence_plot(0, shap_values, X, show=False)

def test_lightgbm():
import lightgbm
import shap

# train XGBoost model
X, y = shap.datasets.boston()
model = lightgbm.sklearn.LGBMRegressor()
model.fit(X, y)

# explain the model's predictions using SHAP values
shap_values = shap.TreeExplainer(model).shap_values(X)

def test_lightgbm_multiclass():
import lightgbm
import shap

# train XGBoost model
X, Y = shap.datasets.iris()
model = lightgbm.sklearn.LGBMClassifier()
model.fit(X, Y)

# explain the model's predictions using SHAP values
shap_values = shap.TreeExplainer(model).shap_values(X)

# ensure plot works for first class
shap.dependence_plot(0, shap_values[0], X, show=False)

0 comments on commit 2d8d103

Please sign in to comment.