Permalink
Browse files

BUG handle two-class multilabel case in LabelBinarizer

Would throw an exception due to special handling of binary case.
  • Loading branch information...
1 parent b5429ab commit f161fabd669223c585fc27c0f6ff22f875bec93d @larsmans larsmans committed Oct 26, 2011
Showing with 14 additions and 3 deletions.
  1. +3 −3 sklearn/preprocessing/__init__.py
  2. +11 −0 sklearn/preprocessing/tests/test_preprocessing.py
@@ -514,10 +514,10 @@ def transform(self, y):
Y : numpy array of shape [n_samples, n_classes]
"""
- if len(self.classes_) == 2:
- Y = np.zeros((len(y), 1))
- else:
+ if self.multilabel or len(self.classes_) > 2:
Y = np.zeros((len(y), len(self.classes_)))
+ else:
+ Y = np.zeros((len(y), 1))
y_is_multilabel = _is_multilabel(y)
@@ -308,6 +308,17 @@ def test_label_binarizer_multilabel():
assert_array_equal(expected, got)
assert_equal(lb.inverse_transform(got), inp)
+ # regression test for the two-class multilabel case
+ lb = LabelBinarizer()
+
+ inp = [[1, 0], [0], [1], [0, 1]]
+ expected = np.array([[1, 1],
+ [1, 0],
+ [0, 1],
+ [1, 1]])
+ got = lb.fit_transform(inp)
+ assert_array_equal(expected, got)
+
def test_label_binarizer_errors():
"""Check that invalid arguments yield ValueError"""

0 comments on commit f161fab

Please sign in to comment.