BUG ensure that parallel/sequential give the same permutation importances #15933

Merged
merged 32 commits into from Dec 20, 2019

Changes from 29 commits

Commits (32)
a47ebe2
BUG ensure that parallel/sequential provide the same results
glemaitre Dec 19, 2019
ab95fd5
iter
glemaitre Dec 19, 2019
d58272d
whats new
glemaitre Dec 19, 2019
7211a53
Check that the test is not trivial
ogrisel Dec 19, 2019
70e1ef4
Typo in PR number
ogrisel Dec 19, 2019
a6909ca
Fix thread-safety issue
ogrisel Dec 19, 2019
a24b9d5
Add non-regression test to check that issue 15810 is fixed.
ogrisel Dec 19, 2019
7ad0b26
Leaner test
ogrisel Dec 19, 2019
fb69870
Support joblib 0.11
ogrisel Dec 19, 2019
f80dc69
BUG Syntax error
thomasjpfan Dec 19, 2019
775e986
MAX_RAND_SEED should be int32
glemaitre Dec 19, 2019
0631299
cosmetic
glemaitre Dec 19, 2019
1a21a98
inplace operation
glemaitre Dec 19, 2019
023eca2
cosmit
glemaitre Dec 19, 2019
e213236
cosmit
glemaitre Dec 19, 2019
be8f1c1
Better comment explaining the need for X.copy()
ogrisel Dec 20, 2019
910ef4f
Fix random seed range
ogrisel Dec 20, 2019
723bf03
Test exact equivalence in column shuffling of pandas dataframes with …
ogrisel Dec 20, 2019
f5bda8c
Add acknowledgment to 15898
glemaitre Dec 20, 2019
e9770cf
factorize max_int_32
glemaitre Dec 20, 2019
9cdc7b8
make max_int_32 inclusive
glemaitre Dec 20, 2019
03ab3a1
explicitly call for max int32
glemaitre Dec 20, 2019
0c25e61
fix
glemaitre Dec 20, 2019
fe4cac6
revert max int32 changes
glemaitre Dec 20, 2019
7bdb93a
fix
glemaitre Dec 20, 2019
d399d96
Test with non-numpy-native column
ogrisel Dec 20, 2019
bdaffb5
reshuffling by position
glemaitre Dec 20, 2019
42e8cb5
remove unused import
glemaitre Dec 20, 2019
1834fca
[ci skip] typos & better comment
ogrisel Dec 20, 2019
74a1c54
Merge branch 'master' into is/random_state_parallel
ogrisel Dec 20, 2019
5cf37f6
TST: check dataframe with a weird index
ogrisel Dec 20, 2019
51f7467
FIX make column permutation robust to weird indices
ogrisel Dec 20, 2019
12 changes: 12 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -15,6 +15,18 @@ This is a bug-fix release to primarily resolve some packaging issues in version
Changelog
---------

:mod:`sklearn.inspection`
.........................

- |Fix| :func:`inspection.permutation_importance` will return the same
  `importances` when a `random_state` is given for both `n_jobs=1` and
  `n_jobs>1`, both with shared-memory (thread-based) backends and with
  isolated-memory, process-based backends.
  It also avoids casting the data to object dtype and avoids a read-only
  error on large dataframes with `n_jobs>1`, as reported in :issue:`15810`.
  Follow-up of :pr:`15898` by :user:`Shivam Gargsya <shivamgargsya>`.
  :pr:`15933` by :user:`Guillaume Lemaitre <glemaitre>` and `Olivier Grisel`_.
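For illustration (not part of the changelog entry), a minimal sketch of the behaviour this fix guarantees, using an arbitrary dataset and model:

```python
# Minimal sketch (not from the PR): with a fixed random_state, the patched
# permutation_importance returns identical importances for sequential and
# parallel runs.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LinearRegression

X, y = make_regression(n_samples=200, n_features=5, random_state=0)
model = LinearRegression().fit(X, y)

seq = permutation_importance(model, X, y, n_repeats=5, random_state=0, n_jobs=1)
par = permutation_importance(model, X, y, n_repeats=5, random_state=0, n_jobs=2)
np.testing.assert_allclose(seq.importances, par.importances)
```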

:mod:`sklearn.metrics`
......................

62 changes: 30 additions & 32 deletions sklearn/inspection/_permutation_importance.py
@@ -4,41 +4,37 @@
from joblib import delayed

from ..metrics import check_scoring
from ..utils import Bunch
from ..utils import check_random_state
from ..utils import check_array
from ..utils import Bunch


def _safe_column_setting(X, col_idx, values):
"""Set column on X using `col_idx`"""
if hasattr(X, "iloc"):
X.iloc[:, col_idx] = values
else:
X[:, col_idx] = values


def _safe_column_indexing(X, col_idx):
"""Return column from X using `col_idx`"""
if hasattr(X, "iloc"):
return X.iloc[:, col_idx].values
else:
return X[:, col_idx]


def _calculate_permutation_scores(estimator, X, y, col_idx, random_state,
n_repeats, scorer):
"""Calculate score when `col_idx` is permuted."""
original_feature = _safe_column_indexing(X, col_idx).copy()
temp = original_feature.copy()
random_state = check_random_state(random_state)

# Work on a copy of X to ensure thread-safety in case of threading based
# parallelism. Furthermore, making a copy is also useful when the joblib
# backend is 'loky' (default) or the old 'multiprocessing': in those cases,
# if X is large it will automatically be backed by a readonly memory map
# (memmap). X.copy() on the other hand is always guaranteed to return a
# writable data-structure whose columns can be shuffled inplace.
X_permuted = X.copy()
Review comment (Member):

Note for the reviewers: the fact that we always make a copy here also fixes the issue with read-only memmaps as reported in #15810.

Review comment (Member):

Just by chance, are you aware of a way to avoid re-allocating the full dataframe, and instead make a view for most columns except the one we want to change inplace?

For instance, if one makes a slice of a dataframe and then tries to modify a column, pandas will raise a warning about a view being modified, but I'm not sure whether it will actually change the original dataframe inplace in this case.

Review comment (Member):

It depends on the internal block structure of the dataframe, but this is considered private API and is likely to change in future versions of pandas. I would rather stay safe for now.

Review comment (Member):

In the future, I can see this being possible when pandas switches to using a columnar data structure to hold its data, as detailed in the pandas roadmap.
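To make the view-vs-copy ambiguity discussed above concrete, a minimal standalone sketch (assuming pandas is installed; not code from the PR):

```python
# Whether modifying a column of a sliced frame also modifies the parent
# depends on the frame's internal block layout and on the pandas version,
# which is why the PR works on an explicit X.copy() instead.
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": np.arange(5.0), "b": np.arange(5.0)})
sub = df[:3]                       # may be a view or a copy of df
sub["a"] = sub["a"].values[::-1]   # pandas may warn about a chained assignment
print(df["a"].tolist())            # whether df changed is not guaranteed
```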

scores = np.zeros(n_repeats)
shuffling_idx = np.arange(X.shape[0])
for n_round in range(n_repeats):
random_state.shuffle(temp)
_safe_column_setting(X, col_idx, temp)
feature_score = scorer(estimator, X, y)
random_state.shuffle(shuffling_idx)
if hasattr(X_permuted, "iloc"):
# reset the index so that pandas reassigns by position instead of
# by index label
X_permuted.iloc[:, col_idx] = X_permuted.iloc[
shuffling_idx, col_idx].reset_index(drop=True)
else:
X_permuted[:, col_idx] = X_permuted[shuffling_idx, col_idx]
feature_score = scorer(estimator, X_permuted, y)
scores[n_round] = feature_score

_safe_column_setting(X, col_idx, original_feature)
return scores
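For illustration, a standalone sketch (assuming pandas; not code from the PR) of why the shuffled column must be re-indexed before being assigned back: pandas aligns Series assignments on index labels, which would otherwise silently undo the permutation.

```python
# Assigning a shuffled Series back to a dataframe column aligns on the
# index labels and restores the original order; dropping the index first
# makes the assignment effectively positional.
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [10.0, 20.0, 30.0, 40.0]})
idx = np.array([2, 0, 3, 1])

shuffled = df["x"].iloc[idx]       # shuffled values, original index labels
df["x"] = shuffled                 # aligned on index -> order is restored
print(df["x"].tolist())            # [10.0, 20.0, 30.0, 40.0]

df["x"] = shuffled.reset_index(drop=True)  # fresh 0..n-1 index -> keeps shuffle
print(df["x"].tolist())                    # [30.0, 10.0, 40.0, 20.0]
```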


@@ -104,20 +100,22 @@ def permutation_importance(estimator, X, y, scoring=None, n_repeats=5,
.. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
2001. https://doi.org/10.1023/A:1010933404324
"""
if hasattr(X, "iloc"):
X = X.copy() # Dataframe
else:
X = check_array(X, force_all_finite='allow-nan', dtype=np.object,
copy=True)

if not hasattr(X, "iloc"):
X = check_array(X, force_all_finite='allow-nan', dtype=None)

# Precompute a random seed from the random state to be used
# to get a fresh independent RandomState instance for each
# parallel call to _calculate_permutation_scores, irrespective of
# whether variables are shared between workers, which depends on the
# active joblib backend (sequential, thread-based or process-based).
random_state = check_random_state(random_state)
scorer = check_scoring(estimator, scoring=scoring)
random_seed = random_state.randint(np.iinfo(np.int32).max + 1)

scorer = check_scoring(estimator, scoring=scoring)
baseline_score = scorer(estimator, X, y)
scores = np.zeros((X.shape[1], n_repeats))

scores = Parallel(n_jobs=n_jobs)(delayed(_calculate_permutation_scores)(
estimator, X, y, col_idx, random_state, n_repeats, scorer
estimator, X, y, col_idx, random_seed, n_repeats, scorer
) for col_idx in range(X.shape[1]))

importances = baseline_score - np.array(scores)
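A standalone sketch (not the library code) of the seeding pattern used above: the caller's random_state is consumed once to draw an int32 seed, and every task then builds its own RandomState from that seed, so the results do not depend on the joblib backend:

```python
# Derive one int seed from the caller's random_state, then give every
# parallel task its own fresh RandomState so that sequential, thread-based
# and process-based executions produce the same numbers.
import numpy as np
from joblib import Parallel, delayed
from sklearn.utils import check_random_state

def _task(col_idx, seed):
    rng = check_random_state(seed)   # fresh, private RandomState per call
    return rng.permutation(5)        # stands in for the column shuffling

random_state = check_random_state(0)
seed = random_state.randint(np.iinfo(np.int32).max + 1)

sequential = [_task(i, seed) for i in range(3)]
parallel = Parallel(n_jobs=2)(delayed(_task)(i, seed) for i in range(3))
for a, b in zip(sequential, parallel):
    np.testing.assert_array_equal(a, b)
```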
123 changes: 123 additions & 0 deletions sklearn/inspection/tests/test_permutation_importance.py
@@ -6,17 +6,23 @@
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_boston
from sklearn.datasets import load_iris
from sklearn.datasets import make_classification
from sklearn.datasets import make_regression
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import scale
from sklearn.utils import parallel_backend
from sklearn.utils._testing import _convert_container


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_permutation_importance_correlated_feature_regression(n_jobs):
@@ -150,3 +156,120 @@ def test_permutation_importance_linear_regresssion():
scoring='neg_mean_squared_error')
assert_allclose(expected_importances, results.importances_mean,
rtol=1e-1, atol=1e-6)


def test_permutation_importance_equivalence_sequential_parallel():
# regression test to make sure that sequential and parallel calls will
# output the same results.
X, y = make_regression(n_samples=500, n_features=10, random_state=0)
lr = LinearRegression().fit(X, y)

importance_sequential = permutation_importance(
lr, X, y, n_repeats=5, random_state=0, n_jobs=1
)

# First check that the problem is structured enough and that the model is
# complex enough to not yield trivial, constant importances:
imp_min = importance_sequential['importances'].min()
imp_max = importance_sequential['importances'].max()
assert imp_max - imp_min > 0.3

# Then actually check that parallelism does not impact the results,
# either with shared memory (threading) or with isolated memory
# via process-based parallelism using the default backend
# ('loky' or 'multiprocessing') depending on the joblib version:

# process-based parallelism (by default):
importance_processes = permutation_importance(
lr, X, y, n_repeats=5, random_state=0, n_jobs=2)
assert_allclose(
importance_processes['importances'],
importance_sequential['importances']
)

# thread-based parallelism:
with parallel_backend("threading"):
importance_threading = permutation_importance(
lr, X, y, n_repeats=5, random_state=0, n_jobs=2
)
assert_allclose(
importance_threading['importances'],
importance_sequential['importances']
)


@pytest.mark.parametrize("n_jobs", [None, 1, 2])
def test_permutation_importance_equivalence_array_dataframe(n_jobs):
# This test checks that the column shuffling logic has the same behavior
# on a dataframe and on a simple numpy array.
pd = pytest.importorskip('pandas')

# non-regression test to make sure that the array and dataframe code
# paths output the same results.
X, y = make_regression(n_samples=100, n_features=5, random_state=0)
X_df = pd.DataFrame(X)

# Add a categorical feature that is statistically linked to y:
binner = KBinsDiscretizer(n_bins=3, encode="ordinal")
cat_column = binner.fit_transform(y.reshape(-1, 1))

# Concatenate the extra column to the numpy array: integers will be
# cast to float values
X = np.hstack([X, cat_column])
assert X.dtype.kind == "f"

# Insert extra column as a non-numpy-native dtype (while keeping backward
# compat for old pandas versions):
if hasattr(pd, "Categorical"):
cat_column = pd.Categorical(cat_column.ravel())
else:
cat_column = cat_column.ravel()
new_col_idx = len(X_df.columns)
X_df[new_col_idx] = cat_column
assert X_df[new_col_idx].dtype == cat_column.dtype

rf = RandomForestRegressor(n_estimators=5, max_depth=3, random_state=0)
rf.fit(X, y)

n_repeats = 3
importance_array = permutation_importance(
rf, X, y, n_repeats=n_repeats, random_state=0, n_jobs=n_jobs
)

# First check that the problem is structured enough and that the model is
# complex enough to not yield trivial, constant importances:
imp_min = importance_array['importances'].min()
imp_max = importance_array['importances'].max()
assert imp_max - imp_min > 0.3

# Now check that importances computed on the dataframe match the values
# of those computed on the array with the same data.
importance_dataframe = permutation_importance(
rf, X_df, y, n_repeats=n_repeats, random_state=0, n_jobs=n_jobs
)
assert_allclose(
importance_array['importances'],
importance_dataframe['importances']
)


@pytest.mark.parametrize("input_type", ["array", "dataframe"])
def test_permutation_importance_large_memmaped_data(input_type):
# Smoke, non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/15810
n_samples, n_features = int(5e4), 4
X, y = make_classification(n_samples=n_samples, n_features=n_features,
random_state=0)
assert X.nbytes > 1e6  # trigger joblib memmapping

X = _convert_container(X, input_type)
clf = DummyClassifier(strategy='prior').fit(X, y)

# Actual smoke test: should not raise any error:
n_repeats = 5
r = permutation_importance(clf, X, y, n_repeats=n_repeats, n_jobs=2)

# Auxiliary check: DummyClassifier is feature independent:
# permuting features should not change the predictions
expected_importances = np.zeros((n_features, n_repeats))
assert_allclose(expected_importances, r.importances)
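As related background, a standalone sketch (not from the test file): with process-based backends joblib memory-maps large array arguments, and the resulting memmap is read-only, which is why the worker must copy the data before shuffling in place (the `max_nbytes` threshold is passed explicitly here just for illustration):

```python
# joblib memmaps large inputs for worker processes; the memmapped array is
# read-only, so a column cannot be shuffled in place unless the worker first
# makes a writable copy.
import numpy as np
from joblib import Parallel, delayed

def shuffle_first_column(X):
    X = X.copy()  # without this copy, shuffling a read-only memmap raises
    np.random.RandomState(0).shuffle(X[:, 0])
    return X[:5, 0]

X = np.random.RandomState(0).rand(50_000, 4)  # ~1.6 MB of float64
out = Parallel(n_jobs=2, max_nbytes="1M")(    # arrays above 1 MB get memmapped
    delayed(shuffle_first_column)(X) for _ in range(2)
)
```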