Skip to content

Commit

Permalink
BUG ensure that parallel/sequential give the same permutation importa…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored and Pan Jan committed Mar 3, 2020
1 parent 11d2c38 commit b6d097c
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 32 deletions.
13 changes: 13 additions & 0 deletions doc/whats_new/v0.22.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,25 @@ This is a bug-fix release to primarily resolve some packaging issues in version
Changelog
---------


:mod:`sklearn.cluster`
......................

- |Fix| :class:`KMeans` with ``algorithm="elkan"`` now uses the same stopping
criterion as with the default ``algorithm="full"``. :pr:`15930` by
:user:`inder128`.

:mod:`sklearn.inspection`
.........................

- |Fix| :func:`inspection.permutation_importance` will return the same
`importances` when a `random_state` is given, for both `n_jobs=1` and
`n_jobs>1`, with both shared memory backends (thread-safety) and
isolated memory, process-based backends.
It also avoids casting the data to object dtype and avoids a read-only
error on large dataframes with `n_jobs>1`, as reported in :issue:`15810`.
Follow-up of :pr:`15898` by :user:`Shivam Gargsya <shivamgargsya>`.
:pr:`15933` by :user:`Guillaume Lemaitre <glemaitre>` and `Olivier Grisel`_.

:mod:`sklearn.metrics`
......................
Expand Down
61 changes: 29 additions & 32 deletions sklearn/inspection/_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,36 @@
from joblib import delayed

from ..metrics import check_scoring
from ..utils import Bunch
from ..utils import check_random_state
from ..utils import check_array
from ..utils import Bunch


def _safe_column_setting(X, col_idx, values):
"""Set column `col_idx` of X to `values`, in place.

Uses positional `.iloc` assignment when X exposes `iloc`, and plain
`X[:, col_idx]` indexing otherwise.
"""
if hasattr(X, "iloc"):
X.iloc[:, col_idx] = values
else:
X[:, col_idx] = values


def _safe_column_indexing(X, col_idx):
"""Return column `col_idx` of X.

When X exposes `iloc`, the column is selected positionally and its
`.values` is returned; otherwise `X[:, col_idx]` is returned.
"""
if hasattr(X, "iloc"):
return X.iloc[:, col_idx].values
else:
return X[:, col_idx]


def _calculate_permutation_scores(estimator, X, y, col_idx, random_state,
n_repeats, scorer):
"""Calculate score when `col_idx` is permuted.

Returns an array holding one `scorer` result per repeat (`n_repeats`
entries).
"""
original_feature = _safe_column_indexing(X, col_idx).copy()
temp = original_feature.copy()
random_state = check_random_state(random_state)

# Work on a copy of X to ensure thread-safety in case of threading based
# parallelism. Furthermore, making a copy is also useful when the joblib
# backend is 'loky' (default) or the old 'multiprocessing': in those cases,
# if X is large it will automatically be backed by a readonly memory map
# (memmap). X.copy() on the other hand is always guaranteed to return a
# writable data-structure whose columns can be shuffled inplace.
X_permuted = X.copy()
scores = np.zeros(n_repeats)
shuffling_idx = np.arange(X.shape[0])
for n_round in range(n_repeats):
random_state.shuffle(temp)
_safe_column_setting(X, col_idx, temp)
feature_score = scorer(estimator, X, y)
random_state.shuffle(shuffling_idx)
if hasattr(X_permuted, "iloc"):
col = X_permuted.iloc[shuffling_idx, col_idx]
col.index = X_permuted.index
X_permuted.iloc[:, col_idx] = col
else:
X_permuted[:, col_idx] = X_permuted[shuffling_idx, col_idx]
feature_score = scorer(estimator, X_permuted, y)
scores[n_round] = feature_score

_safe_column_setting(X, col_idx, original_feature)
return scores


Expand Down Expand Up @@ -104,20 +99,22 @@ def permutation_importance(estimator, X, y, scoring=None, n_repeats=5,
.. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
2001. https://doi.org/10.1023/A:1010933404324
"""
if hasattr(X, "iloc"):
X = X.copy() # Dataframe
else:
X = check_array(X, force_all_finite='allow-nan', dtype=np.object,
copy=True)

if not hasattr(X, "iloc"):
X = check_array(X, force_all_finite='allow-nan', dtype=None)

# Precompute random seed from the random state to be used
# to get a fresh independent RandomState instance for each
# parallel call to _calculate_permutation_scores, irrespective of
# the fact that variables are shared or not depending on the active
# joblib backend (sequential, thread-based or process-based).
random_state = check_random_state(random_state)
scorer = check_scoring(estimator, scoring=scoring)
random_seed = random_state.randint(np.iinfo(np.int32).max + 1)

scorer = check_scoring(estimator, scoring=scoring)
baseline_score = scorer(estimator, X, y)
scores = np.zeros((X.shape[1], n_repeats))

scores = Parallel(n_jobs=n_jobs)(delayed(_calculate_permutation_scores)(
estimator, X, y, col_idx, random_state, n_repeats, scorer
estimator, X, y, col_idx, random_seed, n_repeats, scorer
) for col_idx in range(X.shape[1]))

importances = baseline_score - np.array(scores)
Expand Down
126 changes: 126 additions & 0 deletions sklearn/inspection/tests/test_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_boston
from sklearn.datasets import load_iris
from sklearn.datasets import make_classification
from sklearn.datasets import make_regression
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import scale
from sklearn.utils import parallel_backend
from sklearn.utils._testing import _convert_container


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_permutation_importance_correlated_feature_regression(n_jobs):
Expand Down Expand Up @@ -150,3 +156,123 @@ def test_permutation_importance_linear_regresssion():
scoring='neg_mean_squared_error')
assert_allclose(expected_importances, results.importances_mean,
rtol=1e-1, atol=1e-6)


def test_permutation_importance_equivalence_sequential_parallel():
# regression test to make sure that sequential and parallel calls will
# output the same results.
X, y = make_regression(n_samples=500, n_features=10, random_state=0)
lr = LinearRegression().fit(X, y)

importance_sequential = permutation_importance(
lr, X, y, n_repeats=5, random_state=0, n_jobs=1
)

# First check that the problem is structured enough and that the model is
# complex enough to not yield trivial, constant importances:
imp_min = importance_sequential['importances'].min()
imp_max = importance_sequential['importances'].max()
assert imp_max - imp_min > 0.3

# Then actually check that parallelism does not impact the results,
# either with shared memory (threading) or with isolated memory
# via process-based parallelism using the default backend
# ('loky' or 'multiprocessing') depending on the joblib version:

# process-based parallelism (by default):
importance_processes = permutation_importance(
lr, X, y, n_repeats=5, random_state=0, n_jobs=2)
assert_allclose(
importance_processes['importances'],
importance_sequential['importances']
)

# thread-based parallelism:
with parallel_backend("threading"):
importance_threading = permutation_importance(
lr, X, y, n_repeats=5, random_state=0, n_jobs=2
)
assert_allclose(
importance_threading['importances'],
importance_sequential['importances']
)


@pytest.mark.parametrize("n_jobs", [None, 1, 2])
def test_permutation_importance_equivalence_array_dataframe(n_jobs):
# This test checks that the column shuffling logic has the same behavior
# for both a dataframe and a simple numpy array.
pd = pytest.importorskip('pandas')

# Build a small regression dataset, first as a numpy array and then as a
# dataframe holding the same values.
X, y = make_regression(n_samples=100, n_features=5, random_state=0)
X_df = pd.DataFrame(X)

# Add a categorical feature that is statistically linked to y:
binner = KBinsDiscretizer(n_bins=3, encode="ordinal")
cat_column = binner.fit_transform(y.reshape(-1, 1))

# Concatenate the extra column to the numpy array: integers will be
# cast to float values
X = np.hstack([X, cat_column])
assert X.dtype.kind == "f"

# Insert extra column as a non-numpy-native dtype (while keeping backward
# compat for old pandas versions):
if hasattr(pd, "Categorical"):
cat_column = pd.Categorical(cat_column.ravel())
else:
cat_column = cat_column.ravel()
new_col_idx = len(X_df.columns)
X_df[new_col_idx] = cat_column
assert X_df[new_col_idx].dtype == cat_column.dtype

# Stitch an arbitrary index to the dataframe:
X_df.index = np.arange(len(X_df)).astype(str)

rf = RandomForestRegressor(n_estimators=5, max_depth=3, random_state=0)
rf.fit(X, y)

n_repeats = 3
importance_array = permutation_importance(
rf, X, y, n_repeats=n_repeats, random_state=0, n_jobs=n_jobs
)

# First check that the problem is structured enough and that the model is
# complex enough to not yield trivial, constant importances:
imp_min = importance_array['importances'].min()
imp_max = importance_array['importances'].max()
assert imp_max - imp_min > 0.3

# Now check that the importances computed on the dataframe match the
# values of those computed on the array with the same data.
importance_dataframe = permutation_importance(
rf, X_df, y, n_repeats=n_repeats, random_state=0, n_jobs=n_jobs
)
assert_allclose(
importance_array['importances'],
importance_dataframe['importances']
)


@pytest.mark.parametrize("input_type", ["array", "dataframe"])
def test_permutation_importance_large_memmaped_data(input_type):
# Smoke, non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/15810
n_samples, n_features = int(5e4), 4
X, y = make_classification(n_samples=n_samples, n_features=n_features,
random_state=0)
assert X.nbytes > 1e6 # trigger joblib memmapping

X = _convert_container(X, input_type)
clf = DummyClassifier(strategy='prior').fit(X, y)

# Actual smoke test: should not raise any error:
n_repeats = 5
r = permutation_importance(clf, X, y, n_repeats=n_repeats, n_jobs=2)

# Auxiliary check: DummyClassifier is feature independent:
# permuting a feature should not change the predictions
expected_importances = np.zeros((n_features, n_repeats))
assert_allclose(expected_importances, r.importances)

0 comments on commit b6d097c

Please sign in to comment.