[MRG+2] parallelized VotingClassifier #5805

Merged
merged 13 commits into from Aug 26, 2016

Conversation

@olologin
Contributor

olologin commented Nov 12, 2015

First version; looks like it's working :)
Also, I added a sample_weight parameter to the fit method; I don't know if it's appropriate.

@olologin

Contributor

olologin commented Nov 13, 2015

In case someone wants to test it with nested parallelism: in the code below we call this version of VotingClassifier from cross_val_score.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier
from sklearn import datasets
from sklearn.model_selection import cross_val_score

# Load the iris dataset
iris = datasets.load_iris()
X, y = iris.data[:, 1:3], iris.target


"""Check classification by majority label on dataset iris."""
clf1 = LogisticRegression(random_state=123)
clf2 = RandomForestClassifier(random_state=123)
clf3 = GaussianNB()
eclf = VotingClassifier(estimators=[
    ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
    voting='hard', n_jobs=-1, backend='multiprocessing')
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy', n_jobs=-1)

With backend='multiprocessing' it shows the message "UserWarning: Multiprocessing-backed parallel loops cannot be nested, setting n_jobs=1", so cross_val_score keeps its n_jobs=-1 workers while the VotingClassifier itself falls back to running in a single thread. With backend='threading' it doesn't show any warning.

In both cases everything works fine and provides correct results.

@giorgiop

Contributor

giorgiop commented Nov 13, 2015

In both cases everything works fine and provides correct results.

Could you please add tests showing that? Same for the sample_weight.

@olologin

Contributor

olologin commented Nov 13, 2015

Could you please add tests showing that? Same for the sample_weight.

OK, I'll add a couple of tests with multithreading/multiprocessing. I've already run the existing tests from test_voting_classifier with different default threading/processing parameters.

assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))

@betatim

betatim Nov 16, 2015

Contributor

two empty lines between functions please (PEP8)

voting='soft', n_jobs=-1, backend='multiprocessing').fit(X, y)
assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))

@betatim

betatim Nov 16, 2015

Contributor

two empty lines between functions please (PEP8)

return [estimator.predict_proba(X) for estimator in estimators]
def _flatten_list(list_):

@betatim

betatim Nov 16, 2015

Contributor

Check out, for example, bagging.py:367, where list(itertools.chain.from_iterable(...)) is used instead of a special flatten function. It would be good, if possible, to keep the code consistent.
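
For illustration, the suggested idiom looks like this (a minimal sketch with toy data, not the PR's code):

import itertools

# Each job returns a list of per-estimator results; flatten them into a
# single list without a custom helper.
results_per_job = [[1, 2], [3], [4, 5]]
flattened = list(itertools.chain.from_iterable(results_per_job))
assert flattened == [1, 2, 3, 4, 5]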

@olologin

olologin Nov 16, 2015

Contributor

Thanks for the comments, and sorry for such obvious problems. I'll commit once I get some conclusion from issue #5820 (comment).

@betatim

Contributor

betatim commented Nov 16, 2015

There are a few more PEP8-related things to fix in the code (param = 42 should be param=42, etc.). If you install pep8, it will tell you about them.

@olologin olologin changed the title from parallelized VotingClassifier to [MRG] parallelized VotingClassifier Jan 8, 2016

The number of jobs to run in parallel for both `fit` and `predict`.
If -1, then the number of jobs is set to the number of cores.
backend: str, {'multiprocessing', 'threading'} (default='multiprocessing')

@MechCoder

MechCoder Apr 28, 2016

Member

I don't think we should allow the user to choose this. AFAIK, the threading backend is useful only if the code releases the GIL (e.g. runs in a nogil block), and the user is highly unlikely to know for which estimators that is the case.

We should maybe just use the default "multiprocessing" backend, just like in cross_val_score and in other places.
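
For context, backend selection in joblib looks roughly like this (a toy workload, purely illustrative):

from math import sqrt

from joblib import Parallel, delayed

# Threads avoid process startup and pickling overhead, but they only
# give a speedup when the underlying code releases the GIL.
results = Parallel(n_jobs=2, backend='threading')(
    delayed(sqrt)(i ** 2) for i in range(8))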

        X,
        transformed_y,
        sample_weight
    ) for i in range(n_jobs)

@MechCoder

MechCoder Apr 29, 2016

Member

I'm really wondering whether joblib doesn't do the job allocation automatically; in other words, could you just iterate over range(len(self.estimators)), as done in the first example here (https://pythonhosted.org/joblib/parallel.html)?

Maybe @ogrisel or @glouppe (the original author of _partition_estimators) can clarify?
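
For reference, a minimal sketch of what that suggestion amounts to (the fit_one helper and toy data are illustrative, not the PR's code):

from joblib import Parallel, delayed
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

def fit_one(estimator, X, y):
    # joblib queues one task per estimator and balances them across
    # workers, so no manual partitioning is needed.
    return estimator.fit(X, y)

X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 0, 1, 1]
estimators = [LogisticRegression(), GaussianNB()]
fitted = Parallel(n_jobs=2)(
    delayed(fit_one)(clone(est), X, y) for est in estimators)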

@MechCoder

MechCoder Apr 29, 2016

Member

If there is some reason behind this, then we would have to do this wherever n_splits > n_jobs in cross-validation strategies and in 20 other places!

@olologin

olologin Apr 29, 2016

Contributor

@MechCoder, no, it seems there is no reason behind this; see #5820.

I'll fix everything and try to rebase tomorrow; too tired today :)

@MechCoder

MechCoder Apr 29, 2016

Member

Not to scare you, but a first +1 need not guarantee a second one immediately :P

sample_weight : array-like, shape = [n_samples] or None
Sample weights. If None, then samples are equally weighted.
Note that this is supported only if the base estimator supports
sample weighting.

@MechCoder

MechCoder Apr 29, 2016

Member

weighting -> weights

('lr', clf1), ('rf', clf2), ('gnb', clf3)],
voting='soft', n_jobs=1, backend='threading').fit(X, y)
eclf2 = VotingClassifier(estimators=[
('lr', clone(clf1)), ('rf', clone(clf2)), ('gnb', clone(clf3))],

@MechCoder

MechCoder Apr 29, 2016

Member

This should work without cloning.

voting='soft', n_jobs=1, backend='threading').fit(X, y)
eclf2 = VotingClassifier(estimators=[
('lr', clone(clf1)), ('rf', clone(clf2)), ('gnb', clone(clf3))],
voting='soft', n_jobs=-1, backend='multiprocessing').fit(X, y)

@MechCoder

MechCoder Apr 29, 2016

Member

You should test with a smaller value of n_jobs, like 2, for instance.

voting='hard', n_jobs=1).fit(X, y)
eclf2 = VotingClassifier(estimators=[
('lr', clone(clf1)), ('rf', clone(clf2)), ('gnb', clone(clf3))],
voting='hard', n_jobs=-1).fit(X, y)

@MechCoder
voting='soft', n_jobs=1).fit(X, y, sample_weight=np.ones((len(y),)))
eclf2 = VotingClassifier(estimators=[
('lr', clone(clf1)), ('rf', clone(clf2)), ('svc', clone(clf3))],
voting='soft', n_jobs=-1).fit(X, y)

@MechCoder

MechCoder Apr 29, 2016

Member

You can set this to be 1, since you are testing only for sample_weight

@MechCoder

Member

MechCoder commented Apr 29, 2016

@olologin Could you rebase?

assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
def test_parallel_majority_label_iris():

@MechCoder

MechCoder Apr 30, 2016

Member

Is this test not redundant? Isn't the previous test sufficient?

@olologin

olologin May 1, 2016

Contributor

It tests 'hard' voting stability (so that the single-threaded and two-threaded versions give the same results) on the iris dataset, while the previous one tests the 'soft' version.
Anyway, the first test is fast (a small artificial dataset), so it shouldn't increase testing time noticeably.

def test_sample_weight():
"""
Tests sample_weight parameter of VotingClassifier

@MechCoder

MechCoder Apr 30, 2016

Member

Could you check indentation here?

@olologin

olologin May 1, 2016

Contributor

I didn't understand what exactly you mean; pep8 doesn't complain about anything. But I've changed my indentation a little bit.

@MechCoder

Member

MechCoder commented Apr 30, 2016

LGTM pending nitpick @olologin

cc: @agramfort | @TomDLT for a quick second pass.

@MechCoder MechCoder changed the title from [MRG] parallelized VotingClassifier to [MRG+1] parallelized VotingClassifier Apr 30, 2016

@MechCoder

Member

MechCoder commented Apr 30, 2016

And sorry for the delays.

@olologin

Contributor

olologin commented May 1, 2016

@MechCoder, actually I had almost forgotten about this PR; I thought it was someone else's :)

Could you take a look at this one too: #6116? I think it's a good PR too.

@MechCoder

Member

MechCoder commented May 5, 2016

I can try after the semester break, if no one gets to it by then.

self.estimators_ = Parallel(n_jobs=self.n_jobs)(
    delayed(_parallel_fit_estimator)(
        clone(clf),

@jnothman

jnothman Jun 21, 2016

Member

A few cosmetic issues here.

  1. Let's avoid this level of nesting.
  2. Closing parentheses conventionally appear right after what they're closing, not on a new line.

@jnothman

jnothman Aug 23, 2016

Member

Could you sort this minor issue out?

@jnothman

Member

jnothman commented Jun 21, 2016

For many estimators (e.g. linear models), predicting in parallel will degrade performance, particularly with the multiprocessing backend. Perhaps we should be using the threading backend at predict time, for which overhead is much smaller, or remove it altogether. You're welcome to benchmark, but I'm not sure how you'd come up with a realistic set of voters in the ensemble.

@olologin

Contributor

olologin commented Jul 31, 2016

@jnothman, yes, you are right about this. I made some benchmarks with this code:

import xgboost as xgb
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.ensemble import VotingClassifier
from sklearn.datasets import load_digits

digits = load_digits()

lr = LogisticRegression(C=1000)
svc = SVC(gamma=.001, C=100, kernel='rbf', probability=True)
rf = RandomForestClassifier(n_estimators=400)
extraGini = ExtraTreesClassifier(n_estimators=400, criterion='gini')
extraEntropy = ExtraTreesClassifier(n_estimators=400, criterion='entropy')
xgb_model = xgb.XGBClassifier(nthread=2, silent=True, colsample_bytree=.4, learning_rate=0.05, max_depth=4, gamma=0, n_estimators=800)

# Set up the ensemble classifier
eclf = VotingClassifier(estimators=[
    ('lr', lr),
    ('svc', svc),
    ('rf', rf),
    ('extraGini', extraGini),
    ('extraEntropy', extraEntropy),
    ('xgb_model', xgb_model)
], voting='soft', n_jobs=4)

eclf.fit(digits.data, digits.target)
%timeit eclf.predict(digits.data)
%timeit eclf.predict_proba(digits.data)

The version with the multiprocessing backend takes ~2.1 s, the threading backend ~1.3 s, and without any parallelization ~1.9 s for predict_proba and predict.

I think I could either revert parallelization in the predict methods completely (though the threading backend is faster, the difference is not that big) or change them to the threading backend. What do you think is the best solution?

@jnothman

Member

jnothman commented Jul 31, 2016

I don't feel like I have the expertise to judge this, but would lean conservatively (i.e. no parallelism) since we can't know the makeup of the ensemble, and prediction time is small relative to fit time.

@jnothman

Member

jnothman commented Aug 1, 2016

Also, all prediction can be parallelised over samples, so slow predictors can deal with that themselves, I suppose.
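
For example, a slow predictor could chunk the samples itself; a minimal sketch (the parallel_predict helper is illustrative, not scikit-learn API):

import numpy as np
from joblib import Parallel, delayed

def parallel_predict(estimator, X, n_jobs=2):
    # Split the samples into chunks, predict each chunk in its own job,
    # then stitch the predictions back together.
    chunks = np.array_split(X, n_jobs)
    predictions = Parallel(n_jobs=n_jobs, backend='threading')(
        delayed(estimator.predict)(chunk) for chunk in chunks)
    return np.concatenate(predictions)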

@jnothman

Member

jnothman commented Aug 23, 2016

Also rebase on master and we'll try to get this merged soon.

@@ -207,3 +207,46 @@ def test_gridsearch():
grid = GridSearchCV(estimator=eclf, param_grid=params, cv=5)
grid.fit(iris.data, iris.target)
def test_parallel_predict_proba_on_toy_problem():

@jnothman

jnothman Aug 23, 2016

Member

Not sure this name is relevant any longer.

@@ -233,6 +233,9 @@ Enhancements
- Added new return type ``(data, target)`` : tuple option to :func:`load_iris` dataset. (`#7049 <https://github.com/scikit-learn/scikit-learn/pull/7049>`_)
By `Manvendra Singh`_ and `Nelson Liu`_.
- Added ``n_jobs`` parameter for :class:`VotingClassifier` to fit underlying estimators in parallel

@MechCoder

MechCoder Aug 24, 2016

Member

and also sample_weight

@MechCoder

Member

MechCoder commented Aug 24, 2016

We can merge after rebase.

@jnothman

Member

jnothman commented Aug 24, 2016

Uh, @MechCoder I've not actually given this my +1.

@jnothman

Member

jnothman commented Aug 24, 2016

(nor has anyone but you)

assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
sample_weight_ = np.random.RandomState(123).uniform(size=(len(y),))

@jnothman

jnothman Aug 24, 2016

Member

I don't get why there's an _ after sample_weight here.

@@ -47,6 +58,10 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
predicted class labels (`hard` voting) or class probabilities
before averaging (`soft` voting). Uses uniform weights if `None`.
n_jobs : int, optional (default=1)
The number of jobs to run in parallel for both `fit`.

@jnothman

jnothman Aug 24, 2016

Member

"both" should be removed.

fit won't show as code/tt unless you put it in double-backticks.
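
Applying both points, the parameter description would read along these lines:

n_jobs : int, optional (default=1)
    The number of jobs to run in parallel for ``fit``.
    If -1, then the number of jobs is set to the number of cores.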

def test_sample_weight():
"""Tests sample_weight parameter of VotingClassifier"""
clf1 = LogisticRegression(random_state=123)

@jnothman

jnothman Aug 24, 2016

Member

You don't test what happens with base estimators lacking sample_weight support, it seems.

def _parallel_fit_estimator(estimator, X, y, sample_weight):
    """Private function used to fit an estimator within a job."""
    if (sample_weight is not None and
            has_fit_parameter(estimator, "sample_weight")):
        estimator.fit(X, y, sample_weight=sample_weight)
    else:
        estimator.fit(X, y)
    return estimator

@jnothman

jnothman Aug 24, 2016

Member

I'm not happy with this use of has_fit_parameter. So far all uses of has_fit_parameter raise an error when sample_weight isn't supported. The problem with the present use is that if sample_weight support is added to an estimator that formerly lacked it (something that has frequently occurred), the behaviour of VotingClassifier is changed without notice. I would rather, at this point, have a requirement that all estimators require sample_weight if sample_weight is provided. Another PR might consider allowing the user to specify a subset of estimators as accepting unweighted samples.

@MechCoder

MechCoder Aug 24, 2016

Member

Really, really sorry that I did not catch this. :-(

Yes, we should raise an error at the beginning of fit if sample_weight is provided and not all estimators support sample_weight, as done here (https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/multioutput.py#L81).
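
A minimal sketch of such a check, assuming the estimators are stored as (name, estimator) pairs (the helper name is illustrative):

from sklearn.utils.validation import has_fit_parameter

def _check_sample_weight_support(estimators, sample_weight):
    # Fail fast if sample_weight is given but any underlying estimator
    # cannot accept it in fit().
    if sample_weight is None:
        return
    for name, clf in estimators:
        if not has_fit_parameter(clf, 'sample_weight'):
            raise ValueError("Underlying estimator '%s' does not support"
                             " sample weights." % name)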

('lr', clf1), ('svc', clf3), ('knn', clf4)],
voting='soft')
msg = ('Underlying estimator \'knn\' does not support sample weights.')
assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)

@jnothman

jnothman Aug 26, 2016

Member

PEP8 requires a newline at the end of the file.

@jnothman

Member

jnothman commented Aug 26, 2016

LGTM!

@jnothman jnothman changed the title from [MRG+1] parallelized VotingClassifier to [MRG+2] parallelized VotingClassifier Aug 26, 2016

@olologin

Contributor

olologin commented Aug 26, 2016

@jnothman, thanks. Sorry for having so many style problems in such a small PR :) I work with C++ at my job, so sometimes a different code style kicks in unintentionally.

@jnothman

Member

jnothman commented Aug 26, 2016

No big deal. A commit hook running flake8 on changed files can help...

@jnothman jnothman merged commit b3e122a into scikit-learn:master Aug 26, 2016

2 of 3 checks passed

continuous-integration/appveyor/pr: Waiting for AppVeyor build to complete
ci/circleci: Your tests passed on CircleCI!
continuous-integration/travis-ci/pr: The Travis CI build passed
@jnothman

Member

jnothman commented Aug 26, 2016

Thanks!

TomDLT added a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016

[MRG+2] parallelized VotingClassifier and sample_weight support (scikit-learn#5805)

* parallelized VotingClassifier

* rename list to list_ to avoid problems

* Added new tests for sample_weight, multithreading and multiprocessing

* assert_equal -> assert_array_equal

* Fixed sample_weight existence check and test_sample_weight

* Code is clearer now

* Tests refactoring, 'backend' parameter removed

* Tests indentation fix

* reverted parallel predict and predict_proba to single threaded version

* what's new section added

* minor fixes

* check for sample_weight support in underlying estimators added

* newline at the end of test