
FIX validate in fit for LabelBinarizer estimator #21434

Merged (11 commits, Nov 2, 2021)

Conversation

krumetoft
Contributor

Reference Issues/PRs

This closes the LabelBinarizer part of #21406

What does this implement/fix? Explain your changes.

Removed a parameter check from __init__ of LabelBinarizer that is also performed by the label_binarize function during transform.

#DataUmbrella sprint
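As context, scikit-learn's convention is that `__init__` only stores parameters and all validation happens in `fit`. A minimal sketch of that pattern (a toy class, not the actual scikit-learn source):

```python
# Minimal sketch (not the real LabelBinarizer) of the validate-in-fit
# convention this PR follows: __init__ stores parameters unchanged,
# and fit is the first place that checks them.
class ToyBinarizer:
    def __init__(self, neg_label=0, pos_label=1):
        # No validation here: this keeps get_params/set_params and
        # clone() working even on estimators built with invalid values.
        self.neg_label = neg_label
        self.pos_label = pos_label

    def fit(self, y):
        # Parameter validation deferred to fit.
        if self.neg_label >= self.pos_label:
            raise ValueError(
                f"neg_label={self.neg_label} must be strictly less than "
                f"pos_label={self.pos_label}."
            )
        self.classes_ = sorted(set(y))
        return self


# Constructing with bad parameters no longer raises ...
est = ToyBinarizer(neg_label=2, pos_label=1)
try:
    est.fit([0, 1, 0, 1])  # ... but fit does
except ValueError as exc:
    print(exc)
```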

@glemaitre
Member

I think that we should add an entry in the changelog since it could have an effect on third-party libraries.

Please add an entry to the change log at doc/whats_new/v1.1.rst. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:. For instance, it should look like:

- |Fix| :class:`preprocessing.LabelBinarizer` now validates input parameters in `fit`
  instead of `__init__`.
  :pr:`21434` by :user:`krumetoft  <krumetoft>`.

It should go in the section:

:mod:`sklearn.preprocessing`

@glemaitre glemaitre self-requested a review October 23, 2021 20:17
@glemaitre
Member

You will need to edit the test where we were testing the pattern:

with pytest.raises(...):
    LabelBinarizer(...)

by calling fit on the instance, since __init__ will no longer raise the error. At the same time, if you could add all the error messages that we should match, it would make for a much better test. I get the following diff:

diff --git a/sklearn/preprocessing/_label.py b/sklearn/preprocessing/_label.py
index 72bd8c6d65..12f6d08b5c 100644
--- a/sklearn/preprocessing/_label.py
+++ b/sklearn/preprocessing/_label.py
@@ -275,6 +275,19 @@ class LabelBinarizer(TransformerMixin, BaseEstimator):
         self : object
             Returns the instance itself.
         """
+        if self.neg_label >= self.pos_label:
+            raise ValueError(
+                f"neg_label={self.neg_label} must be strictly less than "
+                f"pos_label={self.pos_label}."
+            )
+
+        if self.sparse_output and (self.pos_label == 0 or self.neg_label != 0):
+            raise ValueError(
+                "Sparse binarization is only supported with non "
+                "zero pos_label and zero neg_label, got "
+                f"pos_label={self.pos_label} and neg_label={self.neg_label}"
+            )
+
         self.y_type_ = type_of_target(y)
         if "multioutput" in self.y_type_:
             raise ValueError(
diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py
index 5142144bcb..3dc4afaf89 100644
--- a/sklearn/preprocessing/tests/test_label.py
+++ b/sklearn/preprocessing/tests/test_label.py
@@ -124,25 +124,35 @@ def test_label_binarizer_errors():
     lb = LabelBinarizer().fit(one_class)
 
     multi_label = [(2, 3), (0,), (0, 2)]
-    with pytest.raises(ValueError):
+    err_msg = "You appear to be using a legacy multi-label data representation."
+    with pytest.raises(ValueError, match=err_msg):
         lb.transform(multi_label)
 
     lb = LabelBinarizer()
-    with pytest.raises(ValueError):
+
+    err_msg = "This LabelBinarizer instance is not fitted yet"
+    with pytest.raises(ValueError, match=err_msg):
         lb.transform([])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=err_msg):
         lb.inverse_transform([])
 
-    with pytest.raises(ValueError):
-        LabelBinarizer(neg_label=2, pos_label=1)
-    with pytest.raises(ValueError):
-        LabelBinarizer(neg_label=2, pos_label=2)
-
-    with pytest.raises(ValueError):
-        LabelBinarizer(neg_label=1, pos_label=2, sparse_output=True)
+    input_labels = [0, 1, 0, 1]
+    err_msg = "neg_label=2 must be strictly less than pos_label=1."
+    with pytest.raises(ValueError, match=err_msg):
+        LabelBinarizer(neg_label=2, pos_label=1).fit(input_labels)
+    err_msg = "neg_label=2 must be strictly less than pos_label=2."
+    with pytest.raises(ValueError, match=err_msg):
+        LabelBinarizer(neg_label=2, pos_label=2).fit(input_labels)
+    err_msg = (
+        "Sparse binarization is only supported with non zero pos_label and zero "
+        "neg_label, got pos_label=2 and neg_label=1"
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        LabelBinarizer(neg_label=1, pos_label=2, sparse_output=True).fit(input_labels)
 
     # Fail on y_type
-    with pytest.raises(ValueError):
+    err_msg = "foo format is not supported"
+    with pytest.raises(ValueError, match=err_msg):
         _inverse_binarize_thresholding(
             y=csr_matrix([[1, 2], [2, 1]]),
             output_type="foo",
@@ -152,11 +162,13 @@ def test_label_binarizer_errors():
 
     # Sequence of seq type should raise ValueError
     y_seq_of_seqs = [[], [1, 2], [3], [0, 1, 3], [2]]
-    with pytest.raises(ValueError):
+    err_msg = "You appear to be using a legacy multi-label data representation"
+    with pytest.raises(ValueError, match=err_msg):
         LabelBinarizer().fit_transform(y_seq_of_seqs)
 
     # Fail on the number of classes
-    with pytest.raises(ValueError):
+    err_msg = "The number of class is not equal to the number of dimension of y."
+    with pytest.raises(ValueError, match=err_msg):
         _inverse_binarize_thresholding(
             y=csr_matrix([[1, 2], [2, 1]]),
             output_type="foo",
@@ -165,7 +177,8 @@ def test_label_binarizer_errors():
         )
 
     # Fail on the dimension of 'binary'
-    with pytest.raises(ValueError):
+    err_msg = "output_type='binary', but y.shape"
+    with pytest.raises(ValueError, match=err_msg):
         _inverse_binarize_thresholding(
             y=np.array([[1, 2, 3], [2, 1, 3]]),
             output_type="binary",
@@ -174,9 +187,10 @@ def test_label_binarizer_errors():
         )
 
     # Fail on multioutput data
-    with pytest.raises(ValueError):
+    err_msg = "Multioutput target data is not supported with label binarization"
+    with pytest.raises(ValueError, match=err_msg):
         LabelBinarizer().fit(np.array([[1, 3], [2, 1]]))
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=err_msg):
         label_binarize(np.array([[1, 3], [2, 1]]), classes=[1, 2, 3])
 
 

You can use it to edit your PR.

@glemaitre glemaitre changed the title LabelBinarizer - Removed a redundant param check in init #21406 FIX validate in fit for LabelBinarizer Oct 23, 2021
@krumetoft
Contributor Author

Thank you, @glemaitre! I've just updated doc/whats_new/v1.1.rst and test_label.py.

On a second review, I am wondering if it is an issue that input will be tested in transform, but not in fit (as the tests are performed by label_binarize in transform only)?

@glemaitre
Member

On a second review, I am wondering if it is an issue that input will be tested in transform, but not in fit (as the tests are performed by label_binarize in transform only)?

If you check my diff, I move the validation in fit:

diff --git a/sklearn/preprocessing/_label.py b/sklearn/preprocessing/_label.py
index 72bd8c6d65..12f6d08b5c 100644
--- a/sklearn/preprocessing/_label.py
+++ b/sklearn/preprocessing/_label.py
@@ -275,6 +275,19 @@ class LabelBinarizer(TransformerMixin, BaseEstimator):
         self : object
             Returns the instance itself.
         """
+        if self.neg_label >= self.pos_label:
+            raise ValueError(
+                f"neg_label={self.neg_label} must be strictly less than "
+                f"pos_label={self.pos_label}."
+            )
+
+        if self.sparse_output and (self.pos_label == 0 or self.neg_label != 0):
+            raise ValueError(
+                "Sparse binarization is only supported with non "
+                "zero pos_label and zero neg_label, got "
+                f"pos_label={self.pos_label} and neg_label={self.neg_label}"
+            )
+
         self.y_type_ = type_of_target(y)
         if "multioutput" in self.y_type_:
             raise ValueError(

Otherwise, we don't do any validation.

@krumetoft
Contributor Author

Apologies, I missed that - corrected now!

@glemaitre
Member

Can you run black on the 2 files that are detected by our linter: https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=34025&view=logs&jobId=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&j=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&t=fc67071d-c3d4-58b8-d38e-cafc0d3c731a

In the future, you might want to install pre-commit. It will do the reformatting for you before committing.

@glemaitre glemaitre changed the title FIX validate in fit for LabelBinarizer FIX validate in fit for LabelBinarizer estimator Oct 24, 2021
@krumetoft
Contributor Author

Can you run black on the 2 files that are detected by our linter: https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=34025&view=logs&jobId=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&j=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&t=fc67071d-c3d4-58b8-d38e-cafc0d3c731a

In the future, you might want to install pre-commit. It will do the reformatting for you before committing.

Thank you for your patience and suggestion - I will use pre-commit.

Member

@jjerphan jjerphan left a comment


LGTM.

Thank you, @krumetoft, for this first-time contribution!

@krumetoft
Contributor Author

Thank you, @jjerphan! I need to thank Guillaume for his patience and guidance.
In case it is important for sprint tracking: I am a repeat contributor :) but under the profile @krumeto. I just could not figure out, in time before the sprint, how to safely use two git profiles at the same time, so I ended up working with the corporate one.

@jjerphan
Member

You can change a local git repository's config to use another identity (e.g. @krumeto in this case) using git config with the --local option.
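As a concrete sketch (the demo repo and email address are placeholders; in practice you would run the `git config` lines inside the scikit-learn clone):

```shell
# Throwaway repo just to demonstrate the per-repository identity switch.
cd "$(mktemp -d)"
git init -q identity-demo && cd identity-demo

# --local writes to this repository's .git/config only,
# leaving the global (corporate) identity untouched.
git config --local user.name "krumeto"
git config --local user.email "krumeto@example.com"   # placeholder address

git config --local user.name    # prints: krumeto
```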

Member

@glemaitre glemaitre left a comment


LGTM

Member

@thomasjpfan thomasjpfan left a comment


Minor comments about test. Otherwise LGTM

(Review comments on sklearn/preprocessing/tests/test_label.py, now resolved.)
@krumetoft
Contributor Author

Hey @thomasjpfan, thank you, it makes sense! I changed test_label.py accordingly.

@jjerphan
Member

jjerphan commented Oct 29, 2021

Regarding your last commit, @krumetoft, you can use the Co-authored-by attribute, see this piece of docs from GitHub.
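For reference, the trailer goes at the very end of the commit message, separated from the body by a blank line (the repo and message text below are illustrative; the co-author line matches the one recorded on this PR):

```shell
# Demonstrate in a throwaway repo; in practice, commit your staged changes.
cd "$(mktemp -d)" && git init -q demo && cd demo
git config user.name "krumetoft" && git config user.email "krumetoft@example.com"

# The Co-authored-by trailer sits on its own line at the end of the
# message, after a blank line, so GitHub credits the co-author.
git commit --allow-empty -m "Apply review suggestions

Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>"

git log -1 --format=%b    # shows the Co-authored-by trailer
```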

@krumeto
Contributor

krumeto commented Oct 29, 2021

Thank you, @jjerphan! I saw the piece of code in one of Olivier's comments and decided to give it a run :)

@glemaitre
Member

Argh, I did not see that there is a conflict. @krumetoft @krumeto Could you solve the merge conflict?

@krumeto
Contributor

krumeto commented Nov 2, 2021

Hey @glemaitre, the merge conflict is resolved (line 292 in _label.py).

@jjerphan jjerphan merged commit c241fe7 into scikit-learn:main Nov 2, 2021
@jjerphan
Member

jjerphan commented Nov 2, 2021

Thank you, @krumeto!

samronsin pushed a commit to samronsin/scikit-learn that referenced this pull request on Nov 30, 2021

Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>