-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support multi-class LightGBM output and the LightGBM sklearn interface
Fixes #103
- Loading branch information
Showing
4 changed files
with
121 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import numpy as np | ||
import multiprocessing | ||
import sys | ||
|
||
try: | ||
import xgboost | ||
except ImportError: | ||
pass | ||
except: | ||
print("xgboost is installed...but failed to load!") | ||
pass | ||
|
||
class MimicExplainer: | ||
"""Fits a mimic model to the original model and then explains predictions using the mimic model. | ||
Tree SHAP allows for very fast SHAP value explainations of flexible gradient boosted decision | ||
tree (GBDT) models. Since GBDT models are so flexible we can train them to mimic any black-box | ||
model and then using Tree SHAP we can explain them. This won't work well for images, but for | ||
any type of problem that GBDTs do reasonable well on, they should also be able to learn how to | ||
explain black-box models on the data. This mimic explainer also allows you to use a linear model, | ||
but keep in mind that will not do as well at explaining typical non-linear black-box models. In | ||
the future we could include other mimic model types given enough demand/help. Finally, we would | ||
like to note that this explainer is vaugely inspired by https://arxiv.org/abs/1802.07814 where | ||
they learn an explainer that can be applied to any input. | ||
""" | ||
|
||
def __init__(self, model, data, mimic_model="xgboost", mimic_model_params={}): | ||
self.mimic_model_type = mimic_model | ||
self.mimic_model_params = mimic_model_params | ||
|
||
# convert incoming inputs to standardized iml objects | ||
self.link = convert_to_link(link) | ||
self.model = convert_to_model(model) | ||
self.keep_index = kwargs.get("keep_index", False) | ||
self.data = convert_to_data(data, keep_index=self.keep_index) | ||
match_model_to_data(self.model, self.data) | ||
|
||
self.model_out = self.model.f(data.data) | ||
|
||
# enforce our current input type limitations | ||
assert isinstance(self.data, DenseData), "Shap explainer only supports the DenseData input currently." | ||
assert not self.data.transposed, "Shap explainer does not support transposed DenseData currently." | ||
|
||
# warn users about large background data sets | ||
if len(self.data.weights) < 100: | ||
log.warning("Using only " + str(len(self.data.weights)) + " training data samples could cause " + | ||
"the mimic model poorly to fit the real model. Consider using more training samples " + | ||
"or if you don't have more samples, using shap.inflate(data, N) to generate more.") | ||
|
||
self._train_mimic_model() | ||
|
||
def _train_mimic_model(self): | ||
|
||
if self.mimic_model_type == "xgboost": | ||
self.mimic_model = xgboost.train(self.mimic_model_params, xgboost.DMatrix(data.data)) | ||
|
||
def shap_values(self, X, **kwargs): | ||
""" Estimate the SHAP values for a set of samples. | ||
Parameters | ||
---------- | ||
X : numpy.array or pandas.DataFrame | ||
A matrix of samples (# samples x # features) on which to explain the model's output. | ||
Returns | ||
------- | ||
For a models with a single output this returns a matrix of SHAP values | ||
(# samples x # features + 1). The last column is the base value of the model, which is | ||
the expected value of the model applied to the background dataset. This causes each row to | ||
sum to the model output for that sample. For models with vector outputs this returns a list | ||
of such matrices, one for each output. | ||
""" | ||
|
||
phi = None | ||
if self.mimic_model_type == "xgboost": | ||
if not str(type(X)).endswith("xgboost.core.DMatrix'>"): | ||
X = xgboost.DMatrix(X) | ||
phi = self.trees.predict(X, pred_contribs=True) | ||
|
||
if phi is not None: | ||
if len(phi.shape) == 3: | ||
return [phi[:, i, :] for i in range(phi.shape[1])] | ||
else: | ||
return phi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters