
Random Forest Performance #1435

Closed · pprett opened this issue Dec 2, 2012 · 39 comments

@pprett (Member) commented Dec 2, 2012

Random Forest is a popular classification technique; recent benchmarks [1][2] have shown that the performance of sklearn's RandomForestClassifier is inferior to that of competing software implementations.

The performance penalty most likely stems from the underlying tree building procedure; however, changes here require considerable effort. These changes include:

  • Better representations for data set partitions (currently we use a bit mask)

Some low-hanging fruit may also be found in the forest module itself.

[1] http://continuum.io/blog/wiserf-use-cases-and-benchmarks
[2] http://wise.io/wiserf.html

@pprett (Member Author) commented Dec 2, 2012

Cost of sampling w/ replacement:

                  Acc    Train time
bootstrap=True    0.95   149.68 s
bootstrap=False   0.96    94.38 s

The costs stem from a) fancy indexing and b) re-computing X_argsorted.
Both can be avoided by using sample weights rather than fancy indexing - then training times between the two should be equal.
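
A minimal sketch of the idea (assuming a tree fit API that accepts sample_weight, as proposed in #522): drawing n samples with replacement is equivalent to giving each original sample an integer weight equal to the number of times it was drawn, so the tree trains on the unmodified arrays and X_argsorted can be shared.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, random_state=0)
rng = np.random.RandomState(0)

# Bootstrap via integer sample weights instead of fancy indexing:
# no copy of X, no re-computation of per-tree auxiliary arrays.
n_samples = X.shape[0]
indices = rng.randint(0, n_samples, n_samples)
sample_weight = np.bincount(indices, minlength=n_samples).astype(np.float64)

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X, y, sample_weight=sample_weight)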

@amueller (Member) commented Dec 2, 2012

Wow, that is a huge overhead.
Btw, a student in my lab just spent two weeks optimizing a C++ implementation for images, and I think there is still a lot to gain in our implementation of the trees.
We should probably start from @bdholt1's implementation and see if we can somehow speed it up.
Or do you think we should keep the presorting?

@pprett (Member Author) commented Dec 2, 2012

@amueller I too think that there's lots to be gained. Regarding Brian's PR: what is the idea behind his enhancement? AFAIK he sorts features only when they are needed (i.e. when they fall into a feature sub-sample)?

@amueller (Member) commented Dec 2, 2012

Yes, exactly. Thinking about it, maybe that doesn't help as much, ...
There was the idea to pre-sort and then rebuild X_argsorted in each split to respect the order of the nodes (same PR, I think).

Also, speeding up ExtraTrees is probably even easier: they don't need any sorting at all!

@pprett (Member Author) commented Dec 2, 2012

Personally, I don't expect big gains from lazy pre-sorting either. The second idea is something different; that's the way R's "randomForest" does it: it basically re-orders samples such that partitions are consecutive regions in the data / auxiliary arrays.
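
A rough sketch of that re-ordering (an illustration of the general technique, not R's actual code): each node owns a contiguous slice of one shared index array, and a split is applied by partitioning that slice in place.

import numpy as np

def partition(sample_idx, start, end, X, feature, threshold):
    """Reorder sample_idx[start:end] in place so that samples going to
    the left child come first; return the boundary position."""
    pos = start
    for i in range(start, end):
        if X[sample_idx[i], feature] <= threshold:
            sample_idx[pos], sample_idx[i] = sample_idx[i], sample_idx[pos]
            pos += 1
    return pos  # left child owns [start:pos), right child owns [pos:end)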

@pprett (Member Author) commented Dec 2, 2012

I did a quick hack on top of @ndawe's PR #522 (sample weights for trees):

Here is the result on the MNIST benchmark:

                          ACC     Train time
MNIST (master)            0.9521  152.52 s
MNIST (rf-tree-weights)   0.9502   67.36 s

That's a 2-fold increase in performance (and probably more in memory efficiency)!

** The slight difference in ACC might be due to numerical issues

The branch is here https://github.com/pprett/scikit-learn/tree/rf-tree-weights .

@glouppe (Contributor) commented Dec 2, 2012

Nice! That's quite good news, Peter. Maybe it's time to finally review all of @ndawe's good work. It's been on my todo list for a long time.


@amueller (Member) commented Dec 4, 2012

Btw, I think this is a duplicate of a previous issue by @bdholt1.

We should have a look at the data mining literature on building trees fast:

  • SLIQ: A fast scalable classifier for data mining - Mehta et al., 1996
  • RainForest - Gehrke, Ramakrishnan & Ganti, 1998

@pprett (Member Author) commented Dec 4, 2012

I agree - the point is: we should add sample weights ASAP and get rid of the sampling-with-replacement overhead in RF - then we can tackle the tree building itself.

@glouppe we could split the work on reviewing #522 - are you busy at the moment?

Regarding the literature pointed out by @amueller: the literature on scalable decision tree induction is vast. We should definitely start to collect the most interesting approaches (github wiki page?) and do a reading group. I did some investigation of different software implementations (focus on GBRT) - you can find it here https://docs.google.com/spreadsheet/ccc?key=0AlBhwRZOwyxRdGo1V3A0eHYtNTY5TDVIa29pYWVjd1E (still work in progress though).

@glouppe (Contributor) commented Dec 4, 2012

@pprett I have been more and more busy lately :-) But this is on my todo list. I plan to review the code at the end of the week. Your help is more than welcome though!

I also agree that we should add sample weights ASAP so that we can get rid of the (huge) sampling-with-replacement overhead for free.

@amueller (Member) commented Dec 5, 2012

Btw, can we close either this one or #964?

@glouppe (Contributor) commented Dec 21, 2012

I got access to wakari.io, where one can use WiseRF.

WiseRF is indeed faster, but the gap is not as bad (i.e., not "5x to 100x faster") as the benchmarks mentioned above indicate. I turned off bootstrap to see where we should be once it is properly reimplemented using sample weights.

This is nothing very scientific though. It is just one test.

In [22]: X, y = make_classification(n_samples=10000, n_features=1000, n_classes=2)

...

In [28]: clf = WiseRF(n_estimators=10,  n_jobs=1)

In [29]: %timeit clf.fit(X, y)
1 loops, best of 3: 11.5 s per loop

In [30]: clf = RandomForestClassifier(n_estimators=10, bootstrap=False, n_jobs=1)

In [31]: %timeit clf.fit(X, y)
1 loops, best of 3: 25.5 s per loop

@pprett (Member Author) commented Dec 22, 2012

Great news - thanks for investigating. Btw, do you know whether WiseRF supports instance subsampling? (I cannot find anything in the Anaconda docs.)

@pprett (Member Author) commented Dec 27, 2012

@glouppe I recently checked the difference in test-time performance (both batches and single data points).

I'm testing with 10 features and 100 trees.

First, WiseRF:

X, y = datasets.make_hastie_10_2()  # X.shape == (12000, 10)
x = X[0].reshape((1, 10))
rf = WiseRF(n_estimators=100)

%timeit rf.predict(X)
1 loops, best of 3: 2.88 s per loop

%timeit rf.predict(x)
1 loops, best of 3: 2.69 s per loop

Cannot believe my eyes - 2.69 seconds for a single data point - are you kidding me?! Looks like there is a huge overhead involved; performance-wise it makes no difference whether you predict 10000 examples or just one.

Now sklearn:

%timeit skrf.predict(X)
1 loops, best of 3: 276 ms per loop

%timeit skrf.predict(x)
100 loops, best of 3: 4.94 ms per loop

That's better, but still - 5 ms for one data point and 100 trees is pretty slow. IMHO we could do better (fewer function calls, faster input checks).

@glouppe (Contributor) commented Dec 27, 2012

@pprett From what I have been able to discover, they actually store the forest as a string (!). (See the WiseRF.forest attribute.) I guess it is related to the Blaze module they mention on their blog. This would indeed explain the large overhead.

http://continuum.io/blog/blaze

@glouppe (Contributor) commented Dec 27, 2012

I am also glad to see that we are "up to 544x faster" at prediction time (sic) ;)

@ndawe (Member) commented Jan 7, 2013

Adding the content of #1532 here:

Currently the tree fitting procedure tries all possible splits between unique values of each feature (in find_best_split and _smallest_sample_larger_than in _tree.pyx). This can be very expensive for large datasets.

TMVA [1] implements both this same procedure and a mode that histograms each feature with a fixed number of bins [2]:

"The cut values are optimised by scanning over the variable range with a granularity that is set via the option nCuts. The default value of nCuts=20 proved to be a good compromise between computing time and step size. Finer stepping values did not increase noticeably the performance of the BDTs. However, a truly optimal cut, given the training sample, is determined by setting nCuts=-1. This invokes an algorithm that tests all possible cuts on the training sample and finds the best one."

[1] http://tmva.sourceforge.net/
[2] http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf (sections 8.12.2 and 8.12.3)
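
To make the binned mode concrete, here is a minimal sketch (an illustration of the nCuts idea, not TMVA's or scikit-learn's code) of scanning a fixed number of candidate thresholds instead of every unique feature value. y must contain non-negative integer class labels.

import numpy as np

def gini(labels):
    """Gini impurity of an integer label array."""
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

def best_binned_split(x, y, n_cuts=20):
    """Scan n_cuts evenly spaced thresholds over x and return the one
    minimizing the weighted impurity of the two children."""
    thresholds = np.linspace(x.min(), x.max(), n_cuts + 2)[1:-1]
    best_t, best_score = None, np.inf
    for t in thresholds:
        mask = x <= t
        if mask.all() or not mask.any():
            continue  # degenerate split: all samples on one side
        n_left = mask.sum()
        score = (n_left * gini(y[mask])
                 + (len(y) - n_left) * gini(y[~mask])) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score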

@bdholt1 (Member) commented Jan 8, 2013

I really like the idea of introducing an n_cuts parameter with a sensible default value.

@amueller (Member) commented Feb 5, 2013

I was just wondering about the in-place sorting of X_argsorted that we discussed. That would mean we need a copy of X_argsorted per tree and sharing the arrays across processes (as @ogrisel is working on) would not work any more, right?

@glouppe (Contributor) commented Feb 5, 2013

Correct. :/


@pprett (Member Author) commented Feb 5, 2013

I agree


@amueller (Member) commented Feb 5, 2013

But we probably still want that, right?
On the other hand, we wouldn't need to store the sample_mask any more, but I guess that is not that big since it is bool(?).

@ogrisel (Member) commented Feb 5, 2013

One copy of the whole data per sub-estimator? That does not seem like a reasonable approach to me, unless you pre-allocate one temporary buffer per computational worker and reuse those buffers sequentially for each new sub-estimator fit on the individual workers. However, if the original dataset barely fits in RAM, then the n_cpus independent X_argsorted buffers won't fit at the same time, and that will prevent using all the available CPU power.
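
A toy sketch of the buffer-per-worker idea (purely illustrative, not scikit-learn code): each worker process allocates one scratch array at start-up and reuses it for every sub-estimator it fits.

import numpy as np
from multiprocessing import Pool

_buffer = None  # one scratch array per worker process

def _init_worker(shape):
    global _buffer
    _buffer = np.empty(shape, dtype=np.int32)

def _fit_one_tree(seed):
    # A real implementation would refill _buffer with this tree's copy
    # of X_argsorted and run tree induction on it; here we just touch it.
    _buffer.fill(seed)
    return seed

if __name__ == "__main__":
    with Pool(4, initializer=_init_worker, initargs=((10000, 100),)) as pool:
        print(pool.map(_fit_one_tree, range(8)))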

@amueller (Member) commented Feb 5, 2013

So maybe we want to make that an option? Or think harder about how to implement it. But I don't think there is a chance to avoid allocating something dataset-sized per estimator and still be fast. Also, it would be good to know how large the sample mask that we currently use is.

@glouppe (Contributor) commented Feb 5, 2013

sample_mask is small, O(n).


@jtoy commented Feb 21, 2013

Are predictions supposed to be slow? rf.predict(one_observation_with_ten_vars) takes ~1 second; that is a long time when there are millions of rows to predict.

@bdholt1 (Member) commented Feb 22, 2013

I think you'll find that there is some overhead, so it should be faster with more samples.

@pprett (Member Author) commented Feb 22, 2013

@jtoy can you please post your RandomForest arguments, `one_observation_with_ten_vars.flags` and `one_observation_with_ten_vars.shape`?


@amueller (Member) commented Mar 2, 2013

After meditating over it a bit, I think it would be easiest if we first try to speed up the extra-trees. They don't need sorting and, if I am not mistaken, could work by just storing lists of sample points in each node.
So they don't need to copy anything and can still be fast. (See the sketch below.)

For Random Forests, it looks like there is a non-trivial memory/speed tradeoff.
My first thought on Random Forests would be to optionally use binning for the splits, which is also not that hard and would create a huge speedup (as again you don't need the sample mask and can just store lists of points).
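
For context, a minimal sketch of why extra-trees need no sorting (an illustration of the split rule, not the actual _tree.pyx code): a candidate threshold is drawn uniformly between the node's observed min and max, so each node only scans its samples once.

import numpy as np

def extra_trees_split(x, rng):
    """Draw an Extra-Trees style random threshold for feature values x:
    uniform between the observed minimum and maximum, no sort required."""
    lo, hi = x.min(), x.max()
    return rng.uniform(lo, hi)

rng = np.random.RandomState(0)
x = rng.randn(100)
print(extra_trees_split(x, rng))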

@bdholt1 (Member) commented Mar 2, 2013

I've got a test implementation of the binning working, and it's definitely quicker, although I haven't done proper benchmarking.

@ndawe (Member) commented Mar 2, 2013

Another benefit of binning is that since the cuts are only ever placed at the bin edges, the tree is somewhat less prone to overfitting.

@jtoy commented Apr 27, 2013

Is this still an issue? I have previously worked around it by doing multiple predictions at a time, but I believe the problem is still there.

@amueller (Member) commented

@jtoy which issue in particular are you referring to? There is still room for improvement in the forest speed ;)

@jtoy commented Apr 27, 2013

The issue I have seen is with running predictions in scikit-learn vs. R on the same datasets: scikit-learn is several orders of magnitude slower.

@amueller (Member) commented

Several orders of magnitude? That doesn't seem right. I guess it would depend a lot on the parameters and dataset, but I think there shouldn't be more than a factor of, say, 2 or 3, AFAIK. @glouppe?

@pprett (Member Author) commented Apr 28, 2013

@jtoy can you please elaborate on the benchmark - dataset size, model parameters, etc.? How do you test: batch prediction or prediction of single data points?

One issue that often bites users is the fact that our forest estimators use n_jobs processes at both training and test time. At test time, forking multiple processes can be quite costly - especially if you predict single data points. I strongly recommend setting n_jobs to 1 after you have trained your model, as sketched below.
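
A minimal sketch of that workflow (the dataset here is a synthetic stand-in):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)

# Train with all cores...
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(X, y)

# ...then drop back to a single process so each predict call does not
# pay the cost of forking workers.
clf.set_params(n_jobs=1)
print(clf.predict(X[:1]))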

@pprett (Member Author) commented Jul 22, 2013

Thanks to #2131 our Random Forest implementation is now much more efficient.

Memory consumption is still an issue, though. I opened a dedicated issue to track its progress: #2179.

pprett closed this as completed on Jul 22, 2013
@ogrisel (Member) commented Jul 7, 2015

Just to put a note: despite the fact that our implementation of RFs and GBRTs is much more optimized than it used to be, it still does not implement binning / approximate histograms for speeding up the best split search, as mentioned in @ndawe's comment: #1435 (comment).

xgboost (another very fast open source implementation of RFs and GBRTs, written in C++) apparently uses approximate feature histograms implemented with a custom quantile sketch data structure: https://github.com/dmlc/xgboost/blob/master/src/tree/updater_histmaker-inl.hpp

This might be the primary reason for xgboost's improved performance (even in single-threaded mode).

@glouppe (Contributor) commented Jul 9, 2015

I met the author of xgboost some months ago and he told me that binned splits are used only in the non-distributed case. I agree, however, that we should look a bit more into this implementation in order to better understand the performance gap.

In the case of boosting, one thing that I know is that the trees built with XGBoost are strictly different from ours, because of different loss functions and different impurity criteria. In particular, both include regularization terms which prevent complex (and deep) trees from being constructed; in addition to generalizing better, such trees may also be faster to construct.
