From 90c57fc577d47ac0e12221af47fefc9a25e75c0f Mon Sep 17 00:00:00 2001
From: Stijn Tonk
Date: Thu, 29 Dec 2016 02:46:53 +0100
Subject: [PATCH] FIX Split data using _safe_split in _permutation_test_score (#5697)

Squashed commits:
[94fd9f4] split data using _safe_split in _permutation_test_score
[522053b] adding test case test_permutation_test_score_pandas() to check if permutation_test_score plays nice with pandas dataframe/series
[21b23ce] running test_permutation_test_score_pandas on iris data to prevent warnings.
[15a48bf] adding safe_indexing to _shuffle function
[9ea5c9e] adding test case test_permutation_test_score_pandas() to check if permutation_test_score plays nice with pandas dataframe/series
[3cf5e8f] split data using _safe_split in _permutation_test_score to fix error when using Pandas DataFrame/Series
---
 sklearn/cross_validation.py                   |  8 +++++---
 sklearn/model_selection/_validation.py        |  8 +++++---
 .../model_selection/tests/test_validation.py  | 19 +++++++++++++++++++
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index a4a1e3d65c7ca..03c74b88f5f28 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -1756,8 +1756,10 @@ def _permutation_test_score(estimator, X, y, cv, scorer):
     """Auxiliary function for permutation_test_score"""
     avg_score = []
     for train, test in cv:
-        estimator.fit(X[train], y[train])
-        avg_score.append(scorer(estimator, X[test], y[test]))
+        X_train, y_train = _safe_split(estimator, X, y, train)
+        X_test, y_test = _safe_split(estimator, X, y, test, train)
+        estimator.fit(X_train, y_train)
+        avg_score.append(scorer(estimator, X_test, y_test))
     return np.mean(avg_score)
 
 
@@ -1770,7 +1772,7 @@ def _shuffle(y, labels, random_state):
     for label in np.unique(labels):
         this_mask = (labels == label)
         ind[this_mask] = random_state.permutation(ind[this_mask])
-    return y[ind]
+    return safe_indexing(y, ind)
 
 
 def check_cv(cv, X=None, y=None, classifier=False):
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 6f8fd352d210e..a703ee964e03f 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -688,8 +688,10 @@ def _permutation_test_score(estimator, X, y, groups, cv, scorer):
     """Auxiliary function for permutation_test_score"""
     avg_score = []
     for train, test in cv.split(X, y, groups):
-        estimator.fit(X[train], y[train])
-        avg_score.append(scorer(estimator, X[test], y[test]))
+        X_train, y_train = _safe_split(estimator, X, y, train)
+        X_test, y_test = _safe_split(estimator, X, y, test, train)
+        estimator.fit(X_train, y_train)
+        avg_score.append(scorer(estimator, X_test, y_test))
     return np.mean(avg_score)
 
 
@@ -702,7 +704,7 @@ def _shuffle(y, groups, random_state):
     for group in np.unique(groups):
         this_mask = (groups == group)
         indices[this_mask] = random_state.permutation(indices[this_mask])
-    return y[indices]
+    return safe_indexing(y, indices)
 
 
 def learning_curve(estimator, X, y, groups=None,
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index fe3f3df926034..5d28fbe23cc1a 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -1079,3 +1079,22 @@ def test_score_memmap():
                 break
             except WindowsError:
                 sleep(1.)
+
+
+def test_permutation_test_score_pandas():
+    # check permutation_test_score doesn't destroy pandas dataframe
+    types = [(MockDataFrame, MockDataFrame)]
+    try:
+        from pandas import Series, DataFrame
+        types.append((Series, DataFrame))
+    except ImportError:
+        pass
+    for TargetType, InputFeatureType in types:
+        # X dataframe, y series
+        iris = load_iris()
+        X, y = iris.data, iris.target
+        X_df, y_ser = InputFeatureType(X), TargetType(y)
+        check_df = lambda x: isinstance(x, InputFeatureType)
+        check_series = lambda x: isinstance(x, TargetType)
+        clf = CheckingClassifier(check_X=check_df, check_y=check_series)
+        permutation_test_score(clf, X_df, y_ser)