Fix issue #3745: code book generation for OutputCodeClassifier #3768

Closed
wants to merge 11 commits
31 changes: 22 additions & 9 deletions doc/modules/multiclass.rst
@@ -200,9 +200,9 @@ matrix which keeps track of the location/code of each class is called the
code book. The code size is the dimensionality of the aforementioned space.
Intuitively, each class should be represented by a code as unique as
possible and a good code book should be designed to optimize classification
accuracy. In this implementation, we simply use a randomly-generated code
book as advocated in [2]_ although more elaborate methods may be added in the
future.
accuracy. In this implementation, we randomly generate subsets of the
exhaustive codes suggested in [1]_ multiple times and pick the one with the
most separation between different classes.

At fitting time, one binary classifier per bit in the code book is fitted.
At prediction time, the classifiers are used to project new points in the
@@ -222,8 +222,21 @@ one-vs-the-rest. In this case, some classifiers will in theory correct for
the mistakes made by other classifiers, hence the name "error-correcting".
In practice, however, this may not happen as classifier mistakes will
typically be correlated. The error-correcting output codes have a similar
effect to bagging.

effect to bagging. The maximum value for ``code_size`` is
Member: It's a bit strange given that code_size is a percentage.

Contributor Author: I also feel it is a bit strange; I just reused the original name. I guess the reason is that n_classes is only decided after fit is called, but the valid range of code_size still depends on n_classes.

``(2^(n_classes-1) - 1) / n_classes`` as suggested by [1]_, since the codes
that are all 0s or all 1s, and the complements of existing codes, need to be
excluded.
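
As a quick numerical illustration of this bound (a sketch added for this
write-up, not part of the patch itself)::

    # Number of admissible columns in the exhaustive code book and the
    # corresponding upper bound on code_size (a multiplier of n_classes).
    for n_classes in (3, 4, 5):
        n_columns = 2 ** (n_classes - 1) - 1
        print(n_classes, n_columns, n_columns / float(n_classes))
    # 3 classes ->  3 columns, max code_size = 1.0
    # 4 classes ->  7 columns, max code_size = 1.75
    # 5 classes -> 15 columns, max code_size = 3.0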

The ``strategy`` attribute allows the user to select the strategy used to
code the classes. Two strategies are currently supported: (1) ``"random"``,
where a random ``n_classes x int(n_classes * code_size)`` matrix is
generated, and
Member: "n_classes x " => "n_classes * "

entries > 0.5 are set to be 1 and the rest to be 0 or -1; (2)
``"max_hamming"``, where random subsets of the exhaustive code are sampled
and the one with the largest total Hamming distance between classes is chosen.

The ``max_iter`` attribute, used when ``strategy="max_hamming"``, allows the
user to generate a code book with the largest separation between classes from
several randomly generated subsets of the exhaustive code book.
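
Conceptually, the ``"max_hamming"`` strategy amounts to the following
selection loop (a rough sketch; the helper names are placeholders, not real
functions in the code base)::

    best_code_book, best_separation = None, 0
    for _ in range(max_iter):
        # "sample_exhaustive_subset" stands for drawing code_size columns of
        # the exhaustive code book; "total_hamming_distance" for the sum of
        # pairwise Hamming distances between the class code words.
        candidate = sample_exhaustive_subset(n_classes, code_size)
        separation = total_hamming_distance(candidate)
        if separation > best_separation:
            best_code_book, best_separation = candidate, separation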

Multiclass learning
-------------------
@@ -236,14 +249,14 @@ Below is an example of multiclass learning using Output-Codes::
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf = OutputCodeClassifier(LinearSVC(random_state=0),
... code_size=2, random_state=0)
... code_size=1, max_iter=5, random_state=0)
>>> clf.fit(X, y).predict(X)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
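
The ``strategy`` and ``max_iter`` parameters introduced by this pull request
can also be set explicitly; for example (output omitted, shown as an
illustration rather than a tested doctest)::

    from sklearn import datasets
    from sklearn.multiclass import OutputCodeClassifier
    from sklearn.svm import LinearSVC

    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    clf = OutputCodeClassifier(LinearSVC(random_state=0), code_size=1,
                               strategy="max_hamming", max_iter=10,
                               random_state=0)
    clf.fit(X, y)
    len(clf.estimators_)   # one binary classifier per code book column: 3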

.. topic:: References:
99 changes: 87 additions & 12 deletions sklearn/multiclass.py
@@ -41,8 +41,9 @@
from .base import BaseEstimator, ClassifierMixin, clone, is_classifier
from .base import MetaEstimatorMixin
from .preprocessing import LabelBinarizer
from .metrics.pairwise import euclidean_distances
from .metrics.pairwise import euclidean_distances, pairwise_distances
from .utils import check_random_state
from .utils.random import sample_without_replacement
from .utils.validation import _num_samples
from .utils import deprecated
from .externals.joblib import Parallel
@@ -592,6 +593,40 @@ def predict_ecoc(estimators, classes, code_book, X):
return ecoc.predict(X)


def _random_code_book(n_classes, random_state, code_size):
    """Randomly generate a code book."""
    code_book = random_state.random_sample((n_classes, code_size))
    code_book[code_book > 0.5] = 1
    code_book[code_book != 1] = 0
    return code_book


def _max_hamming_code_book(n_classes, random_state, code_size, max_iter):
Member: A ValueError should be raised if max_iter does not have a proper value.

"""Randomly sample subsets of exhaustive code for n_classes, and choose
the one gives the largest hamming distances.
"""
max_code_size = np.power(2, n_classes-1) - 1
if code_size > max_code_size:
raise ValueError("The code size is larger than the possible "
Member: I would say "The number of code words is larger than the number of ..."

"exhaustive codes.")
if np.power(2, code_size) < n_classes:
raise ValueError("The code size must be large enough to "
"distinguish every class.")
tmp_code_book = np.zeros((n_classes, code_size))
Member: I would initialize the code book with the best one as default: tmp_code_book => best_code_book.

Member: This will avoid some edge case issues.

    dist = 0
Member: One of the project conventions is to avoid abbreviations and use plain variable names. I would rename this to sum_pairwise_distances or something similar.

    for k in range(max_iter):
        p = sample_without_replacement(n_samples=code_size,
                                       n_population=max_code_size)
        tmp_code_book = ((p[:, None] + max_code_size + 1) &
                         (1 << np.arange(n_classes - 1, -1, -1)) > 0
                         ).astype(int).T
Member: Can you add a comment on what is done on this line?

        tmp_distance = np.sum(pairwise_distances(tmp_code_book,
                                                 metric='hamming'))
        if tmp_distance > dist:
            dist = tmp_distance
            code_book = tmp_code_book
    return code_book
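
For reference, the bit manipulation in the loop maps each sampled integer onto
one column of the exhaustive code book: adding max_code_size + 1 shifts the
sample into [2**(n_classes-1), 2**n_classes - 2], so the leading bit is always
1 (the first class always falls on the positive side, which rules out
complementary columns) and the all-zero / all-one columns can never be drawn;
masking with 1 << bit then extracts the binary digits. A small self-contained
sketch of the same computation (illustration only, not part of the patch):

    import numpy as np

    n_classes = 3
    max_code_size = 2 ** (n_classes - 1) - 1      # 3 usable columns
    p = np.array([0, 1, 2])                       # pretend these were sampled
    bits = 1 << np.arange(n_classes - 1, -1, -1)  # array([4, 2, 1])
    columns = ((p[:, None] + max_code_size + 1) & bits) > 0
    code_book = columns.astype(int).T             # shape (n_classes, code_size)
    # code_book == [[1, 1, 1],
    #               [0, 0, 1],
    #               [0, 1, 0]]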


class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
"""(Error-Correcting) Output-Code multiclass strategy

@@ -616,6 +651,13 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
one-vs-the-rest. A number greater than 1 will require more classifiers
than one-vs-the-rest.

max_iter : int, default: 10
Maximum number of iterations used to generate a good code book; must be an
integer larger than 0. At each iteration, a random code book is generated,
and the previous one is replaced only if the total Hamming distance between
code words increases. This parameter is only used when the strategy is
"max_hamming".

random_state : numpy.RandomState, optional
The generator used to initialize the codebook. Defaults to
numpy.random.
@@ -626,6 +668,18 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are
used. Thus for n_jobs = -2, all CPUs but one are used.

strategy : str, {'auto', 'random', 'max_hamming'}, optional, default: "auto"
The strategy used to generate a code book for all classes. Three options
are currently available:

(1) "random": randomly generate an n_classes x int(n_classes * code_size)
matrix, and set entries > 0.5 to be 1 and the rest to be 0 or -1;
(2) "max_hamming": select subsets of the exhaustive code book multiple times
Member: Sorry, but I don't understand where the max_hamming strategy is explained and presented in "Error-Correcting Output Codes Library". Do you have a reference for this?

Member: This is important since we strive to merge only well-established algorithms.

Member: By the way, given my understanding of the algorithm, "iter_hamming" would be a better name, since it tries to iteratively maximize the Hamming distance between codes.

Contributor Author: This algorithm is similar to the dense random coding in that paper. However, the way they generate the code book at each iteration is the old implementation in this project, which I believe is problematic, so I changed it to the current implementation.

I totally understand your point that only well-established algorithms can be merged into the project. My point is that the original implementation here is a kind of heuristic (obviously not "well-established"), which is why I want to improve it. More importantly, as I understand it, the paper by Dietterich et al. proposed the general idea of error-correcting output coding for multi-class problems and provided the design principles for ECOC, i.e. maximizing the row separation and the column separation. My implementation is a rather crude way to optimize the row separation, by repeatedly sampling valid code books and choosing the "best" one. By "valid code book", I mean a subset of code words from the exhaustive code of Dietterich et al. Maybe I can maximize the column separation as well.

Member: Can you cite the original paper "Reducing Multiclass to Binary: A Unifying Approach for Margin Classifiers"? It has 1439 citations, which is sufficient for inclusion.

randomly, and choose the one that gives the largest Hamming distances
between classes.
(3) "auto": use "max_hamming" if code_size is in the valid range for it;
otherwise, the "random" strategy is used.

Attributes
----------
estimators_ : list of `int(n_classes * code_size)` estimators
@@ -654,13 +708,23 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
.. [3] "The Elements of Statistical Learning",
Hastie T., Tibshirani R., Friedman J., page 606 (second-edition)
2008.

.. [4] "Error-correcting ouput codes library,"
Escalera, Sergio, Oriol Pujol, and Petia Radeva,
The Journal of Machine Learning Research 11, page 661-664,
2010.

"""

def __init__(self, estimator, code_size=1.5, random_state=None, n_jobs=1):
def __init__(self, estimator, code_size=1, max_iter=10,
random_state=None, n_jobs=1, strategy="auto"):
self.estimator = estimator
self.code_size = code_size
self.max_iter = max_iter
self.random_state = random_state
self.n_jobs = n_jobs
self.strategy = strategy


def fit(self, X, y):
"""Fit underlying estimators.
@@ -682,21 +746,32 @@ def fit(self, X, y):
"".format(self.code_size))

_check_estimator(self.estimator)
random_state = check_random_state(self.random_state)

self.classes_ = np.unique(y)

random_state = check_random_state(self.random_state)
n_classes = self.classes_.shape[0]
code_size_ = int(n_classes * self.code_size)
Member: Can you rename code_size_ to n_code_words? It wasn't clear that this is a percentage anymore.


# FIXME: there are more elaborate methods than generating the codebook
# randomly.
self.code_book_ = random_state.random_sample((n_classes, code_size_))
self.code_book_[self.code_book_ > 0.5] = 1

if hasattr(self.estimator, "decision_function"):
self.code_book_[self.code_book_ != 1] = -1
max_code_size = np.power(2, n_classes-1) - 1
if self.strategy == "auto":
if code_size_ > max_code_size or np.power(2, code_size_) < n_classes:
self.code_book_ = _random_code_book(n_classes, random_state, code_size_)
else:
self.code_book_ = _max_hamming_code_book(n_classes,
random_state,
code_size_,
self.max_iter)
elif self.strategy == "random":
self.code_book_ = _random_code_book(n_classes, random_state, code_size_)
elif self.strategy == "max_hamming":
self.code_book_ = _max_hamming_code_book(n_classes,
random_state,
code_size_,
self.max_iter)
else:
self.code_book_[self.code_book_ != 1] = 0
raise ValueError("Unknown coding strategy %r" % self.strategy)
if hasattr(self.estimator, "decision_function"):
self.code_book_[self.code_book_ == 0] = -1
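
# Note: when the base estimator exposes decision_function, the code book is
# recoded from {0, 1} to {-1, 1} so that the code words lie in the same range
# as the signed decision values that are matched against them at prediction
# time; estimators that only provide predict_proba keep the {0, 1} coding.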

classes_index = dict((c, i) for i, c in enumerate(self.classes_))

64 changes: 59 additions & 5 deletions sklearn/tests/test_multiclass.py
@@ -13,6 +13,8 @@
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multiclass import OneVsOneClassifier
from sklearn.multiclass import OutputCodeClassifier
from sklearn.multiclass import _random_code_book
from sklearn.multiclass import _max_hamming_code_book

from sklearn.multiclass import fit_ovr
from sklearn.multiclass import fit_ovo
@@ -24,6 +26,7 @@

from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics.pairwise import pairwise_distances

from sklearn.preprocessing import LabelBinarizer

@@ -483,24 +486,75 @@ def test_ovo_string_y():
ovo.fit(X, y)
assert_array_equal(y, ovo.predict(X))

def test_code_book_functions():
random_state = np.random
random_state.seed(0)
code_book = _random_code_book(3, random_state, 10000)
proportion_of_1 = np.sum((code_book==1).astype(int)) * 1.0 / 30000
assert_true(proportion_of_1 > 0.48 and proportion_of_1 < 0.52)
Member: You can use assert_greater and assert_lower here.

code_book = _max_hamming_code_book(5, random_state, 15, 10)
assert_equal(5, code_book.shape[0])
assert_equal(15, code_book.shape[1])

def test_ecoc_exceptions():
ecoc = OutputCodeClassifier(LinearSVC(random_state=0))
assert_raises(ValueError, ecoc.predict, [])

ecoc = OutputCodeClassifier(LinearSVC(random_state=0),
strategy="abc")
assert_raises(ValueError, ecoc.fit, [], [])
ecoc = OutputCodeClassifier(LinearSVC(random_state=0), code_size=1.5,
strategy="max_hamming")
assert_raises(ValueError, ecoc.fit, [], np.array([0, 1, 2]))
ecoc = OutputCodeClassifier(LinearSVC(random_state=0), code_size=0.01,
strategy="max_hamming")
assert_raises(ValueError, ecoc.fit, [], np.array([0, 1, 2, 3]))

def test_ecoc_fit_predict():
# A classifier which implements decision_function.
ecoc = OutputCodeClassifier(LinearSVC(random_state=0),
code_size=2, random_state=0)
code_size=1, random_state=0)
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), n_classes * 2)
assert_equal(len(ecoc.estimators_), n_classes * 1)

# A classifier which implements predict_proba.
ecoc = OutputCodeClassifier(MultinomialNB(), code_size=2, random_state=0)
ecoc = OutputCodeClassifier(MultinomialNB(), code_size=1, random_state=0)
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), n_classes * 1)

def test_ecoc_strategy():
# For the iris dataset, code_size=1.5 will use _random_code_book
ecoc = OutputCodeClassifier(LinearSVC(random_state=0),
code_size=1.5, random_state=0)
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), n_classes * 2)
assert_equal(len(ecoc.estimators_), int(n_classes * 1.5))

# Set the strategy to be "random"
ecoc = OutputCodeClassifier(LinearSVC(random_state=0),
code_size=1.5, random_state=0,
strategy="random")
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), int(n_classes * 1.5))

# Set the strategy to be "max_hamming"
ecoc = OutputCodeClassifier(LinearSVC(random_state=0),
code_size=1.0, random_state=0,
strategy="max_hamming")
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), n_classes * 1)

def test_max_hamming_code_book():
# Test that the code book can be improved by using a larger max_iter
random_state = np.random
random_state.seed(0)
Member: Can you use random_state?

random_state = np.random.RandomState(0)

dist0 = np.sum(pairwise_distances(_max_hamming_code_book(5, random_state,
10, 1),
metric='hamming'))
random_state = np.random
random_state.seed(0)
Member: Can you use random_state?

random_state = np.random.RandomState(0)

dist1 = np.sum(pairwise_distances(_max_hamming_code_book(5, random_state,
10, 2),
metric='hamming'))
assert_true(dist1 >= dist0)
Member: Here you can use assert_greater_equal.

def test_ecoc_gridsearch():
ecoc = OutputCodeClassifier(LinearSVC(random_state=0),