[MRG] faster sorting in trees; random forests almost 2× as fast #2747

Merged
merged 3 commits into scikit-learn:master from larsmans:tree-sort Jan 26, 2014


@larsmans
Member

Changed the heapsort in the tree learners into a quicksort and gave it cache-friendlier data access. Speeds up RF training almost two-fold. In fact, profiling with @fabianp's yep tool shows the time taken by sorting going down from 65% to <10% of total running time in the covertype benchmark.

This is taking longer than I thought, but I figured I should at least show @glouppe and @pprett what I've got so far.

TODO:

  • more benchmarks, esp. on a denser dataset than covertype (sparse data is easy :)
  • make tests pass
  • clean up code
  • filter out the cruft
  • decide on the final algorithm: quicksort takes O(n²) time in the worst case, which can be avoided by introsort at the expense of more code (see the sketch below).
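For context, a minimal plain-Python sketch of the introsort idea (illustrative only: the names and the sorted() stand-in for the heapsort fallback are mine, not the PR's Cython code):

import math

def introsort(a, lo=0, hi=None, depth=None):
    # Quicksort, but once the recursion exceeds ~2*log2(n) levels,
    # fall back to a guaranteed O(n log n) sort, capping the worst case.
    if hi is None:
        hi = len(a)
        depth = 2 * max(1, int(math.log2(max(1, len(a)))))
    if hi - lo <= 1:
        return
    if depth == 0:
        a[lo:hi] = sorted(a[lo:hi])  # stand-in for the heapsort fallback
        return
    # Lomuto partition around the last element
    pivot, i = a[hi - 1], lo
    for j in range(lo, hi - 1):
        if a[j] <= pivot:
            a[i], a[j] = a[j], a[i]
            i += 1
    a[i], a[hi - 1] = a[hi - 1], a[i]
    introsort(a, lo, i, depth - 1)
    introsort(a, i + 1, hi, depth - 1)

On an already-sorted input, this last-element-pivot quicksort would degrade to O(n²); the depth cap is exactly what introsort adds.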
@ogrisel
Member
ogrisel commented Jan 13, 2014

Nice! I tagged this PR for the 0.15 milestone if everyone agrees :)

@larsmans
Member

On the flip side, the optimization to the sorting is so good that it makes the rest of the tree code look slow :p

(But again, covertype is really easy. I'll try 20news after SVD-200 as well.)

@pprett
Member
pprett commented Jan 13, 2014

@larsmans I've a benchmark suite that contains datasets with different characteristics -- will send the results tomorrow

@ogrisel
Member
ogrisel commented Jan 13, 2014

You can try on MNIST as well with the mldata loader: there is a script in the MLP PR: https://github.com/IssamLaradji/scikit-learn/blob/multilayer-perceptron/benchmarks/bench_mnist.py
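For reference, loading MNIST through the mldata loader looks roughly like this (a sketch assuming the fetch_mldata API of that era; the linked bench_mnist.py is the authoritative script):

from sklearn.datasets import fetch_mldata

# MNIST from mldata.org: 70000 samples, 784 features; the customary
# split is the first 60000 samples for training, the rest for testing.
mnist = fetch_mldata('MNIST original')
X, y = mnist.data, mnist.target
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]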

@larsmans
Member

@pprett Then be sure to use vanilla quicksort, not the randomized one. Shuffling turns out to be extremely expensive.

@larsmans
Member

pprof (Google perftools) graph w/ quicksort on covertype:

quicksort-pprof

@larsmans
Member

On 20news, all categories, 100 SVD components, 500 trees and four cores of an Intel i7, training time goes down from 24.181s to 11.683s. F1-score goes down from ~.75 to ~.6, though, so I may have a bug somewhere...

@larsmans
Member

Covertype accuracy actually went down the drain as well. This wasn't the case before I rebased, so I must have made a mistake in handling the new X_fx_stride.

@pprett pprett commented on the diff Jan 14, 2014
sklearn/tree/_tree.pyx
            # Evaluate all splits
            self.criterion.reset()
            p = start
            while p < end:
-                while ((p + 1 < end) and
-                       (X[X_sample_stride * samples[p + 1] + X_fx_stride * current_feature] <=
-                        X[X_sample_stride * samples[p] + X_fx_stride * current_feature] + EPSILON_FLT)):
+                while p + 1 < end and Xf[p + 1] <= Xf[p] + EPSILON_FLT:
@pprett
pprett Jan 14, 2014 Member

@larsmans in the block above you set Xf[p] for p in range(0, end-start). Here p runs from range(start, end) - is that correct?

@pprett
pprett Jan 14, 2014 Member

When I did my profiling of the tree code a couple of weeks ago it turned out that for datasets with lots of split points the bulk of time is spent in the while condition -- maybe part of your speed-up stems from this refactoring rather than the new sorting.

@glouppe
glouppe Jan 14, 2014 Member

@larsmans in the block above you set Xf[p] for p in range(0, end-start). Here p runs from range(start, end) - is that correct?

+1, indices are not correct. Please always make p range over [start, end) to avoid bugs and confusion with other parts of the code.

@larsmans
larsmans Jan 14, 2014 Member

Right, this is it. Will change this bit tonight.

@pprett No, this isn't actually the cause of the speedup, it was near 50% before I even introduced this bug.

@pprett
pprett Jan 14, 2014 Member

great - thx Lars


@glouppe glouppe and 1 other commented on an outdated diff Jan 14, 2014
sklearn/tree/_tree.pyx
-            # Sort samples along that feature
-            sort(X, X_sample_stride, X_fx_stride, current_feature, samples + start, end - start)
+            #sort(X, X_sample_stride, X_fx_stride, current_feature, samples + start, end - start)
+            # Sort samples along that feature; first copy the feature values
+            # for the active samples into Xf, s.t. Xf[i] == X[samples[i], j],
+            # so the sort uses the cache more effectively. Shuffle makes sure
+            # we get a randomized quicksort.
+            shuffle(samples, self.n_samples, random_state)
+            for p in range(end - start):
+                #with gil: print(samples[p])
+                #printf("%zd %zd\n", samples[p], self.n_samples)
+                #if samples[p] > self.n_samples:
+                #    abort()
+                Xf[p] = X[X_sample_stride * samples[p + start]
+                          + X_fx_stride * current_feature]
+            qsort(Xf, samples + start, end - start)
@glouppe
glouppe Jan 14, 2014 Member

With regards to my comment below, I would therefore replace this part of the code with:

for p in range(start, end):
    Xf[p] = X[X_sample_stride * samples[p] + X_fx_stride * current_feature]
qsort(Xf + start, samples + start, end - start)
@larsmans
larsmans Jan 14, 2014 Member

@glouppe That would overallocate. Could that be a problem in terms of peak memory usage? If not, I'll simplify the code.

@glouppe
glouppe Jan 14, 2014 Member

It is not really a huge problem in my opinion. In terms of memory usage it is like adding a new feature column. Anyway, you already overallocate Xf since you make it of size self.n_samples above (and not of size end - start). This seems to work :)

In fact, you could also avoid the mallocs/frees by allocating Xf only once during init, as we do for samples. That would be a bit cleaner in my opinion.
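A rough Python analogy of that suggestion (hypothetical names, not the actual Splitter API): allocate the scratch buffer once, sized n_samples like samples, and reuse it at every node:

import numpy as np

class SplitterSketch:
    # Hypothetical sketch: one allocation for the splitter's lifetime
    # instead of a malloc/free pair per split.
    def __init__(self, n_samples, dtype=np.float32):
        self.Xf = np.empty(n_samples, dtype=dtype)  # like `samples`

    def sort_feature(self, X, samples, feature, start, end):
        Xf = self.Xf
        # Copy the active feature values into the scratch buffer so the
        # sort reads contiguous memory, then reorder samples to match.
        Xf[start:end] = X[samples[start:end], feature]
        order = np.argsort(Xf[start:end], kind='quicksort')
        samples[start:end] = samples[start:end][order]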

@larsmans
larsmans Jan 14, 2014 Member

I'll see if that's better. It would save adding a return value to this function.

@glouppe glouppe commented on an outdated diff Jan 15, 2014
sklearn/tree/_tree.pyx
-                if current_threshold == X[X_sample_stride * samples[p] + X_fx_stride * current_feature]:
-                    current_threshold = X[X_sample_stride * samples[p - 1] + X_fx_stride * current_feature]
+                if current_threshold == Xf[p]:
+                    current_threshold = Xf[p]
@glouppe
glouppe Jan 15, 2014 Member

It should be current_threshold = Xf[p - 1].

@arjoly arjoly commented on the diff Jan 15, 2014
sklearn/tree/_tree.pyx
+cdef inline void shuffle(SIZE_t* samples, SIZE_t n,
+                         UINT32_t* random_state) nogil:
+    # Fisher-Yates shuffle
+    cdef SIZE_t i, j
+
+    for i in range(n):
+        j = i + rand_int(n - i, random_state)
+        samples[i], samples[j] = samples[j], samples[i]
+
+
+cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil:
+    # Helper for sort
+    Xf[i], Xf[j] = Xf[j], Xf[i]
@arjoly
arjoly Jan 15, 2014 Member

Swapping with this syntax generates one more C instruction than what is really needed.

  /* "sklearn/tree/_tree.pyx":1141
 * cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil:
 *     # Helper for sort
 *     Xf[i], Xf[j] = Xf[j], Xf[i]             # <<<<<<<<<<<<<<
 *     samples[i], samples[j] = samples[j], samples[i]
 * 
 */
  __pyx_t_1 = (__pyx_v_Xf[__pyx_v_j]);
  __pyx_t_2 = (__pyx_v_Xf[__pyx_v_i]);
  (__pyx_v_Xf[__pyx_v_i]) = __pyx_t_1;
  (__pyx_v_Xf[__pyx_v_j]) = __pyx_t_2;
@larsmans
larsmans Jan 15, 2014 Member

In C, yes. But the assembly for

int t = a[i];
a[i] = a[j];
a[j] = t;

and

int ti = a[i];
int tj = a[j];
a[i] = tj;
a[j] = ti;

is identical (gcc -O2 -S).

The only thing I still need to try is putting an if (i != j) around this, but that's for later.

@pprett
Member
pprett commented Jan 19, 2014

@larsmans can I benchmark the enhancements or are you still working out some issues in the code?

@glouppe
Member
glouppe commented Jan 19, 2014

@pprett As long as the trees are not guaranteed to be the same (which is not the case since accuracy drops), there is no point in benchmarking the current changes. We should invest some time to try to figure this out. I can have a look tomorrow.

@larsmans
Member

I re-applied the patches on top of current master. The first patch, faster heapsort, can AFAIC be merged into master immediately. It gives an almost two-fold speedup and it passes the testsuite.

The second patch, quicksort, doesn't pass all the tests due to randomness issues, but further speeds up tree learning significantly.

@larsmans
Member

@glouppe You can certainly benchmark 238d692, it passes all of the tests.

The second produces somewhat different trees. I'm not sure if we can ever fix that, since neither quicksort nor heapsort are stable sorts.

@pprett
Member
pprett commented Jan 19, 2014

I'm running benchmarks now - should be finished in a couple of hours


@ogrisel
Member
ogrisel commented Jan 19, 2014

The second produces somewhat different trees. I'm not sure if we can ever fix that, since neither quicksort nor heapsort are stable sorts.

But do you get good test accuracy on covertype and other benchmarks with quicksort?

@glouppe
Member
glouppe commented Jan 20, 2014

The second produces somewhat different trees. I'm not sure if we can ever fix that, since neither quicksort nor heapsort are stable sorts.

Stability of the sorting algorithm shouldn't, in theory, have any impact on the trees that are built. As long as the feature values are sorted, the same cutting points should be found. I'll investigate when I have some time.
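A tiny sketch of that point (illustrative, not the tree code): argsorts that break ties differently still expose identical candidate cut points, since those depend only on the sorted values:

import numpy as np

values = np.array([0.5, 0.2, 0.5, 0.9, 0.2])

# quicksort is unstable, mergesort is stable: tied samples may end up
# in different orders, but the sorted values are the same either way.
for kind in ('quicksort', 'mergesort'):
    v = values[np.argsort(values, kind=kind)]
    cuts = (v[:-1] + v[1:]) / 2.0  # midpoints between neighbours
    print(np.unique(cuts))         # same output for both kinds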

@glouppe glouppe commented on an outdated diff Jan 20, 2014
sklearn/tree/_tree.pyx
+    # [<pivot (l) =pivot (r) >pivot].
+    i = l = 0
+    r = n
+    while i < r:
+        if Xf[i] < pivot:
+            swap(Xf, samples, i, l)
+            i += 1
+            l += 1
+        elif Xf[i] > pivot:
+            r -= 1
+            swap(Xf, samples, i, r)
+        else:
+            i += 1
+
+    l -= 1
+    r += 1
@glouppe
glouppe Jan 20, 2014 Member

This is wrong. Removing these two lines solves all the bugs on my box. :)
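For the record, the partition invariant in a pure-Python sketch (illustrative; the actual code is the Cython above):

def three_way_partition(Xf, samples, pivot):
    # Dutch-national-flag pass: afterwards Xf is laid out as
    # [< pivot | == pivot | > pivot], with the equal run in [l, r).
    i = l = 0
    r = len(Xf)
    while i < r:
        if Xf[i] < pivot:
            Xf[i], Xf[l] = Xf[l], Xf[i]
            samples[i], samples[l] = samples[l], samples[i]
            i += 1
            l += 1
        elif Xf[i] > pivot:
            r -= 1
            Xf[i], Xf[r] = Xf[r], Xf[i]
            samples[i], samples[r] = samples[r], samples[i]
        else:
            i += 1
    # No trailing l -= 1 / r += 1: l already marks the first == pivot
    # slot and r the first > pivot slot, so recursing on Xf[:l] and
    # Xf[r:] skips the equal run; the removed lines pushed both
    # boundaries off by one.
    return l, r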

@glouppe
Member
glouppe commented Jan 20, 2014

@larsmans I just submitted the fix to your branch.

Since I am sure caching the feature values would also benefit the other splitters, I'd like to make similar changes to PresortBestSplitter and RandomSplitter. I'll push one more patch to your branch during the day.

@pprett
Member
pprett commented Jan 20, 2014

here are some benchmark results -- I only looked at the first commit 238d692

heap_s_error

heap_s_train_time

heap_s_test_time

All values are relative to Master, where Master is the version after @arjoly's recent MSE enhancement (w/o @jnothman's tree structure refactoring) -- sorry for that, but I only realized when it was too late.
We can see a nice performance improvement for all large datasets (covtype, expedia, solar, bioresponse) -- the improvement is about 15-20%.
Good work @larsmans - awesome!
There was a slight performance decrease on the synthetic regression benchmarks (Friedman #1-3) -- these mostly have a large number of split points, so stability should not be an issue at all.

@pprett
Member
pprett commented Jan 20, 2014

I looked at RandomForestClassifier|Regressor only and used the following parameters:

classification_params = {'n_estimators': 100,
                         'max_depth': None,}
regression_params = {'n_estimators': 100,
                     'max_depth': None, 'max_features': 0.3,
                     'min_samples_leaf': 1,
                     }
@larsmans
Member

@glouppe will merge the patch later this week. @pprett Not as impressive as on the covtype benchmark... did you fix the random seed? I'm surprised to see a difference in accuracy with the refactored heapsort; it should give the exact same ordering.

@pprett
Member
pprett commented Jan 20, 2014

@larsmans each value is the mean of 3 repetitions, each with a different random seed (the same seeds for both branches)

@pprett
Member
pprett commented Jan 20, 2014

@larsmans which parameters did you use for your covtype benchmark?

@larsmans
Member

Just the standard ones from the covtype script, n_estimators=20, random_seed=13.

@larsmans
Member

@pprett How many cores? I see a somewhat smaller speedup w/ one core compared to the four I used to benchmark previously:

master 133.6747s
238d692 86.0302s

That's about 36% off. At four cores, I get 43% off, despite the machine having only four cores and a browser still running.

@pprett
Member
pprett commented Jan 20, 2014

@larsmans I ran only single threaded experiments

@ogrisel
Member
ogrisel commented Jan 21, 2014

I just tried to run:

python benchmarks/bench_covertype.py --classifiers=ExtraTrees --n-jobs=8

On this branch and master. The validation error is the same (~0.021). However I do not see any significant training time improvement (the standard deviations overlap). Maybe the speedup observed by @larsmans is architecture-specific (e.g. related to the CPU cache size)?

Here are some attributes from one of my cores:

model name  : Intel(R) Xeon(R) CPU           X5660  @ 2.80GHz
cpu MHz     : 2660.000
cache size  : 12288 KB
bogomips    : 5585.92
@glouppe
Member
glouppe commented Jan 21, 2014

@ogrisel For a correct implementation, see larsmans#6

@pprett
Member
pprett commented Jan 21, 2014

@ogrisel is this your Apple? makes my Thinkpad W510 look like a sissy..


@ogrisel
Member
ogrisel commented Jan 21, 2014

I tried glouppe/tree-sort and the training speed seems to be approximately the same as well.

@pprett
Member
pprett commented Jan 21, 2014

@ogrisel have you tried RandomForest and --n-jobs=1 ?


@ogrisel
Member
ogrisel commented Jan 21, 2014

Nope, it's my 12-physical-core Linux workstation at Inria (with hyperthreading disabled). It only has 24GB of RAM, but soon we will get a compute server with 16 physical cores and 384GB of RAM for the team ;)

@glouppe
Member
glouppe commented Jan 21, 2014

@ogrisel You are trying with ExtraTrees :-) This PR affects Random Forests only.

@pprett
Member
pprett commented Jan 21, 2014

Very nice... has any of you disabled hyperthreading for their laptops too? Any experience?


@ogrisel
Member
ogrisel commented Jan 21, 2014

@glouppe indeed! Will redo the bench...

@ogrisel
Member
ogrisel commented Jan 21, 2014

Very nice... has anybody of you disabled hyperthreading for their laptops too? any experience?

Actually on my laptop, training Extra Trees with n_jobs=4 with HT seems to be slightly faster than n_jobs=2 without HT. But I should retry. HT was disabled on this workstation prior to my joining Inria, but I am too lazy to reboot. Probably not a big deal anyway.

@ogrisel
Member
ogrisel commented Jan 21, 2014

Ok here are my numbers:

Benchmark in sequential run (train time, test time, validation error):

python benchmarks/bench_covertype.py --classifiers=RandomForest --n-jobs=1 --random-seed=11
  • master:

    RandomForest 176.4649s   0.3957s     0.0302
    
  • this branch:

    RandomForest  86.2451s   0.4110s     0.0359  # incorrect validation error
    
  • glouppe/tree-sort:

    RandomForest  82.9134s   0.3910s     0.0302
    
@pprett
Member
pprett commented Jan 21, 2014

I'm impressed... nice work guys!


@ogrisel
Member
ogrisel commented Jan 21, 2014

Benchmark in parallel run (train time, test time, validation error):

python benchmarks/bench_covertype.py --classifiers=RandomForest --n-jobs=8 --random-seed=11
  • master:

    RandomForest  34.1381s   0.1051s     0.0302
    
  • this branch:

    RandomForest  17.4253s   0.1056s     0.0360  # incorrect validation error
    
  • glouppe/tree-sort:

    RandomForest  16.6501s   0.1051s     0.0302
    
@arjoly
Member
arjoly commented Jan 21, 2014

Nice!

@ogrisel
Member
ogrisel commented Jan 21, 2014

More benchmarks, this time on MNIST with n_estimators=100 & n_jobs=8:

Classifier                                    train-time       test-time      error-rate   
----------------------------------------------------------------------------------------
Random Forest (master)                        12.5797169209  0.239344120026      0.0305
Random Forest (glouppe/tree-sort)             7.6724011898   0.238121986389      0.0305
@ogrisel
Member
ogrisel commented Jan 21, 2014

So the speedup is not covertype-specific. @larsmans could you please merge @glouppe's PR #6 into this PR to have Travis run the tests on it?

@pprett
Member
pprett commented Jan 21, 2014

I'm excited... fortunately I've a new laptop to download WiseRF... I wonder how they compare now...


@arjoly
Member
arjoly commented Jan 21, 2014

@ogrisel May I ask you to try on friedman3 with a high number of samples, or on a dataset with dense features?

edit: The MNIST dataset indeed has sparse features, like covtype, in more or less the same proportion.

@glouppe
Member
glouppe commented Jan 21, 2014

@ogrisel Tests pass, as run by Travis on my branch: https://travis-ci.org/glouppe/scikit-learn/builds/17268697

@ogrisel
Member
ogrisel commented Jan 21, 2014

@glouppe nice!

@arjoly here it is. The difference is smaller but still there (100k training samples, 10k testing samples):

from sklearn.datasets import make_friedman3
from sklearn.ensemble import RandomForestRegressor
from sklearn.cross_validation import train_test_split
from time import time

seed = 42

X, y = make_friedman3(n_samples=110000, random_state=seed)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=10000, random_state=seed)

model = RandomForestRegressor(n_estimators=100, n_jobs=8, random_state=seed)

tic = time()
model.fit(X_train, y_train)
train_time = time() - tic

tic = time()
score = model.score(X_test, y_test)
test_time = time() - tic

print('RandomForestRegressor:'
      ' train: {:.3f}s, test: {:.3f}s, score: {:.3f}'.format(
      train_time, test_time, score))
  • master:

    RandomForestRegressor: train: 11.879s, test: 0.104s, score: 0.999
    
  • glouppe/tree-sort:

    RandomForestRegressor: train: 9.072s, test: 0.103s, score: 0.999
    
@arjoly
Member
arjoly commented Jan 21, 2014

@ogrisel Thanks!!!

@ogrisel
Member
ogrisel commented Jan 21, 2014

We have a linear progression of the training time in both branches:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_friedman3
from sklearn.ensemble import RandomForestRegressor
from sklearn.cross_validation import train_test_split
from time import time

seed = 42
n_trees = 100

train_times = []
scores = []
train_sizes = [1000, 3000, 10000, 30000, 100000]
n_samples_test = 1000

for n_samples_train in train_sizes:
    n_samples = n_samples_train + n_samples_test
    X, y = make_friedman3(n_samples=n_samples, random_state=seed)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=n_samples_test, random_state=seed)

    model = RandomForestRegressor(n_estimators=n_trees, n_jobs=10, random_state=seed)

    tic = time()
    model.fit(X_train, y_train)
    train_time = time() - tic

    tic = time()
    score = model.score(X_test, y_test)
    test_time = time() - tic

    print('RandomForestRegressor:'
          ' n_samples_train={}, train: {:.3f}s, test: {:.3f}s, score: {:.3f}'.format(
          n_samples_train, train_time, test_time, score))
    train_times.append(train_time)
    scores.append(score)


plt.loglog(train_sizes, train_times, 'o-')
plt.ylim(0, 12)
plt.xlabel('# training samples')
plt.ylabel('Training time (s)')
plt.title('Random Forest with %d trees on Friedman 3' % n_trees)
plt.show()
  • master:

    RandomForestRegressor: n_samples_train=1000, train: 0.142s, test: 0.105s, score: 0.970
    RandomForestRegressor: n_samples_train=3000, train: 0.243s, test: 0.104s, score: 0.993
    RandomForestRegressor: n_samples_train=10000, train: 0.750s, test: 0.105s, score: 0.992
    RandomForestRegressor: n_samples_train=30000, train: 2.456s, test: 0.104s, score: 0.998
    RandomForestRegressor: n_samples_train=100000, train: 9.675s, test: 0.105s, score: 0.999
    

master

  • glouppe/tree-sort:

    RandomForestRegressor: n_samples_train=1000, train: 0.144s, test: 0.104s, score: 0.970
    RandomForestRegressor: n_samples_train=3000, train: 0.242s, test: 0.105s, score: 0.993
    RandomForestRegressor: n_samples_train=10000, train: 0.644s, test: 0.105s, score: 0.992
    RandomForestRegressor: n_samples_train=30000, train: 1.946s, test: 0.105s, score: 0.998
    RandomForestRegressor: n_samples_train=100000, train: 7.477s, test: 0.105s, score: 0.999
    

glouppe/tree-sort

@ogrisel
Member
ogrisel commented Jan 21, 2014

From my point of view +1 for merging the current state of the glouppe/tree-sort branch if @larsmans is ok.

@glouppe
Member
glouppe commented Jan 21, 2014

@ogrisel Before merging, I would like to see benchmarks between heapsort and quicksort. It is still not clear which is better. (Both have been optimized in this PR.)

@arjoly
Member
arjoly commented Jan 21, 2014

@ogrisel Before merging, I would like to see benchmarks between heapsort and quicksort. It is still not clear which is better. (Both have been optimized in this PR.)

+1

I would also be interested in comparing the results with introsort.

@ogrisel
Member
ogrisel commented Jan 21, 2014

I changed that in your branch (optimized heap sort instead of new qsort):

diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 6eb39e3..65f5728 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -1,4 +1,3 @@
-# cython: cdivision=True
 # cython: boundscheck=False
 # cython: wraparound=False

@@ -1068,7 +1067,7 @@ cdef class BestSplitter(Splitter):
                 Xf[p] = X[X_sample_stride * samples[p]
                           + X_fx_stride * current_feature]

-            qsort(Xf + start, samples + start, end - start)
+            sort(Xf + start, samples + start, end - start)

             # Evaluate all splits
             self.criterion.reset()

re-cython and re-build:

On the growing Friedman 3 benchmark (same script as above):

RandomForestRegressor: n_samples_train=1000, train: 0.143s, test: 0.105s, score: 0.970
RandomForestRegressor: n_samples_train=3000, train: 0.243s, test: 0.104s, score: 0.993
RandomForestRegressor: n_samples_train=10000, train: 0.744s, test: 0.104s, score: 0.992
RandomForestRegressor: n_samples_train=30000, train: 2.156s, test: 0.103s, score: 0.998
RandomForestRegressor: n_samples_train=100000, train: 8.270s, test: 0.104s, score: 0.999

the plot looks linear as well.

on MNIST:

Random Forest              11.2249410152  0.245896100998      0.0305

so the optimized heap sort is in-between master's heap sort and @larsmans' qsort (with @glouppe's fix).

@ogrisel
Member
ogrisel commented Jan 21, 2014

I would also be interested of comparing the results with introsort.

You will have to implement it first :) The linear progression (up to 100k samples) with qsort makes me think that the quadratic worst-case complexity of quicksort is not impacting us when training decision trees.

@glouppe
Member
glouppe commented Jan 21, 2014

Thanks for the quick bench @ogrisel !

Indeed, this looks good to me. +1 for merge with my fixes.

@arjoly arjoly and 1 other commented on an outdated diff Jan 21, 2014
sklearn/tree/_tree.pyx
@@ -13,7 +13,8 @@
#
# Licence: BSD 3 clause
-from libc.stdlib cimport calloc, free, malloc, realloc
+from libc.stdio cimport perror, printf
+from libc.stdlib cimport calloc, free, malloc, realloc, abort
@arjoly
arjoly Jan 21, 2014 Member

abort is unused

@ogrisel
Member
ogrisel commented Jan 22, 2014

I tried to bench it against WiseRF using @GaelVaroquaux's trial license, but it has expired. The wise.io site does not seem to offer the trial version anymore. Does anybody have a working WiseRF install?

@GaelVaroquaux
Member

Using the same box and the same benchmarks as http://gael-varoquaux.info/blog/?p=169, I computed new benchmarks for this PR. I cannot rerun the WiseRF benchmarks because my license has expired.

dataset, implementation,         train time, test time, accuracy 
digits,  SklearnET(n_jobs=1) r0.14,  2.641s, 0.082s,    0.986
digits,  SklearnRF(n_jobs=1) r0.14,  5.074s, 0.088s,    0.981
digits,  WiseRF(n_jobs=1),           5.665s, 0.108s,    0.979
digits,  SklearnET(n_jobs=1),        2.661s, 0.088s,    0.986
digits,  SklearnRF(n_jobs=1),        2.778s, 0.090s,    0.979
digits,  SklearnET(n_jobs=2) r0.14,  4.874s, 1.478s,    0.986
digits,  SklearnRF(n_jobs=2) r0.14,  5.716s, 1.349s,    0.978
digits,  WiseRF(n_jobs=2),           3.264s, 0.104s,    0.979
digits,  SklearnET(n_jobs=2),        1.821s, 0.203s,    0.986
digits,  SklearnRF(n_jobs=2),        1.890s, 0.203s,    0.980
digits,  SklearnRF(n_jobs=4),        1.176s, 0.204s,    0.979
digits,  SklearnRF(n_jobs=8),        1.023s, 0.206s,    0.979
mnist, SklearnET(n_jobs=1) r0.14,  1378.141s,  4.768s,  0.976
mnist, SklearnRF(n_jobs=1) r0.14,  1639.866s,  4.132s,  0.972
mnist, WiseRF(n_jobs=1),           1102.465s, 14.542s,  0.972
mnist, SklearnET(n_jobs=1),        1365.407s,  3.480s,  0.976
mnist, SklearnRF(n_jobs=1),         703.528s,  3.069s,  0.971
mnist, SklearnET(n_jobs=2),         694.601s,  2.229s,  0.976
mnist, SklearnRF(n_jobs=2),         363.769s,  2.079s,  0.972
mnist, SklearnRF(n_jobs=4),         182.293s,  1.533s,  0.971
mnist, SklearnRF(n_jobs=8),          97.913s,  1.084s,  0.972

So, I think that we can say that for large datasets, we are twice as fast
as WiseRF.

@ogrisel
Member
ogrisel commented Jan 22, 2014

Nice. But only twice as fast as this version of WiseRF (1.5.9), as we could not re-run the bench on a more recent version.

What is also interesting is that now RF seems to be as fast or faster than ET on mnist according to your bench. Did you use the same number of trees for both models?

@ogrisel
Member
ogrisel commented Jan 22, 2014

BTW: converting the data to Fortran-ordered (column-major) layout might even be a bit faster IIRC.
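That is, something along these lines (a sketch; the array shape is a placeholder):

import numpy as np

X = np.random.rand(100000, 54)  # placeholder data, covtype-ish shape
# Column-major ("Fortran") layout makes each feature column contiguous,
# which matches the splitter's access pattern, hence the possible gain.
X_f = np.asfortranarray(X, dtype=np.float32)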

@GaelVaroquaux
Member

What is also interesting is that now RF seems to be as fast or faster
than ET on mnist according to your bench. Did you use the same number
of trees for both models?

RandomForestClassifier(n_estimators=1000, n_jobs=n_jobs,
                       oob_score=False, min_samples_split=1),

ExtraTreesClassifier(n_estimators=1000, n_jobs=n_jobs,
                     oob_score=False, min_samples_split=1),

@glouppe
Member
glouppe commented Jan 22, 2014

@ogrisel This does not surprise me. Extra trees might be much deeper than trees from a random forest. (I might also make additional changes to the code to leverage the CPU cache in RandomSplitter as well.)

@pprett
Member
pprett commented Jan 22, 2014

leverage all the cache!

@arjoly
Member
arjoly commented Jan 22, 2014

Great news, all these improvements!
Nice picture @pprett!

@glouppe
Member
glouppe commented Jan 22, 2014

Ping @larsmans. Do you wish to make any more changes on this?

@ogrisel
Member
ogrisel commented Jan 23, 2014

Not related to this PR, but when benching I found that I had increasing memory usage over a long IPython session. I made a script and indeed there is a memory leak (both in master and in the glouppe/tree-sort branch). The leak was not there in 0.14.1: see the script in #2787

@larsmans
Member

Changed the algorithm to introsort as given in Musser's 1997 paper (except for the fallback to insertion sort in the final phase). Included @glouppe's changes. Please review/merge.

larsmans and others added some commits Jan 6, 2014
@larsmans larsmans ENH faster heapsort in trees
Uses the heapsort version that I'm familiar with: linear-time heapify
followed by n delete-min operations. Also cache-friendlier by copying
active features into a temporary array.

Random forest training time on covertype:
now     34.4123s
before  60.1179s
7c33502
@larsmans larsmans ENH introsort in tree learner
Covertype w/ random forest training time down to 29.0857s.
31491f9
@glouppe @larsmans glouppe ENH Make PresortBestSplitter cache friendly + cosmetics a681c9b
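(As a side note on the first commit above, "ENH faster heapsort in trees": a plain-Python sketch of that heapsort variant, illustrative only, since the real code is Cython over raw pointers. The commit says heapify + n delete-min; the in-place ascending formulation below is the mirror image, a max-heap with n extract-max steps.)

def heapsort(Xf, samples):
    # Sort Xf ascending, applying the same swaps to samples.
    n = len(Xf)

    def sift_down(root, end):
        while 2 * root + 1 < end:
            child = 2 * root + 1
            if child + 1 < end and Xf[child] < Xf[child + 1]:
                child += 1
            if Xf[root] >= Xf[child]:
                break
            Xf[root], Xf[child] = Xf[child], Xf[root]
            samples[root], samples[child] = samples[child], samples[root]
            root = child

    for start in range(n // 2 - 1, -1, -1):  # linear-time heapify
        sift_down(start, n)
    for end in range(n - 1, 0, -1):          # n extract-max operations
        Xf[0], Xf[end] = Xf[end], Xf[0]
        samples[0], samples[end] = samples[end], samples[0]
        sift_down(0, end)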
@larsmans
Member

Also rebased on master to include the changes from #2790.

@glouppe
Member
glouppe commented Jan 26, 2014

Thanks for the merge and the rebase @larsmans

I have personally no further comments to make. Have you benchmarked introsort just to make sure it is indeed faster?

@larsmans
Member

Introsort is about as fast as quicksort. The only reason to use it is to get rid of the worst-case quadratic behavior.

@ogrisel
Member
ogrisel commented Jan 26, 2014

I don't have access to my workstation right now, but on my laptop I get for the covertype bench:

python benchmarks/bench_covertype.py --n-jobs=4 --classifiers=RandomForest --random-seed=1
  • master:

    RandomForest  81.5990s   0.2137s     0.0301

  • this branch:

    RandomForest  34.9615s   0.2133s     0.0301

So this is very good (at least as fast as the previous benchmarks run with quicksort).

@ogrisel
Member
ogrisel commented Jan 26, 2014

I had a quick look at the code and it looks fine to me although I am not familiar with sorting algorithms. +1 for merging on my side.

@ogrisel
Member
ogrisel commented Jan 26, 2014

I also ran my memory leak detection script with the DecisionTreeRegressor class instead of the ExtraTreeRegressor class and neither psutil's reported RSS nor objgraph.get_leaking_objects detected a leak.

@glouppe
Member
glouppe commented Jan 26, 2014

Thanks for the check! @larsmans feel free to merge this in :)


@amueller
Member

awesome!!! great work!

@larsmans larsmans merged commit 9f6dbc5 into scikit-learn:master Jan 26, 2014

1 check passed

The Travis CI build passed
@ogrisel
Member
ogrisel commented Jan 26, 2014

\o/

@ogrisel
Member
ogrisel commented Jan 26, 2014

Please don't forget to add an entry in the whats_new.rst file.

@jnothman
Member

Great work! Nice to know we persist in teaching diverse sorting algorithms
for good reason!


@songgc
songgc commented Jan 27, 2014

After this merge, GBRT regression takes 2× longer on a data set than the previous build (commit bf1635d). The loss score seems OK. BTW, the data set is MSLR-WEB10K/Fold1 (MS learning to rank)

@glouppe
Member
glouppe commented Jan 27, 2014

Could you try with 31491f9 as head instead? The only changes affecting GBRT are with regard to the PresortBestSplitter and shouldn't make things slower. CC: @pprett


@songgc
songgc commented Jan 27, 2014

Hash 31491f9 is even faster than bf1635d by 15%.

@pprett
Member
pprett commented Jan 27, 2014

thanks @songgc - I did a quick benchmark using my solar dataset (regression).

I looked at master ( ), introsort (31491f9), MSE optim (0b7c79b), and best-first (834b375).
I can definitely see a performance regression between 0b7c79b and 31491f9.
Since it wasn't exposed in the latest benchmarks we did, I assume it is an effect of the memory leak fix. I need to check this in more detail.

@songgc I find the 2× performance regression quite harsh - can you tell me which parameters you used (max_features, max_depth, n_estimators)?

@ogrisel
Member
ogrisel commented Jan 27, 2014

@songgc have you fixed the random_state parameter of your GradientBoostingRegressor? I tried on a subsample of 62244 MSLR results / 136 features (500 queries) with GradientBoostingRegressor(n_estimators=100, random_state=1) and it trains in 1m30s both on master and on bf1635d, yielding NDCG@10=0.507 each time.

I also tried to bench GradientBoostingRegressor on a simple make_friedman3 dataset with 100k samples and the training speed is the same.

@pprett
Member
pprett commented Jan 27, 2014

@ogrisel ok - I'm running a benchmark suite now with 3 repetitions between current master and @arjoly's MSE optimization -- I'll keep you posted.

@songgc it would be great if you could post the parameters you used -- tree building performance can differ quite considerably depending on the parameters (e.g. max_features)

@pprett
Member
pprett commented Jan 27, 2014

gbrt-bench-perf-reg

this one just uses smaller datasets -- it looks good IMHO - I used the following params:

classification_params = {'n_estimators': 500, 'loss': 'deviance',
                         'min_samples_leaf': 1, 'max_leaf_nodes': 6,
                         'max_depth': None,
                         'learning_rate': .01, 'subsample': 1.0, 'verbose': 0}

regression_params = {'n_estimators': 500, 'max_leaf_nodes': 6,
                     'max_depth': None,
                     'min_samples_leaf': 1, 'learning_rate': 0.01,
                     'loss': 'ls', 'subsample': 1.0, 'verbose': 0}
@ogrisel
Member
ogrisel commented Jan 27, 2014

Same here, I do not see any regression between 31491f9 and the current master:

I trained GradientBoostingRegressor(n_estimators=100, random_state=1) on the full Fold1 split of MSLR-10K (3 folds train + 1 fold val == 958671 samples) in 33min on master and 36 min on 31491f9. In both cases I get the following scores on the Fold1 test fold:

  • NDCG@5: 0.506
  • NDCG@10: 0.514
  • R2: 0.168

If I understand correctly, the only commit that matters between 31491f9 and master is @glouppe's cache optim a681c9b (aka ENH Make PresortBestSplitter cache friendly + cosmetics). It does indeed seem to work on my box, shaving 3 minutes off the training time.

@larsmans larsmans deleted the larsmans:tree-sort branch Jan 27, 2014
@songgc
songgc commented Jan 27, 2014

My apologies for the false alarm! I have found that I installed version 0.14.1 rather than the master branch...

pip install scikit-learn git+https://github.com/scikit-learn/scikit-learn.git gives me the stable version.
pip install -f scikit-learn file://"a synced local repos" gives me the master branch.

My lesson is "check version first". Currently, my benchmarks are consistent with yours. The speed improvement is impressive! I don't need to transfer data to R for the GBM package :)

@pprett
Member
pprett commented Jan 27, 2014

no worries @songgc, thanks for double-checking -- I'd say there is definitely no reason now to switch to R for the randomForest package ;)

@larsmans
Member

@songgc @pprett Any benchmarks against R? :)

@pprett
Member
pprett commented Jan 27, 2014

I've code for this -- next days are a bit busy -- will post it later this week. I expect quite a difference because 0.14.1 used to be faster as well. WiseRF is the competitor :)


@songgc
songgc commented Jan 27, 2014

@larsmans
I had a benchmark case against GBM as follows:

Data set: MSLR-WEB10K/Fold1

params for GBRT: {'n_estimators': 100, 'max_depth': 4, 'min_samples_split': 10,
                  'learning_rate': 0.03, 'loss': 'ls', 'subsample': 0.5,
                  'random_state': 11, 'verbose': 1}
params for GBM:  {"distribution": "gaussian", "shrinkage": 0.03,
                  "n.tree": 100, "bag.fraction": 0.5, "verbose": True,
                  "n.minobsinnode": 10, "interaction.depth": 6}

Please note that the max depths are different. GBM usually requires deeper trees compared to GBRT to achieve similar performance.

Benchmark result:
library, test MSE, running time
GBRT, 0.5854, 1238s
GBM, 0.5943, 1442s

@pprett
Member
pprett commented Jan 28, 2014

@songgc the current master also includes an option to build GBM-style trees in GradientBoostingRegressor|Classifier -- use the max_leaf_nodes argument (max_leaf_nodes - 1 equals interaction.depth).
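For instance (a sketch reusing @songgc's GBRT parameters from above; max_leaf_nodes=7 would mirror GBM's interaction.depth=6):

from sklearn.ensemble import GradientBoostingRegressor

model = GradientBoostingRegressor(n_estimators=100, learning_rate=0.03,
                                  min_samples_split=10, subsample=0.5,
                                  loss='ls', random_state=11,
                                  max_leaf_nodes=7,  # ~ interaction.depth=6
                                  max_depth=None)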

@amueller amueller modified the milestone: 0.16, 0.15 Jul 15, 2014