From 15a48bf8aa01cdf500972ed21be653edd6f210c6 Mon Sep 17 00:00:00 2001 From: Stijn Tonk Date: Sun, 3 Jan 2016 22:24:33 +0100 Subject: [PATCH] adding safe_indexing to _shuffle function --- sklearn/cross_validation.py | 2 +- sklearn/model_selection/_validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index 1da2121f85507..e9c1a70dd57d5 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1772,7 +1772,7 @@ def _shuffle(y, labels, random_state): for label in np.unique(labels): this_mask = (labels == label) ind[this_mask] = random_state.permutation(ind[this_mask]) - return y[ind] + return safe_indexing(y, ind) def check_cv(cv, X=None, y=None, classifier=False): diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 88c3922f99363..0010b89c82778 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -636,7 +636,7 @@ def _shuffle(y, groups, random_state): for group in np.unique(groups): this_mask = (groups == group) indices[this_mask] = random_state.permutation(indices[this_mask]) - return y[indices] + return safe_indexing(y, indices) def learning_curve(estimator, X, y, groups=None,