Waterfall graph (#16)
* fix some bugs with filters
* implement waterfall
* added systematic rest calculation
* add doc
* fix test instability
aredier committed Oct 6, 2019
1 parent ef62206 commit 626dca7
Showing 6 changed files with 185 additions and 56 deletions.
3 changes: 0 additions & 3 deletions .idea/.gitignore

This file was deleted.

1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ pandas==0.25.1
 scikit-learn==0.21.3
 tqdm==4.36.1
 lime==0.1.1.36
+plotly==4.1.1
20 changes: 19 additions & 1 deletion tests/conftest.py
@@ -3,6 +3,7 @@
 import pytest
 from keras import layers, models
 from keras.wrappers.scikit_learn import KerasClassifier
+from sklearn import base
 from sklearn.linear_model import LogisticRegression

@@ -28,6 +29,23 @@ def make_neural_network():
         model.compile(loss='categorical_crossentropy', optimizer='adam')
         return model

-    model = KerasClassifier(make_neural_network, epochs=10, batch_size=100)
+    model = KerasClassifier(make_neural_network, epochs=100, batch_size=100)
     model.fit(*fake_dataset)
     return model


+@pytest.fixture
+def FakeClassifier():
+
+    class FakeClassifierInner(base.BaseEstimator):
+
+        def fit(self, X, y):
+            return self
+
+        def predict_proba(self, X):
+            return np.array([[0.5, 0.5] for _ in range(X.shape[0])])
+
+        def predict(self, X):
+            return self.predict_proba(X)[:, 1] >= 0.5
+
+    return FakeClassifierInner
146 changes: 105 additions & 41 deletions tests/test_base_explainer.py
@@ -1,75 +1,139 @@
 import operator
-from typing import Optional
+from typing import Optional, Dict, Tuple, List, Union
 
 import pandas as pd
 import pytest
 import sklearn
 import numpy as np
+import plotly.graph_objs as go
 
 from trelawney.base_explainer import BaseExplainer
 
 
 class FakeExplainer(BaseExplainer):
 
     def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame):
-        pass
+        return super().fit(model, x_train, y_train)
+
+    @staticmethod
+    def _regularize(importance_dict: List[Tuple[str, float]]) -> List[Tuple[str, float]]:
+        total = sum(map(operator.itemgetter(1), importance_dict))
+        return [
+            (key, -(2 * (i % 2) - 1) * (value / total))
+            for i, (key, value) in enumerate(importance_dict)
+        ]
 
     def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None):
-        return dict(sorted(
+        importance = self._regularize(sorted(
             ((col, np.mean(np.abs(x_explain.loc[:, col]))) for col in x_explain.columns),
             key=operator.itemgetter(1),
             reverse=True
-        )[:n_cols])
+        ))
+        total_mvmt = sum(map(operator.itemgetter(1), importance))
+        res = dict(importance[:n_cols])
+        res['rest'] = total_mvmt - sum(res.values())
+        return res
 
     def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None):
-        return [
-            dict(sorted(sample_explanation.items(), key=operator.itemgetter(1), reverse=True)[:n_cols])
-            for sample_explanation in x_explain.abs().to_dict(orient='records')
-        ]
+        res = []
+        for sample_explanation in x_explain.abs().to_dict(orient='records'):
+            importance = self._regularize(sorted(sample_explanation.items(), key=operator.itemgetter(1), reverse=True))
+
+            total_mvmt = sum(map(operator.itemgetter(1), importance))
+            res_ind = dict(importance[:n_cols])
+            res_ind['rest'] = total_mvmt - sum(res_ind.values())
+            res.append(res_ind)
+
+        return res
+
+
+def _float_error_resilient_compare(left: Union[List[Dict], Dict], right: Union[List[Dict], Dict]):
+    assert len(left) == len(right)
+    if isinstance(left, list):
+        return [_float_error_resilient_compare(ind_left, ind_right) for ind_left, ind_right in zip(left, right)]
+    for key, value in left.items():
+        assert key in right
+        assert abs(value - right[key]) < 0.0001
 
 
 def test_explainer_basic():
 
     explainer = FakeExplainer()
-    assert explainer.feature_importance(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2'])) == {
-        'var_1': 5., 'var_2': 2.5
-    }
-    assert explainer.feature_importance(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2']), n_cols=1) == {
-        'var_1': 5.
-    }
-
-    assert explainer.explain_local(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2'])) == [
-        {'var_1': 10., 'var_2': 0.},
-        {'var_2': 5., 'var_1': 0.}
-    ]
-    assert explainer.explain_local(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2']), n_cols=1) == [
-        {'var_1': 10.},
-        {'var_2': 5.}
-    ]
+    _float_error_resilient_compare(
+        explainer.feature_importance(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2'])),
+        {'var_1': 5 / 7.5, 'var_2': -2.5 / 7.5, 'rest': 0.}
+    )
+    _float_error_resilient_compare(
+        explainer.feature_importance(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2']), n_cols=1),
+        {'var_1': 5. / 7.5, 'rest': -2.5 / 7.5}
+    )
+
+    _float_error_resilient_compare(
+        explainer.explain_local(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2'])),
+        [{'var_1': 1., 'var_2': 0., 'rest': 0.}, {'var_2': 1., 'var_1': 0., 'rest': 0.}]
+    )
+
+    _float_error_resilient_compare(
+        explainer.explain_local(pd.DataFrame([[10, 0], [0, -5]], columns=['var_1', 'var_2']), n_cols=1),
+        [{'var_1': 1., 'rest': 0.}, {'var_2': 1., 'rest': 0.}]
+    )
 
 
 def test_explainer_filter():
 
     explainer = FakeExplainer()
-    assert explainer.filtered_feature_importance(
-        pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']),
-        cols=['var_1', 'var_3']) == {'var_1': 5., 'var_3': 3.5}
+    _float_error_resilient_compare(
+        explainer.filtered_feature_importance(pd.DataFrame(
+            [[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']), cols=['var_1', 'var_3']
+        ),
+        {'var_1': 10 / 22, 'var_3': -7 / 22, 'rest': 5 / 22}
+    )
+
+    _float_error_resilient_compare(
+        explainer.filtered_feature_importance(
+            pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']), n_cols=1,
+            cols=['var_1', 'var_3']
+        ),
+        {'var_1': 10 / 22, 'rest': -2 / 22}
+    )
+
+    _float_error_resilient_compare(
+        explainer.explain_filtered_local(
+            pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']), cols=['var_1', 'var_3']
+        ),
+        [{'var_1': 10 / 14, 'var_3': -4 / 14, 'rest': 0.}, {'var_3': -3 / 8, 'var_1': 0., 'rest': 5 / 8}]
+    )
+
+    _float_error_resilient_compare(
+        explainer.explain_filtered_local(
+            pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']),
+            cols=['var_1', 'var_3'], n_cols=1
+        ),
+        [{'var_1': 10 / 14, 'rest': -4 / 14}, {'var_3': -3 / 8, 'rest': 5 / 8}]
+    )
+
+
+def test_local_graph(FakeClassifier, fake_dataset):
+
+    model = FakeClassifier()
+    explainer = FakeExplainer()
+    explainer.fit(model, *fake_dataset)
 
-    assert explainer.filtered_feature_importance(
-        pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']),
-        n_cols=1, cols=['var_1', 'var_3']) == {'var_1': 5.}
+    with pytest.raises(ValueError):
+        _ = explainer.graph_local_explanation(pd.DataFrame([[10, 30], [1, 2]], columns=['var_1', 'var_2']))
 
-    assert explainer.explain_filtered_local(
-        pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']),
-        cols=['var_1', 'var_3']) == [
-        {'var_1': 10., 'var_3': 4.},
-        {'var_3': 3., 'var_1': 0.}
-    ]
-    assert explainer.explain_filtered_local(
-        pd.DataFrame([[10, 0, 4], [0, -5, 3]], columns=['var_1', 'var_2', 'var_3']),
-        cols=['var_1', 'var_3'], n_cols=1) == [
-        {'var_1': 10.},
-        {'var_3': 3.}
-    ]
+    graph = explainer.graph_local_explanation(pd.DataFrame([[10, 30]], columns=['var_1', 'var_2']))
+
+    assert len(graph.data) == 1
+    assert isinstance(graph.data[0], go.Waterfall)
+    waterfall = graph.data[0]
+    assert waterfall.x == ('start_value', 'var_2', 'var_1', 'rest', 'output_value')
+    assert waterfall.y == (0., .75, -0.25, 0., 0.5)
+
+    graph = explainer.graph_local_explanation(pd.DataFrame([[10, 30]], columns=['var_1', 'var_2']), n_cols=1)
+
+    assert len(graph.data) == 1
+    assert isinstance(graph.data[0], go.Waterfall)
+    waterfall = graph.data[0]
+    assert waterfall.x == ('start_value', 'var_2', 'rest', 'output_value')
+    assert waterfall.y == (0., .75, -0.25, 0.5)
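
For the record, the expected waterfall constants in test_local_graph can be re-derived by hand: FakeClassifier.predict_proba always returns 0.5 for the positive class, and FakeExplainer._regularize normalises the absolute importances to sum to 1 while flipping the sign of every second entry. A minimal sketch of that arithmetic for the single row [10, 30] (illustration only, not part of the commit):

# Re-derive the expected y-values of the waterfall test for the row [10, 30].
abs_importances = {'var_1': 10.0, 'var_2': 30.0}

# _regularize: sort by magnitude (descending), divide by the total,
# then flip the sign of every second entry.
total = sum(abs_importances.values())                           # 40.0
regularized = {'var_2': 30.0 / total, 'var_1': -10.0 / total}   # 0.75, -0.25

output_value = 0.5   # FakeClassifier always predicts proba 0.5
rest = 0.0           # no columns were cut off, so nothing is folded into rest
start_value = output_value - sum(regularized.values()) - rest   # 0.0

assert [start_value, regularized['var_2'], regularized['var_1'], rest, output_value] \
    == [0.0, 0.75, -0.25, 0.0, 0.5]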
56 changes: 50 additions & 6 deletions trelawney/base_explainer.py
@@ -7,6 +7,7 @@
 
 import sklearn
 import pandas as pd
+import plotly.graph_objs as go
 
 
 class BaseExplainer(abc.ABC):

@@ -20,7 +21,9 @@ class BaseExplainer(abc.ABC):
     in a dataset
     """
 
-    @abc.abstractmethod
+    def __init__(self):
+        self._model_to_explain = None
+
     def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame):
         """
         fits the explainer if needed

@@ -29,7 +32,7 @@ def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train:
         :param x_train: the dataset the model was trained on originally
         :param y_train: the target the model was trained on originally
         """
-        pass
+        self._model_to_explain = model
 
     @abc.abstractmethod
     def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> Dict[str, float]:

@@ -48,14 +51,17 @@ def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = No
 
     @staticmethod
     def _filter_and_limit_dict(col_importance_dic: Dict[str, float], cols: List[str], n_cols: int):
-        return dict(sorted(
+        og_mvmt = sum(col_importance_dic.values())
+        sorted_and_filtered = dict(sorted(
             filter(
                 lambda col_and_importance: col_and_importance[0] in cols,
                 col_importance_dic.items()
             ),
-            key=operator.itemgetter(1),
+            key=lambda x: abs(x[1]),
             reverse=True
         )[:n_cols])
+        sorted_and_filtered['rest'] = sorted_and_filtered.get('rest', 0.) + (og_mvmt - sum(sorted_and_filtered.values()))
+        return sorted_and_filtered
 
     def filtered_feature_importance(self, x_explain: pd.DataFrame, cols: List[str],
                                     n_cols: Optional[int] = None) -> Dict[str, float]:

@@ -85,5 +91,43 @@ def explain_filtered_local(self, x_explain: pd.DataFrame, cols: List[str],
             for sample_importance_dict in self.explain_local(x_explain)
         ]
 
-    def graph_local_explanation(self, x_explain: pd.DataFrame, cols: List[str], n_cols: Optional[int] = None):
-        raise NotImplementedError('graphing functionalities not implemented yet')
+    def graph_local_explanation(self, x_explain: pd.DataFrame, cols: Optional[List[str]] = None,
+                                n_cols: Optional[int] = None) -> go.Figure:
+        """
+        creates a waterfall plotly figure to represent the influence of each feature on the final decision for a
+        single prediction of the model.
+
+        You can filter the columns you want to see in your graph and limit the final number of columns you want to
+        see. If you choose to do so, the filter will be applied first, and of those filtered columns at most
+        `n_cols` will be kept.
+
+        :param x_explain: the example to explain; this must be a dataframe with a single row
+        :param cols: the columns to keep if you want to filter (if None - default - all the columns will be kept)
+        :param n_cols: the number of columns to limit the graph to (if None - default - all the columns will be kept)
+        :raises ValueError: if x_explain doesn't have the right shape
+        """
+        if x_explain.shape[0] != 1:
+            raise ValueError('can only explain single observations, if you only have one sample, use reshape(1, -1)')
+        cols = cols or x_explain.columns
+        importance_dict = self.explain_filtered_local(x_explain, cols=cols, n_cols=n_cols)[0]
+
+        output_value = self._model_to_explain.predict_proba(x_explain.values)[0, 1]
+        start_value = output_value - sum(importance_dict.values())
+        rest = importance_dict.pop('rest')
+
+        sorted_importances = sorted(importance_dict.items(), key=lambda importance: abs(importance[1]), reverse=True)
+        fig = go.Figure(go.Waterfall(
+            orientation="v",
+            measure=['absolute', *['relative' for _ in importance_dict], 'relative', 'absolute'],
+            y=[start_value, *map(operator.itemgetter(1), sorted_importances), rest, output_value],
+            textposition="outside",
+            x=['start_value', *map(operator.itemgetter(0), sorted_importances), 'rest', 'output_value'],
+            connector={"line": {"color": "rgb(63, 63, 63)"}},
+        ))
+        fig.update_layout(
+            title="explanation",
+            showlegend=True
+        )
+        return fig
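
To situate the new API: a fitted explainer can now produce the waterfall figure directly from a single-row dataframe. A minimal usage sketch, assuming a concrete subclass such as LimeExplainer from this repo; clf, X_train and y_train are illustrative placeholders, not part of the commit:

import pandas as pd
from sklearn.linear_model import LogisticRegression
from trelawney.lime_explainer import LimeExplainer

# Placeholder training data: two numeric features, binary target.
X_train = pd.DataFrame({'var_1': [0, 1, 2, 3], 'var_2': [3, 2, 1, 0]})
y_train = pd.DataFrame({'target': [0, 0, 1, 1]})

clf = LogisticRegression().fit(X_train, y_train['target'])

explainer = LimeExplainer()
explainer.fit(clf, X_train, y_train)

# Must be a single-row dataframe, otherwise graph_local_explanation raises ValueError.
fig = explainer.graph_local_explanation(X_train.iloc[[0]], n_cols=1)
fig.show()  # plotly waterfall: start_value -> feature moves -> rest -> output_value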
15 changes: 10 additions & 5 deletions trelawney/lime_explainer.py
@@ -1,3 +1,4 @@
+import operator
 from typing import List, Optional, Dict
 
 import pandas as pd

@@ -39,16 +40,16 @@ class LimeExplainer(BaseExplainer):
     """
 
     def __init__(self, class_names: Optional[List[str]] = None, categorical_features: Optional[List[str]] = None, ):
+        super().__init__()
         self._explainer = None
         if class_names is not None and len(class_names) != 2:
             raise NotImplementedError('Trelawney only handles binary classification case for now. PR welcome ;)')
         self.class_names = class_names
         self._output_len = None
         self.categorical_features = categorical_features
-        self._model_to_explain = None
 
     def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame, ):
-        self._model_to_explain = model
+        super().fit(model, x_train, y_train)
         self._explainer = lime_tabular.LimeTabularExplainer(x_train.values, feature_names=x_train.columns,
                                                             class_names=self.class_names,
                                                             categorical_features=self.categorical_features,

@@ -73,8 +74,12 @@ def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -
         for individual_sample in tqdm(x_explain.iterrows()):
             individual_explanation = self._explainer.explain_instance(individual_sample[1],
                                                                       self._model_to_explain.predict_proba,
-                                                                      num_features=n_cols,
+                                                                      num_features=x_explain.shape[1],
                                                                       top_labels=2)
-            res.append({self._extract_col_from_explanation(col_explanation): col_value
-                        for col_explanation, col_value in individual_explanation.as_list()})
+            individual_explanation = sorted(individual_explanation.as_list(), key=operator.itemgetter(1), reverse=True)
+            skewed_individual_explanation = {self._extract_col_from_explanation(col_name): col_importance
+                                             for col_name, col_importance in individual_explanation[:n_cols]}
+            rest = sum(map(operator.itemgetter(1), individual_explanation)) - sum(skewed_individual_explanation.values())
+            skewed_individual_explanation['rest'] = rest
+            res.append(skewed_individual_explanation)
         return res