Add new create_counterfactuals func #740

Merged · 2 commits · Oct 13, 2020
1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -67,6 +67,7 @@ nav:
- user_guide/evaluate/cochrans_q.md
- user_guide/evaluate/combined_ftest_5x2cv.md
- user_guide/evaluate/confusion_matrix.md
- user_guide/evaluate/create_counterfactual.md
- user_guide/evaluate/feature_importance_permutation.md
- user_guide/evaluate/ftest.md
- user_guide/evaluate/lift_score.md
3 changes: 2 additions & 1 deletion docs/sources/CHANGELOG.md
@@ -20,7 +20,8 @@ The CHANGELOG for the current development version is available at
##### New Features

- The `bias_variance_decomp` now supports Keras estimators. ([#725](https://github.com/rasbt/mlxtend/pull/725) via [@hanzigs](https://github.com/hanzigs))
- Adds new `OneRClassifier` (One Rule Classifier) ([#726](https://github.com/rasbt/mlxtend/pull/726))
- Adds new `mlxtend.classifier.OneRClassifier` (One Rule Classifier) class, a simple rule-based classifier that is often used as a performance baseline or simple interpretable model. ([#726](https://github.com/rasbt/mlxtend/pull/726))
- Adds new `create_counterfactual` method for creating counterfactuals to explain model predictions. ([#740](https://github.com/rasbt/mlxtend/pull/740))


##### Changes
3 changes: 2 additions & 1 deletion docs/sources/USER_GUIDE_INDEX.md
@@ -31,8 +31,9 @@
- [bootstrap_point632_score](user_guide/evaluate/bootstrap_point632_score.md)
- [BootstrapOutOfBag](user_guide/evaluate/BootstrapOutOfBag.md)
- [cochrans_q](user_guide/evaluate/cochrans_q.md)
- [confusion_matrix](user_guide/evaluate/confusion_matrix.md)
- [combined_ftest_5x2cv](user_guide/evaluate/combined_ftest_5x2cv.md)
- [confusion_matrix](user_guide/evaluate/confusion_matrix.md)
- [create_counterfactual](user_guide/evaluate/create_counterfactual.md)
- [feature_importance_permutation](user_guide/evaluate/feature_importance_permutation.md)
- [ftest](user_guide/evaluate/ftest.md)
- [lift_score](user_guide/evaluate/lift_score.md)
9 changes: 6 additions & 3 deletions docs/sources/user_guide/evaluate/confusion_matrix.ipynb
@@ -4,7 +4,10 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
"collapsed": true,
"jupyter": {
"outputs_hidden": true
}
},
"outputs": [],
"source": [
@@ -364,7 +367,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.8.3"
},
"toc": {
"nav_menu": {},
@@ -380,5 +383,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
630 changes: 630 additions & 0 deletions docs/sources/user_guide/evaluate/create_counterfactual.ipynb

Large diffs are not rendered by default.
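Since the new user-guide notebook is not rendered inline, here is a minimal usage sketch of `create_counterfactual` (not taken from the notebook itself; it assumes the same iris/`LogisticRegression` setup used in the PR's tests below):

```python
# Minimal usage sketch, assuming the iris + LogisticRegression setup from the
# PR's tests; the actual notebook content is not rendered on this page.
from sklearn.linear_model import LogisticRegression

from mlxtend.data import iris_data
from mlxtend.evaluate import create_counterfactual

X, y = iris_data()
clf = LogisticRegression().fit(X, y)

x_ref = X[15]  # a class-0 example to be explained

x_counterfact = create_counterfactual(x_reference=x_ref,
                                      y_desired=2,
                                      model=clf,
                                      X_dataset=X,
                                      y_desired_proba=1.0,
                                      lammbda=1.0,
                                      random_seed=123)

print(clf.predict(x_ref.reshape(1, -1)))          # original prediction: class 0
print(clf.predict(x_counterfact.reshape(1, -1)))  # should now favor class 2
```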

3 changes: 2 additions & 1 deletion mlxtend/evaluate/__init__.py
@@ -27,6 +27,7 @@
from .proportion_difference import proportion_difference
from .bias_variance_decomp import bias_variance_decomp
from .accuracy import accuracy_score
from .counterfactual import create_counterfactual

__all__ = ["scoring", "confusion_matrix",
"mcnemar_table", "mcnemar_tables",
@@ -39,4 +40,4 @@
"RandomHoldoutSplit", "PredefinedHoldoutSplit",
"ftest", "combined_ftest_5x2cv",
"proportion_difference", "bias_variance_decomp",
"accuracy_score"]
"accuracy_score", "create_counterfactual"]
112 changes: 112 additions & 0 deletions mlxtend/evaluate/counterfactual.py
@@ -0,0 +1,112 @@
# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

from scipy.optimize import minimize
import warnings

import numpy as np


def create_counterfactual(x_reference, y_desired, model, X_dataset,
                          y_desired_proba=None, lammbda=0.1,
                          random_seed=None):
    """
    Implementation of the counterfactual method by Wachter et al. 2017.

    References:

    - Wachter, S., Mittelstadt, B., & Russell, C. (2017).
      Counterfactual explanations without opening the black box:
      Automated decisions and the GDPR. Harv. JL & Tech., 31, 841.,
      https://arxiv.org/abs/1711.00399

    Parameters
    ----------
    x_reference : array-like, shape=[m_features]
        The data instance (training example) to be explained.

    y_desired : int
        The desired class label for `x_reference`.

    model : estimator
        A (scikit-learn) estimator implementing `predict()` and/or
        `predict_proba()`.
        - If `model` supports `predict_proba()`, it is used by default
          for the first loss term,
          `lambda * (model.predict[_proba](x_counterfact) - y_desired[_proba])^2`
        - Otherwise, the method falls back to `predict()`.

    X_dataset : array-like, shape=[n_examples, m_features]
        A (training) dataset from which the initial counterfactual is
        picked as the starting value for the optimization procedure.

    y_desired_proba : float (default: None)
        A float within the range [0, 1] designating the desired
        class probability for `y_desired`.
        - If `y_desired_proba=None` (default), the first loss term is
          `lambda * (model(x_counterfact) - y_desired)^2`, where
          `y_desired` is a class label.
        - If `y_desired_proba` is not None, the first loss term is
          `lambda * (model(x_counterfact) - y_desired_proba)^2`.

    lammbda : float (default: 0.1)
        Weighting parameter for the first loss term,
        `lambda * (model(x_counterfact) - y_desired[_proba])^2`.

    random_seed : int (default: None)
        If int, random_seed is the seed used by the random number
        generator for selecting the initial counterfactual from
        `X_dataset`.

    Returns
    -------
    x_counterfact : array-like, shape=[m_features]
        The counterfactual example found by the optimization procedure.

    """
    if y_desired_proba is not None:
        use_proba = True
        if not hasattr(model, "predict_proba"):
            raise AttributeError("Your `model` does not support "
                                 "`predict_proba`. Set `y_desired_proba` "
                                 "to `None` to use `predict` instead.")
    else:
        use_proba = False

    if y_desired_proba is None:
        # class label
        y_to_be_annealed_to = y_desired
    else:
        # class proba corresponding to class label y_desired
        y_to_be_annealed_to = y_desired_proba

    # start with a random counterfactual drawn from the dataset
    rng = np.random.RandomState(random_seed)
    x_counterfact = X_dataset[rng.randint(X_dataset.shape[0])]

    # feature-wise absolute deviation of the reference example from the
    # dataset medians (MAD-style scaling term for the distance function)
    mad = np.abs(np.median(X_dataset, axis=0) - x_reference)

    def dist(x_reference, x_counterfact):
        numerator = np.abs(x_reference - x_counterfact)
        return np.sum(numerator/mad)

    def loss(x_counterfact, lammbda):

        if use_proba:
            y_predict = model.predict_proba(
                x_counterfact.reshape(1, -1)).flatten()[y_desired]
        else:
            y_predict = model.predict(x_counterfact.reshape(1, -1))

        # weighted squared deviation from the target label/probability
        diff = lammbda*(y_predict - y_to_be_annealed_to)**2

        return diff + dist(x_reference, x_counterfact)

    res = minimize(loss, x_counterfact, args=(lammbda,),
                   method='Nelder-Mead')

    if not res['success']:
        warnings.warn(res['message'])

    x_counterfact = res['x']

    return x_counterfact
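For reference, the objective minimized by the Nelder-Mead call above can be written compactly as follows, where `x` denotes `x_reference`, `x'` the candidate counterfactual, `f_hat(x')` the model output used in `loss` (the predicted probability for `y_desired`, or the predicted label when `predict_proba` is unavailable), `y'` is `y_desired_proba` (or `y_desired`), and `mad_j` is the scaling term computed above:

```latex
L(x', \lambda) = \lambda \, \bigl(\hat{f}(x') - y'\bigr)^{2}
                 + \sum_{j=1}^{m} \frac{\lvert x_{j} - x'_{j} \rvert}{\mathrm{mad}_{j}}
```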
110 changes: 110 additions & 0 deletions mlxtend/evaluate/tests/test_counterfactual.py
@@ -0,0 +1,110 @@
# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

from mlxtend.evaluate import create_counterfactual
import numpy as np
from sklearn.linear_model import LogisticRegression
from mlxtend.data import iris_data
from mlxtend.classifier import OneRClassifier
from mlxtend.utils import assert_raises


def test__medium_lambda():
    X, y = iris_data()
    clf = LogisticRegression()
    clf.fit(X, y)

    x_ref = X[15]

    res = create_counterfactual(x_reference=x_ref,
                                y_desired=2,
                                model=clf,
                                X_dataset=X,
                                y_desired_proba=1.,
                                lammbda=1,
                                random_seed=123)

    assert np.argmax(clf.predict_proba(x_ref.reshape(1, -1))) == 0
    assert np.argmax(clf.predict_proba(res.reshape(1, -1))) == 2
    assert 0.65 <= round((clf.predict_proba(
        res.reshape(1, -1))).flatten()[-1], 2) <= 0.69


def test__small_lambda():
    X, y = iris_data()
    clf = LogisticRegression()
    clf.fit(X, y)

    x_ref = X[15]

    res = create_counterfactual(x_reference=x_ref,
                                y_desired=2,
                                model=clf,
                                X_dataset=X,
                                y_desired_proba=1.,
                                lammbda=0.0001,
                                random_seed=123)

    assert np.argmax(clf.predict_proba(x_ref.reshape(1, -1))) == 0
    assert np.argmax(clf.predict_proba(res.reshape(1, -1))) == 0
    assert round((clf.predict_proba(
        res.reshape(1, -1))).flatten()[-1], 2) == 0.0


def test__large_lambda():
    X, y = iris_data()
    clf = LogisticRegression()
    clf.fit(X, y)

    x_ref = X[15]

    res = create_counterfactual(x_reference=x_ref,
                                y_desired=2,
                                model=clf,
                                X_dataset=X,
                                y_desired_proba=1.,
                                lammbda=100,
                                random_seed=123)

    assert np.argmax(clf.predict_proba(x_ref.reshape(1, -1))) == 0
    assert np.argmax(clf.predict_proba(res.reshape(1, -1))) == 2
    assert round((clf.predict_proba(
        res.reshape(1, -1))).flatten()[-1], 2) >= 0.96


def test__clf_with_no_proba_fail():
    X, y = iris_data()
    clf = OneRClassifier()
    clf.fit(X, y)

    x_ref = X[15]

    s = ("Your `model` does not support "
         "`predict_proba`. Set `y_desired_proba` "
         "to `None` to use `predict` instead.")

    assert_raises(AttributeError,
                  s,
                  create_counterfactual, x_ref, 2, clf, X, 1., 100, 123)


def test__clf_with_no_proba_pass():
    X, y = iris_data()
    clf = OneRClassifier()
    clf.fit(X, y)

    x_ref = X[15]

    res = create_counterfactual(x_reference=x_ref,
                                y_desired=2,
                                model=clf,
                                X_dataset=X,
                                y_desired_proba=None,
                                lammbda=100,
                                random_seed=123)

    assert clf.predict(x_ref.reshape(1, -1)) == 0
    assert clf.predict(res.reshape(1, -1)) == 2
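To run just these new tests locally, one option (assuming pytest is installed and the snippet is executed from the repository root) is:

```python
# Run only the new counterfactual tests; assumes pytest is installed and
# that this is executed from the mlxtend repository root.
import pytest

pytest.main(["-q", "mlxtend/evaluate/tests/test_counterfactual.py"])
```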