Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Adds Permutation Importance #13146

Merged
merged 98 commits into from Jul 17, 2019
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
00c56de
ENH Adds files
thomasjpfan Feb 8, 2019
29be4f4
ENH Adds permutation importance
thomasjpfan Feb 12, 2019
2e09bfb
RFC Better names
thomasjpfan Feb 12, 2019
f7bb490
STY Flake8
thomasjpfan Feb 12, 2019
6f0175c
ENH: Adds inspect module
thomasjpfan Feb 12, 2019
bf44eb1
DOC Adds pre_dispatch
thomasjpfan Feb 12, 2019
85ed781
DOC Adds permutation importance example
thomasjpfan Feb 12, 2019
66e71dd
Trigger CI
thomasjpfan Feb 13, 2019
a93a9f3
BLD Adds inspect to configuration
thomasjpfan Feb 13, 2019
ee1e77f
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Feb 19, 2019
0670997
RFC Update to only inspect fitted model
thomasjpfan Feb 19, 2019
334c8c3
RFC Removes parameters
thomasjpfan Feb 19, 2019
354ac62
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Mar 1, 2019
260fa54
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Mar 14, 2019
92894a1
ENH: Adds pandas support
thomasjpfan Mar 14, 2019
f45c653
STY Flake8
thomasjpfan Mar 14, 2019
50d8550
DOC Adds new permutation importance example
thomasjpfan Mar 15, 2019
74e915f
ENH Renames module to model_inspection
thomasjpfan Mar 15, 2019
2a7d8e2
DOC Fix links
thomasjpfan Mar 15, 2019
920362a
DOC Fixes image link
thomasjpfan Mar 15, 2019
747599b
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Mar 15, 2019
a48e151
DOC Fixes image link
thomasjpfan Mar 15, 2019
51b745d
DOC Spelling
thomasjpfan Mar 16, 2019
23c8d11
DOC
thomasjpfan Mar 17, 2019
4241414
TST Fix keyword
thomasjpfan Mar 17, 2019
a12bc0c
Rework RF Imp vs Perm Imp example (#4)
ogrisel Apr 1, 2019
e864071
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Apr 1, 2019
9a57e20
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 9, 2019
5798338
WIP
thomasjpfan May 9, 2019
72b9003
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 9, 2019
37d52ba
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 10, 2019
fe675f9
WIP
thomasjpfan May 10, 2019
ced888d
WIP
thomasjpfan May 13, 2019
b0357fc
DOC Adds multcollinear features example
thomasjpfan May 15, 2019
91bf4e2
WIP
thomasjpfan May 15, 2019
a1d5880
DOC: Clean up docs
thomasjpfan May 15, 2019
4eb1e82
TST Adds tests for strings
thomasjpfan May 15, 2019
6f98f11
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 15, 2019
1656985
STY Indent correction
thomasjpfan May 16, 2019
0d34d80
WIP
thomasjpfan May 18, 2019
62868f6
ENH Uses check_X_y
thomasjpfan May 18, 2019
e7efe6d
TST Adds test with strings
thomasjpfan May 18, 2019
d75b557
STY Fix
thomasjpfan May 22, 2019
e3bbcda
TST Adds column transformer to test
thomasjpfan May 22, 2019
6c60e43
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 28, 2019
ed469d6
CLN Address comments
thomasjpfan May 29, 2019
6180975
CLN Removes import
thomasjpfan May 29, 2019
24d740e
TST Adds test with nan
thomasjpfan May 29, 2019
914335d
CLN Removes import
thomasjpfan May 29, 2019
f0beac6
ENH Parallel
thomasjpfan May 29, 2019
ac8d5a3
DOC comments
thomasjpfan May 29, 2019
31e9408
ENH Better handling of pandas
thomasjpfan May 29, 2019
be3f65b
ENH Clear checking of pandas dataframe
thomasjpfan May 29, 2019
e1df6a6
STY Formatting
thomasjpfan May 29, 2019
78aba62
ENH Copies in parallel helper
thomasjpfan May 29, 2019
d6ca3c5
DOC Adds comments
thomasjpfan May 29, 2019
a2aa960
BUG Fix copying
thomasjpfan May 29, 2019
9ff6aa1
BUG Fix for pandas
thomasjpfan May 29, 2019
f112cd3
BUG Fix for pandas
thomasjpfan May 29, 2019
884d648
REV
thomasjpfan May 29, 2019
c64e6a1
BLD Trigger CI
thomasjpfan May 29, 2019
d2fad37
BUG Fix
thomasjpfan May 30, 2019
50b6b98
BUG Fix
thomasjpfan May 30, 2019
14b3efd
TST Does this work
thomasjpfan May 30, 2019
f41f5b3
BUG Fixes test
thomasjpfan May 30, 2019
3cd43ce
BUG Fixes test
thomasjpfan May 30, 2019
318c961
BUG Fix
thomasjpfan May 30, 2019
aa6c79d
BUG Fix
thomasjpfan May 30, 2019
5292136
BUG Fix
thomasjpfan May 30, 2019
9b53e35
Merge branch 'permutation_importance_v2' into permutation_importance
thomasjpfan May 30, 2019
bc3ea96
STY Fix
thomasjpfan May 30, 2019
7d79a49
TST Fix
thomasjpfan May 30, 2019
b487618
TST Fix segfault
thomasjpfan May 31, 2019
7a83608
CLN Address comments
thomasjpfan Jun 17, 2019
af9c961
CLN Address comments
thomasjpfan Jun 17, 2019
664d863
ENH Returns a bunch
thomasjpfan Jun 17, 2019
8a022c6
STY Flake8
thomasjpfan Jun 17, 2019
78ed4e8
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jun 17, 2019
fbebc5e
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jun 24, 2019
d62df83
CLN Renames bunch key
thomasjpfan Jun 25, 2019
118601a
DOC Updates api
thomasjpfan Jun 25, 2019
9f1325f
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jun 25, 2019
2655f82
DOC Updates api
thomasjpfan Jun 25, 2019
ca9a78b
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 1, 2019
1748227
TST Adds permutation test with linear_regression
thomasjpfan Jul 2, 2019
e1607ff
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 2, 2019
fb4f926
DOC update
thomasjpfan Jul 4, 2019
eb154a9
DOC Fix label cutoff
thomasjpfan Jul 4, 2019
b1f9c70
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 4, 2019
946ca59
CLN Address comments
thomasjpfan Jul 9, 2019
5676930
TST Adds test for random_state effect
thomasjpfan Jul 9, 2019
78cefef
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 9, 2019
204c3ab
DOC Adds permutation importance
thomasjpfan Jul 9, 2019
dab6801
DOC Adds ogrisel suggestion
thomasjpfan Jul 9, 2019
c67667f
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 16, 2019
f90eacf
DOC Address guillaumes comments
thomasjpfan Jul 16, 2019
6b428d7
DOC Address andreas comments
thomasjpfan Jul 16, 2019
94c4c56
DOC Update
thomasjpfan Jul 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/inspect/README.txt
@@ -0,0 +1,7 @@
.. _model_inspection_examples:

Model Inspection
----------------

Examples concerning the :mod:`sklearn.inspect` module.

76 changes: 76 additions & 0 deletions examples/inspect/plot_permutation_importance.py
@@ -0,0 +1,76 @@
"""
==================================================
Permutation Importance vs Random Forest Importance
==================================================

The random forest `feature_importances_`, are computed from train set
statistics and are subject to bias with the cardinality of the feature. The
permutation importance of a feature is calculated by measuring how much the
model performance decreases when the feature is permutated.

In this example, we add a column of random numbers to the diabetes dataset.
Then we fit a :class:`sklearn.ensemble.RandomForestRegressor` to this modified
dataset. The feature importance from the random forest is plotted. In this
case, the ``RANDOM`` feature is considerd more important than the ``age`` or
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
``sex`` feature.

Next, we use :func:`sklearn.inspect.permutation_importance` to calcuate the
permutation importance for each feature.
The `sklearn.inspect.permutation_importance` returns a numpy array where
values in each row are the cross-validated scores for a feature. The
permutation importance for the random forest is plotted. In this case,
The ``RANDOM`` feature is less important than ``sex`` and ``age``.
"""
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspect import permutation_importance


def plot_importances(importances, features, highlight=None, ax=None):
N = features.shape[0]

if ax is None:
_, ax = plt.subplots()
y_ticks = range(1, N + 1)
arg_sorted = np.argsort(importances)

color = ["blue" for _ in range(N)]
labels = features[arg_sorted]

if highlight is not None:
for idx, label in enumerate(labels):
if label == highlight:
color[idx] = "red"

ax.barh(y_ticks, importances[arg_sorted], color=color)
ax.set_yticks(y_ticks)
ax.set_xlim(0, np.max(importances)*1.05)
ax.set_ylim(0, N + 1)
ax.set_yticklabels(features[arg_sorted])


ds = load_diabetes()
X, y = ds.data, ds.target
features = np.array(ds.feature_names + ["RAND"])
rng = np.random.RandomState(42)
X = np.hstack([X, rng.normal(scale=1, size=(X.shape[0], 1))])

rf = RandomForestRegressor(n_estimators=50, random_state=rng)
rf.fit(X, y)

fig, (ax1, ax2) = plt.subplots(1, 2)
plot_importances(rf.feature_importances_, features, highlight="RAND", ax=ax1)
ax1.set_title("Feature importance from random forest")

perm_importances = permutation_importance(rf, X, y, random_state=rng,
scoring="explained_variance")
perm_importances_mean = perm_importances.mean(axis=1)
plot_importances(perm_importances_mean, features, highlight="RAND", ax=ax2)
ax2.set_title("Permutation importance")
fig.tight_layout()
plt.show()
1 change: 1 addition & 0 deletions sklearn/__init__.py
Expand Up @@ -75,6 +75,7 @@
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
'preprocessing', 'random_projection', 'semi_supervised',
'svm', 'tree', 'discriminant_analysis', 'impute', 'compose',
'inspect',
# Non-modules:
'clone', 'get_config', 'set_config', 'config_context',
'show_versions']
Expand Down
3 changes: 3 additions & 0 deletions sklearn/inspect/__init__.py
@@ -0,0 +1,3 @@
from .permutation_importance import permutation_importance

__all__ = ['permutation_importance']
182 changes: 182 additions & 0 deletions sklearn/inspect/permutation_importance.py
@@ -0,0 +1,182 @@
"""Permutation importance for estimators"""
from contextlib import contextmanager

import numpy as np

from ..base import is_classifier, clone
from ..utils import check_random_state
from ..utils._joblib import Parallel, delayed
from ..model_selection import check_cv
from ..metrics import check_scoring
from ..utils.metaestimators import _safe_split


@contextmanager
def _permute_column(X, column, random_state):
"""Context manager to permute a column"""
original_feature = X[:, column].copy()
X[:, column] = random_state.permutation(X[:, column])
yield X
X[:, column] = original_feature


def _fit_and_calcuate_permutation_importance(estimator, X, y, train_indices,
test_indices, columns, scoring,
random_state):
"""Fits and calculates permutation importance
Fits ``estimator`` on ``X`` and ``y``
Parameters
----------
estimator : object
A supervised learning estimator with a `fit` and is compatible with
``scorer``.
X : array-like, shape = (n_samples, n_features)
Training data.
y : array-like, shape = (n_samples, ...)
Target relative to ``X``.
train_indices : array of int
Train indicies.
test_indices : array of int
Test indices.
columns : list of integers
A list of columns to calculate the permutation importance. If `None`,
all columns will be used.
scoring : string, callable or None
A string (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
random_state: : RandomState instance
Random number generator.
Returns
-------
permutation_importance_scores : list
Permutation importance scores for each column on the validation set
defined by ``test_indices``.
"""
X_train, y_train = _safe_split(estimator, X, y, train_indices)
X_test, y_test = _safe_split(estimator, X, y, test_indices, train_indices)

estimator.fit(X_train, y_train)
baseline_score = scoring(estimator, X_test, y_test)

permutation_importance_scores = []
for column in columns:
with _permute_column(X_test, column, random_state) as X_perm:
feature_score = scoring(estimator, X_perm, y_test)
permutation_importance_scores.append(baseline_score -
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean when this value is negative? Do we need to clip in that case??

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Negative means that the model performed better with the feature permuted. This could mean that the feature should be dropped.

There is a paragraph about this in https://explained.ai/rf-importance/index.html at Figure 3(a)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I think both the docstring and the user guide should explain the meaning of negative importance.

feature_score)

return permutation_importance_scores


def permutation_importance(estimator, X, y, columns=None, scoring=None, cv=5,
n_jobs=None, pre_dispatch='2*n_jobs',
random_state=None):
"""Permutation importance for feature evaluation.
The permutation importance of a feature is calculated as follows. First,
the estimator is trained on a training set. Then a baseline metric, defined
by ``scoring``, is evaluated on a validation set. Next, a feature column
from the validation set is permuted and the metric is evaluated again.
The permutation importance is defined to be the difference between the
baseline metric and metric from permutating the feature column.
Parameters
----------
estimator : object
A supervised learning estimator with a `fit` and is compatible with
``scorer``.
X : array-like, shape = (n_samples, n_features)
Training data.
y : array-like, shape = (n_samples, ...)
Target relative to ``X``.
columns : list of integers, optional (default=None)
A list of columns to calculate the permutation importance. If `None`,
all columns will be used
scoring : string, callable or None, optional (default=None)
A string (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
cv : int, cross-validation generator or an iterable, optional (default=5)
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- integer, to specify the number of folds.
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, :class:`KFold` is used.
Refer :ref:`User Guide <cross_validation>` for the various
cross-validation strategies that can be used here.
n_jobs : int or None, optional (default=None)
Number of CPUs to use during the cross validation.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
pre_dispatch : int, or string, optional
Controls the number of jobs that get dispatched during parallel
execution. Reducing this number can be useful to avoid an
explosion of memory consumption when more jobs get dispatched
than CPUs can process. This parameter can be:
- None, in which case all the jobs are immediately
created and spawned. Use this for lightweight and
fast-running jobs, to avoid delays due to on-demand
spawning of the jobs
- An int, giving the exact number of total jobs that are
spawned
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'
random_state : int, RandomState instance or None, optional, default None
The seed of the pseudo random number generator that selects a random
feature to update. If int, random_state is the seed used by the random
number generator; If RandomState instance, random_state is the random
number generator; If None, the random number generator is the
RandomState instance used by `np.random`.
Returns
-------
permutation_importance_scores : array, shape (n_columns, n_cv)
Permutation importance scores where the rows are ordered corresponding
to the ``columns`` argument.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a reference - and a user guide!


cv = check_cv(cv, y, classifier=is_classifier(estimator))
random_state = check_random_state(random_state)
scoring = check_scoring(estimator, scoring=scoring)

parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch)

if columns is None:
columns = range(0, X.shape[1])

with parallel:
permutation_importance_scores = parallel(
delayed(_fit_and_calcuate_permutation_importance)(
clone(estimator), X, y, train_indices,
test_indices, columns, scoring, random_state
) for train_indices, test_indices in cv.split(X, y))

return np.array(permutation_importance_scores).T
Empty file.
35 changes: 35 additions & 0 deletions sklearn/inspect/tests/test_permutation_importance.py
@@ -0,0 +1,35 @@
import pytest

import numpy as np

from sklearn.datasets import load_boston
from sklearn.inspect import permutation_importance
from sklearn.ensemble import RandomForestRegressor


@pytest.mark.parametrize("columns", [
None, [0, 2, 4, 6, 8, 10, 12, 13], [1, 3, 5, 7, 9, 11, 13]
])
@pytest.mark.parametrize("scoring", [
None, "neg_mean_absolute_error"
])
def test_permutation_importance_correlated_feature_is_important(
columns, scoring):
rng = np.random.RandomState(42)
X, y = load_boston(return_X_y=True)

# Adds correlated feature to X
y_with_little_noise = y + rng.normal(scale=0.001, size=y.shape[0])
X = np.hstack([X, y_with_little_noise.reshape(-1, 1)])

rf = RandomForestRegressor(n_estimators=50, random_state=42)
permute_scores = permutation_importance(rf, X, y, columns=columns, cv=4,
random_state=42, scoring=scoring)

if columns is None:
assert permute_scores.shape == (X.shape[1], 4)
else:
assert permute_scores.shape == (len(columns), 4)

permuate_score_means = np.mean(permute_scores, axis=-1)
assert np.all(permuate_score_means[-1] > permuate_score_means[:-1])