-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* lime explainer done * getting rid of 2.7 check * bugs, docs and tests * added xgb test * nn tests * update gitignore * update tests * added doctest * changes as per review
- Loading branch information
Showing
14 changed files
with
240 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ python: | |
- 3.7 | ||
- 3.6 | ||
- 3.5 | ||
- 2.7 | ||
install: pip install -U tox-travis | ||
script: tox | ||
deploy: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
pandas==0.25.1 | ||
scikit-learn==0.21.3 | ||
tqdm==4.36.1 | ||
lime==0.1.1.36 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import pandas as pd | ||
import numpy as np | ||
import pytest | ||
from keras import layers, models | ||
from keras.wrappers.scikit_learn import KerasClassifier | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
|
||
@pytest.fixture | ||
def fake_dataset(): | ||
return (pd.DataFrame([list(range(100)), np.random.normal(size=100).tolist()], index=['real', 'fake']).T, | ||
np.array(range(100)) > 50) | ||
|
||
|
||
@pytest.fixture | ||
def fitted_logistic_regression(fake_dataset): | ||
model = LogisticRegression() | ||
return model.fit(*fake_dataset) | ||
|
||
|
||
@pytest.fixture | ||
def fitted_neural_network(fake_dataset): | ||
|
||
def make_neural_network(): | ||
model = models.Sequential([ | ||
layers.Dense(2, input_shape=(2,), activation='softmax') | ||
]) | ||
model.compile(loss='categorical_crossentropy', optimizer='adam') | ||
return model | ||
|
||
model = KerasClassifier(make_neural_network, epochs=10, batch_size=100) | ||
model.fit(*fake_dataset) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from xgboost import XGBClassifier | ||
|
||
from trelawney.lime_explainer import LimeExplainer | ||
|
||
|
||
def _do_explainer_test(explainer): | ||
explanation = explainer.explain_local(pd.DataFrame([[5, 0.1], [95, -0.5]])) | ||
assert len(explanation) == 2 | ||
for single_explanation in explanation: | ||
assert abs(single_explanation['real']) > abs(single_explanation['fake']) | ||
|
||
|
||
def test_lime_explainer_single(fake_dataset, fitted_logistic_regression): | ||
explainer = LimeExplainer(class_names=['false', 'true']) | ||
explainer.fit(fitted_logistic_regression, *fake_dataset) | ||
explanation = explainer.explain_local(pd.DataFrame([[5, 0.1]])) | ||
assert len(explanation) == 1 | ||
single_explanation = explanation[0] | ||
assert abs(single_explanation['real']) > abs(single_explanation['fake']) | ||
|
||
|
||
def test_lime_explainer_multiple(fake_dataset, fitted_logistic_regression): | ||
explainer = LimeExplainer(class_names=['false', 'true']) | ||
explainer.fit(fitted_logistic_regression, *fake_dataset) | ||
_do_explainer_test(explainer) | ||
|
||
|
||
def test_lime_xgb(fake_dataset): | ||
model = XGBClassifier() | ||
x, y = fake_dataset | ||
model.fit(x.values, y) | ||
|
||
explainer = LimeExplainer() | ||
explainer.fit(model, *fake_dataset) | ||
with pytest.raises(TypeError): | ||
explainer.explain_local(x.values) | ||
_do_explainer_test(explainer) | ||
|
||
|
||
def test_lime_nn(fake_dataset, fitted_neural_network): | ||
|
||
explainer = LimeExplainer(class_names=['false', 'true']) | ||
explainer.fit(fitted_neural_network, *fake_dataset) | ||
explanation = explainer.explain_local(pd.DataFrame([[5, 0.1], [95, -0.5]])) | ||
assert len(explanation) == 2 | ||
for single_explanation in explanation: | ||
assert abs(single_explanation['real']) > abs(single_explanation['fake']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import List, Optional, Dict | ||
|
||
import pandas as pd | ||
import sklearn | ||
from lime import lime_tabular | ||
from tqdm import tqdm | ||
|
||
from trelawney.base_explainer import BaseExplainer | ||
|
||
|
||
class LimeExplainer(BaseExplainer): | ||
""" | ||
Lime stands for local interpretable model-agnostic explanations and is a package based on | ||
`this article <https://www.arxiv.org/abs/1602.04938>`_. Lime will explain a single prediction of you model | ||
by crechariotsating a local approximation of your model around said prediction.'sphinx.ext.autodoc', 'sphinx.ext.viewcode'] | ||
.. testsetup:: | ||
>>> import pandas as pd | ||
>>> import numpy as np | ||
>>> from trelawney.lime_explainer import LimeExplainer | ||
>>> from sklearn.linear_model import LogisticRegression | ||
.. doctest:: | ||
>>> X = pd.DataFrame([np.array(range(100)), np.random.normal(size=100).tolist()], index=['real', 'fake']).T | ||
>>> y = np.array(range(100)) > 50 | ||
>>> # training the base model | ||
>>> model = LogisticRegression().fit(X, y) | ||
>>> # creating and fiting the explainer | ||
>>> explainer = LimeExplainer() | ||
>>> explainer.fit(model, X, y) | ||
>>> # explaining observation | ||
>>> explanation = explainer.explain_local(pd.DataFrame([[5, 0.1]]))[0] | ||
>>> abs(explanation['real']) > abs(explanation['fake']) | ||
True | ||
""" | ||
|
||
def __init__(self, class_names: Optional[List[str]] = None, categorical_features: Optional[List[str]] = None, ): | ||
self._explainer = None | ||
if class_names is not None and len(class_names) != 2: | ||
raise NotImplementedError('Trelawney only handles binary classification case for now. PR welcome ;)') | ||
self.class_names = class_names | ||
self._output_len = None | ||
self.categorical_features = categorical_features | ||
self._model_to_explain = None | ||
|
||
def fit(self, model: sklearn.base.BaseEstimator, x_train: pd.DataFrame, y_train: pd.DataFrame, ): | ||
self._model_to_explain = model | ||
self._explainer = lime_tabular.LimeTabularExplainer(x_train.values, feature_names=x_train.columns, | ||
class_names=self.class_names, | ||
categorical_features=self.categorical_features, | ||
discretize_continuous=True) | ||
|
||
def feature_importance(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> Dict[str, float]: | ||
raise NotImplementedError('we are not sure global explaination is mathematically sound for LIME, it is still' | ||
' debated, refer tp https://github.com/skanderkam/trelawney/issues/10') | ||
|
||
@staticmethod | ||
def _extract_col_from_explanation(col_explanation): | ||
is_left_term = len([x for x in col_explanation if x in ['<', '>']]) < 2 | ||
if is_left_term: | ||
return col_explanation.split()[0] | ||
return col_explanation.split()[2] | ||
|
||
def explain_local(self, x_explain: pd.DataFrame, n_cols: Optional[int] = None) -> List[Dict[str, float]]: | ||
if not isinstance(x_explain, pd.DataFrame): | ||
raise TypeError('{} is not supported, please use dataframes'.format(type(x_explain))) | ||
n_cols = n_cols or len(x_explain.columns) | ||
res = [] | ||
for individual_sample in tqdm(x_explain.iterrows()): | ||
individual_explanation = self._explainer.explain_instance(individual_sample[1], | ||
self._model_to_explain.predict_proba, | ||
num_features=n_cols, | ||
top_labels=2) | ||
res.append({self._extract_col_from_explanation(col_explanation): col_value | ||
for col_explanation, col_value in individual_explanation.as_list()}) | ||
return res |