Skip to content

Commit

Permalink
cross_validation
Browse files Browse the repository at this point in the history
  • Loading branch information
ndawe committed Apr 22, 2014
1 parent c5867a7 commit 0ce2180
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions sklearn/cross_validation.py
Expand Up @@ -1115,11 +1115,11 @@ def cross_val_score(estimator, X, y=None, sample_weight=None,
# independent, and that it is pickle-able.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
pre_dispatch=pre_dispatch)
scores = parallel(delayed(_fit_and_score)(
clone(estimator), X, y, sample_weight, scorer,
train, test, verbose, None,
fit_params)
for train, test in cv)
scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y,
sample_weight, scorer,
train, test, verbose, None,
fit_params)
for train, test in cv)
return np.array(scores)[:, 0]


Expand Down Expand Up @@ -1211,8 +1211,8 @@ def _fit_and_score(estimator, X, y, sample_weight,
X_test, y_test, sample_weight_test = _safe_split(
estimator, X, y, sample_weight, test, train)

test_score_params = dict()
train_score_params = dict()
test_score_params = {}
train_score_params = {}
if sample_weight is not None:
fit_params = fit_params.copy()
fit_params['sample_weight'] = sample_weight_train
Expand Down Expand Up @@ -1274,10 +1274,7 @@ def _safe_split(estimator, X, y, sample_weight, indices, train_indices=None):
y_subset = None

if sample_weight is not None:
if not hasattr(sample_weight, "shape"):
sample_weight_subset = [sample_weight[idx] for idx in indices]
else:
sample_weight_subset = sample_weight[indices]
sample_weight_subset = np.asarray(sample_weight)[indices]
else:
sample_weight_subset = None

Expand Down

0 comments on commit 0ce2180

Please sign in to comment.