Fix the issue #3745, the code book generation for OutputCodeClassifier #3768

Closed
wants to merge 11 commits
37 changes: 28 additions & 9 deletions doc/modules/multiclass.rst
@@ -200,9 +200,9 @@ matrix which keeps track of the location/code of each class is called the
code book. The code size is the dimensionality of the aforementioned space.
Intuitively, each class should be represented by a code as unique as
possible and a good code book should be designed to optimize classification
- accuracy. In this implementation, we simply use a randomly-generated code
- book as advocated in [2]_ although more elaborate methods may be added in the
- future.
+ accuracy. In this implementation, we randomly sample subsets of the
+ exhaustive codes suggested in [1]_ several times and keep the one with the
+ greatest separation between classes.
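
For illustration (the same example appears as a comment in this PR's code),
the exhaustive code book for four classes has ``2^3 - 1 = 7`` code words;
each row below is a class and each column is one code word::

    1 1 1 1 1 1 1
    0 0 0 0 1 1 1
    0 0 1 1 0 0 1
    0 1 0 1 0 1 0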

At fitting time, one binary classifier per bit in the code book is fitted.
At prediction time, the classifiers are used to project new points in the
@@ -222,8 +222,21 @@ one-vs-the-rest. In this case, some classifiers will in theory correct for
the mistakes made by other classifiers, hence the name "error-correcting".
In practice, however, this may not happen as classifier mistakes will
typically be correlated. The error-correcting output codes have a similar
- effect to bagging.
+ effect to bagging. The maximum value for ``code_size`` is
+ ``(2^(n_classes - 1) - 1) / n_classes``, as suggested by [1]_, since the
+ codes of all 0s or all 1s, and the complements of existing codes, must be
+ excluded. For four classes, for example, this gives ``7 / 4 = 1.75``.

Member: It's a bit strange given that ``code_size`` is a percentage.

Contributor (author): I also feel it is a bit strange; I just kept the
original name. I guess the reason is that ``n_classes`` is only decided
after ``fit`` is called, but the valid range of ``code_size`` still depends
on ``n_classes``.

The ``strategy`` parameter allows the user to select the strategy used to
code the classes. Two strategies are currently supported: (1) ``"random"``,
where a random ``n_classes x int(n_classes * code_size)`` matrix is
generated, entries > 0.5 are set to 1, and the rest to 0 or -1; (2)
``"iter_hamming"``, where random subsets of the exhaustive code book are
sampled and the one with the largest total Hamming distance is chosen.

The ``max_iter`` parameter, used when ``strategy="iter_hamming"``, sets how
many randomly generated subsets of the exhaustive code book are tried; the
one with the greatest separation between classes is kept. A sketch of this
procedure follows.
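
A minimal sketch of the ``"iter_hamming"`` selection, using the helpers this
implementation relies on (the ``pick_code_book`` wrapper itself is
hypothetical)::

    import numpy as np
    from sklearn.metrics.pairwise import pairwise_distances
    from sklearn.utils.random import sample_without_replacement

    def pick_code_book(n_classes, code_size, max_iter, random_state):
        max_code_size = 2 ** (n_classes - 1) - 1
        n_code_words = int(n_classes * code_size)
        best_book, best_dist = None, 0.0
        for _ in range(max_iter):
            idx = sample_without_replacement(n_population=max_code_size,
                                             n_samples=n_code_words,
                                             random_state=random_state)
            # Shifting by max_code_size + 1 yields integers whose binary
            # expansions are columns of the exhaustive code book.
            bits = ((idx[:, None] + max_code_size + 1) &
                    (1 << np.arange(n_classes - 1, -1, -1))) > 0
            book = bits.astype(int).T  # rows: classes, columns: code words
            # Keep the draw with the largest total Hamming distance.
            dist = pairwise_distances(book, metric='hamming').sum()
            if dist > best_dist:
                best_dist, best_book = dist, book
        return best_book

    book = pick_code_book(n_classes=4, code_size=1.5, max_iter=10,
                          random_state=np.random.RandomState(0))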

Multiclass learning
-------------------
@@ -236,14 +249,14 @@ Below is an example of multiclass learning using Output-Codes::
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf = OutputCodeClassifier(LinearSVC(random_state=0),
- ...                           code_size=2, random_state=0)
+ ...                           code_size=1, max_iter=5, random_state=0)
>>> clf.fit(X, y).predict(X)
  array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
-        1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
+        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
-        2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2,
+        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

.. topic:: References:
@@ -261,3 +274,9 @@ Below is an example of multiclass learning using Output-Codes::
.. [3] "The Elements of Statistical Learning",
Hastie T., Tibshirani R., Friedman J., page 606 (second-edition)
2008.

.. [4] "Reducing multiclass to binary: A unifying approach for margin
classifiers."
Allwein, Erin L., Robert E. Schapire, and Yoram Singer,
The Journal of Machine Learning Research 1: 113-141,
2001.
139 changes: 139 additions & 0 deletions examples/classification/plot_digits_multiclass.py
@@ -0,0 +1,139 @@
"""
===========================
Multi-class classification
===========================

Member: I would rename this "Multi-class encoding" or something, since the
example is about coding strategies, not multi-class classification.

An example showing how to use the different encoding methods in scikit-learn
for a multi-class problem, here hand-written digit recognition. Four encoding
methods available in scikit-learn are compared: (1) the label binarizer built
into SVC; (2) OneVsOneClassifier; (3) OneVsRestClassifier;
(4) OutputCodeClassifier with the two strategies 'random' and 'iter_hamming'.

"""
print(__doc__)

# Author: Qichao Que <que@cse.ohio-state.edu>
# License: BSD 3 clause

# Import timing, plotting and scientific python libraries.
from matplotlib import pyplot as plt
import numpy as np
import time

# Import datasets, classifiers, and performance metrics
from sklearn.svm import SVC
from sklearn import datasets, metrics
from sklearn import multiclass
from sklearn.metrics.pairwise import pairwise_distances

# Load the digits dataset
digits = datasets.load_digits()

# Split the dataset into training and testing data
train_X = digits.data[:1000]
Member: you can use the train_test_split function here if you like.

train_y = digits.target[:1000]
test_X = digits.data[1000:]
test_y = digits.target[1000:]

# Using the label binarizer built into SVC
c = SVC(gamma=0.001)
s = time.time()
c.fit(train_X, train_y)
pred = c.predict(test_X)
e = time.time()
time_svc = e - s
error_svc = np.mean(test_y != pred)
print('Error for SVC: %.3f%%' % (error_svc * 100))

# Using OneVsRestClassifier
c1 = multiclass.OneVsRestClassifier(c)
s = time.time()
c1.fit(train_X, train_y)
pred = c1.predict(test_X)
e = time.time()
time_ovr = e - s
error_ovr = np.mean(test_y != pred)
print('Error for SVC using OneVsRest Code: %.3f%%' % (error_ovr * 100))

# Using OneVsOneClassifier
c1 = multiclass.OneVsOneClassifier(c)
Member: this is what SVC uses internally, right? So there should be no
difference?

Member: Also, this code is pretty repetitive. Can you maybe put it into a
function?

s = time.time()
c1.fit(train_X, train_y)
pred = c1.predict(test_X)
e = time.time()
time_ovo = e - s
error_ovo = np.mean(test_y != pred)
print('Error for SVC using OneVsOne Code: %.3f%%' % (error_ovo * 100))
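
# The repeated fit/predict/timing pattern above could be factored into a
# helper such as this one (a sketch responding to the review comment, not
# part of the original PR):
def benchmark(clf, train_X, train_y, test_X, test_y):
    """Fit clf, predict on the test set, and return (error rate, seconds)."""
    start = time.time()
    clf.fit(train_X, train_y)
    pred = clf.predict(test_X)
    elapsed = time.time() - start
    return np.mean(test_y != pred), elapsed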

# constants for OutputCodeClassifier
max_iter = 50
repeat = 20
code_size = [0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

# Using OutputCodeClassifier with strategy 'iter_hamming'
error_ecoc_iter_hamming = np.zeros((len(code_size), repeat))
random_state = np.random.RandomState(0)
time_ecoc_iter_hamming = np.zeros((len(code_size), repeat))
for j, cs in enumerate(code_size):
    for i in range(repeat):
        c1 = multiclass.OutputCodeClassifier(c, max_iter=max_iter,
                                             code_size=cs,
                                             strategy="iter_hamming",
                                             random_state=random_state)
        s = time.time()
        c1.fit(train_X, train_y)
        pred = c1.predict(test_X)
        e = time.time()
        time_ecoc_iter_hamming[j, i] = e - s
        error_ecoc_iter_hamming[j, i] = np.mean(test_y != pred)
mean_error_ecoc_iter_hamming = np.mean(error_ecoc_iter_hamming, axis=1)
print('Error for SVC using Output Code with iter_hamming strategy:')
print(mean_error_ecoc_iter_hamming)

# Using OutputCodeClassifier with strategy 'random'
error_ecoc_random = np.zeros((len(code_size), repeat))
random_state = np.random.RandomState(0)
time_ecoc_random = np.zeros((len(code_size), repeat))
for j, cs in enumerate(code_size):
    for i in range(repeat):
        c1 = multiclass.OutputCodeClassifier(c, max_iter=max_iter,
                                             code_size=cs,
                                             strategy="random",
                                             random_state=random_state)
        s = time.time()
        c1.fit(train_X, train_y)
        pred = c1.predict(test_X)
        e = time.time()
        time_ecoc_random[j, i] = e - s
        error_ecoc_random[j, i] = np.mean(test_y != pred)
mean_error_ecoc_random = np.mean(error_ecoc_random, axis=1)
print('Error for SVC using Output Code with random strategy:')
print(mean_error_ecoc_random)

plt.figure(1)
# Plot the classification error for the different encoding strategies.
plt.subplot(121)
plt.plot(1, error_svc, 'b', marker='+', label='SVC')
plt.plot(1, error_ovo, 'r', marker='o', label='OneVsOne')
plt.plot(1, error_ovr, 'g', marker='*', label='OneVsRest')
plt.plot(code_size, np.mean(error_ecoc_iter_hamming, axis=1), '-c', marker='x',
label='OC-iter_hamming')
plt.plot(code_size, np.mean(error_ecoc_random, axis=1), '-k', marker='s',
label='OC-random')
plt.legend()
plt.xlabel('code_size')
plt.ylabel('Classification Error')

# Plot the timing results as well.
plt.subplot(122)
plt.plot(1, time_svc, 'b', marker='+', label='SVC')
plt.plot(1, time_ovo, 'r', marker='o', label='OneVsOne')
plt.plot(1, time_ovr, 'g', marker='*', label='OneVsRest')
plt.plot(code_size, np.mean(time_ecoc_iter_hamming, axis=1), '-c', marker='x',
label='iter_hamming')
plt.plot(code_size, np.mean(time_ecoc_random, axis=1), '-k', marker='s',
label='random')
plt.xlabel('code_size')
plt.ylabel('Running time')
plt.legend()

plt.show()
121 changes: 108 additions & 13 deletions sklearn/multiclass.py
@@ -41,8 +41,9 @@
from .base import BaseEstimator, ClassifierMixin, clone, is_classifier
from .base import MetaEstimatorMixin
from .preprocessing import LabelBinarizer
- from .metrics.pairwise import euclidean_distances
+ from .metrics.pairwise import euclidean_distances, pairwise_distances
from .utils import check_random_state
from .utils.random import sample_without_replacement
from .utils.validation import _num_samples
from .utils import deprecated
from .externals.joblib import Parallel
@@ -592,6 +593,54 @@ def predict_ecoc(estimators, classes, code_book, X):
return ecoc.predict(X)


def _random_code_book(n_classes, random_state, code_size):
    """Randomly generate a code book."""
    code_book = random_state.random_sample((n_classes, code_size))
    code_book[code_book > 0.5] = 1
    code_book[code_book != 1] = 0
    return code_book
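
# Hypothetical usage (not part of the PR): ten random code words for four
# classes; the result is a (4, 10) matrix with entries in {0, 1}.
#
#     rng = np.random.RandomState(0)
#     book = _random_code_book(n_classes=4, random_state=rng, code_size=10)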


def _iter_hamming_code_book(n_classes, random_state, code_size, max_iter):
    """Randomly sample subsets of the exhaustive code for n_classes and
    choose the one that gives the largest total Hamming distance.
    """
    if max_iter <= 0:
        raise ValueError("max_iter must be larger than 0.")
    if n_classes > 60:
        max_code_size = np.power(2, 60) - 1
    else:
        max_code_size = np.power(2, n_classes - 1) - 1
    if code_size > max_code_size:
        raise ValueError("The number of code words is larger than the "
                         "number of exhaustive codes.")
    best_code_distance = 0
    for k in range(max_iter):
        p = sample_without_replacement(n_samples=code_size,
                                       n_population=max_code_size)
        # Example for understanding the intuition behind tmp_code_book:
        # the exhaustive code for 4 classes is
        #     1 1 1 1 1 1 1
        #     0 0 0 0 1 1 1
        #     0 0 1 1 0 0 1
        #     0 1 0 1 0 1 0
        # The columns are the binary codes of
        #     [8, 9, ..., 14] = [0, 1, ..., 6] + 7 + 1
        # where 7 is max_code_size for n_classes = 4.
        # The bitwise & extracts the binary code of each integer
        # representing a code word sampled from the exhaustive code.
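        # Adding max_code_size + 1 shifts the sampled integers into the
        # upper half of the range so every code word has n_classes bits;
        # the final transpose below makes rows correspond to classes and
        # columns to code words.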
        tmp_code_book = (((p[:, None] + max_code_size + 1) &
                          (1 << np.arange(n_classes - 1, -1, -1))) > 0
                         ).astype(int).T
Member: Can you add a comment on what is done on this line?

        # The code distance is the sum of Hamming distances between rows.
        tmp_code_distance = np.sum(pairwise_distances(
            tmp_code_book, metric='hamming'))
        if tmp_code_distance > best_code_distance:
            best_code_distance = tmp_code_distance
            best_code_book = tmp_code_book
    return best_code_book
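
# Hypothetical usage (not part of the PR): the best of 10 random draws of
# 6 code words for 4 classes (6 <= 2**3 - 1 = 7 exhaustive code words).
#
#     rng = np.random.RandomState(0)
#     book = _iter_hamming_code_book(4, rng, code_size=6, max_iter=10)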


class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
"""(Error-Correcting) Output-Code multiclass strategy

@@ -616,6 +665,13 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
one-vs-the-rest. A number greater than 1 will require more classifiers
than one-vs-the-rest.

    max_iter : int, default: 10
        Maximum number of iterations used to generate a good code book;
        must be an integer larger than 0. At each iteration a random code
        book is generated, and the previous one is replaced only if the
        total Hamming distance between code words increases.
        This parameter is only used when ``strategy="iter_hamming"``.

random_state : numpy.RandomState, optional
The generator used to initialize the codebook. Defaults to
numpy.random.
@@ -626,6 +682,18 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are
used. Thus for n_jobs = -2, all CPUs but one are used.

    strategy : str, {'auto', 'random', 'iter_hamming'}, optional, default: "auto"
        The strategy used to generate a code book for all classes. Three
        options are currently available:

        (1) "random": randomly generate an n_classes x int(n_classes * code_size)
            matrix, and set entries > 0.5 to 1 and the rest to 0 or -1;
        (2) "iter_hamming": randomly select subsets of the exhaustive code
            book multiple times, and choose the one that gives the largest
            Hamming distances between classes;
        (3) "auto": use "iter_hamming" if code_size is in the valid range
            for it; otherwise, fall back to the "random" strategy.

Attributes
----------
estimators_ : list of `int(n_classes * code_size)` estimators
@@ -654,13 +722,23 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
.. [3] "The Elements of Statistical Learning",
Hastie T., Tibshirani R., Friedman J., page 606 (second-edition)
2008.

.. [4] "Reducing multiclass to binary: A unifying approach for margin
classifiers",
Allwein, Erin L., Robert E. Schapire, and Yoram Singer,
The Journal of Machine Learning Research 1: 113-141,
2001.
"""

-    def __init__(self, estimator, code_size=1.5, random_state=None, n_jobs=1):
+    def __init__(self, estimator, code_size=1, max_iter=10,
+                 random_state=None, n_jobs=1, strategy="auto"):
         self.estimator = estimator
         self.code_size = code_size
         self.max_iter = max_iter
         self.random_state = random_state
         self.n_jobs = n_jobs
         self.strategy = strategy

def fit(self, X, y):
"""Fit underlying estimators.
@@ -682,21 +760,38 @@ def fit(self, X, y):
"".format(self.code_size))

         _check_estimator(self.estimator)
-        random_state = check_random_state(self.random_state)

         self.classes_ = np.unique(y)
-        n_classes = self.classes_.shape[0]
-        code_size_ = int(n_classes * self.code_size)
-
-        # FIXME: there are more elaborate methods than generating the codebook
-        # randomly.
-        self.code_book_ = random_state.random_sample((n_classes, code_size_))
-        self.code_book_[self.code_book_ > 0.5] = 1
-
-        if hasattr(self.estimator, "decision_function"):
-            self.code_book_[self.code_book_ != 1] = -1
-        else:
-            self.code_book_[self.code_book_ != 1] = 0
+        random_state = check_random_state(self.random_state)
+        n_classes = self.classes_.shape[0]
+        if n_classes > 60:
+            max_code_size = np.power(2, 60) - 1
+        else:
+            max_code_size = np.power(2, n_classes - 1) - 1
+        n_code_words = int(n_classes * self.code_size)
+        if n_code_words - 1 < np.log2(n_classes):
+            raise ValueError("The code size must be large enough to "
+                             "distinguish every class.")
+        if self.strategy == "auto":
+            if n_code_words > max_code_size:
+                self.code_book_ = _random_code_book(n_classes, random_state,
+                                                    n_code_words)
+            else:
+                self.code_book_ = _iter_hamming_code_book(n_classes,
+                                                          random_state,
+                                                          n_code_words,
+                                                          self.max_iter)
+        elif self.strategy == "random":
+            self.code_book_ = _random_code_book(n_classes, random_state,
+                                                n_code_words)
+        elif self.strategy == "iter_hamming":
+            self.code_book_ = _iter_hamming_code_book(n_classes,
+                                                      random_state,
+                                                      n_code_words,
+                                                      self.max_iter)
+        else:
+            raise ValueError("Unknown coding strategy %r" % self.strategy)
+        if hasattr(self.estimator, "decision_function"):
+            self.code_book_[self.code_book_ == 0] = -1

         classes_index = dict((c, i) for i, c in enumerate(self.classes_))
