[MRG] Complete rewrite of the tree module #2131

Merged
Merged 124 commits into scikit-learn:master on Jul 22, 2013

@glouppe
Member
glouppe commented Jul 6, 2013

Here is some good news for Kaggle competitors... :)

In this PR, I propose a complete rewrite of the core tree module (_tree.pyx) and of all tree-dependent estimators. In particular, this new implementation factors out the splitting strategy at the core of the tree construction process. Such a strategy is now implemented in a Splitter object, as specified by the new interface in _tree.pxd. As of now, this PR provides two splitting strategies: BestSplitter for finding the best split in a node (this is CART) and RandomSplitter for finding the best random split (this is Extra-Trees).

The PR addresses three issues with the module: modularity, space and speed.

  1. Modularity: It is now more convenient to write a new splitting strategy. For example, one could now easily write a splitting strategy based on binning and plug it into the module.

I think it is also possible to reimplement our old X_argsorted strategy in a shared Splitter object. We will have to carry out some benchmarks though, to see if it is worth it.

  2. Space: no more X_argsorted, no more sample_mask, no more min_density stuff. No more duplicates of X! A tree is now built directly on the original version of X.

  3. Speed: Both Random Forest and Extra-Trees have been sped up. The most significant improvement goes to Extra-Trees, which are now properly implemented, i.e. without any sorting.

As a small benchmark, I ran a little experiment on mnist3vs8 (~12,000 samples, 784 features, 2 classes):

Parameters:
n_estimators=10
max_features=5
bootstrap=False
random_state=0
all other defaults

master
    RandomForestClassifier
        In [4]: %timeit -n 5 clf.fit(X_train, y_train)
        5 loops, best of 3: 22.2 s per loop

    ExtraTreesClassifier
        In [7]: %timeit -n 5 clf.fit(X_train, y_train)
        5 loops, best of 3: 25.8 s per loop

trees-v2
    RandomForestClassifier
        In [5]: %timeit -n 5 clf.fit(X_train, y_train)
        5 loops, best of 3: 1.33 s per loop
        => Speedup = 16.7x

    ExtraTreesClassifier
        In [9]: %timeit -n 5 clf.fit(X_train, y_train)
        5 loops, best of 3: 757 ms per loop
        => Speedup = 34.08x

I think the figures speak for themselves... :-)


This is still a work in progress and a lot of work remains. However, all tests in test_tree.py and test_forest.py already pass.

On my todo list:

  • Fix the GBRT module for the new interface
  • Fix the AdaBoost module for the new interface
  • Deprecate removed parameters
  • Update documentation
  • More benchmarks
  • PEP8 / flake8
  • Check the impact of this refactoring on the RandomState unpickling perf issue from #1622

Once all of this is done, I'll call for reviews. My goal is to have this beast merged in during the sprint. I think it is shaping up very well... :)

CC: @pprett @bdholt1 @ndawe @arjoly

@jakevdp
Member
jakevdp commented Jul 6, 2013

This looks awesome! Though it might ruin wise.io's favorite speed demo... 😄

@satra
Member
satra commented Jul 6, 2013

this is great.

ps. also, at the recent pattern recognition in neuroimaging conference, many people were introduced to random forests and extra-trees in a tutorial. this will benefit all of those folks as well.

@pprett
Member
pprett commented Jul 6, 2013

aahhhh... I'm so excited - great work Gilles!

I will get GBRT running in no time.

@pprett
Member
pprett commented Jul 6, 2013

@glouppe could you make a gist with the benchmark or is it in the repo?

@glouppe
Member
glouppe commented Jul 6, 2013

@pprett The benchmark is not in the repo. I ran the following script inside IPython:

import numpy as np
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
data = np.load("/home/gilles/PhD/db/data/mnist3vs8.npz")
X_train = data["X_train"]
y_train = data["y_train"]
X_test = data["X_test"]
y_test = data["y_test"]

clf = RandomForestClassifier(n_estimators=10, max_features=5, bootstrap=False)
%timeit -n 5 clf.fit(X_train, y_train)

So, nothing very scientific. We should make a proper script to benchmark the two implementations in more detail (this is on the todo list).

@pprett
Member
pprett commented Jul 6, 2013

in _tree.pyx line 88::

 def __cinit__(self, SIZE_t n_outputs, object n_classes):

shouldn't n_classes be an np.ndarray of dtype SIZE_t, or a SIZE_t[::1] memoryview?

(doesn't matter much performance-wise, but the yellow stain in the annotated Cython hurts my eyes)

@pprett
Member
pprett commented Jul 6, 2013

Currently, you call the following line in each iteration over the max_features loop::

f_j = random_state.randint(0, n_features - f_idx)

This calls into the interpreter; for large max_features it might be better to create a random array of size max_features and then loop over that (alternatively, maybe we can use the C API of RandomState).
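
Roughly, the batching idea looks like this (a simplified numpy sketch, not the PR's Cython code; it also ignores that the real loop shrinks the randint range as features are drawn):

import numpy as np

rng = np.random.RandomState(0)
n_features, max_features = 784, 5

# One vectorized call replaces max_features separate rng.randint(...) calls,
# so the interpreter is entered once instead of once per candidate feature.
draws = rng.randint(0, n_features, size=max_features)
for f_idx in range(max_features):
    f_j = draws[f_idx]
    # ... evaluate candidate feature f_j here ...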

@glouppe
Member
glouppe commented Jul 6, 2013

it might be better to create a random array of size max_features and then loop over that

The thing is, we may need to visit more than max_features features (because e.g, some may be constant for that node and should not be counted among the max_features features to be considered).

I'll dig into the C API of RandomState and try to solve that.

@GaelVaroquaux
Member

This looks awesome! Though it might ruin wise.io's favorite speed demo...

It's already fake: if you choose reasonable settings you don't achieve speedups as large as they claim.

@glouppe
Member
glouppe commented Jul 7, 2013

@pprett Both of your comments have been addressed. However, GBRT still happens to be slower than before. I think this comes from the new default splitting strategy in decision trees. As you build more and more trees, pre-sorting X pays off and eventually beats sorting the features on the fly at each node. I'll work on a new splitter starting tomorrow.

@pprett
Member
pprett commented Jul 7, 2013

@glouppe I can do that too - already started some coding - for now I just wanted to write a splitter that implements the same strategy as in our current impl (presorting + sample mask).

do you see a better way to incorporate presorting?

@glouppe
Member
glouppe commented Jul 7, 2013

A better strategy is the one from Breiman's code: X_argsorted is computed once. Then, for each tree, it is duplicated and rearranged in place when splitting a node. This is fast, but requires additional memory space.
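
The bookkeeping looks roughly like this (a toy numpy sketch of the idea only, not Breiman's code nor anything in this branch; goes_left is a boolean mask indexed by sample id):

import numpy as np

def partition_argsorted(X_argsorted, goes_left):
    # X_argsorted[:, j] holds the sample indices sorted by feature j. After a
    # split, stably partition each column so that the samples routed to the
    # left child come first; since the partition is stable, both blocks stay
    # sorted by their feature and the children never need re-sorting.
    for j in range(X_argsorted.shape[1]):
        col = X_argsorted[:, j]
        left = goes_left[col]
        X_argsorted[:, j] = np.concatenate([col[left], col[~left]])
    return X_argsorted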

(I'll be away today, we can discuss this tomorrow.)

@pprett
Member
pprett commented Jul 7, 2013

ok - thanks

@pprett
Member
pprett commented Jul 7, 2013

another random thought: currently you sort the data at the end of find_split. Since find_split is called for each node (internal and leaf) it might be a bit faster to sort only if the node is internal.

what do you think - does it matter? for deep trees and low min_samples_leaf it might not matter much...

@ogrisel
Member
ogrisel commented Jul 7, 2013

It seems that most of the code of methods like find_split does not need to call into Python at all. In that case it would be interesting to add the nogil marker, either for the whole function or for the CPU-intensive inner loops in with nogil blocks.

http://docs.cython.org/src/userguide/external_C_code.html#releasing-the-gil

That would make it possible for a thread-pool version of joblib.Parallel (TODO) to efficiently train random forests on multicore machines without incurring any memory copy of the dataset.

We can discuss this further during the next sprint, but adding nogil declarations right now will make it easier to prototype things with threads later.

@glouppe
Member
glouppe commented Jul 8, 2013

Regarding the wise.io implementation, I just followed the benchmark they posted some time ago (http://about.wise.io/blog/2012/11/22/wiserf-introduction-and-benchmarks/) and compared computing times with this branch. The full benchmark script is available at https://gist.github.com/glouppe/5949526

I used WiseRF 1.5.6 (as available on their website).

On my machine, results are the following:

RandomForestClassifier: Accuracy: 0.95     34.61s
ExtraTreesClassifier:   Accuracy: 0.96     26.65s
WiseRFClassifier:       Accuracy: 0.95     28.48s

Take-home message: according to their blog post, WiseRFClassifier was 7.19x faster than RandomForestClassifier. It is now only 1.21x faster. Barely noticeable. Still, this means we have some room for improvement. Let's beat that!

Also, ExtraTreesClassifier is confirmed to be faster than both of them - while being more accurate! Really, why bother finding the very best splits? ;-)

@arjoly
Member
arjoly commented Jul 8, 2013

Nice !!! :-)

@ogrisel
Member
ogrisel commented Jul 8, 2013

Very cool. I was already sold on extra trees but it's nice to see it confirmed once more.

@ogrisel
Member
ogrisel commented Jul 8, 2013

What is the current performance impact on Adaboost and GBRT?

@glouppe
Member
glouppe commented Jul 8, 2013

What is the current performance impact on Adaboost and GBRT?

Currently, this branch is slower than master for GBRT. We have been profiling the code with Peter, and it turns out that pre-sorting X is a better strategy for shallow trees with a large max_features value (as in the stumps used in GBRT). To solve that, we plan on writing a third Splitter object tuned for that use case, roughly reimplementing our old strategy (pre-sorting X once, for all trees).

@ogrisel
Member
ogrisel commented Jul 8, 2013

Yes this is what I understood from the previous exchanges. What I wanted to know was more actual numbers: is it 10%, 50% or 200% slower or does it really depend on the size of the dataset?

@glouppe
Member
glouppe commented Jul 8, 2013

Yes this is what I understood from the previous exchanges. What I wanted to know was more actual numbers: is it 10%, 50% or 200% slower or does it really depend on the size of the dataset?

About twice as slow on the benchmarks we did.

(On some others, it was faster though... so this is really parameter dependent.)

@yang
yang commented Jul 8, 2013

@glouppe Awesome work! For those of us who have been curious about how wise.io's RFs were so much faster, what do you think was the most important change for the performance gains?

@glouppe
Member
glouppe commented Jul 8, 2013

@yang For historical reasons, I would say: finding a split within a node was previously linear in the total number of samples in X. It is now linear in the number of samples falling into that node, as it should always have been. The new implementation is also organized in a better way, which makes it easy to implement various splitting strategies. In particular, Extra-Trees now benefit from their own Splitter object, which finds splits without having to sort features. Previously, the implementation of Extra-Trees was closely tied to that of Random Forest, which made them actually quite inefficient... Regarding wise.io, I wouldn't say anything, as they have always remained quite obscure about implementation details, but I don't think they do anything magical. Just a plain, classical, but well-coded, Random Forest algorithm.
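
For reference, the split rule itself is trivial; here is a minimal numpy sketch (illustration only, not the PR's Cython code, and it ignores constant features, which need special-casing):

import numpy as np

def random_split(X_node, feature, random_state):
    # Extra-Trees style: draw a threshold uniformly between the feature's min
    # and max over the samples reaching this node -- a single pass over the
    # values, no sorting at all.
    values = X_node[:, feature]
    threshold = random_state.uniform(values.min(), values.max())
    return threshold, values <= threshold

# Example:
# rng = np.random.RandomState(0)
# threshold, goes_left = random_split(rng.rand(100, 10), feature=3, random_state=rng)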

@larsmans larsmans commented on the diff Jul 9, 2013
sklearn/ensemble/_gradient_boosting.pyx
# Define a datatype for the data array
DTYPE = np.float32
ctypedef np.float32_t DTYPE_t
+ctypedef np.npy_intp SIZE_t
@larsmans
larsmans Jul 9, 2013 Member

Good to see npy_intp in use!

@larsmans
Member
larsmans commented Jul 9, 2013

How would you handle negative sample weights, apart from throwing an exception?

@glouppe
Member
glouppe commented Jul 9, 2013

Just redid the benchmark:

RandomForestClassifier: Accuracy: 0.95     27.13s
ExtraTreesClassifier:   Accuracy: 0.95     22.17s
WiseRFClassifier:       Accuracy: 0.95     27.68s

The load on my machine seems to have affected computing times :-)

(Anyway, for a proper comparison, one should run that benchmark several times and then run a t-test to assess significance.)
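
For instance (purely illustrative, with made-up timings):

from scipy.stats import ttest_ind

# Fit times (in seconds) collected over repeated runs of each implementation.
times_master = [27.1, 27.8, 26.9, 27.4, 27.2]
times_branch = [22.2, 22.5, 21.9, 22.7, 22.1]
t_stat, p_value = ttest_ind(times_master, times_branch)
print(p_value)  # a small p-value means the gap is unlikely to be noise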

@arjoly
Member
arjoly commented Jul 12, 2013

Thanks for taking my comment into account (commit 54b77e1 to b26273c).

@pprett
Member
pprett commented Jul 12, 2013

I did some benchmarks with GBRT on covertype - here are the results

I've used the bench_covtype.py script but changed the GBRT params as follows::

GradientBoostingClassifier(n_estimators=200, min_samples_split=5,
                           max_features='log2', max_depth=6, subsample=0.2,
                           verbose=3, random_state=opts.random_seed)
version     train time    test time    error rate
tree-v2     472.6818s     1.5883s      0.1489
master      219.1950s     0.5327s      0.1332

Two things seem strange: a) test time on tree-v2 is much larger -- this shouldn't be the case unless the trees have been built differently (maybe the random state is used differently in master vs tree-v2?)
b) train time is 2x faster on master even though the tree depth is rather large -- maybe this interacts with subsample?

@pprett
Member
pprett commented Jul 12, 2013

Here I have a comparison on a number of datasets -- again using fairly deep trees (depth=6).

for classification I've used the following params::

classification_params = {'loss': 'deviance', 'n_estimators': 250,
                     'min_samples_leaf': 1, 'max_depth': 6,
                     'learning_rate': .6, 'subsample': 1.0}

[plot: master vs tree-v2, GBRT classification benchmarks]

for regression I used the params::

regression_params = {'n_estimators': 250, 'max_depth': 6,
                 'min_samples_leaf': 1, 'learning_rate': 0.1,
                 'loss': 'ls'}

[plot: master vs tree-v2, GBRT regression benchmarks]

@ogrisel
Member
ogrisel commented Jul 12, 2013

The test time change is strange. Could you output the mean effective size of the trees in the ensembles? Maybe an implementation detail has changed the way the trees are actually built.

@pprett
Member
pprett commented Jul 12, 2013

Gilles, you will like this one :-)

I tried your ordinary splitter instead of the breiman splitter for GBRT - here are the results for covertype

version     train time    test time    error rate
tree-v2     144.0397s     0.6534s      0.1481
master      219.1950s     0.5327s      0.1332

way faster! also, the test time thing was probably an artifact...

@pprett
Member
pprett commented Jul 12, 2013

here is the best splitter with smaller trees (depth=3)

version     train time    test time    error rate
tree-v2     73.3558s      0.5371s      0.1958
master      88.0081s      0.3438s      0.1936

and here only 100 trees w/ max_features=None

version     train time    test time    error rate
master      200.9225s     0.1981s      0.1926
tree-v2     227.3296s     0.3089s      0.1954

and here only 20 trees w/o sub-sampling

version     train time    test time    error rate
master      82.1088s      0.0800s      0.2243
tree-v2     139.6819s     0.0920s      0.2243
@pprett
Member
pprett commented Jul 12, 2013

From these benchmarks I basically conclude that your default splitter actually does pretty well for GBRT on larger datasets (even with "shallow" trees).

@glouppe
Member
glouppe commented Jul 12, 2013

Thanks for the benchmarks, Peter! So basically, this is good news. I will remove the BreimanSplitter since it does not prove to be that effective. I have actually always preferred the default one, since it uses no additional memory and happens to be faster, as your last results suggest.

@glouppe
Member
glouppe commented Jul 12, 2013

I have just removed BreimanSplitter and made GBRT use the default splitter. This greatly simplifies the code :)

I have also just pushed a small optimization to predict. I don't know if it will change anything. This is quite strange though, since the code is basically identical to master. Maybe this is due to compilation options (I added -O3 and -funroll-all-loops in setup.py; this proved to be slightly faster on my machine, but that may not be the case on yours. Could you run a quick benchmark with and without loop unrolling? If it happens to be slower, I'll remove it.)
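
For reference, this is the kind of change I mean in setup.py (a sketch only, assuming a plain numpy.distutils Extension; the real setup.py may wire things differently):

from numpy.distutils.core import Extension

ext = Extension(
    "sklearn.tree._tree",
    sources=["sklearn/tree/_tree.c"],
    # Whether -funroll-all-loops actually helps is compiler/machine dependent,
    # hence the request to benchmark with and without it.
    extra_compile_args=["-O3", "-funroll-all-loops"],
)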

@glouppe
Member
glouppe commented Jul 12, 2013

maybe the random state is used differently in master vs tree-v2?

Yes, random states are used differently, but the variance shouldn't be that high...

@glouppe
Member
glouppe commented Jul 12, 2013

I'll have some time tomorrow - I'll launch some more benchmarks, for RandomForest, Extra-Trees and GBRT on larger datasets, comparing results with master.

@ogrisel ogrisel and 1 other commented on an outdated diff Jul 12, 2013
sklearn/ensemble/weight_boosting.py
"""Implement a single boost using the SAMME.R real algorithm."""
estimator = self._make_estimator()
- if X_argsorted is not None:
- estimator.fit(X, y, sample_weight=sample_weight,
- X_argsorted=X_argsorted)
- else:
- estimator.fit(X, y, sample_weight=sample_weight)
+ try:
+ estimator.set_params(random_state=self.random_state)
+ except:
+ pass
@ogrisel
ogrisel Jul 12, 2013 Member

This looks very very fishy. What kind of exceptions have you seen in practice?

@ogrisel
ogrisel Jul 12, 2013 Member

If random_state is not an expected parameter, all scikit-learn models should raise the same exception type and we should only catch that one.

@glouppe
glouppe Jul 12, 2013 Member

AdaBoost works with any estimator. If we don't add this line, then the unit test using an SVM inside AdaBoost crashes, because the SVM doesn't have a random_state parameter. Yet, in my opinion, if the base estimator has one, then it should properly be controlled by the meta-estimator.

@ogrisel
ogrisel Jul 12, 2013 Member

That's what I thought; let's be specific and only catch the expected exception type(s): I suspect we expect AttributeError in this case.

@ogrisel ogrisel commented on an outdated diff Jul 12, 2013
sklearn/ensemble/weight_boosting.py
"""Implement a single boost using the SAMME discrete algorithm."""
estimator = self._make_estimator()
- if X_argsorted is not None:
- estimator.fit(X, y, sample_weight=sample_weight,
- X_argsorted=X_argsorted)
- else:
- estimator.fit(X, y, sample_weight=sample_weight)
+ try:
+ estimator.set_params(random_state=self.random_state)
+ except:
+ pass
@ogrisel
ogrisel Jul 12, 2013 Member

Same here.

@ogrisel ogrisel commented on an outdated diff Jul 12, 2013
sklearn/tree/_tree.pxd
- cdef double* threshold
- cdef double* value
- cdef double* best_error
- cdef double* init_error
- cdef int* n_samples
-
- cdef np.ndarray features
+ cdef public SIZE_t node_count # Counter for node IDs
+ cdef public SIZE_t capacity # Capacity
+ cdef SIZE_t* children_left # children_left[i] is the left child of node i
+ cdef SIZE_t* children_right # children_right[i] is the right child of node i
+ cdef SIZE_t* feature # features[i] is the feature used for splitting node i
+ cdef double* threshold # threshold[i] is the threshold value at node i
+ cdef double* value # value[i] is the values contained at node i
+ cdef double* impurity # impurity[i] is the impurity of node i
+ cdef SIZE_t* n_node_samples # n_node_samples[i] is the number of samples at node i
@ogrisel
ogrisel Jul 12, 2013 Member

It's a very good practice to document the parameters of the cython code this way. Thanks very much it helps a lot understanding the code.

@ogrisel
ogrisel Jul 12, 2013 Member

However, could you explain what the counter (node_count), capacity and impurity actually are and what they are used for? Maybe in a multi-line comment right below or above the block of attributes.

@ogrisel ogrisel and 1 other commented on an outdated diff Jul 12, 2013
sklearn/tree/_tree.pxd
- double* _best_t, double* _best_error,
- double* _initial_error)
-
+ cpdef build(self, np.ndarray X,
+ np.ndarray y,
+ np.ndarray sample_weight=*)
+
+ cdef SIZE_t add_node(self, SIZE_t parent,
+ bint is_left,
+ bint is_leaf,
+ SIZE_t feature,
+ double threshold,
+ double impurity,
+ SIZE_t n_node_samples)
+
+ cdef void resize(self, SIZE_t capacity=*)
@ogrisel
ogrisel Jul 12, 2013 Member

What do the equal-star signs mean? I had never seen that before.

@glouppe
glouppe Jul 12, 2013 Member

It is Cython syntax. Basically, you cannot put default values in .pxd files and have to use that instead (I don't know why, but that is how it is).

@ogrisel
ogrisel Jul 12, 2013 Member

Alright, thanks.

@glouppe
Member
glouppe commented Jul 13, 2013

I just ran some benchmarks comparing master and this branch on some datasets. For each, I built either a RandomForestClassifier or an ExtraTreesClassifier with n_estimators=100, max_features=sqrt and all default parameters.

The two plots below show the speedup (i.e. the ratio between master and trees-v2) of the new implementation for fit and predict.

[plots: fit speedup and predict speedup of trees-v2 over master, per dataset]

As my very first results suggested, this new implementation is most beneficial to Extra-Trees, which are way faster (from 6x to 18x!). Random Forests are also faster, though not as spectacularly (from 2x to 4x); still, this is a big improvement. To be completely honest though, it is easy to make these results look quite different by tuning max_features. For low values, the new implementation will be even faster.

Regarding predictions, results do not seem to show any speedup (computing times are all almost the same, since speedups are around 1.0), which is expected since the implementation is nearly identical.

@ogrisel
Member
ogrisel commented Jul 13, 2013

This looks very good from a performance point of view and all tests pass. Maybe a final pass of pep8 / pyflakes? Have you checked that the existing examples still look good?

@ogrisel
Member
ogrisel commented Jul 13, 2013

BTW @glouppe I like the style of your plot. Do you have a matplotlibrc file somewhere? Or some bench/plot script in a gist?

@glouppe
Member
glouppe commented Jul 13, 2013

BTW @glouppe I like the style of your plot. Do you have a matplotlibrc file somewhere? Or some bench/plot script in a gist?

Then you like Excel plots :-)

I'll check the examples and do another pass of pep8/flake8 (but those don't really work on Cython files...).

@ogrisel
Member
ogrisel commented Jul 13, 2013

Then you like Excel plots :-)

arf

@amueller
Member

Just wanted to say how awesome this is. I think this is a very important improvement. 👍

@ogrisel
Member
ogrisel commented Jul 13, 2013

I ran most of the examples and compared the plots with the plots from the master examples published on the website, and they still look very similar.

I suspect that the previous test failures are caused by tests that are too sensitive to the RNG or numerical accuracy.

@ogrisel
Member
ogrisel commented Jul 13, 2013

It seems that this branch is slower on the covertype benchmark for extra trees:

$ python benchmarks/bench_covertype.py --n-jobs=2 --classifiers=ExtraTrees

yields:

  • on this branch:
Classifier   train-time test-time error-rate
--------------------------------------------
ExtraTrees   181.3093s   0.8924s     0.0198
  • on master:
Classifier   train-time test-time error-rate
--------------------------------------------
ExtraTrees   163.4042s   0.8334s     0.0202

Not a big deal though as the memory improvement is more important than the 10% difference in fit time in my opinion.

@ogrisel
Member
ogrisel commented Jul 13, 2013

I ran it twice and the results are stable to the second.

@glouppe
Member
glouppe commented Jul 13, 2013

This is really odd. I'll run the benchmark to see...

Edit: oh yeah, max_features is set explicitly to None, which is why it is slower. I think we should update the benchmark file and use the default values instead (i.e., sqrt).

@ogrisel
Member
ogrisel commented Jul 13, 2013

@glouppe any idea about the benchmarks? Can you reproduce the behavior on your box?

@glouppe
Member
glouppe commented Jul 13, 2013

Sorry, I replied by editing my previous message. This comes from max_features, which is set to None (all features) in the benchmark script. In that case, pre-sorting as we do in master is a fine and relevant strategy and actually has better complexity than sorting all features on the fly at every node. You could say this shouldn't make a difference for Extra-Trees, but pre-sorting allowed us in master to know almost directly which samples were going left and which were going right. In the new implementation, we have to partition the node first (move each sample right or left), which may be a bit slower... I am currently doing some tests on covertype. I'll post my results in a few minutes.

@ogrisel
Member
ogrisel commented Jul 13, 2013

Alright, thanks.

@glouppe
Member
glouppe commented Jul 13, 2013

Here are my results; conclusions are different on my box. The new branch is faster... but I am still not surprised by your results, as I have just explained.

Covertype with Extra-Trees(n_estimators=20, n_jobs=2, min_samples_split=5)

                                train-time  test-time   error-rate
                                ----------------------------------
master, max_features=None       173.7674s   0.7530s     0.0202
trees-v2, max_features=None     153.3088s   0.8716s     0.0195
master, max_features=sqrt       115.7302s   0.8738s     0.0214
trees-v2, max_features=sqrt     57.5994s    0.9001s     0.0216
@glouppe
Member
glouppe commented Jul 13, 2013

(I also tend to think that covertype is a rather particular dataset (hundreds of thousands of samples, few features), and it is in my opinion not a good idea to draw strong conclusions from it.)

@ogrisel
Member
ogrisel commented Jul 13, 2013

Alright. Thanks for investigating further.

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/forest.py
self.max_features = max_features
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

The message should be: "The min_density parameter is deprecated as of version 0.14 and will be removed in 0.16."

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/forest.py
self.max_features = max_features
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

Same remark.

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/forest.py
self.max_features = max_features
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

To be updated as well.

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/forest.py
self.max_features = max_features
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

here again.

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/forest.py
self.max_features = max_features
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

here again.

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/forest.py
self.max_features = 1
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

one more :)

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/partial_dependence.py
>>> kwargs = dict(X=samples, percentiles=(0, 1), grid_resolution=2)
>>> partial_dependence(gb, [0], **kwargs)
- (array([[-10.72892297, 10.72892297]]), [array([ 0., 1.])])
+ (array([[-5.67647953, 5.67647953]]), [array([ 0., 1.])])
@ogrisel
ogrisel Jul 13, 2013 Member

I would put an ellipsis here and only test the first 2 or 3 digits.
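
Something along these lines (a sketch; digits shortened, ELLIPSIS enabled explicitly here for clarity):

>>> partial_dependence(gb, [0], **kwargs)  # doctest: +ELLIPSIS
(array([[-5.67..., 5.67...]]), [array([ 0., 1.])])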

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/ensemble/weight_boosting.py
@@ -956,6 +936,11 @@ def _boost(self, iboost, X, y, sample_weight, X_argsorted=None):
"""
estimator = self._make_estimator()
+ try:
+ estimator.set_params(random_state=self.random_state)
+ except:
@ogrisel
ogrisel Jul 13, 2013 Member

except ValueError

@ogrisel ogrisel commented on the diff Jul 13, 2013
sklearn/tree/tests/test_tree.py
+
+ # sample_weight = np.ones(X.shape[0])
+ # sample_weight[0] = -1
+ # clf = tree.DecisionTreeClassifier(random_state=1)
+ # clf.fit(X, y, sample_weight=sample_weight)
+
+ # # Check that predict_proba returns valid probabilities in the presence of
+ # # samples with negative weight
+ # X = iris.data
+ # y = iris.target
+
+ # sample_weight = rng.normal(.5, 1.0, X.shape[0])
+ # clf = tree.DecisionTreeClassifier(random_state=1)
+ # clf.fit(X, y, sample_weight=sample_weight)
+ # proba = clf.predict_proba(X)
+ # assert (proba >= 0).all() and (proba <= 1).all()
@ogrisel
ogrisel Jul 13, 2013 Member

Why are those tests commented out? Do negative weights no longer work? Will this change of behavior cause issues for some users? What was the use case for negative weights in the first place?

@glouppe
glouppe Jul 13, 2013 Member

This is something to be discussed. The branch does not explicitly check whether weights are negative. In master, this requires some checks here and there in the criteria to avoid undefined situations. I still find the treatment of negative sample weights somewhat fishy though... Unless there is a strong demand for it, I would prefer to leave it like this. (I still can't wrap my mind around "negative weights" and still don't understand how someone could exploit that in an educated way...)

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/tree/tree.py
random_state)
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

Same remark for the deprecation versions.

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/tree/tree.py
random_state)
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

here again

@ogrisel ogrisel commented on an outdated diff Jul 13, 2013
sklearn/tree/tree.py
random_state)
+ if min_density is not None:
+ warn("The min_density parameter is deprecated and will be removed "
+ "in 0.15.", DeprecationWarning)
+
+ if compute_importances is not None:
+ warn("Setting compute_importances is no longer "
+ "required. Variable importances are now computed on the fly "
+ "when accessing the feature_importances_ attribute. This "
+ "parameter will be removed in 0.15.", DeprecationWarning)
@ogrisel
ogrisel Jul 13, 2013 Member

Here again. Maybe this deprecation message should be put in two module-level constants and reused everywhere to avoid duplication.
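
For example (a sketch only; the constant names below are made up):

from warnings import warn

# Hypothetical module-level constants, reused by every estimator in the module
# instead of repeating the string literals.
MIN_DENSITY_DEPRECATION = (
    "The min_density parameter is deprecated as of version 0.14 and will be "
    "removed in 0.16.")
COMPUTE_IMPORTANCES_DEPRECATION = (
    "Setting compute_importances is no longer required as of version 0.14. "
    "Variable importances are now computed on the fly when accessing the "
    "feature_importances_ attribute. This parameter will be removed in 0.16.")

def _warn_deprecated_params(min_density=None, compute_importances=None):
    # Illustrative helper that each estimator __init__ could call.
    if min_density is not None:
        warn(MIN_DENSITY_DEPRECATION, DeprecationWarning)
    if compute_importances is not None:
        warn(COMPUTE_IMPORTANCES_DEPRECATION, DeprecationWarning)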

@ogrisel
Member
ogrisel commented Jul 13, 2013

For the style:

  • there are many lines in _tree.pyx that break the 80-column boundary
  • there are if <condition>: <execution statement> one-liners; the execution statement should be on its own line
  • I prefer multiline statements wrapped in parens instead of using the "\" escape for the end-of-line char:
            is_leaf = (depth >= self.max_depth) or \
                      (n_node_samples < self.min_samples_split) or \
                      (n_node_samples < 2 * self.min_samples_leaf)

could be rewritten:

            is_leaf = ((depth >= self.max_depth) or
                       (n_node_samples < self.min_samples_split) or
                       (n_node_samples < 2 * self.min_samples_leaf))
  • out = np.zeros((n_samples, ), dtype=np.int32) should be out = np.zeros((n_samples,), dtype=np.int32)

Other notes:

  • I don't understand the docstring: "Resize all inner arrays to capacity, if < 0 double capacity."
  • The inline Cython functions rand_double and rand_int use the rand function from libc, which uses an external singleton PRNG seeded with srand(self.random_state.randint(0, RAND_MAX)), where random_state is an instance of the numpy PRNG. Using the external singleton from libc would make threaded random forests non-deterministic because of the race condition on the singleton. If there is no significant penalty, I would rather handle the state of the libc RNG explicitly in each tree instance by using rand_r instead of rand:

"""
The function rand() is not reentrant or thread-safe, since it uses hidden state that is modified on each call. This might just be the seed value to be used by the next call, or it might be something more elaborate. In order to get reproducible behavior in a threaded application, this state must be made explicit; this can be done using the reentrant function rand_r().

Like rand(), rand_r() returns a pseudo-random integer in the range [0, RAND_MAX]. The seedp argument is a pointer to an unsigned int that is used to store state between calls. If rand_r() is called with the same initial value for the integer pointed to by seedp, and that value is not modified between calls, then the same pseudo-random sequence will result.
"""
taken from http://linux.die.net/man/3/srand

Alternatively, one could try to use the C API of the numpy RNG instance (which is a Mersenne Twister implementation), which might have better statistical properties than the libc rand and rand_r functions and is also not subject to platform-specific behavior.

@ogrisel
Member
ogrisel commented Jul 13, 2013

I confirm that libc.rand and libc.rand_r do not have the same range under windows and linux: https://bitbucket.org/haypo/hasard/src/tip/doc/engine_list.rst:

"""
RAND_MAX is 32767 on Windows or 2147483647 on Linux.
"""

I think this is not that big an issue (less of one than not using the reentrant rand_r version).

@ogrisel
Member
ogrisel commented Jul 13, 2013

Actually to get both cross-platform homogeneity, re-entrance and larger internal state we should probably use:

http://linux.die.net/man/3/drand48_r

and init the state struct with srand48_r from a random_state.randint.

@glouppe
Member
glouppe commented Jul 14, 2013

Thanks for the report @ogrisel and good catch for the random number generator. I will fix all of those tomorrow.

@glouppe
Member
glouppe commented Jul 14, 2013

The inline Cython functions rand_double and rand_int use the rand function from libc, which uses an external singleton PRNG seeded with srand(self.random_state.randint(0, RAND_MAX)), where random_state is an instance of the numpy PRNG. Using the external singleton from libc would make threaded random forests non-deterministic because of the race condition on the singleton. If there is no significant penalty, I would rather handle the state of the libc RNG explicitly in each tree instance by using rand_r instead of rand:

I agree, but I have one comment though: there is no multi-threading in Python. Since we fork processes (and not threads) to build trees, each one of them should have its own libc random state, shouldn't it? (I can make the changes anyway though.)

@arjoly arjoly commented on an outdated diff Jul 14, 2013
sklearn/ensemble/forest.py
@@ -307,17 +285,8 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
- sample_mask = np.ones((n_samples,), dtype=np.bool)
-
n_jobs, _, starts = _partition_features(self, self.n_features_)
@arjoly
arjoly Jul 14, 2013 Member

Is there still a use for this function?

@GaelVaroquaux
Member

I agree, but I have one comment though : there is no multi-threading in Python.

There is. We tend not to use it, but there is. Now, if you don't release
the GIL in your code, I believe that it won't be executed in parallel, so
you might be safe.

@ogrisel
Member
ogrisel commented Jul 14, 2013

But it's quite easy to release the GIL in perf-critical Cython loops, so
adding threading support to joblib might help us fix the memory issue and
the OpenBLAS segfaults at the same time. Hence thread safety in the tree
code might actually be very relevant. And it's not that complicated to
implement with the reentrant libc API.

@glouppe
Member
glouppe commented Jul 15, 2013

there are if <condition>: <execution statement> one-liners; the execution statement should be on its own line

Could you point them out please? I can't find any in _tree.pyx.

@arjoly arjoly commented on the diff Jul 15, 2013
sklearn/ensemble/_gradient_boosting.pyx
@@ -286,3 +303,39 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X,
if not (0.999 < total_weight < 1.001):
raise ValueError("Total weight should be 1.0 but was %.9f" %
total_weight)
+
+
+def _random_sample_mask(int n_total_samples, int n_total_in_bag, random_state):
@arjoly
arjoly Jul 15, 2013 Member

This should go in sklearn.utils.random.

@arjoly
arjoly Jul 15, 2013 Member

Is there a test for this function?

@glouppe
glouppe Jul 15, 2013 Member

Okay, but I have a strange bug. Cython crashes when I try to compile random.pyx. Even compiling the master version of that file does not work.

  File "/usr/lib/python2.7/cgi.py", line 51, in <module>
    import mimetools
  File "/usr/lib/python2.7/mimetools.py", line 6, in <module>
    import tempfile
  File "/usr/lib/python2.7/tempfile.py", line 34, in <module>
    from random import Random as _Random
ImportError: cannot import name Random

I am using Cython 0.20-dev.

@ogrisel
ogrisel Jul 15, 2013 Member

Maybe this is a Cython regression. Can you try with Cython 0.19?

@glouppe
glouppe Jul 15, 2013 Member

Same bug (0.19.1).

@ogrisel
ogrisel Jul 15, 2013 Member

I cannot reproduce using cython 0.19.1 on the master branch:

$ cython -a sklearn/utils/random.pyx
$ make clean inplace

=> no build error.

@glouppe
glouppe Jul 15, 2013 Member

Hmm, indeed, when building sklearn/utils/random.pyx, it works, but not when building the file from inside the sklearn/utils directory.

@glouppe
glouppe Jul 15, 2013 Member

Meh, there is no pxd file for random.pyx, which makes it not importable from Cython files... Will fix that later.

@ogrisel
Member
ogrisel commented Jul 15, 2013

Could you point out the files please? I don't find any in _tree.pyx.

From line 1344 to 1362.

@ogrisel
Member
ogrisel commented Jul 15, 2013

Why not use drand48_r with srand48_r for seeding? That should ensure uniform behavior independently of the host platform: otherwise a random forest trained under windows will not be reproducible under linux despite using the same random state seed and the same data.

@glouppe
Member
glouppe commented Jul 15, 2013

@ogrisel Because you said it was fine with rand_r and because drand48_r looked more complicated to use :-) (programmers are lazy, you should know that ;)) I'll have a look.

@ogrisel
Member
ogrisel commented Jul 15, 2013

Well the only difference is the allocation of the rng state struct which is no longer an integer. It's quite easy to do in cython:

http://wiki.cython.org/DynamicMemoryAllocation

However I am not sure where to put the call to free. Is the Splitter instance only called during the fit method (that would make sense)? If so, I think it should be cinitialized at the beginning of the fit and then destroyed explicitly to free the memory.

@ogrisel
Member
ogrisel commented Jul 15, 2013

Actually you could just put the call to free in a __dealloc__ function and let the gc call it in time:

http://docs.cython.org/src/userguide/special_methods.html#finalization-method-dealloc

@larsmans
Member

According to my Linux manpages, drand48_r is a GNU-specific extension, so it's bound to not even be available on Windows.

@larsmans
Member

As for rand_r, that's in POSIX, but marked obsolete.

@glouppe
Member
glouppe commented Jul 15, 2013

In man 3 drand48, there is also this note:

NOTES
       These  functions  are declared obsolete by SVID 3, which states that rand(3) should be
       used instead.
@ogrisel
Member
ogrisel commented Jul 15, 2013

@larsmans thanks for the hint. Let us stick with rand_r then...

@larsmans
Member

rand_r is not available on MSVC, or at least not on all versions.

@ogrisel
Member
ogrisel commented Jul 15, 2013

rand_r is not available on MSVC, or at least not on all versions.

Argl. We should find a way to use the C API of the numpy rng then, or CPython's rng if easier.

@larsmans
Member

Would something like this be of use?

from libc.string cimport memcpy


cdef enum:
    RNG_Bufsize = 1024


cdef class RNG:
    cdef double buf[RNG_Bufsize]
    cdef Py_ssize_t idx
    cdef object state   # a callable returning n doubles, e.g. np.random.RandomState(0).random_sample

    def __cinit__(self, sampler):
        self.idx = RNG_Bufsize   # start "empty" to force a refill on first use
        self.state = sampler

    cdef double flt(self) with gil:
        cdef double[::1] fresh
        cdef double r
        if self.idx == RNG_Bufsize:
            # Refill the buffer with a single call into the Python-level RNG.
            fresh = self.state(RNG_Bufsize)
            memcpy(self.buf, &fresh[0], sizeof(self.buf))
            self.idx = 0
        r = self.buf[self.idx]
        self.idx += 1
        return r

I haven't checked how fast this is -- it'll depend on how long it takes to obtain and release the GIL.

@glouppe
Member
glouppe commented Jul 15, 2013

@larsmans Yes, that would work, but our goal is to make the code gil-free (this is not yet the case, but we are close to it) in order to leave us the possibility to multi-thread the code in the future. I really don't know what is best though...

@ogrisel
Member
ogrisel commented Jul 15, 2013

Alright, let's use rand for now and address the issue of the thread safety of the rng later in another PR. We can always maintain our own minimalistic cython wrapper for randomkit later in scikit-learn.

@ogrisel
Member
ogrisel commented Jul 15, 2013

randomkit is the C library that provides the Mersenne twister implementation used in numpy:

https://github.com/numpy/numpy/tree/master/numpy/random/mtrand

@glouppe
Member
glouppe commented Jul 15, 2013

Alright, let's use rand for now and address the issue of the thread safety of the rng later in another PR

Done.

@glouppe
Member
glouppe commented Jul 15, 2013

@ogrisel I think I have addressed all your comments above.

  • Regarding long lines, I have broken most of them, but a few are still there. I left them there on purpose because I find that breaking them actually impairs readability.
  • Regarding random number generation, I think we are still on the safe side, since joblib is based on multiprocessing, and all processes should thus have their own srand state. Let's look at a thread-safe gil-free random number generator during the sprint.
@ogrisel
Member
ogrisel commented Jul 15, 2013

+1, I also asked for advice on the cython-users mailing list in the meantime.

I think I am OK with the current state of this PR. It would be great to get a +1 review of the _tree.pyx code by another DT expert before merging.

@arjoly arjoly commented on the diff Jul 15, 2013
sklearn/ensemble/_gradient_boosting.pyx
+ sample_mask : np.ndarray, shape=[n_total_samples]
+ An ndarray where ``n_total_in_bag`` elements are set to ``True``
+ the others are ``False``.
+ """
+ cdef np.ndarray[float64, ndim=1, mode="c"] rand = \
+ random_state.rand(n_total_samples)
+ cdef np.ndarray[int8, ndim=1, mode="c"] sample_mask = \
+ np_zeros((n_total_samples,), dtype=np_int8)
+
+ cdef int n_bagged = 0
+ cdef int i = 0
+
+ for i from 0 <= i < n_total_samples:
+ if rand[i] * (n_total_samples - i) < (n_total_in_bag - n_bagged):
+ sample_mask[i] = 1
+ n_bagged += 1
@arjoly
arjoly Jul 15, 2013 Member

What is the expected distribution of an element of sample_mask?

@mblondel
Member

In lightning, I copied a small subset of the random kit library used in NumPy and its Cython interface https://github.com/mblondel/lightning/blob/master/lightning/random/random_fast.pyx

glouppe and others added some commits Jul 7, 2013
@glouppe glouppe FIX: fix test_random_hasher 7880b7d
@glouppe glouppe WIP: fix adaboost 5e5dd90
@glouppe glouppe WIP: small optim to regression criterion 5fd21a1
@glouppe glouppe WIP: optimize tree construction procedure faeea4b
@glouppe glouppe WIP: optimization of the tree construction procedure 6de7739
@glouppe glouppe cleanup 561a6b5
@glouppe glouppe recompile _tree.pyx 2debc7a
@glouppe glouppe FIX: export_graphviz test 167bbec
@glouppe glouppe FIX: set random_state in adaboost c806cdf
@glouppe glouppe FIX: doctests 314ffce
@glouppe glouppe FIX: doctests in partial_dependence 93806cf
@glouppe glouppe FIX: feature_selection doctest ead9631
@glouppe glouppe FIX: feature_selection doctest (bis) 2d69d4c
@glouppe glouppe WIP: allow Splitter objects to be passed in constructors e66224f
@glouppe glouppe FIX ca3cf08
@glouppe glouppe Some PEP8 / Flake8 563f907
@glouppe glouppe Small optimization to RandomSplitter 7f9a595
@glouppe glouppe FIX: fix RandomSplitter 9b62f3c
@glouppe glouppe Cosmit fa34848
@glouppe glouppe FIX: free old structures 950622b
@glouppe glouppe WIP: Added BreimanSplitter 8f18d96
@glouppe glouppe WIP: small optimizations 3f540da
@glouppe glouppe WIP: fix BreimanSplitter b398083
@glouppe glouppe Cleanup e5d0416
@glouppe glouppe WIP: optimize swaps 730d6b0
@glouppe glouppe Regenerate _tree.c 2ee32a2
@glouppe glouppe WIP: some optimizations to criteria a3a7fc8
@glouppe glouppe WIP: add -O3 to setup.py a30fa20
@glouppe glouppe WIP: normalize option for compute_feature_importances 13bb04b
@glouppe glouppe WIP: Added deprecations in tree.py 2027f89
@glouppe glouppe WIP: updated documentation in tree.py 0bd3f50
@glouppe glouppe WIP: added deprecations in forest.py 5170345
@glouppe glouppe WIP: updated documentation 79953a0
@glouppe glouppe WIP: unroll loops 7785287
@glouppe glouppe WIP: setup.py eff6d64
@glouppe glouppe WIP: make sort a function, not a method 42af22d
@glouppe glouppe WIP: Cleaner Splitter interface 2e71ba3
@glouppe glouppe WIP: even cleaner splitter interface e2363c9
@glouppe glouppe WIP: some optimization in criteria 46d43e8
@glouppe glouppe WIP: remove some left-out comments 5d37cbd
@glouppe glouppe WIP: declare weighted_n_node_samples b1848d5
@glouppe glouppe WIP: better swaps ad8b6b9
@glouppe glouppe WIP: remove BreimanSplitter 247172e
@glouppe glouppe WIP: small optimization to predict 8499b11
@glouppe glouppe WIP: catch ValueError only 9b5cafe
@glouppe glouppe WIP: added some documentation details in _tree.pxd a2ea591
@glouppe glouppe WIP: PEP8 a few things 52b193a
@glouppe glouppe Benchmark: use default values in forests 01081b5
@glouppe glouppe WIP: remove irrelevant and unstable doctests f84e463
@glouppe glouppe WIP: address @ogrisel comments 8f0fd21
@glouppe glouppe WIP: address @ogrisel comments (2) 10a4155
@glouppe glouppe WIP: remove partition_features d1e72c5
@glouppe glouppe WIP: style in _tree.pyx 95f5df0
@glouppe glouppe WIP: make resize a private method, improve docstring cf19e1b
@glouppe glouppe WIP: use re-entrant rand_r 4c032ac
@glouppe glouppe FIX: doctest in partial_dependence 4e3d8f0
@glouppe glouppe WIP: break or shorten some long lines 24cb60d
@glouppe glouppe FIX: doctest in feature_selection 570bca0
@glouppe glouppe WIP: break one-liner if statements b564098
@glouppe glouppe WIP: revert use of rand_r 0c79587
@larsmans @glouppe larsmans ENH back-port rand_r from 4.4BSD
This function is not available on Windows and is deprecated by POSIX.
Fetched from OpenBSD:
http://www.openbsd.org/cgi-bin/cvsweb/src/lib/libc/stdlib/rand.c

This reverts commit 6f737ad.
c7266ec
@larsmans @glouppe larsmans FIX move rand_r to tree module for now
Simple build and link without an intermediate Cython module.
3b81aee
@glouppe glouppe FIX: broken tests based on rng a398577
@glouppe glouppe DOC: update header in rand_r.c 0be3a70
@glouppe glouppe TEST: skip test in feature_selection (too unstable) c1b93ed
@glouppe glouppe FIX: one more doctest a42701e
@glouppe glouppe WIP: Faster predictions if n_outputs==1 dc94d14
@glouppe glouppe WIP: Break comments on new line cb994bf
@glouppe glouppe WIP: make criteria nogil ready cc3f2cf
@glouppe glouppe WIP: enforce contiguous arrays to optimize construction bbdccca
@glouppe glouppe WIP: avoid data conversion in AdaBoost e0dcedd
@glouppe glouppe WIP: use np.ascontiguousarray instead of array2d f03095a
@glouppe glouppe TEST: add test_memory_layout 4e4316c
@glouppe glouppe FIX: broken test ed9d503
@glouppe glouppe WIP: Make trees and forests support string labels 2c3861f
@glouppe glouppe WIP: refactor some code in forest.fit 1bd7c8d
@glouppe glouppe TEST: skip doctest in feature_selection (unstable) 757de67
@glouppe glouppe WIP: better check inputs 8f9f2e4
@glouppe glouppe WIP: check inputs for gbrt 80dba53
@glouppe
Member
glouppe commented Jul 22, 2013

This PR has been rebased on top of master. Both @pprett and @arjoly gave their +1, @ogrisel also had a look earlier.

I think this is ready to be merged. Shall I click the green button?

@GaelVaroquaux
Member

I think this is ready to be merged. Shall I click the green button?

Go for it. Hurray, hurray, hurray

@glouppe glouppe merged commit 239054d into scikit-learn:master Jul 22, 2013

1 check passed

default: The Travis CI build passed
@glouppe
Member
glouppe commented Jul 22, 2013

Done :)

@pprett
Member
pprett commented Jul 22, 2013

awesome!

@GaelVaroquaux
Member

F####ing aye!

@arjoly
Member
arjoly commented Jul 22, 2013

Congratulations !!!

@satra
Member
satra commented Jul 22, 2013

fantastic work guys!

@mblondel
Member

๐Ÿป

@pprett pprett referenced this pull request Jul 22, 2013
Closed

Random Forest Performance #1435

@amueller
Member

Grats! awesome work!

@amueller
Member

it seems there is no whatsnew entry. I'm pretty sure there should be one ;)

@glouppe
Member
glouppe commented Jul 22, 2013

Just did it in c0b35eb

Thanks for the support guys!

@amueller
Member

btw did you end up using -funroll-all-loops? Afaik it should never be used (it unrolls loops whose length is not known at compile time) and usually slows things down. - sorry for my late feedback, I'm at least three months behind on sklearn :-/

@amueller amueller commented on the diff Jul 22, 2013
sklearn/ensemble/weight_boosting.py
"""Implement a single boost using the SAMME.R real algorithm."""
estimator = self._make_estimator()
- if X_argsorted is not None:
- estimator.fit(X, y, sample_weight=sample_weight,
- X_argsorted=X_argsorted)
- else:
- estimator.fit(X, y, sample_weight=sample_weight)
+ try:
+ estimator.set_params(random_state=self.random_state)
@amueller
amueller Jul 22, 2013 Member

there is actually a helper function in utils.testing that does that, but it's probably not really worth using here.

@glouppe
Member
glouppe commented Jul 22, 2013

Indeed, this was still lying around. I have now removed it.

@jnothman
Member

Well done! But I notice that all the attribute documentation for Tree has disappeared...

@jnothman jnothman referenced this pull request Jul 25, 2013
Closed

DOC docstring for Tree #2215

@ndawe
Member
ndawe commented Aug 22, 2013

Great work @glouppe! Just getting caught up on sklearn developments. Excited to see the speed improvements in my analysis framework!
