New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the issue #3745, the code book generation for OutputCodeClassifier #3768
Changes from all commits
9565a2a
9d98cca
98640e4
e2cd7f1
b3ad677
d1c08fd
e5187ee
95d0343
346d6f2
169a8a3
f4c3fe9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
""" | ||
=========================== | ||
Multi-class classification | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would rename this "Multi-class encoding" or something since the example is about coding strategies, not multi-class classification. |
||
=========================== | ||
|
||
An example showing how to use the different encoding methods in scikit-learn | ||
for a multi-class problem, for example hand-written digit recognition. A | ||
comparison is made between the different encoding methods in scikit-learn: | ||
(1) Using the label_binarizer in SVC; (2) OneVsOneClassifier; | ||
(3) OneVsRestClassifier; (4) OutputCodeClassifier with two strategies, | ||
'random' and 'iter_hamming'. | ||
|
||
""" | ||
print(__doc__) | ||
|
||
# Author: Qichao Que <que@cse.ohio-state.edu> | ||
# License: BSD 3 clause | ||
|
||
# Import timing, plotting and scientific Python libraries. | ||
from matplotlib import pyplot as plt | ||
import numpy as np | ||
import time | ||
|
||
# Import datasets, classifiers, and performance metrics | ||
from sklearn.svm import SVC | ||
from sklearn import datasets, metrics | ||
from sklearn import multiclass | ||
from sklearn.metrics.pairwise import pairwise_distances | ||
|
||
# Load the digits dataset | ||
digits = datasets.load_digits() | ||
|
||
# Split the dataset into training and testing data | ||
train_X = digits.data[:1000] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can use the train_test_split function here if you like. |
||
train_y = digits.target[:1000] | ||
test_X = digits.data[1000:] | ||
test_y = digits.target[1000:] | ||
|
||
# Baseline: SVC fitted directly on the multi-class labels.
c = SVC(gamma=0.001)
s = time.time()
c.fit(train_X, train_y)
pred = c.predict(test_X)
e = time.time()
# Wall-clock time for fit + predict together.
time_svc = e - s
# The mean of the boolean mismatch mask is the error rate as a float on
# every Python/NumPy version: no integer division and no use of the
# removed ``np.int`` alias.
error_svc = np.mean(test_y != pred)
# '%' and '*' have equal precedence and associate left to right, so the
# percentage must be parenthesised; otherwise the formatted string itself
# is repeated 100 times.
print('Error for SVC: %.3f%%' % (error_svc * 100))
|
||
# Using OneVsRestClassifier wrapped around the same SVC.
c1 = multiclass.OneVsRestClassifier(c)
s = time.time()
c1.fit(train_X, train_y)
pred = c1.predict(test_X)
e = time.time()
# Wall-clock time for fit + predict together.
time_ovr = e - s
# Fraction of misclassified test samples, computed as a float regardless
# of Python/NumPy version (avoids integer division and ``np.int``).
error_ovr = np.mean(test_y != pred)
# Parenthesise the percentage: without it the formatted string is
# repeated 100 times instead of the value being scaled.
print('Error for SVC using OneVsRest Code: %.3f%%' % (error_ovr * 100))
|
||
# Using OneVsOneClassifier | ||
c1 = multiclass.OneVsOneClassifier(c) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is what SVC uses internally, right? So there should be no difference? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, this code is pretty repetitive. Can you maybe put it into a function? |
||
s = time.time()
c1.fit(train_X, train_y)
pred = c1.predict(test_X)
e = time.time()
# Wall-clock time for fit + predict together.
time_ovo = e - s
# Fraction of misclassified test samples, computed as a float regardless
# of Python/NumPy version (avoids integer division and ``np.int``).
error_ovo = np.mean(test_y != pred)
# Parenthesise the percentage: without it the formatted string is
# repeated 100 times instead of the value being scaled.
print('Error for SVC using OneVsOne Code: %.3f%%' % (error_ovo * 100))
|
||
# constants for OutputCodeClassifier | ||
max_iter = 50 | ||
repeat = 20 | ||
code_size = [0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0] | ||
|
||
# Using OutputCodeClassifier with strategy 'iter_hamming'.
# Rows index the code sizes, columns index the repeats.
error_ecoc_iter_hamming = np.zeros((len(code_size), repeat))
time_ecoc_iter_hamming = np.zeros((len(code_size), repeat))
# A single RandomState seeded with 0 is shared by all repeats.
random_state = np.random.RandomState(0)
for j, cs in enumerate(code_size):
    for i in range(repeat):
        c1 = multiclass.OutputCodeClassifier(
            c, max_iter=max_iter, code_size=cs, strategy="iter_hamming",
            random_state=random_state)
        s = time.time()
        c1.fit(train_X, train_y)
        pred = c1.predict(test_X)
        e = time.time()
        time_ecoc_iter_hamming[j, i] = e - s
        # Float error rate on every Python/NumPy version (no integer
        # division, no removed ``np.int`` alias).
        error_ecoc_iter_hamming[j, i] = np.mean(test_y != pred)
# Average over the repeats: one mean error per code size.
mean_error_ecoc_iter_hamming = np.mean(error_ecoc_iter_hamming, axis=1)
print('Error for SVC using Output Code with iter_hamming strategy: ')
print(mean_error_ecoc_iter_hamming)
|
||
# Using OutputCodeClassifier with strategy 'random'.
# Rows index the code sizes, columns index the repeats.
error_ecoc_random = np.zeros((len(code_size), repeat))
time_ecoc_random = np.zeros((len(code_size), repeat))
# A single RandomState seeded with 0 is shared by all repeats.
random_state = np.random.RandomState(0)
for j, cs in enumerate(code_size):
    for i in range(repeat):
        c1 = multiclass.OutputCodeClassifier(
            c, max_iter=max_iter, code_size=cs, strategy="random",
            random_state=random_state)
        s = time.time()
        c1.fit(train_X, train_y)
        pred = c1.predict(test_X)
        e = time.time()
        time_ecoc_random[j, i] = e - s
        # Float error rate on every Python/NumPy version (no integer
        # division, no removed ``np.int`` alias).
        error_ecoc_random[j, i] = np.mean(test_y != pred)
# Average over the repeats: one mean error per code size.
mean_error_ecoc_random = np.mean(error_ecoc_random, axis=1)
print('Error for SVC using Output Code with random strategy:')
print(mean_error_ecoc_random)
|
||
plt.figure(1)

# Each panel: (subplot position, y-axis label, single-point baselines,
# per-code-size curves). Baselines are drawn at x=1; curves are the mean
# over repeats for every code size.
panels = [
    (121, 'Classification Error',
     [(error_svc, 'b', '+', 'SVC'),
      (error_ovo, 'r', 'o', 'OneVsOne'),
      (error_ovr, 'g', '*', 'OneVsRest')],
     [(error_ecoc_iter_hamming, '-c', 'x', 'OC-iter_hamming'),
      (error_ecoc_random, '-k', 's', 'OC-random')]),
    (122, 'Running time',
     [(time_svc, 'b', '+', 'SVC'),
      (time_ovo, 'r', 'o', 'OneVsOne'),
      (time_ovr, 'g', '*', 'OneVsRest')],
     [(time_ecoc_iter_hamming, '-c', 'x', 'iter_hamming'),
      (time_ecoc_random, '-k', 's', 'random')]),
]

for position, y_label, baselines, curves in panels:
    plt.subplot(position)
    for value, fmt, marker, label in baselines:
        plt.plot(1, value, fmt, marker=marker, label=label)
    for values, fmt, marker, label in curves:
        plt.plot(code_size, np.mean(values, axis=1), fmt, marker=marker,
                 label=label)
    plt.legend()
    plt.xlabel('code_size')
    plt.ylabel(y_label)

plt.show()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,8 +41,9 @@ | |
from .base import BaseEstimator, ClassifierMixin, clone, is_classifier | ||
from .base import MetaEstimatorMixin | ||
from .preprocessing import LabelBinarizer | ||
from .metrics.pairwise import euclidean_distances | ||
from .metrics.pairwise import euclidean_distances, pairwise_distances | ||
from .utils import check_random_state | ||
from .utils.random import sample_without_replacement | ||
from .utils.validation import _num_samples | ||
from .utils import deprecated | ||
from .externals.joblib import Parallel | ||
|
@@ -592,6 +593,54 @@ def predict_ecoc(estimators, classes, code_book, X): | |
return ecoc.predict(X) | ||
|
||
|
||
def _random_code_book(n_classes, random_state, code_size):
    """Generate a random binary code book.

    Parameters
    ----------
    n_classes : int
        Number of rows of the code book (one code word per class).

    random_state : numpy.random.RandomState
        Generator the uniform samples are drawn from.

    code_size : int
        Number of columns of the code book (bits per code word).

    Returns
    -------
    code_book : ndarray, shape (n_classes, code_size)
        Floating point array holding 1 where the uniform sample drawn for
        that entry exceeded 0.5 and 0 everywhere else.
    """
    samples = random_state.random_sample((n_classes, code_size))
    # Threshold in one step; keep the float dtype of the drawn samples so
    # callers can later overwrite entries in place (e.g. 0 -> -1).
    return (samples > 0.5).astype(samples.dtype)
|
||
|
||
def _iter_hamming_code_book(n_classes, random_state, code_size, max_iter):
    """Pick the best of ``max_iter`` random subsets of the exhaustive code.

    In every iteration ``code_size`` distinct columns of the exhaustive code
    for ``n_classes`` are sampled; the candidate whose rows have the largest
    summed pairwise Hamming distance is returned.

    Parameters
    ----------
    n_classes : int
        Number of classes, i.e. number of rows of the code book.

    random_state : numpy.random.RandomState
        Generator used to sample the columns.

    code_size : int
        Number of columns (code words) to sample.

    max_iter : int
        Number of candidate code books to draw. Must be larger than 0.

    Returns
    -------
    code_book : ndarray, shape (n_classes, code_size)
        Integer array of 0/1 entries.

    Raises
    ------
    ValueError
        If ``max_iter`` is not positive or ``code_size`` exceeds the number
        of columns of the exhaustive code.
    """
    if max_iter <= 0:
        raise ValueError("max_iter must be larger than 0.")
    # Plain Python integers are used here: np.power(2, 60) overflows on
    # platforms whose default NumPy integer is 32 bits wide.
    if n_classes > 60:
        max_code_size = 2 ** 60 - 1
    else:
        max_code_size = 2 ** (n_classes - 1) - 1
    if code_size > max_code_size:
        raise ValueError("The number of code words is larger than the number "
                         "of exhaustive codes.")

    # One mask per class, most significant bit first.
    # NOTE(review): for n_classes > 61 the masks above bit 60 can never
    # match a sampled code (all codes are below 2**61) and shifts of 64 or
    # more overflow int64 -- confirm the intended behaviour for that range.
    bit_masks = 1 << np.arange(n_classes - 1, -1, -1, dtype=np.int64)

    # Start below any reachable distance so the first candidate is always
    # kept and a code book is returned even if every distance is 0.
    best_code_distance = -1
    best_code_book = None
    for _ in range(max_iter):
        # Forward random_state so the sampling is reproducible.
        p = sample_without_replacement(n_population=max_code_size,
                                       n_samples=code_size,
                                       random_state=random_state)
        # Example for understanding the intuition behind tmp_code_book.
        # The exhaustive code for 4 classes is
        #   1 1 1 1 1 1 1
        #   0 0 0 0 1 1 1
        #   0 0 1 1 0 0 1
        #   0 1 0 1 0 1 0
        # The columns are the binary codes of
        #   [8, 9, ..., 14] = [0, 1, ..., 6] + 7 + 1
        # where 7 is max_code_size for n_classes = 4. Each sampled index
        # is therefore shifted by max_code_size + 1 and AND-ed with every
        # bit mask to recover its binary digits: one row of the
        # intermediate array per sampled column, one entry per class. The
        # transpose puts classes on the rows.
        codes = np.asarray(p, dtype=np.int64) + max_code_size + 1
        tmp_code_book = ((codes[:, None] & bit_masks) > 0).astype(int).T

        # The code distance is the sum of Hamming distances between rows.
        tmp_code_distance = np.sum(pairwise_distances(
            tmp_code_book, metric='hamming'))
        if tmp_code_distance > best_code_distance:
            best_code_distance = tmp_code_distance
            best_code_book = tmp_code_book
    return best_code_book
|
||
|
||
class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): | ||
"""(Error-Correcting) Output-Code multiclass strategy | ||
|
||
|
@@ -616,6 +665,13 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): | |
one-vs-the-rest. A number greater than 1 will require more classifiers | ||
than one-vs-the-rest. | ||
|
||
max_iter : int, default: 10 | ||
Maximum number of iterations to generate a good code. An integer larger | ||
than 0. In each iteration, a random code will be generated and the old | ||
code will be replaced only when the total Hamming distance between | ||
code words increases. | ||
This parameter will be used when the strategy is "iter_hamming" | ||
|
||
random_state : numpy.RandomState, optional | ||
The generator used to initialize the codebook. Defaults to | ||
numpy.random. | ||
|
@@ -626,6 +682,18 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): | |
useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are | ||
used. Thus for n_jobs = -2, all CPUs but one are used. | ||
|
||
strategy : str, {'auto', 'random', 'iter_hamming'}, optional, default: "auto" | ||
The strategy to generate a code book for all classes. Three options are | ||
available currently: | ||
|
||
(1) "random": randomly generate a n_class x int(n_class*code_size) | ||
matrix, and set entries > 0.5 to be 1, the rest to be 0 or -1; | ||
(2) "iter_hamming": randomly select a subset of the exhaustive code book | ||
multiple times, and choose the one that gives the largest hamming distances | ||
between classes. | ||
(3) "auto": use "iter_hamming" if the code_size is in the valid range | ||
for it; otherwise, "random" strategy will be used. | ||
|
||
Attributes | ||
---------- | ||
estimators_ : list of `int(n_classes * code_size)` estimators | ||
|
@@ -654,13 +722,23 @@ class OutputCodeClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): | |
.. [3] "The Elements of Statistical Learning", | ||
Hastie T., Tibshirani R., Friedman J., page 606 (second-edition) | ||
2008. | ||
|
||
.. [4] "Reducing multiclass to binary: A unifying approach for margin | ||
classifiers", | ||
Allwein, Erin L., Robert E. Schapire, and Yoram Singer, | ||
The Journal of Machine Learning Research 1: 113-141, | ||
2001. | ||
""" | ||
|
||
def __init__(self, estimator, code_size=1.5, random_state=None, n_jobs=1): | ||
def __init__(self, estimator, code_size=1, max_iter=10, | ||
random_state=None, n_jobs=1, strategy="auto"): | ||
self.estimator = estimator | ||
self.code_size = code_size | ||
self.max_iter = max_iter | ||
self.random_state = random_state | ||
self.n_jobs = n_jobs | ||
self.strategy = strategy | ||
|
||
|
||
def fit(self, X, y): | ||
"""Fit underlying estimators. | ||
|
@@ -682,21 +760,38 @@ def fit(self, X, y): | |
"".format(self.code_size)) | ||
|
||
_check_estimator(self.estimator) | ||
random_state = check_random_state(self.random_state) | ||
|
||
self.classes_ = np.unique(y) | ||
n_classes = self.classes_.shape[0] | ||
code_size_ = int(n_classes * self.code_size) | ||
|
||
# FIXME: there are more elaborate methods than generating the codebook | ||
# randomly. | ||
self.code_book_ = random_state.random_sample((n_classes, code_size_)) | ||
self.code_book_[self.code_book_ > 0.5] = 1 | ||
|
||
if hasattr(self.estimator, "decision_function"): | ||
self.code_book_[self.code_book_ != 1] = -1 | ||
random_state = check_random_state(self.random_state) | ||
n_classes = self.classes_.shape[0] | ||
if n_classes > 60: | ||
max_code_size = np.power(2, 60) - 1 | ||
else: | ||
max_code_size = np.power(2, n_classes-1) - 1 | ||
n_code_words = int(n_classes * self.code_size) | ||
if n_code_words - 1 < np.log2(n_classes): | ||
raise ValueError("The code size must be large enough to " | ||
"distinguish every class.") | ||
if self.strategy == "auto": | ||
if n_code_words > max_code_size: | ||
self.code_book_ = _random_code_book(n_classes, random_state, n_code_words) | ||
else: | ||
self.code_book_ = _iter_hamming_code_book(n_classes, | ||
random_state, | ||
n_code_words, | ||
self.max_iter) | ||
elif self.strategy == "random": | ||
self.code_book_ = _random_code_book(n_classes, random_state, n_code_words) | ||
elif self.strategy == "iter_hamming": | ||
self.code_book_ = _iter_hamming_code_book(n_classes, | ||
random_state, | ||
n_code_words, | ||
self.max_iter) | ||
else: | ||
self.code_book_[self.code_book_ != 1] = 0 | ||
raise ValueError("Unknown coding strategy %r" % self.strategy) | ||
if hasattr(self.estimator, "decision_function"): | ||
self.code_book_[self.code_book_ == 0] = -1 | ||
|
||
classes_index = dict((c, i) for i, c in enumerate(self.classes_)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit strange given that code_size is a percentage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also find it a bit strange. I just kept the original name. I guess the reason is that `n_classes` is only determined after `fit` is called. But the range of `code_size` still depends on `n_classes`.