Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

BUG handle two-class multilabel case in LabelBinarizer

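Previously, transform would throw an exception on two-class multilabel input, because any two-class problem was handled as the special binary case (a single output column).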
  • Loading branch information...
commit f161fabd669223c585fc27c0f6ff22f875bec93d 1 parent b5429ab
Lars larsmans authored
6 sklearn/preprocessing/__init__.py
View
@@ -514,10 +514,10 @@ def transform(self, y):
Y : numpy array of shape [n_samples, n_classes]
"""
- if len(self.classes_) == 2:
- Y = np.zeros((len(y), 1))
- else:
+ if self.multilabel or len(self.classes_) > 2:
Y = np.zeros((len(y), len(self.classes_)))
+ else:
+ Y = np.zeros((len(y), 1))
y_is_multilabel = _is_multilabel(y)
11 sklearn/preprocessing/tests/test_preprocessing.py
View
@@ -308,6 +308,17 @@ def test_label_binarizer_multilabel():
assert_array_equal(expected, got)
assert_equal(lb.inverse_transform(got), inp)
+ # regression test for the two-class multilabel case
+ lb = LabelBinarizer()
+
+ inp = [[1, 0], [0], [1], [0, 1]]
+ expected = np.array([[1, 1],
+ [1, 0],
+ [0, 1],
+ [1, 1]])
+ got = lb.fit_transform(inp)
+ assert_array_equal(expected, got)
+
def test_label_binarizer_errors():
"""Check that invalid arguments yield ValueError"""
Please sign in to comment.
Something went wrong with that request. Please try again.