[MRG] faster sorting in trees; random forests almost 2× as fast #2747

Merged (3 commits) on Jan 26, 2014

9 participants

larsmans commented Jan 13, 2014

Changed the heapsort in the tree learners into a quicksort and gave it cache-friendlier data access. This speeds up RF learning almost two-fold. In fact, profiling with @fabianp's yep tool shows the time taken by sorting going down from 65% to <10% of total running time in the covertype benchmark.

This is taking longer than I thought, but I figured I should at least show @glouppe and @pprett what I've got so far.

TODO:

  • more benchmarks, esp. on a denser dataset than covertype (sparse data is easy :)
  • make tests pass
  • clean up code
  • filter out the cruft
  • decide on the final algorithm: quicksort takes O(n²) time in the worst case, which can be avoided by introsort at the expense of more code.
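To make the last item concrete, here is a minimal, purely illustrative Python sketch of introsort: quicksort with a depth cap of roughly 2·log2(n), falling back to heapsort on the current slice when the cap is hit. Names and structure are hypothetical, not the PR's actual Cython code.

```python
import math
import heapq

def introsort(a):
    """Sort list `a` in place: quicksort with a depth cap, heapsort fallback."""
    maxdepth = 2 * max(1, int(math.log2(len(a)))) if a else 0
    _introsort(a, 0, len(a) - 1, maxdepth)

def _introsort(a, lo, hi, depth):
    if hi - lo < 1:
        return
    if depth == 0:
        # Pivots degenerated: finish this slice in guaranteed O(n log n).
        chunk = a[lo:hi + 1]
        heapq.heapify(chunk)
        a[lo:hi + 1] = [heapq.heappop(chunk) for _ in range(len(chunk))]
        return
    p = _partition(a, lo, hi)
    _introsort(a, lo, p - 1, depth - 1)
    _introsort(a, p + 1, hi, depth - 1)

def _partition(a, lo, hi):
    # Median-of-three: after these swaps, a[hi] holds the median pivot.
    mid = (lo + hi) // 2
    if a[mid] < a[lo]:
        a[lo], a[mid] = a[mid], a[lo]
    if a[hi] < a[lo]:
        a[lo], a[hi] = a[hi], a[lo]
    if a[mid] < a[hi]:
        a[mid], a[hi] = a[hi], a[mid]
    # Lomuto partition around pivot = a[hi].
    pivot = a[hi]
    i = lo
    for j in range(lo, hi):
        if a[j] <= pivot:
            a[i], a[j] = a[j], a[i]
            i += 1
    a[i], a[hi] = a[hi], a[i]
    return i
```

The depth cap bounds the worst case at O(n log n) while keeping quicksort's constants on typical inputs, at the cost of the extra fallback code the TODO item mentions.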
ogrisel commented Jan 13, 2014

Nice! I tagged this PR for 0.15 milestone if everyone agrees :)

larsmans commented Jan 13, 2014

On the flip side, the optimization to the sorting is so good that it makes the rest of the tree code look slow :p

(But again, covertype is really easy. I'll try 20news after SVD-200 as well.)

pprett commented Jan 13, 2014

@larsmans I've a benchmark suite that contains datasets with different characteristics -- will send the results tomorrow

ogrisel commented Jan 13, 2014

You can try on MNIST as well with the mldata loader: there is a script in the MLP PR: https://github.com/IssamLaradji/scikit-learn/blob/multilayer-perceptron/benchmarks/bench_mnist.py

larsmans commented Jan 13, 2014

@pprett Then be sure to use vanilla quicksort, not the randomized one. Shuffling turns out to be extremely expensive.

larsmans commented Jan 13, 2014

pprof (Google perftools) graph w/ quicksort on covertype:

[image: quicksort-pprof call graph]

larsmans commented Jan 13, 2014

On 20news, all categories, 100 SVD components, 500 trees and four cores of an Intel i7, training time goes down from 24.181s to 11.683s. F1-score goes down from ~.75 to ~.6, though, so I may have a bug somewhere...

larsmans commented Jan 13, 2014

Covertype accuracy actually went down the drain as well. This wasn't the case before I rebased, I must have made a mistake in handling the new X_fx_stride.

In sklearn/tree/_tree.pyx:

    -            while ((p + 1 < end) and
    -                   (X[X_sample_stride * samples[p + 1] + X_fx_stride * current_feature] <=
    -                    X[X_sample_stride * samples[p] + X_fx_stride * current_feature] + EPSILON_FLT)):
    +            while p + 1 < end and Xf[p + 1] <= Xf[p] + EPSILON_FLT:
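The change above swaps repeated strided gathers from X for sequential reads of the contiguous buffer Xf. A small numpy sketch (with hypothetical names mirroring the snippet) showing that the two forms of the while condition are equivalent once Xf holds the gathered column; the speed win in the Cython code comes from the contiguous access pattern, not from different results:

```python
import numpy as np

EPSILON_FLT = 1e-7

def run_length_strided(X, samples, f, p):
    # Old condition: two strided gathers from X per loop iteration.
    end = len(samples)
    while p + 1 < end and X[samples[p + 1], f] <= X[samples[p], f] + EPSILON_FLT:
        p += 1
    return p

def run_length_contiguous(Xf, p):
    # New condition: sequential reads of the small contiguous buffer Xf.
    end = len(Xf)
    while p + 1 < end and Xf[p + 1] <= Xf[p] + EPSILON_FLT:
        p += 1
    return p

rng = np.random.RandomState(0)
X = rng.randint(0, 5, size=(1000, 50)).astype(np.float64)
f = 3
# Samples ordered by feature f, as the splitter sees them after sorting.
samples = np.argsort(X[:, f], kind='mergesort')
Xf = X[samples, f]  # gather once, then scan many times

assert run_length_strided(X, samples, f, 0) == run_length_contiguous(Xf, 0)
```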

pprett commented Jan 14, 2014

@larsmans in the block above you set Xf[p] for p in range(0, end-start). Here p runs from range(start, end) - is that correct?

pprett commented Jan 14, 2014

When I did my profiling of the tree code a couple of weeks ago, it turned out that for datasets with lots of split points the bulk of the time is spent in the while condition -- maybe part of your speed-up stems from this refactoring rather than the new sorting.

glouppe commented Jan 14, 2014

@larsmans in the block above you set Xf[p] for p in range(0, end-start). Here p runs from range(start, end) - is that correct?

+1, the indices are not correct. Please always make p range over [start, end) to avoid bugs and confusion with other parts of the code.

larsmans commented Jan 14, 2014

Right, this is it. Will change this bit tonight.

@pprett No, this isn't actually the cause of the speedup, it was near 50% before I even introduced this bug.

pprett commented Jan 14, 2014

great - thx Lars

In sklearn/tree/_tree.pyx:

    #abort()
    Xf[p] = X[X_sample_stride * samples[p + start]
              + X_fx_stride * current_feature]
    qsort(Xf, samples + start, end - start)

glouppe commented Jan 14, 2014

With regards to my comment below, I would therefore replace this part of the code with:

for p in range(start, end):
    Xf[p] = X[X_sample_stride * samples[p] + X_fx_stride * current_feature]
qsort(Xf + start, samples + start, end - start)

larsmans commented Jan 14, 2014

@glouppe That would overallocate. Could that be a problem in terms of peak memory usage? If not, I'll simplify the code.

glouppe commented Jan 14, 2014

It is not really a huge problem in my opinion. In terms of memory usage it is like adding a new feature column. Anyway, you already overallocate Xf since you make it of size self.n_samples above (and not of size end - start). This seems to work :)

In fact, you could also avoid the mallocs/frees by allocating Xf only once during init, as we do for samples. That would be a bit cleaner in my opinion.

larsmans commented Jan 14, 2014

I'll see if that's better. It would save adding a return value to this function.

In sklearn/tree/_tree.pyx:

    -if current_threshold == X[X_sample_stride * samples[p] + X_fx_stride * current_feature]:
    -    current_threshold = X[X_sample_stride * samples[p - 1] + X_fx_stride * current_feature]
    +if current_threshold == Xf[p]:
    +    current_threshold = Xf[p]

glouppe commented Jan 15, 2014

It should be current_threshold = Xf[p - 1].

In sklearn/tree/_tree.pyx:

    cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil:
        # Helper for sort
        Xf[i], Xf[j] = Xf[j], Xf[i]

arjoly commented Jan 15, 2014

Swapping with this syntax generates one more C instruction than what is really needed.

      /* "sklearn/tree/_tree.pyx":1141
     * cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil:
     *     # Helper for sort
     *     Xf[i], Xf[j] = Xf[j], Xf[i]             # <<<<<<<<<<<<<<
     *     samples[i], samples[j] = samples[j], samples[i]
     *
     */
      __pyx_t_1 = (__pyx_v_Xf[__pyx_v_j]);
      __pyx_t_2 = (__pyx_v_Xf[__pyx_v_i]);
      (__pyx_v_Xf[__pyx_v_i]) = __pyx_t_1;
      (__pyx_v_Xf[__pyx_v_j]) = __pyx_t_2;

larsmans commented Jan 15, 2014

In C, yes. But the assembly for

    int t = a[i];
    a[i] = a[j];
    a[j] = t;

and

    int ti = a[i];
    int tj = a[j];
    a[i] = tj;
    a[j] = ti;

is identical (gcc -O2 -S).

The only thing I still need to try is putting an if (i != j) around this, but that's for later.

pprett commented Jan 19, 2014

@larsmans can I benchmark the enhancements or are you still working out some issues in the code?

glouppe commented Jan 19, 2014

@pprett As long as the trees are not guaranteed to be the same (which is not the case since accuracy drops), there is no point in benchmarking the current changes. We should invest some time to try to figure this out. I can have a look tomorrow.

larsmans commented Jan 19, 2014

I re-applied the patches on top of current master. The first patch, faster heapsort, can AFAIC be merged into master immediately. It gives an almost two-fold speedup and it passes the testsuite.

The second patch, quicksort, doesn't pass all the tests due to randomness issues, but further speeds up tree learning significantly.

larsmans commented Jan 19, 2014

@glouppe You can certainly benchmark 238d692, it passes all of the tests.

The second produces somewhat different trees. I'm not sure if we can ever fix that, since neither quicksort nor heapsort are stable sorts.

pprett commented Jan 19, 2014

I'm running benchmarks now - should be finished in a couple of hours

ogrisel commented Jan 19, 2014

The second produces somewhat different trees. I'm not sure if we can ever fix that, since neither quicksort nor heapsort are stable sorts.

But do you get good test accuracy on covertype and other benchmarks with quicksort?

glouppe commented Jan 20, 2014

The second produces somewhat different trees. I'm not sure if we can ever fix that, since neither quicksort nor heapsort are stable sorts.

Stability of the sorting algorithm shouldn't, in theory, have any impact on the trees that are built. As long as the feature values are sorted, the same cutting points should be found. I'll investigate when I have some time.
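This argument can be checked numerically: a stable and an unstable sort may order tied samples differently, but the sorted value sequence, and hence the set of candidate thresholds between distinct adjacent values, is identical. A rough numpy sketch (the threshold rule here is a simplification of the real splitter's):

```python
import numpy as np

rng = np.random.RandomState(42)
values = rng.randint(0, 4, size=50).astype(float)  # lots of ties

stable = np.argsort(values, kind='mergesort')    # stable: preserves tie order
unstable = np.argsort(values, kind='quicksort')  # unstable: may permute ties

# Tied samples may come out in a different order, but the sorted
# value sequences are identical...
v, w = values[stable], values[unstable]
assert np.array_equal(v, w)

# ...so the candidate thresholds (midpoints between distinct adjacent
# sorted values) are the same for both orderings.
def thresholds(sorted_vals):
    return {(sorted_vals[p] + sorted_vals[p + 1]) / 2
            for p in range(len(sorted_vals) - 1)
            if sorted_vals[p + 1] > sorted_vals[p]}

assert thresholds(v) == thresholds(w)
```

Differences between trees would then have to come from elsewhere (e.g. tie-breaking on sample order downstream), not from the sort itself.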

In sklearn/tree/_tree.pyx:

    i += 1
    l -= 1
    r += 1

glouppe commented Jan 20, 2014

This is wrong. Removing these two lines solves all the bugs on my box. :)

glouppe commented Jan 20, 2014

@larsmans I just submitted the fix to your branch.

Since I am sure caching the feature values should be also profitable to other splitters, I'd like to make similar changes to PresortBestSplitter and RandomSplitter. I'll push one more patch to your branch during the day.

pprett commented Jan 20, 2014

Here are some benchmark results -- I only looked at the first commit, 238d692.

[images: heap_s_error, heap_s_train_time, heap_s_test_time]

All values are relative to master, where master is the version after @arjoly's recent MSE enhancement (w/o @jnothman's tree structure refactoring) -- sorry for that, but I only realized when it was too late.

We can see a nice performance improvement for all large datasets (covtype, expedia, solar, bioresponse) -- the improvement is about 15-20%.

Good work @larsmans - awesome!

There was a slight performance decrease on the synthetic regression benchmarks (Friedman #1-3) -- these mostly have a large number of split points, so stability should not be an issue at all.

pprett commented Jan 20, 2014

I looked at RandomForestClassifier|Regressor only and used the following parameters:

classification_params = {'n_estimators': 100,
                         'max_depth': None}
regression_params = {'n_estimators': 100,
                     'max_depth': None,
                     'max_features': 0.3,
                     'min_samples_leaf': 1}
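For readers reproducing the benchmark, those dictionaries plug into the estimators roughly as below. The toy datasets are only there to make the snippet self-contained; pprett's actual benchmark datasets are not part of this thread.

```python
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

classification_params = {'n_estimators': 100, 'max_depth': None}
regression_params = {'n_estimators': 100, 'max_depth': None,
                     'max_features': 0.3, 'min_samples_leaf': 1}

# Stand-in data so the snippet runs end to end.
Xc, yc = make_classification(n_samples=300, n_features=10, random_state=0)
clf = RandomForestClassifier(random_state=0, **classification_params).fit(Xc, yc)

Xr, yr = make_regression(n_samples=300, n_features=10, random_state=0)
reg = RandomForestRegressor(random_state=0, **regression_params).fit(Xr, yr)
```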
larsmans commented Jan 20, 2014

@glouppe will merge the patch later this week. @pprett Not as impressive as on the covtype benchmark... did you fix the random seed? I'm surprised to see a difference in accuracy with the refactored heapsort, it should give the exact same ordering.

pprett commented Jan 20, 2014

@larsmans each value is the mean of 3 repetitions, each with a different random seed (the same for both branches)

pprett commented Jan 20, 2014

@larsmans which parameters did you use for your covtype benchmark?

larsmans commented Jan 20, 2014

Just the standard ones from the covtype script, n_estimators=20, random_seed=13.

larsmans commented Jan 20, 2014

@pprett How many cores? I see a somewhat smaller speedup w/ one core compared to the four I used to benchmark previously:

    master   133.6747s
    238d692   86.0302s

That's about 36% off. At four cores, I get 43% off, despite having only four cores and a browser still running.

pprett commented Jan 20, 2014

@larsmans I ran only single threaded experiments

ogrisel commented Jan 21, 2014

I just tried to run:

python benchmarks/bench_covertype.py --classifiers=ExtraTrees --n-jobs=8

on this branch and master. The validation error is the same (~0.021). However I do not see any significant training time improvement (the standard deviations overlap). Maybe the speedup observed by @larsmans is architecture specific (e.g. related to the CPU cache size)?

Here are some attributes from one of my cores:

model name  : Intel(R) Xeon(R) CPU           X5660  @ 2.80GHz
cpu MHz     : 2660.000
cache size  : 12288 KB
bogomips    : 5585.92
glouppe commented Jan 21, 2014

@ogrisel For a correct implementation, see larsmans#6

pprett commented Jan 21, 2014

@ogrisel is this your Apple? makes my Thinkpad W510 look like a sissy..

ogrisel commented Jan 21, 2014

I tried glouppe/tree-sort and the training speed seems to be approximately the same as well.

pprett commented Jan 21, 2014

@ogrisel have you tried RandomForest and --n-jobs=1?

ogrisel commented Jan 21, 2014

Nope, it's my 12-physical-core Linux workstation @ Inria (with hyperthreading disabled). It only has 24GB of RAM, but soon we will get a compute server with 16 physical cores and 384 GB of RAM for the team ;)

glouppe commented Jan 21, 2014

@ogrisel You are trying with ExtraTrees :-) This PR affects Random Forests only.

pprett commented Jan 21, 2014

Very nice... has any of you disabled hyperthreading for their laptops too? any experience?

ogrisel commented Jan 21, 2014

@glouppe indeed! Will redo the bench...

ogrisel commented Jan 21, 2014

Very nice... has any of you disabled hyperthreading for their laptops too? any experience?

Actually, on my laptop, training Extra Trees with n_jobs=4 with HT seems to be slightly faster than n_jobs=2 without HT. But I should retry. HT was disabled on this workstation prior to my joining Inria, but I am too lazy to reboot. Probably not a big deal anyway.

ogrisel commented Jan 21, 2014

Ok, here are my numbers. Benchmark in a sequential run:

python benchmarks/bench_covertype.py --classifiers=RandomForest --n-jobs=1 --random-seed=11

  • master:
    RandomForest 176.4649s   0.3957s     0.0302
  • this branch:
    RandomForest  86.2451s   0.4110s     0.0359  # incorrect validation error
  • glouppe/tree-sort:
    RandomForest  82.9134s   0.3910s     0.0302
    
@glouppe

Member

glouppe commented Jan 22, 2014

@ogrisel This does not surprise me. Extra trees might be much deeper than trees from a random forest. (I might also do additional changes to the code to leverage the CPU cache in RandomSplitter as well.)

@pprett

Member

pprett commented Jan 22, 2014

leverage all the cache!

@arjoly

Member

arjoly commented Jan 22, 2014

Great news, all these improvements!
Nice picture @pprett!

@glouppe

Member

glouppe commented Jan 22, 2014

Ping @larsmans. Do you wish to make any more changes on this?

@ogrisel

Member

ogrisel commented Jan 23, 2014

Not related to this PR, but when benching I found that I had increasing memory usage over a long IPython session. I made a script and indeed there is a memory leak (both in master and in the glouppe/tree-sort branch). The leak was not there in 0.14.1: see the script in #2787

@larsmans

Member

larsmans commented Jan 25, 2014

Changed the algorithm to introsort as given in Musser's 1997 paper (except for the fallback to insertion sort in the final phase). Included @glouppe's changes. Please review/merge.

larsmans and others added some commits Jan 6, 2014

ENH faster heapsort in trees

Uses the heapsort version that I'm familiar with: linear-time heapify followed by n delete-min operations. Also cache-friendlier by copying active features into a temporary array.

Random forest training time on covertype:
now     34.4123s
before  60.1179s

ENH introsort in tree learner

Covertype w/ random forest training time down to 29.0857s.
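The commit message above describes the classic two-phase heapsort: a linear-time bottom-up heapify followed by repeated extract-max operations. As a rough illustration (the real implementation is Cython in sklearn/tree/_tree.pyx and sorts feature values together with sample indices), the scheme looks like:

```python
def sift_down(a, start, end):
    """Restore the max-heap property for the subtree rooted at `start`."""
    root = start
    while 2 * root + 1 <= end:
        child = 2 * root + 1
        # pick the larger of the two children
        if child + 1 <= end and a[child] < a[child + 1]:
            child += 1
        if a[root] < a[child]:
            a[root], a[child] = a[child], a[root]
            root = child
        else:
            return

def heapsort(a):
    n = len(a)
    # phase 1: linear-time heapify (bottom-up)
    for start in range(n // 2 - 1, -1, -1):
        sift_down(a, start, n - 1)
    # phase 2: n extract-max steps, growing a sorted suffix
    for end in range(n - 1, 0, -1):
        a[0], a[end] = a[end], a[0]
        sift_down(a, 0, end - 1)

data = [5, 2, 9, 1, 5, 6]
heapsort(data)
```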
@larsmans

Member

larsmans commented Jan 25, 2014

Also rebased on master to include the changes from #2790.

@glouppe

Member

glouppe commented Jan 26, 2014

Thanks for the merge and the rebase @larsmans

I have personally no further comments to make. Have you benchmarked introsort just to make sure it is indeed faster?

@larsmans

Member

larsmans commented Jan 26, 2014

Introsort is about as fast as quicksort. The only reason to use it is to get rid of the worst-case quadratic behavior.
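For reference, introsort (Musser, 1997) is quicksort with a recursion-depth budget of about 2·log2(n); when the budget runs out, the current slice is handed off to a guaranteed O(n log n) sort, which caps the worst case. A simplified sketch, not the PR's Cython code: the actual implementation uses median-of-three pivoting and a heapsort fallback, while here a plain last-element pivot and sorted() stand in:

```python
import math

def _partition(a, lo, hi):
    # simple Lomuto partition with the last element as pivot
    pivot = a[hi]
    i = lo
    for j in range(lo, hi):
        if a[j] < pivot:
            a[i], a[j] = a[j], a[i]
            i += 1
    a[i], a[hi] = a[hi], a[i]
    return i

def _introsort(a, lo, hi, depth):
    if hi <= lo:
        return
    if depth == 0:
        # depth budget exhausted: fall back to an O(n log n) sort
        # (heapsort in the actual tree code)
        a[lo:hi + 1] = sorted(a[lo:hi + 1])
        return
    p = _partition(a, lo, hi)
    _introsort(a, lo, p - 1, depth - 1)
    _introsort(a, p + 1, hi, depth - 1)

def introsort(a):
    if len(a) > 1:
        _introsort(a, 0, len(a) - 1, 2 * int(math.log2(len(a))))

data = [3, 7, 1, 9, 4, 4, 0, 8]
introsort(data)
```

The depth limit is what removes quicksort's quadratic worst case: an adversarial input can only force 2·log2(n) bad partitions before the fallback takes over.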

@ogrisel

Member

ogrisel commented Jan 26, 2014

I don't have access to my workstation right now, but on my laptop I get for the covertype bench:

python benchmarks/bench_covertype.py --n-jobs=4 --classifiers=RandomForest --random-seed=1
  • master:

    RandomForest  81.5990s   0.2137s     0.0301

  • this branch:

    RandomForest  34.9615s   0.2133s     0.0301

So this is very good (at least as fast as the previous benchmarks run with quicksort).

@ogrisel

Member

ogrisel commented Jan 26, 2014

I had a quick look at the code and it looks fine to me although I am not familiar with sorting algorithms. +1 for merging on my side.

@ogrisel

Member

ogrisel commented Jan 26, 2014

I also ran my memory leak detection script with the DecisionTreeRegressor class instead of the ExtraTreeRegressor class and neither psutil's reported RSS nor objgraph.get_leaking_objects detected a leak.
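For readers without psutil, a rough equivalent of this kind of RSS-growth check can be written with the stdlib tracemalloc module. This is not the script from #2787; the fit_once callables below are illustrative stand-ins for repeatedly fitting an estimator:

```python
import gc
import tracemalloc

def allocated_growth(fit_once, n_iter=20):
    """Growth in traced allocations across n_iter calls, after a warm-up
    call so one-time caches are excluded from the measurement."""
    tracemalloc.start()
    fit_once()                                   # warm-up
    gc.collect()
    before, _ = tracemalloc.get_traced_memory()
    for _ in range(n_iter):
        fit_once()
    gc.collect()
    after, _ = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return after - before

_retained = []

def leaky():
    _retained.append(bytearray(100_000))         # memory kept alive each call

def clean():
    bytearray(100_000)                           # freed after the call returns
```

A leaky callable shows growth roughly proportional to n_iter, while a clean one stays near zero.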

@glouppe

Member

glouppe commented Jan 26, 2014

Thanks for the check! @larsmans feel free to merge this in :)

@amueller

Member

amueller commented Jan 26, 2014

awesome!!! great work!

larsmans added a commit that referenced this pull request Jan 26, 2014

Merge pull request #2747 from larsmans/tree-sort
[MRG] faster sorting in trees; random forests almost 2× as fast

@larsmans larsmans merged commit 9f6dbc5 into scikit-learn:master Jan 26, 2014

1 check passed

default The Travis CI build passed
@ogrisel

Member

ogrisel commented Jan 26, 2014

\o/

@ogrisel

Member

ogrisel commented Jan 26, 2014

Please don't forget to add an entry in the whats_new.rst file.

@jnothman

Member

jnothman commented Jan 26, 2014

Great work! Nice to know we persist in teaching diverse sorting algorithms for good reason!

@songgc

songgc commented Jan 27, 2014

After this merge, the GBRT regression takes 2 times longer on a data set than the previous build (commit bf1635d). The loss score seems OK. BTW, the data set is MSLR-WEB10K/Fold1 (MS learning to rank)

@glouppe

Member

glouppe commented Jan 27, 2014

Could you try with 31491f9 as head instead? The only changes on GBRT are with regards to the PresortBestSplitter and shouldn't make things slower. CC: @pprett

@songgc

songgc commented Jan 27, 2014

Hash 31491f9 is even faster than bf1635d by 15%.

@pprett

Member

pprett commented Jan 27, 2014

thanks @songgc - I did a quick benchmark using my solar dataset (regression).

I looked at master ( ), introsort (31491f9), MSE Optim (0b7c79b), best-first (834b375).
I can definitely see a performance regression between 0b7c79b and 31491f9.
Since it wasn't exposed in the latest benchmarks we did, I assume it is an effect of the memory leak fix. I need to check this in more detail.

@songgc I find the 2x performance regression quite harsh - can you tell me which parameters you used (max_features, max_depth, n_estimators)?

@ogrisel

Member

ogrisel commented Jan 27, 2014

@songgc have you fixed the random_state parameter of your GradientBoostingRegressor? I tried on a subsample of 62244 MSLR results / 136 (500 queries) with GradientBoostingRegressor(n_estimators=100, random_state=1) and it trains in 1m30s both in master and on bf1635d and yields NDCG@10=0.507 each time.

I also tried to bench GradientBoostingRegressor on a simple make_friedman3 dataset with 100k samples and the training speed is the same.

@pprett

Member

pprett commented Jan 27, 2014

@ogrisel ok - I'm running a benchmark suite now with 3 repetitions between current master and @arjoly's MSE optimization -- I'll keep you posted.

@songgc it would be great if you could post the parameters you used -- tree building performance can differ quite considerably depending on the parameters used (e.g. max_features)

@pprett

Member

pprett commented Jan 27, 2014

[benchmark figure: gbrt-bench-perf-reg]

this one just uses smaller datasets -- it looks good IMHO - I used the following params:

classification_params = {'n_estimators': 500, 'loss': 'deviance',
                         'min_samples_leaf': 1, 'max_leaf_nodes': 6,
                         'max_depth': None,
                         'learning_rate': .01, 'subsample': 1.0, 'verbose': 0}

regression_params = {'n_estimators': 500, 'max_leaf_nodes': 6,
                     'max_depth': None,
                     'min_samples_leaf': 1, 'learning_rate': 0.01,
                     'loss': 'ls', 'subsample': 1.0, 'verbose': 0,
                     }
@ogrisel

Member

ogrisel commented Jan 27, 2014

Same here, I do not see any regression between 31491f9 and the current master:

I trained GradientBoostingRegressor(n_estimators=100, random_state=1) on the full Fold1 split of MSLR-10K (3 folds train + 1 fold val == 958671 samples) in 33 min on master and 36 min on 31491f9. In both cases I get the following scores on the Fold1 test fold:

  • NDCG@5: 0.506
  • NDCG@10: 0.514
  • R2: 0.168

If I understand correctly, the only commit with an impact between 31491f9 and master is @glouppe's cache optim a681c9b (aka: ENH Make PresortBestSplitter cache friendly + cosmetics). It seems to indeed work on my box, shaving 3 min off the training time.

@larsmans larsmans deleted the larsmans:tree-sort branch Jan 27, 2014

@songgc

songgc commented Jan 27, 2014

My apologies for the false alarm! I have found that I installed version 0.14.1 rather than the master branch...

pip install scikit-learn git+https://github.com/scikit-learn/scikit-learn.git gives me the stable version.
pip install -f scikit-learn file://"a synced local repos" gives me the master branch.

My lesson is "check version first". Currently, my benchmarks are consistent with you guys. The speed improvement is impressive! I don't need to transfer data to R for the GBM package :)
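The "check version first" lesson can be automated with a tiny helper along these lines (a hypothetical snippet, not part of scikit-learn) that reports which version of a package actually got imported, and from where:

```python
import importlib

def report_version(modname):
    """Return (version, install path) for modname, or None if not importable."""
    try:
        mod = importlib.import_module(modname)
    except ImportError:
        return None
    return (getattr(mod, "__version__", "unknown"),
            getattr(mod, "__file__", "built-in"))

# e.g. print(report_version("sklearn")) before trusting a benchmark run
info = report_version("json")
```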

@pprett

Member

pprett commented Jan 27, 2014

no worries @songgc thanks for double-checking -- I'd say there is definitely no reason now to switch to R for the randomForest package ;)

@larsmans

Member

larsmans commented Jan 27, 2014

@songgc @pprett Any benchmarks against R? :)

@pprett

Member

pprett commented Jan 27, 2014

I've code for this -- next days are a bit busy -- will post it later this week - I expect quite a difference because 0.14.1 used to be faster as well. WiseRF is the competitor :)

@songgc

songgc commented Jan 27, 2014

@larsmans
I had a benchmark case against GBM as follows:
Data set: MSLR-WEB10K/Fold1
params:
for GBRT: {'n_estimators': 100, 'max_depth': 4, 'min_samples_split': 10,
'learning_rate': 0.03, 'loss': 'ls', 'subsample': 0.5, 'random_state': 11, 'verbose': 1}
for GBM: {"distribution": "gaussian", "shrinkage": 0.03,
"n.tree": 100, "bag.fraction": 0.5, "verbose": True,
"n.minobsinnode": 10, "interaction.depth": 6}
Please note that the max depths are different. GBM usually requires deeper trees than GBRT to achieve similar performance.

Benchmark result:
library, test MSE, running time
GBRT, 0.5854, 1238s
GBM, 0.5943, 1442s

@pprett

Member

pprett commented Jan 28, 2014

@songgc the current master also includes an option to build GBM style trees in GradientBoostingRegressor|Classifier -- use the max_leaf_nodes argument (max_leaf_nodes - 1 equals interaction.depth)
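A minimal sketch of that GBM-style configuration; the dataset and hyperparameter values here are illustrative, not a benchmark setup:

```python
from sklearn.datasets import make_friedman3
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_friedman3(n_samples=2000, random_state=0)

# max_leaf_nodes grows each tree best-first, like interaction.depth in
# R's gbm: max_leaf_nodes - 1 corresponds to interaction.depth = 6 here.
est = GradientBoostingRegressor(n_estimators=50, learning_rate=0.1,
                                max_leaf_nodes=7,   # GBM-style tree size
                                max_depth=None,     # leaf count, not depth, governs
                                random_state=0).fit(X, y)
r2 = est.score(X, y)
```

Setting max_depth=None disables the depth constraint so the leaf-count limit alone controls tree size.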

@amueller amueller modified the milestones: 0.16, 0.15 Jul 15, 2014
