From f161fabd669223c585fc27c0f6ff22f875bec93d Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Wed, 26 Oct 2011 16:55:39 +0200 Subject: [PATCH] BUG handle two-class multilabel case in LabelBinarizer LabelBinarizer.transform would throw an exception on two-class multilabel input, because the two-class (binary) case was special-cased to a single output column. --- sklearn/preprocessing/__init__.py | 6 +++--- sklearn/preprocessing/tests/test_preprocessing.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py index 732b6f386113e..90965b7a5fc2c 100644 --- a/sklearn/preprocessing/__init__.py +++ b/sklearn/preprocessing/__init__.py @@ -514,10 +514,10 @@ def transform(self, y): Y : numpy array of shape [n_samples, n_classes] """ - if len(self.classes_) == 2: - Y = np.zeros((len(y), 1)) - else: + if self.multilabel or len(self.classes_) > 2: Y = np.zeros((len(y), len(self.classes_))) + else: + Y = np.zeros((len(y), 1)) y_is_multilabel = _is_multilabel(y) diff --git a/sklearn/preprocessing/tests/test_preprocessing.py b/sklearn/preprocessing/tests/test_preprocessing.py index 330d5902b9803..68dba4fb699e9 100644 --- a/sklearn/preprocessing/tests/test_preprocessing.py +++ b/sklearn/preprocessing/tests/test_preprocessing.py @@ -308,6 +308,17 @@ def test_label_binarizer_multilabel(): assert_array_equal(expected, got) assert_equal(lb.inverse_transform(got), inp) + # regression test for the two-class multilabel case + lb = LabelBinarizer() + + inp = [[1, 0], [0], [1], [0, 1]] + expected = np.array([[1, 1], + [1, 0], + [0, 1], + [1, 1]]) + got = lb.fit_transform(inp) + assert_array_equal(expected, got) + def test_label_binarizer_errors(): """Check that invalid arguments yield ValueError"""