+ Change the process of generating output code for OutputCodeClassifier.
+ Add test case.
+ Update the documentation.
queqichao committed Oct 13, 2014
1 parent 7ac7fdc commit 9565a2a
Showing 3 changed files with 66 additions and 26 deletions.
22 changes: 14 additions & 8 deletions doc/modules/multiclass.rst
@@ -200,9 +200,9 @@ matrix which keeps track of the location/code of each class is called the
code book. The code size is the dimensionality of the aforementioned space.
Intuitively, each class should be represented by a code as unique as
possible and a good code book should be designed to optimize classification
accuracy. In this implementation, we simply use a randomly-generated code
book as advocated in [2]_ although more elaborate methods may be added in the
future.
accuracy. In this implementation, we randomly generate a subset of the
exhaustive codes suggested in [1]_ multiple times, and pick the one with the
most separation between the different classes.
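
For illustration, the exhaustive code book for a small number of classes can
be enumerated directly. The following is a minimal sketch, not the code from
this change; ``exhaustive_code_book`` is a hypothetical helper mirroring the
construction in [1]_::

    import numpy as np

    def exhaustive_code_book(n_classes):
        # Each column is the binary expansion of an integer in
        # [2**(n_classes - 1), 2**n_classes - 2]; fixing the leading
        # bit to 1 excludes the constant columns and the complements
        # of columns already present.
        n_columns = 2 ** (n_classes - 1) - 1
        code_book = np.zeros((n_classes, n_columns))
        for i in range(n_columns):
            bits = bin(i + n_columns + 1)[2:].rjust(n_classes, '0')
            code_book[:, i] = [int(b) for b in bits]
        return code_book

    exhaustive_code_book(3)  # 3 classes -> 2**2 - 1 = 3 columns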

At fitting time, one binary classifier per bit in the code book is fitted.
At prediction time, the classifiers are used to project new points in the
@@ -222,7 +222,13 @@ one-vs-the-rest. In this case, some classifiers will in theory correct for
the mistakes made by other classifiers, hence the name "error-correcting".
In practice, however, this may not happen as classifier mistakes will
typically be correlated. The error-correcting output codes have a similar
effect to bagging.
effect to bagging. The maximum value for ``code_size`` is
``(2^(n_classes-1) - 1) / n_classes``, as suggested in [1]_, since the
all-zero and all-one code words, as well as the complements of existing
codes, need to be excluded.
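
As a quick check of this bound, a hand computation with four classes::

    >>> n_classes = 4
    >>> (2 ** (n_classes - 1) - 1) / float(n_classes)
    1.75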

The ``max_iter`` parameter allows the user to generate a code book with the
most separation between classes, chosen from several randomly generated
subsets of the exhaustive code book.
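
The selection step can be sketched as follows. This is illustrative only;
``total_hamming`` and ``pick_code_book`` are hypothetical helpers, and each
candidate stands for one random draw of columns from the exhaustive code
book::

    import numpy as np

    def total_hamming(code_book):
        # Summed pairwise Hamming distance between code words (rows).
        n = code_book.shape[0]
        return sum(np.sum(code_book[i] != code_book[j])
                   for i in range(n) for j in range(i + 1, n))

    def pick_code_book(candidates):
        # Keep the random draw whose classes are most separated.
        return max(candidates, key=total_hamming)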


Multiclass learning
@@ -236,14 +242,14 @@ Below is an example of multiclass learning using Output-Codes::
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf = OutputCodeClassifier(LinearSVC(random_state=0),
... code_size=2, random_state=0)
... code_size=1, max_iter=5, random_state=0)
>>> clf.fit(X, y).predict(X)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

.. topic:: References:
58 changes: 44 additions & 14 deletions sklearn/multiclass.py
@@ -616,6 +616,12 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
one-vs-the-rest. A number greater than 1 will require more classifiers
than one-vs-the-rest.
max_iter : int
Maximum number of iterations used to generate a good code book. Must be
an integer larger than 0. At each iteration, a random code book is
generated, and the previous one is replaced only if the total Hamming
distance between code words increases.
random_state : numpy.RandomState, optional
The generator used to initialize the codebook. Defaults to
numpy.random.
@@ -656,12 +662,48 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
2008.
"""

def __init__(self, estimator, code_size=1.5, random_state=None, n_jobs=1):
def __init__(self, estimator, code_size=1, max_iter=10,
random_state=None, n_jobs=1):
self.estimator = estimator
self.code_size = code_size
self.max_iter = max_iter
self.random_state = random_state
self.n_jobs = n_jobs

def _generate_codebook(self):
    random_state = check_random_state(self.random_state)
    n_classes = self.classes_.shape[0]
    code_size_ = int(n_classes * self.code_size)
    max_code_size_ = np.power(2, n_classes - 1) - 1
    if code_size_ > max_code_size_:
        raise ValueError("The code size is larger than the possible "
                         "exhaustive codes.")
    if np.power(2, code_size_) < n_classes:
        raise ValueError("The code size must be large enough to "
                         "distinguish every class.")
    tmp_code_book = np.zeros((n_classes, code_size_))
    best_dist = 0
    for _ in range(self.max_iter):
        # Draw code_size_ distinct columns from the exhaustive code book.
        # Offsetting by max_code_size_ + 1 fixes the leading bit of every
        # column to 1, which excludes the constant columns and the
        # complements of columns that are already representable.
        p = random_state.permutation(max_code_size_)
        for i in range(code_size_):
            code = bin(p[i] + max_code_size_ + 1)[2:].rjust(n_classes, '0')
            for j in range(n_classes):
                tmp_code_book[j][i] = 0 if code[j] == '0' else 1
        # Total pairwise Hamming distance between the code words.
        tmp_dist = 0
        for i in range(n_classes - 1):
            for j in range(i + 1, n_classes):
                tmp_dist += np.sum(np.abs(tmp_code_book[i] -
                                          tmp_code_book[j]))
        if tmp_dist > best_dist:
            best_dist = tmp_dist
            # Copy, since tmp_code_book is overwritten on the next draw.
            self.code_book_ = tmp_code_book.copy()
    if hasattr(self.estimator, "decision_function"):
        self.code_book_[self.code_book_ == 0] = -1

def fit(self, X, y):
"""Fit underlying estimators.
@@ -682,21 +724,9 @@ def fit(self, X, y):
"".format(self.code_size))

_check_estimator(self.estimator)
random_state = check_random_state(self.random_state)

self.classes_ = np.unique(y)
n_classes = self.classes_.shape[0]
code_size_ = int(n_classes * self.code_size)

# FIXME: there are more elaborate methods than generating the codebook
# randomly.
self.code_book_ = random_state.random_sample((n_classes, code_size_))
self.code_book_[self.code_book_ > 0.5] = 1

if hasattr(self.estimator, "decision_function"):
self.code_book_[self.code_book_ != 1] = -1
else:
self.code_book_[self.code_book_ != 1] = 0
self._generate_codebook()

classes_index = dict((c, i) for i, c in enumerate(self.classes_))

12 changes: 8 additions & 4 deletions sklearn/tests/test_multiclass.py
@@ -487,19 +487,23 @@ def test_ovo_string_y():
def test_ecoc_exceptions():
ecoc = OutputCodeClassifier(LinearSVC(random_state=0))
assert_raises(ValueError, ecoc.predict, [])
ecoc = OutputCodeClassifier(LinearSVC(random_state=0), code_size=3)
assert_raises(ValueError, ecoc.fit, [], np.array([0, 1, 2, 3]))
ecoc = OutputCodeClassifier(LinearSVC(random_state=0), code_size=0.01)
assert_raises(ValueError, ecoc.fit, [], np.array([0, 1, 2, 3]))


def test_ecoc_fit_predict():
# A classifier which implements decision_function.
ecoc = OutputCodeClassifier(LinearSVC(random_state=0),
code_size=2, random_state=0)
code_size=1, random_state=0)
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), n_classes * 2)
assert_equal(len(ecoc.estimators_), n_classes * 1)

# A classifier which implements predict_proba.
ecoc = OutputCodeClassifier(MultinomialNB(), code_size=2, random_state=0)
ecoc = OutputCodeClassifier(MultinomialNB(), code_size=1, random_state=0)
ecoc.fit(iris.data, iris.target).predict(iris.data)
assert_equal(len(ecoc.estimators_), n_classes * 2)
assert_equal(len(ecoc.estimators_), n_classes * 1)


def test_ecoc_gridsearch():
