Skip to content

Commit

Permalink
pep8 fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
ldirer committed Jul 17, 2014
1 parent b95136c commit 7f1b3b9
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions sklearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,9 @@ def check_warm_start(name, X, y):
clf_ws = None
for n_estimators in [1, 10, 20]:
if clf_ws is None:
clf_ws = ForestEstimator(n_estimators=n_estimators, random_state=seed,
n_jobs=1, warm_start=True)
clf_ws = ForestEstimator(n_estimators=n_estimators,
random_state=seed, n_jobs=1,
warm_start=True)
else:
clf_ws.set_params(n_estimators=n_estimators)
clf_ws.fit(X, y)
Expand All @@ -699,6 +700,47 @@ def test_warm_start():
yield check_warm_start, name, X, y


def check_warm_start_zero_n_estimators(name, X, y):
"""Test if warm start with zero n_estimators raises error """
ForestEstimator = FOREST_ESTIMATORS[name]
clf = ForestEstimator(n_estimators=100, max_depth=1, warm_start=True)
clf.fit(X, y)
clf.set_params(n_estimators=0)
assert_raises(ValueError, clf.fit, X, y)


def test_warm_start_zero_n_estimators():
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
X = X.astype(np.float32)
for name in FOREST_ESTIMATORS:
yield check_warm_start_zero_n_estimators, name, X, y


def check_warm_start_clear(name, X, y):
"""Test if fit clears state and grows a new forest when warm_start==False.
"""
ForestEstimator = FOREST_ESTIMATORS[name]
clf = ForestEstimator(n_estimators=20, max_depth=1, warm_start=False,
random_state=1)
clf.fit(X, y)

clf_2 = ForestEstimator(n_estimators=20, max_depth=1, warm_start=True,
random_state=None)
clf_2.fit(X, y) # inits state
# assert_array_almost_equal(clf_2.predict(X), clf.predict(X))
clf_2.set_params(warm_start=False, random_state=1)
clf_2.fit(X, y) # clears old state and equals clf

assert_array_almost_equal(clf_2.apply(X), clf.apply(X))


def test_warm_start_clear():
X, y = datasets.make_hastie_10_2(n_samples=60, random_state=1)
X = X.astype(np.float32)
for name in FOREST_ESTIMATORS:
yield check_warm_start_zero_n_estimators, name, X, y


if __name__ == "__main__":
import nose
nose.runmodule()

0 comments on commit 7f1b3b9

Please sign in to comment.