[MRG] faster sorting in trees; random forests almost 2× as fast #2747
Conversation
Nice! I tagged this PR for the 0.15 milestone, if everyone agrees :)
On the flip side, the sorting optimization is so good that it makes the rest of the tree code look slow :p (But again, covertype is really easy. I'll try 20news after SVD-200 as well.)
@larsmans I have a benchmark suite that contains datasets with different characteristics -- will send the results tomorrow.
You can try MNIST as well with the mldata loader: there is a script in the MLP PR: https://github.com/IssamLaradji/scikit-learn/blob/multilayer-perceptron/benchmarks/bench_mnist.py
@pprett Then be sure to use vanilla quicksort, not the randomized one. Shuffling turns out to be extremely expensive.
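The usual way to keep quicksort's average-case behavior without paying for a full pre-shuffle is median-of-three pivot selection. A hedged plain-Python sketch (the helper name is mine; this is an illustration, not the PR's Cython code):

```python
def median_of_three(a, lo, hi):
    """Return the median of a[lo], a[mid], a[hi-1].

    Using this value as the quicksort pivot is cheap protection against
    already-sorted inputs, with none of the cost of shuffling the array.
    """
    mid = (lo + hi - 1) // 2
    x, y, z = a[lo], a[mid], a[hi - 1]
    if x < y:
        if y < z:
            return y
        return max(x, z)
    # here x >= y
    if x < z:
        return x
    return max(y, z)

assert median_of_three([3, 1, 2], 0, 3) == 2   # median of {3, 1, 2}
```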
On 20news, all categories, 100 SVD components, 500 trees and four cores of an Intel i7, training time goes down from 24.181s to 11.683s. F1-score goes down from ~.75 to ~.6, though, so I may have a bug somewhere...
Covertype accuracy actually went down the drain as well. This wasn't the case before I rebased; I must have made a mistake in handling the new
-    while ((p + 1 < end) and
-           (X[X_sample_stride * samples[p + 1] + X_fx_stride * current_feature] <=
-            X[X_sample_stride * samples[p] + X_fx_stride * current_feature] + EPSILON_FLT)):
+    while p + 1 < end and Xf[p + 1] <= Xf[p] + EPSILON_FLT:
@larsmans in the block above you set Xf[p] for p in range(0, end-start). Here p runs over range(start, end) -- is that correct?
When I did my profiling of the tree code a couple of weeks ago it turned out that for datasets with lots of split points the bulk of time is spent in the while condition -- maybe part of your speed-up stems from this refactoring rather than the new sorting.
@larsmans in the block above you set Xf[p] for p in range(0, end-start). Here p runs from range(start, end) - is that correct?
+1, the indices are not correct. Please always make p range over [start, end) to avoid bugs and confusion with other parts of the code.
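The invariant under discussion -- Xf filled and indexed over the same [start, end) interval that p runs over -- can be illustrated with a plain-Python sketch (the helper name and generator form are mine, not the PR's Cython splitter code):

```python
def find_split_positions(Xf, start, end, eps=1e-7):
    """Yield the positions p where Xf[p] starts a new run of feature
    values, i.e. valid cut points between distinct sorted values.

    Assumes Xf holds sorted feature values for indices in [start, end).
    If Xf were instead filled from index 0 (as in the buggy version
    discussed above), every comparison below would read wrong entries.
    """
    p = start
    while p < end - 1:
        # Skip runs of (nearly) equal values: no split inside a run.
        while p + 1 < end and Xf[p + 1] <= Xf[p] + eps:
            p += 1
        if p + 1 < end:
            yield p + 1
        p += 1

positions = list(find_split_positions([0.0, 1.0, 1.0, 2.0], start=0, end=4))
# positions == [1, 3]: cut before the first 1.0 and before the 2.0
```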
Right, this is it. Will change this bit tonight.
@pprett No, this isn't actually the cause of the speedup, it was near 50% before I even introduced this bug.
great - thx Lars
@larsmans can I benchmark the enhancements, or are you still working out some issues in the code?
@pprett As long as the trees are not guaranteed to be the same (which is not the case since accuracy drops), there is no point in benchmarking the current changes. We should invest some time to try to figure this out. I can have a look tomorrow.
I re-applied the patches on top of current master. The first patch, faster heapsort, can AFAIC be merged into master immediately. It gives an almost two-fold speedup and passes the test suite. The second patch, quicksort, doesn't pass all the tests due to randomness issues, but further speeds up tree learning significantly.
I'm running benchmarks now - should be finished in a couple of hours.
But do you get good test accuracy on covertype and other benchmarks with quicksort?
Stability of the sorting algorithm shouldn't, in theory, have any impact on the trees that are built. As long as the feature values are sorted, the same cutting points should be found. I'll investigate when I have some time.
    i += 1
    ...
    l -= 1
    r += 1
This is wrong. Removing these two lines solves all the bugs on my box. :)
@larsmans I just submitted the fix to your branch. Since I am sure caching the feature values should also be profitable to other splitters, I'd like to make similar changes to
Here are some benchmark results -- I only looked at the first commit 238d692. All values are relative to Master, where Master is the version after @arjoly's recent MSE enhancement (w/o @jnothman's tree structure refactoring) -- sorry for that, but I only realized when it was too late.
I used:
classification_params = {'n_estimators': 100, 'max_depth': None}
regression_params = {'n_estimators': 100, 'max_depth': None, 'max_features': 0.3, 'min_samples_leaf': 1}
@larsmans each value is the mean of 3 repetitions, each with a different random seed (the same for both branches)
@larsmans which parameters did you use for your covtype benchmark?
Just the standard ones from the covtype script,
@larsmans I ran only single-threaded experiments
I just tried to run:
On this branch and master. The validation error is the same (~0.021). However, I do not see any significant training-time improvement (the standard deviations overlap). Maybe the speedup observed by @larsmans is architecture-specific (e.g. related to the CPU cache size)? Here are some attributes from one of my cores:
Introsort is about as fast as quicksort. The only reason to use it is to get rid of the worst-case quadratic behavior.
I don't have access to my workstation right now, but on my laptop I get for the covertype bench:
So this is very good (at least as fast as the previous benchmarks run with quicksort).
I had a quick look at the code and it looks fine to me, although I am not familiar with sorting algorithms. +1 for merging on my side.
I also ran my memory leak detection script with the
Thanks for the check! @larsmans feel free to merge this in :)
awesome!!! great work!
\o/
Please don't forget to add an entry in the
Great work! Nice to know we persist in teaching diverse sorting algorithms.
After this merge, GBRT regression takes twice as long on a data set as the previous build (commit bf1635d). The loss score seems OK. BTW, the data set is MSLR-WEB10K/Fold1 (MS learning to rank).
Could you try with 31491f9 as head?
Thanks @songgc - I did a quick benchmark using my solar dataset (regression). I looked at master ( ), introsort (31491f9), MSE Optim (0b7c79b), and best-first (834b375). @songgc I find the 2x performance regression quite severe - can you tell me which parameters you used (max_features, max_depth, n_estimators)?
@songgc have you fixed the
I also tried to bench
@ogrisel ok - I'm running a benchmark suite now with 3 repetitions between current master and @arjoly's MSE optimization -- I'll keep you posted. @songgc it would be great if you could post the parameters you used -- tree-building performance can differ quite considerably depending on the parameters used (e.g. max_features)
This one just uses smaller datasets -- it looks good IMHO. I used the following params:
Same here, I do not see any regression between 31491f9 and the current master: I trained
If I understand correctly, the only commit with an impact between 31491f9 and master is @glouppe's cache optim a681c9b (aka: ENH Make PresortBestSplitter cache friendly + cosmetics). It does seem to work on my box, shaving 3 minutes off the training time.
My apologies for the false alarm! I have found that I installed version 0.14.1 rather than the master branch... pip install scikit-learn git+https://github.com/scikit-learn/scikit-learn.git gives me the stable version. My lesson is "check version first". Currently, my benchmarks are consistent with you guys. The speed improvement is impressive! I don't need to transfer data to R for the GBM package :)
No worries @songgc, thanks for double-checking -- I'd say there is definitely no reason now to switch to R for the randomForest package ;)
I've code for this -- the next days are a bit busy -- will post it later this
@larsmans Benchmark result:
@songgc the current master also includes an option to build GBM-style trees in GradientBoostingRegressor|Classifier -- use the
Changed the heapsort in the tree learners into a quicksort and gave it cache-friendlier data access. Speeds up RF training almost two-fold. In fact, profiling with @fabianp's tool shows the time taken by sort to go down from 65% to <10% of total running time in the covertype benchmark. This is taking longer than I thought, but I figured I should at least show @glouppe and @pprett what I've got so far.
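The cache-friendlier access pattern described above amounts to gathering the current feature's values for the active samples into one contiguous buffer before sorting, instead of striding through X inside the hot loop. A hedged NumPy sketch of the idea (function and variable names are mine, not the PR's Cython code):

```python
import numpy as np

def sort_samples_by_feature(X, samples, feature):
    """Copy the active samples' values for one feature into a contiguous
    buffer Xf, then sort samples and Xf together.

    The strided reads of X happen once, in the gather; after that the
    sorting loop touches only contiguous memory, which is much kinder
    to the CPU cache.
    """
    Xf = X[samples, feature].copy()        # single strided gather from X
    order = np.argsort(Xf, kind="stable")  # sort the contiguous buffer
    return samples[order], Xf[order]

X = np.array([[3.0, 0.5],
              [1.0, 2.0],
              [2.0, 1.5]])
samples = np.array([0, 1, 2])
sorted_samples, Xf = sort_samples_by_feature(X, samples, feature=0)
# sorted_samples -> [1, 2, 0], Xf -> [1.0, 2.0, 3.0]
```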
TODO: