diff --git a/sklearn/dummy.py b/sklearn/dummy.py index b12b314c4a91f..e20c5d2c678ef 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -13,7 +13,7 @@ from .utils.validation import _num_samples from .utils.validation import check_array from .utils.validation import check_consistent_length -from .utils.validation import check_is_fitted +from .utils.validation import check_is_fitted, _check_sample_weight from .utils.random import _random_choice_csc from .utils.stats import _weighted_percentile from .utils.multiclass import class_distribution @@ -143,6 +143,9 @@ def fit(self, X, y, sample_weight=None): check_consistent_length(X, y, sample_weight) + if sample_weight is not None: + sample_weight = _check_sample_weight(sample_weight, X) + if self.strategy == "constant": if self.constant is None: raise ValueError("Constant target value has to be specified " @@ -471,8 +474,9 @@ def fit(self, X, y, sample_weight=None): self.n_outputs_ = y.shape[1] check_consistent_length(X, y, sample_weight) + if sample_weight is not None: - sample_weight = np.asarray(sample_weight) + sample_weight = _check_sample_weight(sample_weight, X) if self.strategy == "mean": self.constant_ = np.average(y, axis=0, weights=sample_weight)