Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Argument validation in make_multilabel_classification() #16006

Merged
merged 9 commits into from Jan 4, 2020
9 changes: 9 additions & 0 deletions sklearn/datasets/_samples_generator.py
Expand Up @@ -342,6 +342,15 @@ def make_multilabel_classification(n_samples=100, n_features=20, n_classes=5,
Only returned if ``return_distributions=True``.

"""
# Validation of the arguments
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
if n_classes == 0 and not allow_unlabeled:
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Invalid set of arguments passed: " +
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
"n_classes = 0 and allow_unlabeled = False"
)
if length == 0:
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Invalid argument passed: length = 0")
glemaitre marked this conversation as resolved.
Show resolved Hide resolved

generator = check_random_state(random_state)
p_c = generator.rand(n_classes)
p_c /= p_c.sum()
Expand Down
12 changes: 12 additions & 0 deletions sklearn/datasets/tests/test_samples_generator.py
Expand Up @@ -222,6 +222,18 @@ def test_make_multilabel_classification_return_indicator_sparse():
assert sp.issparse(Y)


def test_make_multilabel_classification_valid_arguments():
    """Check that degenerate arguments are rejected with a ValueError.

    Two calls are exercised: ``n_classes=0`` combined with
    ``allow_unlabeled=False``, and ``length=0``. Each call is expected to
    raise a ValueError whose message matches the text paired with it below.
    """
    invalid_calls = [
        (
            {"allow_unlabeled": False, "n_classes": 0},
            "Invalid set of arguments passed: "
            "n_classes = 0 and allow_unlabeled = False",
        ),
        ({"length": 0}, "Invalid argument passed: length = 0"),
    ]
    for kwargs, expected_message in invalid_calls:
        # ``match`` is applied as a regex search against the error text.
        with pytest.raises(ValueError, match=expected_message):
            make_multilabel_classification(**kwargs)


def test_make_hastie_10_2():
X, y = make_hastie_10_2(n_samples=100, random_state=0)
assert X.shape == (100, 10), "X shape mismatch"
Expand Down