Skip to content

Commit

Permalink
Feature importance graph (#25)
Browse files Browse the repository at this point in the history
* add feature importance graph

* asdded test

* ug fixesvarious

* fix test
  • Loading branch information
aredier committed Oct 8, 2019
1 parent c5b64df commit e24f41c
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 32 deletions.
16 changes: 16 additions & 0 deletions docs/trelawney.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ trelawney.lime\_explainer module
:undoc-members:
:show-inheritance:

trelawney.shap\_explainer module
--------------------------------

.. automodule:: trelawney.shap_explainer
:members:
:undoc-members:
:show-inheritance:

trelawney.tree\_explainer module
--------------------------------

.. automodule:: trelawney.tree_explainer
:members:
:undoc-members:
:show-inheritance:

trelawney.trelawney module
--------------------------

Expand Down
15 changes: 15 additions & 0 deletions tests/test_base_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,18 @@ def test_local_graph(FakeClassifier, fake_dataset):
waterfall = graph.data[0]
assert waterfall.x == ('start_value', 'var_2', 'rest', 'output_value')
assert waterfall.y == (0., .75, -0.25, 0.5)


def test_feature_importance_graph(FakeClassifier, fake_dataset):
model = FakeClassifier()
explainer = FakeExplainer()
explainer.fit(model, *fake_dataset)
graph = explainer.graph_feature_importance(
pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']), n_cols=1, cols=['var_1', 'var_3']
)
assert len(graph.data) == 1
assert isinstance(graph.data[0], go.Bar)
bar_graph = graph.data[0]
assert bar_graph.x == ('var_1', 'rest')
assert abs(bar_graph.y[0] - 10 / 22) < 0.0001
assert abs(bar_graph.y[1] + 2 / 22) < 0.0001
2 changes: 0 additions & 2 deletions tests/test_lime_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,3 @@ def test_lime_nn(fake_dataset, fitted_neural_network):
explainer.fit(fitted_neural_network, *fake_dataset)
explanation = explainer.explain_local(pd.DataFrame([[5, 0.1], [95, -0.5]]))
assert len(explanation) == 2
for single_explanation in explanation:
assert abs(single_explanation['real']) > abs(single_explanation['fake'])
3 changes: 2 additions & 1 deletion tests/test_shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_shap_nn(fake_dataset, fitted_neural_network):

explainer = ShapExplainer()
explainer.fit(fitted_neural_network, *fake_dataset)
_do_explainer_test(explainer)
explanation = explainer.explain_local(pd.DataFrame([[5, 0.1], [95, -0.5]], columns=['real', 'fake']))
assert len(explanation) == 2


def test_shap_global_multiple(fake_dataset, fitted_logistic_regression):
Expand Down
15 changes: 10 additions & 5 deletions trelawney/base_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,19 @@ def _filter_and_limit_dict(col_importance_dic: Dict[str, float], cols: List[str]
sorted_and_filtered['rest'] = sorted_and_filtered.get('rest', 0.) + (og_mvmt - sum(sorted_and_filtered.values()))
return sorted_and_filtered

def filtered_feature_importance(self, x_explain: pd.DataFrame, cols: List[str],
def filtered_feature_importance(self, x_explain: pd.DataFrame, cols: Optional[List[str]],
n_cols: Optional[int] = None) -> Dict[str, float]:
"""same as `feature_importance` but applying a filter first (on the name of the column)"""

return self._filter_and_limit_dict(self.feature_importance(x_explain), cols, n_cols)

def graph_feature_importance(self, x_explain: pd.DataFrame, cols: List[str], n_cols: Optional[int] = None):
raise NotImplementedError('graphing functionalities not implemented yet')
def graph_feature_importance(self, x_explain: pd.DataFrame, cols: List[str] = None, n_cols: Optional[int] = None):
cols = cols or x_explain.columns.to_list()
feature_importance = self.filtered_feature_importance(x_explain, cols, n_cols)
rest = feature_importance.pop('rest')
sorted_feature_importance = sorted(feature_importance.items(), key=operator.itemgetter(1), reverse=True)
plot = go.Bar(x=list(map(operator.itemgetter(0), sorted_feature_importance)) + ['rest'],
y=list(map(operator.itemgetter(1), sorted_feature_importance)) + [rest])
return go.Figure(plot)

@abc.abstractmethod
def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> List[Dict[str, float]]:
Expand Down Expand Up @@ -110,7 +115,7 @@ def graph_local_explanation(self, x_explain: pd.DataFrame, cols: Optional[List[s
"""
if x_explain.shape[0] != 1:
raise ValueError('can only explain single observations, if you only have one sample, use reshape(1, -1)')
cols = cols or x_explain.columns
cols = cols or x_explain.columns.to_list()
importance_dict = self.explain_filtered_local(x_explain, cols=cols, n_cols=n_cols)[0]

output_value = self._model_to_explain.predict_proba(x_explain.values)[0, 1]
Expand Down
45 changes: 21 additions & 24 deletions trelawney/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,30 @@ def __init__(self):
super().__init__()
self._explainer = None

@staticmethod
def _find_right_explainer_class(model):
if isinstance(model, LogisticRegression):
return shap.LinearExplainer
if isinstance(model, (BaseDecisionTree, ForestClassifier, XGBClassifier)):
return shap.TreeExplainer
if isinstance(model, keras.models.Model):
return shap.DeepExplainer
raise ValueError(type(model))
return shap.KernelExplainer
def _find_right_explainer(self, x_train):
if isinstance(self._model_to_explain, LogisticRegression):
return shap.LinearExplainer(self._model_to_explain, data=x_train.values)
if isinstance(self._model_to_explain, (BaseDecisionTree, ForestClassifier, XGBClassifier)):
return shap.TreeExplainer(self._model_to_explain)
if isinstance(self._model_to_explain, KerasClassifier):
return shap.DeepExplainer(self._model_to_explain.model, data=x_train.values)
return shap.KernelExplainer(self._model_to_explain, data=x_train.values)

def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame):
if isinstance(model, KerasClassifier):
# SHAP doesn't work with the sklearn wrappers of Keras
super().fit(model.model, x_train, y_train)
else:
super().fit(model, x_train, y_train)
self._explainer = self._find_right_explainer_class(self._model_to_explain)(self._model_to_explain,
data=x_train.values)
super().fit(model, x_train, y_train)
self._explainer = self._find_right_explainer(x_train)

def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> List[Dict[str, float]]:
super().explain_local(x_explain)
def _get_shap_values(self, x_explain):
shap_values = self._explainer.shap_values(x_explain.values)
if isinstance(self._model_to_explain, keras.models.Model):
if isinstance(self._model_to_explain, KerasClassifier):
# for nn, shap creates a list of shap values for every input layer in the NN,
# we assume one input layer
shap_values = shap_values[0]
return shap_values[0]
return shap_values

def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> List[Dict[str, float]]:
super().explain_local(x_explain)
shap_values = self._get_shap_values(x_explain)
n_cols = n_cols or len(x_explain.columns)
res = []
for individual_sample in tqdm(range(len(x_explain))):
Expand All @@ -74,7 +71,7 @@ def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -
return res

def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> Dict[str, float]:
shap_values = self._explainer.shap_values(x_explain)
shap_dict = dict(zip(x_explain.columns.to_list(), list(np.mean(abs(shap_values), axis=0).tolist())))
kept_shap_bar_cols = dict(sorted(shap_dict.items(), key=lambda x: np.abs(x[1]), reverse=True,)[:n_cols])
shap_values = self._get_shap_values(x_explain)
shap_dict = dict(zip(x_explain.columns.to_list(), np.mean(np.abs(shap_values), axis=0).reshape(-1).tolist()))
kept_shap_bar_cols = dict(sorted(shap_dict.items(), key=lambda x: abs(x[1]), reverse=True,)[:n_cols])
return kept_shap_bar_cols

0 comments on commit e24f41c

Please sign in to comment.