Skip to content

Commit

Permalink
Argument validation in make_multilabel_classification() (scikit-learn…
Browse files Browse the repository at this point in the history
  • Loading branch information
rushabh-v authored and Pan Jan committed Mar 3, 2020
1 parent 48c47aa commit 9378346
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.23.rst
Expand Up @@ -63,6 +63,10 @@ Changelog
by :user:`Stephanie Andrews <gitsteph>` and
:user:`Reshama Shaikh <reshamas>`.

- |Fix| :func:`datasets.make_multilabel_classification` now generates
`ValueError` for arguments `n_classes < 1` OR `length < 1`.
:pr:`16006` by :user:`Rushabh Vasani <rushabh-v>`.

:mod:`sklearn.feature_extraction`
.................................

Expand Down
11 changes: 11 additions & 0 deletions sklearn/datasets/_samples_generator.py
Expand Up @@ -342,6 +342,17 @@ def make_multilabel_classification(n_samples=100, n_features=20, n_classes=5,
Only returned if ``return_distributions=True``.
"""
if n_classes < 1:
raise ValueError(
"'n_classes' should be an integer greater than 0. Got {} instead."
.format(n_classes)
)
if length < 1:
raise ValueError(
"'length' should be an integer greater than 0. Got {} instead."
.format(length)
)

generator = check_random_state(random_state)
p_c = generator.rand(n_classes)
p_c /= p_c.sum()
Expand Down
12 changes: 12 additions & 0 deletions sklearn/datasets/tests/test_samples_generator.py
Expand Up @@ -222,6 +222,18 @@ def test_make_multilabel_classification_return_indicator_sparse():
assert sp.issparse(Y)


@pytest.mark.parametrize(
"params, err_msg",
[
({"n_classes": 0}, "'n_classes' should be an integer"),
({"length": 0}, "'length' should be an integer")
]
)
def test_make_multilabel_classification_valid_arguments(params, err_msg):
with pytest.raises(ValueError, match=err_msg):
make_multilabel_classification(**params)


def test_make_hastie_10_2():
X, y = make_hastie_10_2(n_samples=100, random_state=0)
assert X.shape == (100, 10), "X shape mismatch"
Expand Down

0 comments on commit 9378346

Please sign in to comment.