
[MRG+2] parallelized VotingClassifier #5805

Merged
merged 13 commits into from Aug 26, 2016

Conversation

olologin
Contributor

First version, looks like it's working :)
Also, I added a sample_weight parameter to the fit method; I don't know whether that's appropriate.

@olologin
Contributor Author

In case someone wants to test it with nested parallelism: in the code below we call the parallelized VotingClassifier from within cross_val_score.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier
from sklearn import datasets
from sklearn.model_selection import cross_val_score

# Load the iris dataset and keep two of its features
iris = datasets.load_iris()
X, y = iris.data[:, 1:3], iris.target

# Check classification by majority label on the iris dataset
clf1 = LogisticRegression(random_state=123)
clf2 = RandomForestClassifier(random_state=123)
clf3 = GaussianNB()
eclf = VotingClassifier(
    estimators=[('lr', clf1), ('rf', clf2), ('gnb', clf3)],
    voting='hard', n_jobs=-1, backend='multiprocessing')
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy', n_jobs=-1)

With backend='multiprocessing' it shows the message UserWarning: Multiprocessing-backed parallel loops cannot be nested, setting n_jobs=1, so the nested VotingClassifier just falls back to a single process while cross_val_score keeps its n_jobs=-1 workers. With backend='threading' it doesn't show any warning.

In both cases everything works fine and provides correct results.
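As a minimal illustration of the backend difference (plain joblib, not code from this PR): the threading backend nests without triggering the fallback warning that the multiprocessing backend produces.

```python
from joblib import Parallel, delayed


def double(v):
    return 2 * v


def inner(x):
    # inner parallel loop, like the one inside VotingClassifier.fit;
    # with the threading backend it nests without a fallback warning
    return sum(Parallel(n_jobs=2, backend="threading")(
        delayed(double)(v) for v in range(x)))


# outer parallel loop, like the one inside cross_val_score
results = Parallel(n_jobs=2, backend="threading")(
    delayed(inner)(i) for i in range(4))
print(results)  # [0, 0, 2, 6]
```

Swapping the outer backend to "multiprocessing" is what makes joblib emit the nesting warning and force the inner loop to n_jobs=1.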

@giorgiop
Contributor

In both cases everything works fine and provides correct results.

Could you please add tests showing that? Same for sample_weight.

@olologin
Contributor Author

Could you please add tests showing that? Same for the sample_weight

OK, I'll add a couple of tests for multithreading/multiprocessing. I've already run the existing tests from test_voting_classifier with different default threading/processing parameters.


assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))

Member

Two empty lines between functions, please (PEP 8).

@betatim
Copy link
Member

betatim commented Nov 16, 2015

There are a few more PEP 8-related things to fix in the code (`param = 42` should be `param=42`, etc.). You can install the pep8 tool to have it point them out.

@olologin olologin changed the title parallelized VotingClassifier [MRG] parallelized VotingClassifier Jan 8, 2016
The number of jobs to run in parallel for both `fit` and `predict`.
If -1, then the number of jobs is set to the number of cores.

backend: str, {'multiprocessing', 'threading'} (default='multiprocessing')
Member

I don't think we should allow the user to choose this. AFAIK, the threading backend is useful only if the underlying code releases the GIL, and the user is highly unlikely to know in which estimators that is the case.

We should maybe just use the default "multiprocessing" backend, as in cross_val_score and other places.

@MechCoder
Member

@olologin Could you rebase?

@olologin olologin force-pushed the votingclassifier_multithreading branch from 955733c to 4c00284 Compare April 30, 2016 05:49
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))


def test_parallel_majority_label_iris():
Member

Is this test not redundant? Isn't the previous test sufficient?

Contributor Author

It tests 'hard' voting stability on the iris dataset (so that the single-threaded and two-threaded versions give the same results), while the previous one tests the 'soft' version.
Anyway, the first test is fast (small artificial dataset), so it shouldn't increase testing time noticeably.

@MechCoder
Member

MechCoder commented Apr 30, 2016

LGTM pending nitpick @olologin

cc: @agramfort | @TomDLT for a quick second pass.

@MechCoder MechCoder changed the title [MRG] parallelized VotingClassifier [MRG+1] parallelized VotingClassifier Apr 30, 2016
@MechCoder
Member

And sorry for the delays…

@olologin
Contributor Author

olologin commented May 1, 2016

@MechCoder, actually I had almost forgotten about this PR; I thought it was someone else’s :)

Could you take a look at this one too: #6116? I think it's a good PR as well.

@MechCoder
Member

I can try after the semester break, if no one gets to it by then.


self.estimators_ = Parallel(n_jobs=self.n_jobs)(
    delayed(_parallel_fit_estimator)(
        clone(clf),
Member

@jnothman jnothman Jun 21, 2016

A few cosmetic issues here.

  1. Let's avoid this level of nesting.
  2. closing parentheses conventionally appear right after what they're closing, not on a new line.
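A flatter layout along the lines of this comment might look as follows (a sketch with a stand-in for the PR's private `_parallel_fit_estimator` helper, not the merged code):

```python
from joblib import Parallel, delayed
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression


def _parallel_fit_estimator(estimator, X, y):
    # stand-in for the PR's helper: fit the (already cloned) estimator
    return estimator.fit(X, y)


X, y = load_iris(return_X_y=True)
estimators = [LogisticRegression(max_iter=200),
              LogisticRegression(C=0.5, max_iter=200)]

# shallow nesting; closing parentheses hug what they close
fitted = Parallel(n_jobs=2)(
    delayed(_parallel_fit_estimator)(clone(clf), X, y)
    for clf in estimators)
print(len(fitted))  # 2
```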

Member

Could you sort this minor issue out?

@jnothman
Member

For many estimators (e.g. linear models), predicting in parallel will degrade performance, particularly with the multiprocessing backend. Perhaps we should be using the threading backend at predict time, for which overhead is much smaller, or remove it altogether. You're welcome to benchmark, but I'm not sure how you'd come up with a realistic set of voters in the ensemble.

@olologin
Contributor Author

olologin commented Jul 31, 2016

@jnothman, yes, you are right about this. I ran some benchmarks with this code:

import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.ensemble import VotingClassifier
from sklearn.datasets import load_digits

digits = load_digits()

lr = LogisticRegression(C=1000)
svc = SVC(gamma=.001, C=100, kernel='rbf', probability=True)
rf = RandomForestClassifier(n_estimators=400)
extra_gini = ExtraTreesClassifier(n_estimators=400, criterion='gini')
extra_entropy = ExtraTreesClassifier(n_estimators=400, criterion='entropy')
xgb_model = xgb.XGBClassifier(nthread=2, silent=True, colsample_bytree=.4,
                              learning_rate=0.05, max_depth=4, gamma=0,
                              n_estimators=800)

# Set up the ensemble classifier
eclf = VotingClassifier(estimators=[
    ('lr', lr),
    ('svc', svc),
    ('rf', rf),
    ('extra_gini', extra_gini),
    ('extra_entropy', extra_entropy),
    ('xgb_model', xgb_model)
], voting='soft', n_jobs=4)

eclf.fit(digits.data, digits.target)
%timeit eclf.predict(digits.data)
%timeit eclf.predict_proba(digits.data)

With the multiprocessing backend it takes ~2.1 s, with the threading backend ~1.3 s, and without any parallelization ~1.9 s for predict_proba and predict.

I think I could either revert the parallelization of the predict methods completely (the threading backend is faster, but the difference is not that big) or switch them to the threading backend. What do you think is the best solution?

@jnothman
Member

I don't feel like I have the expertise to judge this, but would lean conservatively (i.e. no parallelism) since we can't know the makeup of the ensemble, and prediction time is small relative to fit time.

@jnothman
Member

jnothman commented Aug 1, 2016

Also, all prediction can be parallelised over samples, so slow predictors can deal with that, I suppose.
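A sketch of that idea (not part of this PR): split the samples into chunks and predict each chunk in a thread, then stitch the results back together.

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression

X, y = load_digits(return_X_y=True)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# predict chunk-by-chunk in threads, then concatenate the pieces
chunks = np.array_split(X, 4)
parts = Parallel(n_jobs=2, backend="threading")(
    delayed(clf.predict)(chunk) for chunk in chunks)
pred = np.concatenate(parts)

print(np.array_equal(pred, clf.predict(X)))  # True
```

Whether this pays off depends on the estimator: for cheap predictors the chunking overhead dominates, which is the concern raised above.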

@olologin olologin force-pushed the votingclassifier_multithreading branch from 904a172 to 520089c Compare August 1, 2016 02:48
@MechCoder
Member

We can merge after a rebase.

@olologin olologin force-pushed the votingclassifier_multithreading branch from e950b2f to 1c2d6b4 Compare August 24, 2016 06:36
@olologin olologin force-pushed the votingclassifier_multithreading branch from 1c2d6b4 to 86d02f1 Compare August 24, 2016 07:20
@jnothman
Member

Uh, @MechCoder I've not actually given this my +1.

@jnothman
Member

(nor has anyone but you)

assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))

sample_weight_ = np.random.RandomState(123).uniform(size=(len(y),))
Member

I don't get why there's an _ after sample_weight here.

    ('lr', clf1), ('svc', clf3), ('knn', clf4)],
    voting='soft')
msg = "Underlying estimator 'knn' does not support sample weights."
assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)
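The support check this test exercises can be performed with scikit-learn's `has_fit_parameter` helper, which inspects the signature of an estimator's `fit`. (Shown standalone; whether the PR uses exactly this helper isn't visible in the excerpt.)

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.utils.validation import has_fit_parameter

# SVC.fit accepts sample_weight; KNeighborsClassifier.fit does not
print(has_fit_parameter(SVC(), "sample_weight"))                   # True
print(has_fit_parameter(KNeighborsClassifier(), "sample_weight"))  # False
```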
Member

PEP8 requires newline at end of file

@jnothman
Member

LGTM!

@jnothman jnothman changed the title [MRG+1] parallelized VotingClassifier [MRG+2] parallelized VotingClassifier Aug 26, 2016
@olologin
Contributor Author

@jnothman, thanks. Sorry for having so many style problems in such a small PR :) I work with C++ at my job, so sometimes a different code style kicks in unintentionally.

@jnothman
Member

No big deal. A commit hook running flake8 on changed files can help...
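One possible shape for such a hook, as a `.git/hooks/pre-commit` config sketch (not an official scikit-learn hook): run flake8 only on the Python files staged for the commit.

```shell
#!/bin/sh
# Sketch of a pre-commit hook: lint only the staged .py files.
files=$(git diff --cached --name-only --diff-filter=ACM | grep '\.py$')
if [ -n "$files" ]; then
    flake8 $files || exit 1
fi
```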

@jnothman jnothman merged commit b3e122a into scikit-learn:master Aug 26, 2016
@jnothman
Member

Thanks!

TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016
…it-learn#5805)

* parallelized VotingClassifier

* rename list to list_ to avoid problems

* Added new tests for sample_weight, multithreading and multiprocessing

* assert_equal -> assert_array_equal

* Fixed sample_weight existence check and test_sample_weight

* Code is clearer now

* Tests refactoring, 'backend' parameter removed

* Tests indentation fix

* reverted parallel predict and predict_proba to single threaded version

* what's new section added

* minor fixes

* check for sample_weight support in underlying estimators added

* newline at the end of test