From 93783462c2194bba2ce6469934ffebc6ec12b3c2 Mon Sep 17 00:00:00 2001 From: Rushabh Vasani Date: Sat, 4 Jan 2020 23:05:25 +0530 Subject: [PATCH] Argument validation in make_multilabel_classification() (#16006) --- doc/whats_new/v0.23.rst | 4 ++++ sklearn/datasets/_samples_generator.py | 11 +++++++++++ sklearn/datasets/tests/test_samples_generator.py | 12 ++++++++++++ 3 files changed, 27 insertions(+) diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst index 1941aacb7a7b0..b476c34b380cc 100644 --- a/doc/whats_new/v0.23.rst +++ b/doc/whats_new/v0.23.rst @@ -63,6 +63,10 @@ Changelog by :user:`Stephanie Andrews ` and :user:`Reshama Shaikh `. +- |Fix| :func:`datasets.make_multilabel_classification` now generates + `ValueError` for arguments `n_classes < 1` OR `length < 1`. + :pr:`16006` by :user:`Rushabh Vasani `. + :mod:`sklearn.feature_extraction` ................................. diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 8893aedbdfc5a..10c87d988c324 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -342,6 +342,17 @@ def make_multilabel_classification(n_samples=100, n_features=20, n_classes=5, Only returned if ``return_distributions=True``. """ + if n_classes < 1: + raise ValueError( + "'n_classes' should be an integer greater than 0. Got {} instead." + .format(n_classes) + ) + if length < 1: + raise ValueError( + "'length' should be an integer greater than 0. Got {} instead." + .format(length) + ) + generator = check_random_state(random_state) p_c = generator.rand(n_classes) p_c /= p_c.sum() diff --git a/sklearn/datasets/tests/test_samples_generator.py b/sklearn/datasets/tests/test_samples_generator.py index 433baca985b87..c683e277c705a 100644 --- a/sklearn/datasets/tests/test_samples_generator.py +++ b/sklearn/datasets/tests/test_samples_generator.py @@ -222,6 +222,18 @@ def test_make_multilabel_classification_return_indicator_sparse(): assert sp.issparse(Y) +@pytest.mark.parametrize( + "params, err_msg", + [ + ({"n_classes": 0}, "'n_classes' should be an integer"), + ({"length": 0}, "'length' should be an integer") + ] +) +def test_make_multilabel_classification_valid_arguments(params, err_msg): + with pytest.raises(ValueError, match=err_msg): + make_multilabel_classification(**params) + + def test_make_hastie_10_2(): X, y = make_hastie_10_2(n_samples=100, random_state=0) assert X.shape == (100, 10), "X shape mismatch"