Skip to content

Commit

Permalink
TST label_binarize_dense_output
Browse files Browse the repository at this point in the history
  • Loading branch information
rsivapr committed Sep 25, 2013
1 parent c855fa7 commit 83a496f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 8 deletions.
22 changes: 14 additions & 8 deletions sklearn/preprocessing/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from scipy.sparse import coo_matrix
from scipy.sparse import coo_matrix, issparse

from ..base import BaseEstimator, TransformerMixin

Expand Down Expand Up @@ -170,6 +170,11 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
pos_label : int (default: 1)
Value with which positive labels must be encoded.
dense_output : boolean (default: False)
If True, ensure that the output of label_binarize is a
dense numpy array even if the binarize matrix is sparse.
If False, the binarized data use a sparse representation.
Attributes
----------
`classes_` : array of shape [n_class]
Expand Down Expand Up @@ -207,12 +212,13 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
LabelBinarizer with fixed classes.
"""

def __init__(self, neg_label=0, pos_label=1):
def __init__(self, neg_label=0, pos_label=1, dense_output=True):
if neg_label >= pos_label:
raise ValueError("neg_label must be strictly less than pos_label.")

self.neg_label = neg_label
self.pos_label = pos_label
self.dense_output = dense_output

@property
@deprecated("Attribute `multilabel` was renamed to `multilabel_` in "
Expand Down Expand Up @@ -260,7 +266,7 @@ def transform(self, y):
Returns
-------
Y : numpy array of shape [n_samples, n_classes]
Y : numpy array or COO matrix of shape [n_samples, n_classes]
"""
self._check_fitted()

Expand All @@ -271,9 +277,9 @@ def transform(self, y):
" input.")

return label_binarize(y, self.classes_,
multilabel=self.multilabel_,
pos_label=self.pos_label,
neg_label=self.neg_label)
neg_label=self.neg_label,
dense_output=self.dense_output)

def inverse_transform(self, Y, threshold=None):
"""Transform binary labels back to multi-class labels
Expand Down Expand Up @@ -315,13 +321,13 @@ def inverse_transform(self, Y, threshold=None):
threshold = self.neg_label + half

if self.multilabel_:
if not(isinstance(Y, coo_matrix)):
if not(issparse(Y)):
Y = np.array(Y > threshold, dtype=int)
# Return the predictions in the same format as in fit
if self.indicator_matrix_:
# Label indicator matrix format
return Y
elif isinstance(Y, coo_matrix):
elif issparse(Y):
# Splitting and processing sparse matrix
y = []
for i in range(Y.shape[0]):
Expand Down Expand Up @@ -404,7 +410,7 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,
raise ValueError("neg_label cannot equal pos_label")

y_type = type_of_target(y)
n_samples = len(y)
n_samples = y.shape[0] if issparse(y) else len(y)
n_classes = len(classes)
classes = np.asarray(classes)
sorted_class = np.sort(classes)
Expand Down
77 changes: 77 additions & 0 deletions sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn.preprocessing.label import LabelEncoder
from sklearn.preprocessing.label import label_binarize

from scipy.sparse import coo_matrix, csc_matrix, csr_matrix

from sklearn import datasets
from sklearn.linear_model.stochastic_gradient import SGDClassifier
Expand Down Expand Up @@ -147,6 +148,17 @@ def test_label_binarizer_multilabel():
assert_equal([set(x) for x in lb.inverse_transform(got)],
[set(x) for x in inp])

def test_label_binarizer_sparse_rep():
# TODO !!!


lb = LabelBinarizer(neg_label=neg_label, pos_label=pos_label,
dense_output=dense_output)
output = lb.fit_transform(y)
assert_array_equal(toarray(output), expected)
inverse_output = lb.inverse_transform(output)
assert_array_equal(toarray(inverse_output), y)


def test_label_binarizer_errors():
"""Check that invalid arguments yield ValueError"""
Expand All @@ -164,6 +176,8 @@ def test_label_binarizer_errors():
assert_raises(ValueError, LabelBinarizer, neg_label=2, pos_label=1)
assert_raises(ValueError, LabelBinarizer, neg_label=2, pos_label=2)

lb = LabelBinarizer(neg_label=-1, dense_output=False)
assert_raises(ValueError, lb.transform, [(1,0,2)])

def test_label_encoder():
"""Test LabelEncoder's transform and inverse_transform methods"""
Expand Down Expand Up @@ -253,3 +267,66 @@ def test_label_binarize_with_multilabel_indicator():

output = lb.fit(y).transform(y)
assert_array_equal(output, expected)

############

def test_label_binarize_with_class_order():
out = label_binarize([1, 6], classes=[1, 2, 4, 6])
expected = np.array([[1, 0, 0, 0], [0, 0, 0, 1]])
assert_array_equal(out, expected)

# Modified class order
out = label_binarize([1, 6], classes=[1, 6, 4, 2])
expected = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
assert_array_equal(out, expected)

def check_dense_output(y, classes, pos_label, neg_label, expected):
for dense_output in [True, False]:
output = label_binarize(y, classes, neg_label=neg_label,
pos_label=pos_label, dense_output=dense_output)
assert_array_equal(toarray(output), expected)

def test_label_binarize_dense_output_binary():
y = [0, 1, 0]
classes = [0, 1]
pos_label = 2
neg_label = -1
expected = np.array([[2, -1], [-1, 2], [2, -1]])[:, 1].reshape((-1, 1))

yield check_dense_output, y, classes, pos_label, neg_label, expected


def test_label_binarize_dense_output_multiclass():
y = [0, 1, 2]
classes = [0, 1, 2]
pos_label = 2
neg_label = 0
expected = 2 * np.eye(3)

yield check_dense_output, y, classes, pos_label, neg_label, expected

assert_raises(ValueError, label_binarize, y, classes, neg_label=-1,
pos_label=pos_label, dense_output=False)


def test_label_binarize_dense_output_multilabel():
y_seq = [(1,), (0, 1, 2), tuple()]
y_seq_repeated = [(1, 1,), (1, 0, 2, 2), tuple()]
y_ind = np.array([[0, 1, 0], [1, 1, 1], [0, 0, 0]])
classes = [0, 1, 2]
pos_label = 2
neg_label = 0
expected = pos_label * y_ind
case_list = [y_seq, y_ind, y_seq_repeated, coo_matrix(y_ind),
csr_matrix(y_ind), csc_matrix(y_ind)]
for y in case_list:
yield check_dense_output, y, classes, pos_label, neg_label, expected

assert_raises(ValueError, label_binarize, y, classes, neg_label=-1,
pos_label=pos_label, dense_output=False)

def test_label_binarize_errors():
y = [1,2,3,1]
classes = [0,1,2,3]
assert_raises(ValueError, label_binarize, y, classes, neg_label=1,
pos_label=1)

0 comments on commit 83a496f

Please sign in to comment.