From 0ce21808198958eea88feb21b6e3f765a9be8d51 Mon Sep 17 00:00:00 2001 From: Noel Dawe Date: Mon, 24 Feb 2014 19:31:18 -0800 Subject: [PATCH] cross_validation --- sklearn/cross_validation.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index 3fce33a0873d8..c93baf5bf3488 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -1115,11 +1115,11 @@ def cross_val_score(estimator, X, y=None, sample_weight=None, # independent, and that it is pickle-able. parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch) - scores = parallel(delayed(_fit_and_score)( - clone(estimator), X, y, sample_weight, scorer, - train, test, verbose, None, - fit_params) - for train, test in cv) + scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, + sample_weight, scorer, + train, test, verbose, None, + fit_params) + for train, test in cv) return np.array(scores)[:, 0] @@ -1211,8 +1211,8 @@ def _fit_and_score(estimator, X, y, sample_weight, X_test, y_test, sample_weight_test = _safe_split( estimator, X, y, sample_weight, test, train) - test_score_params = dict() - train_score_params = dict() + test_score_params = {} + train_score_params = {} if sample_weight is not None: fit_params = fit_params.copy() fit_params['sample_weight'] = sample_weight_train @@ -1274,10 +1274,7 @@ def _safe_split(estimator, X, y, sample_weight, indices, train_indices=None): y_subset = None if sample_weight is not None: - if not hasattr(sample_weight, "shape"): - sample_weight_subset = [sample_weight[idx] for idx in indices] - else: - sample_weight_subset = sample_weight[indices] + sample_weight_subset = np.asarray(sample_weight)[indices] else: sample_weight_subset = None