Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] added validation for sample weight in DummyRegressor #15505

Merged
merged 3 commits into from Nov 4, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions sklearn/dummy.py
Expand Up @@ -13,7 +13,7 @@
from .utils.validation import _num_samples
from .utils.validation import check_array
from .utils.validation import check_consistent_length
from .utils.validation import check_is_fitted
from .utils.validation import check_is_fitted, _check_sample_weight
from .utils.random import _random_choice_csc
from .utils.stats import _weighted_percentile
from .utils.multiclass import class_distribution
Expand Down Expand Up @@ -143,6 +143,9 @@ def fit(self, X, y, sample_weight=None):

check_consistent_length(X, y, sample_weight)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

if self.strategy == "constant":
if self.constant is None:
raise ValueError("Constant target value has to be specified "
Expand Down Expand Up @@ -471,8 +474,9 @@ def fit(self, X, y, sample_weight=None):
self.n_outputs_ = y.shape[1]

check_consistent_length(X, y, sample_weight)

if sample_weight is not None:
sample_weight = np.asarray(sample_weight)
sample_weight = _check_sample_weight(sample_weight, X)

if self.strategy == "mean":
self.constant_ = np.average(y, axis=0, weights=sample_weight)
Expand Down