
[MRG+1] #3364 warm start in random forests #3409

Closed
wants to merge 3 commits into from

6 participants

@ldirer
ldirer commented Jul 17, 2014

Related to issue #3364.
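
For context, a minimal sketch of the intended usage (example data and parameter values are illustrative, matching the behavior discussed below, not final documentation):

from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import RandomForestClassifier

X, y = make_hastie_10_2(n_samples=100, random_state=1)

# Grow 10 trees, then warm-start and grow 10 more on a second fit call.
clf = RandomForestClassifier(n_estimators=10, warm_start=True, random_state=42)
clf.fit(X, y)
clf.set_params(n_estimators=20)
clf.fit(X, y)  # only the 10 additional trees are trained
assert len(clf.estimators_) == 20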

@arjoly arjoly and 1 other commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -252,23 +253,43 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_old_estimators = len(self.estimators_)
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Do we need this variable?

@ldirer
ldirer added a note Jul 17, 2014

n_old_estimators is used just 3 times, but I think it makes the code easier to read than repeating len(self.estimators_).

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -252,23 +253,43 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_old_estimators = len(self.estimators_)
+ n_new_estimators = self.n_estimators - n_old_estimators
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

For clarity, I would rename n_new_estimators to n_more_estimators

@arjoly arjoly and 1 other commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_old_estimators = len(self.estimators_)
+ n_new_estimators = self.n_estimators - n_old_estimators
+
+ if self.warm_start and n_old_estimators > 0:
+ # We draw from the random state to get the random state we would
+ # have got if we hadn't used a warm_start.
+ balancing_draw = random_state.randint(MAX_INT,
+ size=n_old_estimators)
+
+ if n_new_estimators < 0:
+ raise ValueError('n_estimators=%d must be larger or equal to '
+ 'estimators_.shape[0]=%d when '
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

estimators_ is a list, not an array, so this should be len(estimators_).

@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

I am surprised that it works.

@ldirer
ldirer added a note Jul 17, 2014

Indeed it doesn't work.
I just copy-pasted the ValueError from the GB class, and I did not test for that behavior.
I will write an additional test for this if that makes sense.

@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Hm, it's not tested apparently.

@arjoly arjoly and 1 other commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -279,7 +300,7 @@ def fit(self, X, y, sample_weight=None):
for i in range(n_jobs))
# Reduce
- self.estimators_ = list(itertools.chain(*all_trees))
+ self.estimators_ += list(itertools.chain(*all_new_trees))
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

I would use self.estimators_.extend(list(itertools.chain.from_iterable(all_new_trees))) here.

I am not sure that the list(...) is still useful.

@ldirer
ldirer added a note Jul 17, 2014

Indeed it seems the list has no effect.
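
A toy illustration (placeholder values, not the forest code) of why the extra list(...) is redundant: list.extend consumes any iterable, including the lazy iterator returned by chain.from_iterable.

from itertools import chain

all_new_trees = [["tree_1", "tree_2"], ["tree_3"]]  # one sublist per job
estimators = ["old_tree"]
estimators.extend(chain.from_iterable(all_new_trees))
assert estimators == ["old_tree", "tree_1", "tree_2", "tree_3"]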

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/tests/test_forest.py
@@ -670,6 +670,37 @@ def test_1d_input():
yield check_1d_input, name, X, X_2d, y
+def check_warm_start(name, X, y, n_trees):
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

n_trees -> n_estimators

@coveralls

Coverage Status

Coverage increased (+0.0%) when pulling 8aaf707 on ldirer:3364_warm_start into 814a3ea on scikit-learn:master.

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/tests/test_forest.py
@@ -670,6 +670,37 @@ def test_1d_input():
yield check_1d_input, name, X, X_2d, y
+def check_warm_start(name, X, y, n_trees):
+ """Test if fitting incrementally with warm start gives the same results
+ as a normal fit."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ seed = 42
+ clf_warm_start = None
+ for n_estimators in np.append(np.arange(1, n_trees, n_trees//2), n_trees):
+ if clf_warm_start is None:
+ clf_warm_start = ForestEstimator(n_estimators=n_estimators,
+ random_state=seed, n_jobs=1,
+ warm_start=True)
+ else:
+ clf_warm_start.n_estimators = n_estimators
+ clf_warm_start.fit(X, y)
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Can you check that len(clf_warm_start) increased appropriately?

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/tests/test_forest.py
@@ -670,6 +670,37 @@ def test_1d_input():
yield check_1d_input, name, X, X_2d, y
+def check_warm_start(name, X, y, n_trees):
+ """Test if fitting incrementally with warm start gives the same results
+ as a normal fit."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ seed = 42
+ clf_warm_start = None
+ for n_estimators in np.append(np.arange(1, n_trees, n_trees//2), n_trees):
+ if clf_warm_start is None:
+ clf_warm_start = ForestEstimator(n_estimators=n_estimators,
+ random_state=seed, n_jobs=1,
+ warm_start=True)
+ else:
+ clf_warm_start.n_estimators = n_estimators
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Can you use set_params?

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -252,23 +253,43 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_old_estimators = len(self.estimators_)
+ n_new_estimators = self.n_estimators - n_old_estimators
+
+ if self.warm_start and n_old_estimators > 0:
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

I would move this just before the generation of trees.

@ldirer
ldirer commented Jul 17, 2014

I made new commits including changes from all but the first of the comments above.

@arjoly
scikit-learn member
arjoly commented Jul 17, 2014

Can you document the new attribute?

@ldirer
ldirer commented Jul 17, 2014

I documented the new attribute and added some tests.

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -252,23 +253,45 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_more_estimators = self.n_estimators - len(self.estimators_)
+
+ if n_more_estimators < 0:
+ raise ValueError('n_estimators=%d must be larger or equal to '
+ 'len(estimators_)=%d when '
+ 'warm_start==True'
+ % (self.n_estimators,
+ len(self.estimators_)))
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

This could be put on fewer lines.

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -868,6 +904,11 @@ class RandomForestRegressor(ForestRegressor):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
+ warm_start : bool, default: False
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Could you use the same style as the other arguments?

@coveralls

Coverage Status

Coverage decreased (-0.01%) when pulling 7f1b3b9 on ldirer:3364_warm_start into 814a3ea on scikit-learn:master.

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/tests/test_forest.py
+ err_msg="Failed with {0}".format(name))
+
+
+def test_warm_start():
+ X, y = datasets.make_hastie_10_2(n_samples=80, random_state=1)
+ X = X.astype(np.float32)
+ for name in FOREST_ESTIMATORS:
+ yield check_warm_start, name, X, y
+
+
+def check_warm_start_zero_n_estimators(name, X, y):
+ """Test if warm start with zero n_estimators raises error """
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ clf = ForestEstimator(n_estimators=100, max_depth=1, warm_start=True)
+ clf.fit(X, y)
+ clf.set_params(n_estimators=0)
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Can you set this to a non-zero value?

@arjoly arjoly commented on an outdated diff Jul 17, 2014
sklearn/ensemble/tests/test_forest.py
+def test_warm_start_zero_n_estimators():
+ X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
+ X = X.astype(np.float32)
+ for name in FOREST_ESTIMATORS:
+ yield check_warm_start_zero_n_estimators, name, X, y
+
+
+def check_warm_start_clear(name, X, y):
+ """Test if fit clears state and grows a new forest when warm_start==False.
+ """
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ clf = ForestEstimator(n_estimators=20, max_depth=1, warm_start=False,
+ random_state=1)
+ clf.fit(X, y)
+
+ clf_2 = ForestEstimator(n_estimators=20, max_depth=1, warm_start=True,
@arjoly
scikit-learn member
arjoly added a note Jul 17, 2014

Can you try with n_estimators < 20?

@ogrisel ogrisel commented on an outdated diff Jul 17, 2014
sklearn/ensemble/bagging.py
@@ -277,7 +277,8 @@ def fit(self, X, y, sample_weight=None):
self.estimators_ = None
# Parallel loop
- n_jobs, n_estimators, starts = _partition_estimators(self)
+ n_jobs, n_estimators, starts = _partition_estimators(
+ n_estimators=self.n_estimators, n_jobs=self.n_jobs)
@ogrisel
scikit-learn member
ogrisel added a note Jul 17, 2014

There is no need to use the kwargs notation here as the argument names are already as descriptive as they can be.

@ogrisel ogrisel and 2 others commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -252,23 +253,43 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_more_estimators = self.n_estimators - len(self.estimators_)
+
+ if n_more_estimators < 0:
+ raise ValueError('n_estimators=%d must be larger or equal to '
+ 'len(estimators_)=%d when warm_start==True'
+ % (self.n_estimators, len(self.estimators_)))
+
@ogrisel
scikit-learn member
ogrisel added a note Jul 17, 2014

Maybe we would also have:

if n_more_estimators == 0:
    warnings.warn("Warm fitting without increasing n_estimators does not fit new trees.")

and a test that checks that the warning is actually raised in that case.

@mblondel
scikit-learn member
mblondel added a note Jul 17, 2014

I think I would raise the ValueError exception when n_more_estimators <= 0, not n_more_estimators < 0.

@ogrisel
scikit-learn member
ogrisel added a note Jul 18, 2014

I think I would prefer a warning. I don't think the following sequence alone should raise an exception in general:

model = SomeModel()
model.fit(X, y)
model.set_params(warm_start=True)
model.fit(X, y)
@arjoly
scikit-learn member
arjoly added a note Jul 18, 2014

I would go for a warning if the number of estimators doesn't change and for an exception if it changes.
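
A sketch of a test for the warning branch, under the assumption that refitting with an unchanged n_estimators warns rather than raises (the behavior eventually adopted below); the dataset and parameters are illustrative.

import warnings

from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import RandomForestClassifier

X, y = make_hastie_10_2(n_samples=20, random_state=1)

clf = RandomForestClassifier(n_estimators=5, warm_start=True, random_state=1)
clf.fit(X, y)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clf.fit(X, y)  # n_estimators unchanged: should warn, not raise
assert any(issubclass(w.category, UserWarning) for w in caught)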

@ogrisel ogrisel commented on an outdated diff Jul 17, 2014
sklearn/ensemble/forest.py
@@ -252,23 +253,43 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Out of bag estimation only available"
" if bootstrap=True")
+ random_state = check_random_state(self.random_state)
+
+ if not self.warm_start:
+ # Free allocated memory, if any
+ self.estimators_ = []
+
+ n_more_estimators = self.n_estimators - len(self.estimators_)
+
+ if n_more_estimators < 0:
+ raise ValueError('n_estimators=%d must be larger or equal to '
+ 'len(estimators_)=%d when warm_start==True'
+ % (self.n_estimators, len(self.estimators_)))
+
+ if self.warm_start and n_more_estimators == 0:
@ogrisel
scikit-learn member
ogrisel added a note Jul 17, 2014

Actually the warning should be raised here I think.

@ogrisel
scikit-learn member
ogrisel commented Jul 17, 2014

@arjoly @mblondel do you think we should add the warning? Other than that and the previous style comment, +1 on my side.

@ogrisel
scikit-learn member
ogrisel commented Jul 17, 2014

BTW @ldirer you should edit the title of the PR to prefix with:

  • "[WIP]" to tell others that you are still working on some changes and this is not ready for final review yet
  • "[MRG]" to tell others that your are done with what you had planned and ask for reviewers to comment on this or merge.
@coveralls

Coverage Status

Coverage decreased (-0.01%) when pulling 38acb41 on ldirer:3364_warm_start into 814a3ea on scikit-learn:master.

@ldirer ldirer changed the title from #3364 warm start in random forests to [WIP] #3364 warm start in random forests Jul 17, 2014

@ldirer
ldirer commented Jul 17, 2014

There is another issue not (yet) covered by any tests:

clf_warm_start.fit(X, y)
clf_warm_start.predict(X)
clf_warm_start.set_params(oob_score=True)
clf_warm_start.fit(X, y)
clf_warm_start.predict(X) 

I would say the expected behavior is that the oob_score is computed from the existing trees.
Currently the above code actually raises an error.
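
A sketch of the expected end state (mirroring the oob test later added to this PR); estimator names and values are illustrative:

from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import RandomForestClassifier

X, y = make_hastie_10_2(n_samples=20, random_state=1)

clf = RandomForestClassifier(n_estimators=15, warm_start=True, bootstrap=True,
                             oob_score=False, random_state=1)
clf.fit(X, y)
clf.set_params(oob_score=True)
clf.fit(X, y)  # no new trees needed; oob_score_ is computed from existing ones
assert hasattr(clf, "oob_score_")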

@ogrisel
scikit-learn member
ogrisel commented Jul 17, 2014

I would say the expected behavior is that the oob_score is computed from the existing trees.

Sounds reasonable. @glouppe @arjoly @mblondel what is your opinion?

@ldirer ldirer changed the title from [WIP] #3364 warm start in random forests to [MRG] #3364 warm start in random forests Jul 18, 2014
@ldirer ldirer changed the title from [MRG] #3364 warm start in random forests to [WIP] #3364 warm start in random forests Jul 18, 2014
@coveralls

Coverage Status

Coverage increased (+0.06%) when pulling ba0df01 on ldirer:3364_warm_start into 814a3ea on scikit-learn:master.

@ldirer ldirer changed the title from [WIP] #3364 warm start in random forests to [MRG] #3364 warm start in random forests Jul 19, 2014
@ogrisel ogrisel commented on an outdated diff Jul 19, 2014
sklearn/ensemble/forest.py
- # Reduce
- self.estimators_ = list(itertools.chain(*all_trees))
+ if n_more_estimators == 0:
+ warn("Warm fitting without increasing n_estimators does not "
@ogrisel
scikit-learn member
ogrisel added a note Jul 19, 2014

Maybe "Warm-start fitting" instead.

@coveralls

Coverage Status

Coverage increased (+0.01%) when pulling e644c2e on ldirer:3364_warm_start into e23d9c9 on scikit-learn:master.

@ogrisel
scikit-learn member
ogrisel commented Jul 19, 2014

Apart from this last batch of minor comments, this looks good to me.

@ogrisel ogrisel changed the title from [MRG] #3364 warm start in random forests to [MRG+1] #3364 warm start in random forests Jul 19, 2014
@arjoly
scikit-learn member
arjoly commented Jul 19, 2014

@glouppe is it good for you? Any opinion?

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
@@ -670,6 +670,141 @@ def test_1d_input():
yield check_1d_input, name, X, X_2d, y
+def check_warm_start(name, X, y):
+ """Test if fitting incrementally with warm start gives a forest of the
+ right size and the same results as a normal fit."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ seed = 42
+ clf_ws = None
+ for n_estimators in [5, 10]:
+ if clf_ws is None:
+ clf_ws = ForestEstimator(n_estimators=n_estimators,
+ random_state=seed, n_jobs=1,
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

You don't need to set n_jobs=1

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+ right size and the same results as a normal fit."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ seed = 42
+ clf_ws = None
+ for n_estimators in [5, 10]:
+ if clf_ws is None:
+ clf_ws = ForestEstimator(n_estimators=n_estimators,
+ random_state=seed, n_jobs=1,
+ warm_start=True)
+ else:
+ clf_ws.set_params(n_estimators=n_estimators)
+ clf_ws.fit(X, y)
+ assert_equal(len(clf_ws), n_estimators)
+
+ clf_no_ws = ForestEstimator(n_estimators=10, random_state=seed,
+ n_jobs=1, warm_start=False)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

same here

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
@@ -670,6 +670,141 @@ def test_1d_input():
yield check_1d_input, name, X, X_2d, y
+def check_warm_start(name, X, y):
+ """Test if fitting incrementally with warm start gives a forest of the
+ right size and the same results as a normal fit."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ seed = 42
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

You can add this as an argument of the check function:

check_warm_start(name, X, y, random_state=42)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

By the way, random_state is better than seed, since it matches the parameter name.

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+ warm_start=True)
+ else:
+ clf_ws.set_params(n_estimators=n_estimators)
+ clf_ws.fit(X, y)
+ assert_equal(len(clf_ws), n_estimators)
+
+ clf_no_ws = ForestEstimator(n_estimators=10, random_state=seed,
+ n_jobs=1, warm_start=False)
+ clf_no_ws.fit(X, y)
+ assert_array_equal(clf_ws.apply(X), clf_no_ws.apply(X),
+ err_msg="Failed with {0}".format(name))
+
+
+def test_warm_start():
+ X, y = datasets.make_hastie_10_2(n_samples=80, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

This could go inside the check function to avoid verbose tests.

@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

You could probably reduce the number of samples to 20

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+ clf = ForestEstimator(n_estimators=5, max_depth=1, warm_start=False,
+ random_state=1)
+ clf.fit(X, y)
+
+ clf_2 = ForestEstimator(n_estimators=5, max_depth=1, warm_start=True,
+ random_state=2)
+ clf_2.fit(X, y) # inits state
+ clf_2.set_params(warm_start=False, random_state=1)
+ clf_2.fit(X, y) # clears old state and equals clf
+
+ assert_array_almost_equal(clf_2.apply(X), clf.apply(X))
+
+
+def test_warm_start_clear():
+ X, y = datasets.make_hastie_10_2(n_samples=60, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

same here
You could probably reduce n_samples to 20

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+ for name in FOREST_ESTIMATORS:
+ yield check_warm_start_clear, name, X, y
+
+
+def check_warm_start_smaller_n_estimators(name, X, y):
+ """Test if warm start second fit with smaller n_estimators raises error."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ clf = ForestEstimator(n_estimators=5, max_depth=1, warm_start=True)
+ clf.fit(X, y)
+ clf.set_params(n_estimators=4)
+ assert_raises(ValueError, clf.fit, X, y)
+
+
+def test_warm_start_smaller_n_estimators():
+ X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

same here

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+
+ clf_2 = ForestEstimator(n_estimators=5, max_depth=3, warm_start=True,
+ random_state=1)
+ clf_2.fit(X, y)
+ # Now clf_2 equals clf.
+
+ clf_2.set_params(random_state=2)
+ assert_warns(UserWarning, clf_2.fit, X, y)
+ # If we had fit the trees again we would have got a different forest as we
+ # changed the random state.
+ assert_array_equal(clf.apply(X), clf_2.apply(X))
+
+
+def test_warm_start_equal_n_estimators():
+ X, y = datasets.make_hastie_10_2(n_samples=60, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

same here

@arjoly arjoly commented on the diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+ clf.fit(X, y)
+
+ clf_2 = ForestEstimator(n_estimators=5, max_depth=3, warm_start=False,
+ random_state=1, bootstrap=True, oob_score=False)
+ clf_2.fit(X, y)
+ clf_2.set_params(warm_start=True, oob_score=True, n_estimators=15)
+ clf_2.fit(X, y)
+
+ assert_true(hasattr(clf_2, 'oob_score_'))
+ assert_equal(clf.oob_score_, clf_2.oob_score_)
+
+ # Test that oob_score is computed even if we don't need to train
+ # additional trees.
+ clf_3 = ForestEstimator(n_estimators=15, max_depth=3, warm_start=True,
+ random_state=1, bootstrap=True, oob_score=False)
+ clf_3.fit(X, y)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

Can you assert here that oob_score_ is absent?

@arjoly arjoly and 2 others commented on an outdated diff Jul 20, 2014
sklearn/ensemble/tests/test_forest.py
+
+ # Test that oob_score is computed even if we don't need to train
+ # additional trees.
+ clf_3 = ForestEstimator(n_estimators=15, max_depth=3, warm_start=True,
+ random_state=1, bootstrap=True, oob_score=False)
+ clf_3.fit(X, y)
+ clf_3.set_params(oob_score=True)
+ clf_3.fit(X, y)
+
+ assert_true(hasattr(clf_3, 'oob_score_'))
+ assert_equal(clf.oob_score_, clf_3.oob_score_)
+
+
+def test_warm_start_oob():
+ X, y = datasets.make_hastie_10_2(n_samples=40, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

same here

@ldirer
ldirer added a note Jul 20, 2014

I was worried that the oob score could be the same just 'by chance', and that the test would not fail although it should.
If you think 20 samples are enough or that this is not a concern, I'll change it.

@ogrisel
scikit-learn member
ogrisel added a note Jul 20, 2014

Why do you need astype here? It should work with the default dtype.

@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

It should be enough since we are just checking that the oob scores are the same.
Am I missing something?

@ldirer
ldirer added a note Jul 21, 2014

As for the 40 samples, it was mostly a way to make sure there is enough 'randomness' to get different oob scores if the computation or the forest is different.
Taking too few samples makes it more likely that two different forests give the same oob error.
I don't think it is very relevant though, and I am changing n_samples to 20; I just explain it to answer your question.

As for astype, I used it because it was used in the existing tests. I think I read somewhere in the docs that it makes some computations faster, but I am not sure it makes sense here.
Should I remove it from my tests and the existing ones?

@ogrisel
scikit-learn member
ogrisel added a note Jul 21, 2014

The RF model will convert internally to 32-bit floats, so this is the same in this case and should be negligible on such a small dataset anyway.

It is only interesting to load the data directly in 32-bit format if, for instance, you have it stored on a hard drive and want to avoid the memory copy. In practice this only matters for datasets that are big enough to use a large fraction of the RAM.

@ogrisel
scikit-learn member
ogrisel added a note Jul 21, 2014

So yes please remove it from the new test. You can leave it in the other tests, it's harmless.

@ldirer
ldirer added a note Jul 21, 2014

I removed it from the new tests.
For information, the cast is actually required in the two old tests that use it; removing it breaks them. It seems to be because they both use tree_.apply(X) at some point.

@coveralls

Coverage Status

Coverage increased (+0.06%) when pulling 65cb768 on ldirer:3364_warm_start into e23d9c9 on scikit-learn:master.

@coveralls

Coverage Status

Coverage increased (+0.06%) when pulling 2f6406e on ldirer:3364_warm_start into e23d9c9 on scikit-learn:master.

@coveralls

Coverage Status

Coverage increased (+0.01%) when pulling d6a8881 on ldirer:3364_warm_start into 5517bad on scikit-learn:master.

@arjoly arjoly commented on an outdated diff Jul 20, 2014
sklearn/ensemble/forest.py
+ # making threading always more efficient than multiprocessing in
+ # that case.
+ all_new_trees = Parallel(n_jobs=n_jobs, verbose=self.verbose,
+ backend="threading")(
+ delayed(_parallel_build_trees)(
+ trees[starts[i]:starts[i + 1]],
+ self,
+ X,
+ y,
+ sample_weight,
+ verbose=self.verbose)
+ for i in range(n_jobs))
+
+ # Reduce
+ self.estimators_.extend(
+ itertools.chain.from_iterable(all_new_trees))
@arjoly
scikit-learn member
arjoly added a note Jul 20, 2014

Nitpick: maybe we could import chain from itertools to reduce the number of characters.

@ogrisel
scikit-learn member
ogrisel commented Jul 21, 2014

@mblondel @arjoly are you +1 besides the last nitpick comment?

I would like to remove the batching for forests in a new PR, as I found via a series of benchmarks that it is never faster than no batching with the threading backend, while making the code base more complex. I would like to merge that PR first though, as otherwise rebasing will be complex and error-prone.

@mblondel
scikit-learn member

I didn't follow this PR closely. I trust @arjoly and your decision on this one :)

@arjoly arjoly commented on an outdated diff Jul 21, 2014
sklearn/ensemble/forest.py
- # Reduce
- self.estimators_ = list(itertools.chain(*all_trees))
+ if n_more_estimators == 0:
@arjoly
scikit-learn member
arjoly added a note Jul 21, 2014

elif n_more_estimators == 0:?

@arjoly arjoly commented on an outdated diff Jul 21, 2014
sklearn/ensemble/tests/test_forest.py
+ warm_start=True)
+ else:
+ clf_ws.set_params(n_estimators=n_estimators)
+ clf_ws.fit(X, y)
+ assert_equal(len(clf_ws), n_estimators)
+
+ clf_no_ws = ForestEstimator(n_estimators=10, random_state=random_state,
+ warm_start=False)
+ clf_no_ws.fit(X, y)
+ assert_array_equal(clf_ws.apply(X), clf_no_ws.apply(X),
+ err_msg="Failed with {0}".format(name))
+
+
+def test_warm_start():
+ X, y = datasets.make_hastie_10_2(n_samples=20, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 21, 2014

Would it be possible to generate the dataset inside the check function?

@arjoly
scikit-learn member
arjoly added a note Jul 21, 2014

This gives cleaner test output.

@arjoly arjoly commented on an outdated diff Jul 21, 2014
sklearn/ensemble/tests/test_forest.py
+ clf = ForestEstimator(n_estimators=5, max_depth=1, warm_start=False,
+ random_state=1)
+ clf.fit(X, y)
+
+ clf_2 = ForestEstimator(n_estimators=5, max_depth=1, warm_start=True,
+ random_state=2)
+ clf_2.fit(X, y) # inits state
+ clf_2.set_params(warm_start=False, random_state=1)
+ clf_2.fit(X, y) # clears old state and equals clf
+
+ assert_array_almost_equal(clf_2.apply(X), clf.apply(X))
+
+
+def test_warm_start_clear():
+ X, y = datasets.make_hastie_10_2(n_samples=20, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 21, 2014

Would it be possible to generate the dataset inside the check function?

@arjoly arjoly commented on an outdated diff Jul 21, 2014
sklearn/ensemble/tests/test_forest.py
+ for name in FOREST_ESTIMATORS:
+ yield check_warm_start_clear, name, X, y
+
+
+def check_warm_start_smaller_n_estimators(name, X, y):
+ """Test if warm start second fit with smaller n_estimators raises error."""
+ ForestEstimator = FOREST_ESTIMATORS[name]
+ clf = ForestEstimator(n_estimators=5, max_depth=1, warm_start=True)
+ clf.fit(X, y)
+ clf.set_params(n_estimators=4)
+ assert_raises(ValueError, clf.fit, X, y)
+
+
+def test_warm_start_smaller_n_estimators():
+ X, y = datasets.make_hastie_10_2(n_samples=20, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 21, 2014

Would it be possible to generate the dataset inside the check function?

@arjoly arjoly commented on an outdated diff Jul 21, 2014
sklearn/ensemble/tests/test_forest.py
+
+ clf_2 = ForestEstimator(n_estimators=5, max_depth=3, warm_start=True,
+ random_state=1)
+ clf_2.fit(X, y)
+ # Now clf_2 equals clf.
+
+ clf_2.set_params(random_state=2)
+ assert_warns(UserWarning, clf_2.fit, X, y)
+ # If we had fit the trees again we would have got a different forest as we
+ # changed the random state.
+ assert_array_equal(clf.apply(X), clf_2.apply(X))
+
+
+def test_warm_start_equal_n_estimators():
+ X, y = datasets.make_hastie_10_2(n_samples=20, random_state=1)
+ X = X.astype(np.float32)
@arjoly
scikit-learn member
arjoly added a note Jul 21, 2014

Would it be possible to generate the dataset inside the check function?

@arjoly
scikit-learn member
arjoly commented Jul 21, 2014

LGTM once the last comments are addressed. Thanks @ldirer!

@glouppe
scikit-learn member
glouppe commented Jul 21, 2014

I'll have a look at this tomorrow morning.

@glouppe glouppe commented on an outdated diff Jul 22, 2014
sklearn/ensemble/forest.py
- # Reduce
- self.estimators_ = list(itertools.chain(*all_trees))
+ elif n_more_estimators == 0:
+ warn("Warm-start fitting without increasing n_estimators does not "
+ "fit new trees.")
+ else:
+ # Assign chunk of trees to jobs
+ n_jobs, n_trees, starts = _partition_estimators(n_more_estimators,
+ self.n_jobs)
+ trees = []
+
+ if self.warm_start and len(self.estimators_) > 0:
+ # We draw from the random state to get the random state we
+ # would have got if we hadn't used a warm_start.
+ balancing_draw = random_state.randint(
+ MAX_INT, size=len(self.estimators_))
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

Don't name the variable balancing_draw; it is confusing. Rather, call random_state.randint(MAX_INT, size=len(self.estimators_)) directly without saving the return value.
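
A toy illustration of the idea (independent of the forest code): discarding as many draws as there are existing trees keeps the seeds drawn for the new trees aligned with what a cold fit of the full forest would have produced.

import numpy as np

MAX_INT = np.iinfo(np.int32).max

# Cold fit of a 10-tree forest: 10 seeds drawn from a fresh random state.
rs = np.random.RandomState(42)
cold_seeds = rs.randint(MAX_INT, size=10)

# Warm-start fit with 5 existing trees: discard 5 draws, then draw 5 more.
rs = np.random.RandomState(42)
rs.randint(MAX_INT, size=5)  # advance the state; return value is ignored
warm_seeds = rs.randint(MAX_INT, size=5)

assert (cold_seeds[5:] == warm_seeds).all()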

@glouppe glouppe commented on an outdated diff Jul 22, 2014
sklearn/ensemble/forest.py
@@ -706,6 +735,11 @@ class RandomForestClassifier(ForestClassifier):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
+ warm_start : bool, optional (default=False)
+ When set to ``True``, reuse the solution of the previous call to fit
+ and add more estimators to the ensemble, otherwise, just fits a whole
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

fits -> fit

@glouppe glouppe commented on an outdated diff Jul 22, 2014
sklearn/ensemble/forest.py
@@ -872,6 +908,11 @@ class RandomForestRegressor(ForestRegressor):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
+ warm_start : bool, optional (default=False)
+ When set to ``True``, reuse the solution of the previous call to fit
+ and add more estimators to the ensemble, otherwise, just fits a whole
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

fits -> fit

@glouppe glouppe commented on an outdated diff Jul 22, 2014
sklearn/ensemble/forest.py
@@ -1028,6 +1071,11 @@ class ExtraTreesClassifier(ForestClassifier):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
+ warm_start : bool, optional (default=False)
+ When set to ``True``, reuse the solution of the previous call to fit
+ and add more estimators to the ensemble, otherwise, just fits a whole
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

fits -> fit

@glouppe glouppe commented on an outdated diff Jul 22, 2014
sklearn/ensemble/forest.py
@@ -1198,6 +1248,11 @@ class ExtraTreesRegressor(ForestRegressor):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
+ warm_start : bool, optional (default=False)
+ When set to ``True``, reuse the solution of the previous call to fit
+ and add more estimators to the ensemble, otherwise, just fits a whole
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

fits -> fit

@glouppe glouppe commented on an outdated diff Jul 22, 2014
sklearn/ensemble/forest.py
@@ -1333,6 +1390,11 @@ class RandomTreesEmbedding(BaseForest):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
+ warm_start : bool, optional (default=False)
+ When set to ``True``, reuse the solution of the previous call to fit
+ and add more estimators to the ensemble, otherwise, just fits a whole
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

fits -> fit

@glouppe glouppe and 1 other commented on an outdated diff Jul 22, 2014
sklearn/ensemble/tests/test_forest.py
+ clf_ws = None
+ for n_estimators in [5, 10]:
+ if clf_ws is None:
+ clf_ws = ForestEstimator(n_estimators=n_estimators,
+ random_state=random_state,
+ warm_start=True)
+ else:
+ clf_ws.set_params(n_estimators=n_estimators)
+ clf_ws.fit(X, y)
+ assert_equal(len(clf_ws), n_estimators)
+
+ clf_no_ws = ForestEstimator(n_estimators=10, random_state=random_state,
+ warm_start=False)
+ clf_no_ws.fit(X, y)
+ assert_array_equal(clf_ws.apply(X), clf_no_ws.apply(X),
+ err_msg="Failed with {0}".format(name))
@glouppe
scikit-learn member
glouppe added a note Jul 22, 2014

Could you also make sure that all random_state values from clf_ws are identical to the random_state values in clf_no_ws? I.e., compare set([tree.random_state for tree in clf_ws]) with set([tree.random_state for tree in clf_no_ws])

@ogrisel
scikit-learn member
ogrisel added a note Jul 22, 2014

Good idea.
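
A self-contained sketch of that check (classifier choice, sizes, and seed are illustrative); it assumes the warm-start and cold-fit forests end up with the same per-tree seeds, which is what the PR's alignment trick is meant to guarantee:

from sklearn.datasets import make_hastie_10_2
from sklearn.ensemble import RandomForestClassifier

X, y = make_hastie_10_2(n_samples=20, random_state=1)

clf_ws = RandomForestClassifier(n_estimators=5, warm_start=True, random_state=42)
clf_ws.fit(X, y)
clf_ws.set_params(n_estimators=10)
clf_ws.fit(X, y)

clf_no_ws = RandomForestClassifier(n_estimators=10, random_state=42)
clf_no_ws.fit(X, y)

assert (set(tree.random_state for tree in clf_ws.estimators_) ==
        set(tree.random_state for tree in clf_no_ws.estimators_))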

@glouppe
scikit-learn member
glouppe commented Jul 22, 2014

Nice addition to the forest module! Thanks for your work.

+1 for merge once my comments are taken into account.

ldirer added some commits Jul 17, 2014
@ldirer ldirer Changed _partition_estimators signature to make it compatible with warm start.
90a28ef
@ldirer ldirer Added warm start to random forests.
	Changed _partition_estimators signature to make it compatible with warm start.

Documented warm_start attribute.
6785ffa
@ldirer
ldirer commented Jul 22, 2014

Thanks for your comments, I included them in the last commits.

@ogrisel
scikit-learn member
ogrisel commented Jul 23, 2014

There is a real failure under Python 2.6:

======================================================================
ERROR: Failure: ImportError (cannot import name assert_set_equal)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/travis/anaconda/envs/testenv/lib/python2.6/site-packages/nose/loader.py", line 414, in loadTestsFromName
    addr.filename, addr.module)
  File "/home/travis/anaconda/envs/testenv/lib/python2.6/site-packages/nose/importer.py", line 47, in importFromPath
    return self.importFromDir(dir_path, fqname)
  File "/home/travis/anaconda/envs/testenv/lib/python2.6/site-packages/nose/importer.py", line 94, in importFromDir
    mod = load_module(part_fqname, fh, filename, desc)
  File "/home/travis/build/scikit-learn/scikit-learn/sklearn/ensemble/tests/test_forest.py", line 14, in <module>
    from nose.tools import assert_set_equal
ImportError: cannot import name assert_set_equal

----------------------------------------------------------------------

Please just use assert_equal instead.

Also please rebase on master to get the less verbose Travis output and squash your commits.
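
For reference, a sketch of the replacement: compare the two sets directly with assert_equal (available with Python 2.6's unittest) instead of assert_set_equal. The set contents here are stand-ins for the tree random states compared in the test.

from nose.tools import assert_equal  # assert_set_equal needs Python 2.7+

seeds_ws = set([1, 2, 3])     # stand-in for set(tree.random_state for tree in clf_ws)
seeds_no_ws = set([3, 2, 1])  # stand-in for the non-warm-start forest
assert_equal(seeds_ws, seeds_no_ws)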

@ogrisel
scikit-learn member
ogrisel commented Jul 25, 2014

I will fix the test, rebase and merge to master.

@ogrisel
scikit-learn member
ogrisel commented Jul 25, 2014

Rebased and merged.

@ogrisel ogrisel closed this Jul 25, 2014
@ogrisel
scikit-learn member
ogrisel commented Jul 25, 2014

Thanks @ldirer for this great first contrib to the project!

@arjoly
scikit-learn member
arjoly commented Jul 25, 2014

Thanks @ldirer !

@ldirer
ldirer commented Jul 25, 2014

Thanks for wrapping it up!
I somehow broke my install trying to get Sphinx to work, so I was no longer able to run the tests.
