Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG added classes parameter to LabelBinarizer #1643

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/modules/preprocessing.rst
Expand Up @@ -360,7 +360,7 @@ matrix from a list of multi-class labels::

>>> lb = preprocessing.LabelBinarizer()
>>> lb.fit([1, 2, 6, 4, 2])
LabelBinarizer(neg_label=0, pos_label=1)
LabelBinarizer(classes=None, label_type='auto', neg_label=0, pos_label=1)
>>> lb.classes_
array([1, 2, 4, 6])
>>> lb.transform([1, 6])
Expand Down
2 changes: 1 addition & 1 deletion sklearn/multiclass.py
Expand Up @@ -247,7 +247,7 @@ def predict_proba(self, X):
@property
def multilabel_(self):
"""Whether this is a multilabel classifier"""
return self.label_binarizer_.multilabel
return self.label_binarizer_.multilabel_

def score(self, X, y):
if self.multilabel_:
Expand Down
136 changes: 99 additions & 37 deletions sklearn/preprocessing.py
Expand Up @@ -16,6 +16,7 @@
from .utils import check_arrays, array2d, atleast2d_or_csr, safe_asarray
from .utils import warn_if_not_float
from .utils.fixes import unique
from .utils import deprecated

from .utils.sparsefuncs import inplace_csr_row_normalize_l1
from .utils.sparsefuncs import inplace_csr_row_normalize_l2
Expand Down Expand Up @@ -622,6 +623,18 @@ def _is_multilabel(y):
_is_label_indicator_matrix(y))


def _get_label_type(y):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, you can put this in utils.multiclass. :-)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose we replace is_multilabel with something like this, that moreover may handle multiple ys: #1985

multilabel = _is_multilabel(y)
if multilabel:
if _is_label_indicator_matrix(y):
label_type = "multilabel-indicator"
else:
label_type = "multilabel-list"
else:
label_type = "multiclass"
return label_type


class OneHotEncoder(BaseEstimator, TransformerMixin):
"""Encode categorical integer features using a one-hot aka one-of-K scheme.

Expand Down Expand Up @@ -919,24 +932,41 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):

Parameters
----------

neg_label: int (default: 0)
neg_label : int (default: 0)
Value with which negative labels must be encoded.

pos_label: int (default: 1)
pos_label : int (default: 1)
Value with which positive labels must be encoded.

classes : ndarray of int or None (default)
Array of possible classes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of the classes parameter is not clear from the comment. Is it merely so that you needn't fit if you already know the class labels? Or is it to define an ordering of the labels? What happens if an excess label is given? What happens if a label is absent but found in the y passed to fit, or in transform?

Are there cases where finding an absent class means we should raise an error, but others where it makes sense to remove the entry, or substitute a different class? Should we have a handle_unknown parameter selecting among these, or is that excessively flexible? Certainly, the behaviour is implicit at the moment, and that's a bad thing.

There is really no reason the same functionality shouldn't apply to LabelEncoder, and I think they should inherit from a common class (as in jnothman@8669605).


label_type : string, default="auto"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we call this target_type?

Expected type of y.
Possible values are:
- "multiclass", y is an array-like of ints
- "multilabel-indicator", y is an indicator matrix of classes
- "multiclass-list", y is a list of lists of labels
- "auto", the form of y is determined during 'fit'. If 'fit' is not
called, multiclass is assumed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you wrote y is ..., I think that an article is missing.
It may be better to say something like y is in the indicator matrix of classes format, y is in the list of lists of labels format.


Attributes
----------
`classes_`: array of shape [n_class]
`classes_` : array of shape [n_class]
Holds the label for each class.

`label_type_` : string
The type of label used. Inferred from training data if
``label_type="auto"``, otherwise identical to the ``label_type``
parameter.


Examples
--------
>>> from sklearn import preprocessing
>>> lb = preprocessing.LabelBinarizer()
>>> lb.fit([1, 2, 6, 4, 2])
LabelBinarizer(neg_label=0, pos_label=1)
LabelBinarizer(classes=None, label_type='auto', neg_label=0, pos_label=1)
>>> lb.classes_
array([1, 2, 4, 6])
>>> lb.transform([1, 6])
Expand All @@ -950,19 +980,29 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
array([1, 2, 3])
"""

def __init__(self, neg_label=0, pos_label=1):
if neg_label >= pos_label:
raise ValueError("neg_label must be strictly less than pos_label.")
def __init__(self, neg_label=0, pos_label=1, classes=None,
label_type='auto'):

self.neg_label = neg_label
self.pos_label = pos_label
self.classes = classes
self.label_type = label_type

def _check_fitted(self):
if not hasattr(self, "classes_"):
raise ValueError("LabelBinarizer was not fitted yet.")
if self.classes is not None:
self.classes_ = np.unique(self.classes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order of the classes won't be preserve.
I don't know if it matters.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know. I was wondering about that. I think I'd like to keep it like this for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only place, where I have seen such things is in the test_precision_recall_f1_score_multiclass in test_metrics.py.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"such things" meaning preserving the class order?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the class order must be known to output the precision, recall, fscore in the right order (see averaging type is equal to None).

# default to not doing multi-label things
self.label_type_ = (self.label_type
if self.label_type != "auto"
else "multiclass")
else:
raise ValueError("LabelBinarizer was not fitted yet.")

def fit(self, y):
"""Fit label binarizer
"""Fit label binarizer.

No-op if parameter ``classes`` was specified.

Parameters
----------
Expand All @@ -973,16 +1013,36 @@ def fit(self, y):
Returns
-------
self : returns an instance of self.

"""
self.multilabel = _is_multilabel(y)
if self.multilabel:
self.indicator_matrix_ = _is_label_indicator_matrix(y)
if self.indicator_matrix_:
self.classes_ = np.arange(y.shape[1])
else:
self.classes_ = np.array(sorted(set.union(*map(set, y))))
if self.neg_label >= self.pos_label:
raise ValueError("neg_label must be strictly less than pos_label.")

label_type = _get_label_type(y)

if self.label_type not in ["auto", label_type]:
raise ValueError("label_type was set to %s, but got y of type %s."
% (self.label_type, label_type))

self.label_type_ = label_type

if label_type == "multilabel-indicator":
classes = np.arange(y.shape[1])
elif label_type == "multilabel-list":
classes = np.array(sorted(set.union(*map(set, y))))
else:
classes = np.unique(y)

if self.classes is not None:
classes_set = set(classes)
if not set.issubset(classes_set, self.classes):
difference = set.difference(classes_set, self.classes)
warnings.warn("Found class(es) %s, which was not contained "
"in parameter ``classes`` and will be ignored."
% str(list(difference)))
self.classes_ = np.unique(self.classes)
else:
self.classes_ = np.unique(y)
self.classes_ = classes
return self

def transform(self, y):
Expand All @@ -1000,31 +1060,24 @@ def transform(self, y):
Returns
-------
Y : numpy array of shape [n_samples, n_classes]

"""
self._check_fitted()

if self.multilabel or len(self.classes_) > 2:
if _is_label_indicator_matrix(y):
# nothing to do as y is already a label indicator matrix
return y

label_type = _get_label_type(y)
if label_type != self.label_type_:
raise ValueError("label_type was set to %s, but got y of type %s."
% (self.label_type_, label_type))
if label_type == "multilabel-indicator":
# nothing to do as y is already a label indicator matrix
return y
elif label_type == "multilabel-list" or len(self.classes_) > 2:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the case len(self.classes_) > 2 used here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because for two classes, the output is 1d.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks !

Y = np.zeros((len(y), len(self.classes_)), dtype=np.int)
else:
Y = np.zeros((len(y), 1), dtype=np.int)

Y += self.neg_label

y_is_multilabel = _is_multilabel(y)

if y_is_multilabel and not self.multilabel:
raise ValueError("The object was not fitted with multilabel"
" input!")

elif self.multilabel:
if not _is_multilabel(y):
raise ValueError("y should be a list of label lists/tuples,"
"got %r" % (y,))

if label_type == "multilabel-list":
# inverse map: label => column index
imap = dict((v, k) for k, v in enumerate(self.classes_))

Expand Down Expand Up @@ -1089,10 +1142,10 @@ def inverse_transform(self, Y, threshold=None):
half = (self.pos_label - self.neg_label) / 2.0
threshold = self.neg_label + half

if self.multilabel:
if self.multilabel_:
Y = np.array(Y > threshold, dtype=int)
# Return the predictions in the same format as in fit
if self.indicator_matrix_:
if self.label_type_ == "multilabel-indicator":
# Label indicator matrix format
return Y
else:
Expand All @@ -1108,6 +1161,15 @@ def inverse_transform(self, Y, threshold=None):

return self.classes_[y]

@property
def multilabel_(self):
return self.label_type_ in ["multilabel-list", "multilabel-indicator"]

@property
@deprecated("it will be removed in 0.15. Use ``label_type_`` instead.")
def label_indicator_(self):
return self.label_type_ == "multilabel-indicator"


class KernelCenterer(BaseEstimator, TransformerMixin):
"""Center a kernel matrix
Expand Down
62 changes: 62 additions & 0 deletions sklearn/tests/test_preprocessing.py
Expand Up @@ -8,6 +8,7 @@
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_false

Expand Down Expand Up @@ -640,6 +641,67 @@ def test_label_binarizer_iris():
assert_almost_equal(accuracy, accuracy2)


def test_label_binarizer_classes():
# check that explictly giving classes works
lb = LabelBinarizer(classes=np.arange(3))
y = np.ones(10)
# if classes is specified, we don't need to fit
assert_equal(lb.transform(y).shape, (10, 3))
assert_array_equal(y, np.argmax(lb.transform(y), axis=1))

# check that fitting doesn't change the shape
assert_equal(lb.fit_transform(y).shape, (10, 3))

# also works with weird classes:
lb = LabelBinarizer(classes=['a', 'b', 'see'])
transformed = lb.transform(['see', 'see'])
assert_equal(transformed.shape, (2, 3))
assert_array_equal(np.argmax(transformed, axis=1), [2, 2])
# test inverse transform
assert_array_equal(['see', 'see'], lb.inverse_transform(transformed))

# also works with multilabel data if we say so:
lb = LabelBinarizer(classes=np.arange(1, 3),
label_type="multilabel-list")
y = [(1, 2), (1,), ()]
Y = np.array([[1, 1],
[1, 0],
[0, 0]])
assert_array_equal(lb.transform(y), Y)
assert_array_equal(lb.fit_transform(y), Y)
# inverse transform of label indicator matrix to label
assert_array_equal(lb.inverse_transform(Y), y)

# inverse transform with indicator_matrix=True
lb = LabelBinarizer(classes=np.arange(1, 3),
label_type="multilabel-indicator")
assert_array_equal(lb.inverse_transform(Y), Y)

lb = LabelBinarizer(classes=np.arange(1, 3), label_type="multiclass")
assert_raise_message(ValueError, "label_type was set to multiclass, "
"but got y of type multilabel-list.",
lb.fit, y)
lb = LabelBinarizer(classes=np.arange(1, 3))
assert_raise_message(ValueError, "label_type was set to multiclass,"
" but got y of type multilabel-list.",
lb.transform, y)

# check that labels present at fit time that are not in 'classes'
# will be ignored but a warning will be shown
lb = LabelBinarizer(classes=[1, 2])
with warnings.catch_warnings(record=True) as w:
transformed = lb.fit_transform([0, 1, 2])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for testing this :)


# warning was raised:
assert_equal(len(w), 1)
assert_true("not contained in parameter ``classes`` and will be ignored."
in str(w[0]))

# result is as for binary case
assert_equal(transformed.shape, (3, 1))
assert_array_equal(transformed.ravel(), [0, 0, 1])


def test_label_binarizer_multilabel_unlabeled():
"""Check that LabelBinarizer can handle an unlabeled sample"""
lb = LabelBinarizer()
Expand Down