Skip to content

Commit

Permalink
ENH iForest - expose warm_start (#13451) (#13496)
Browse files Browse the repository at this point in the history
* ENH iForest - expose warm_start (#13451)

* Incorporates comments from PR #13496

* versionadded=0.21

* adition in whatsnew

* test using iris dataset

* Update sklearn/ensemble/tests/test_iforest.py

smaller dataset for testing

Co-Authored-By: petibear <40757147+petibear@users.noreply.github.com>

* Trigger CI

* Corrected the PR reference

* doc entry on warm_start + renamed the test

* Corrections in the doc example

* comments made inline in the doc example
  • Loading branch information
pmarko1711 authored and adrinjalali committed Mar 27, 2019
1 parent 7500693 commit 49cdee6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
13 changes: 13 additions & 0 deletions doc/modules/outlier_detection.rst
Expand Up @@ -252,6 +252,19 @@ This algorithm is illustrated below.
:align: center
:scale: 75%

.. _iforest_warm_start:

The :class:`ensemble.IsolationForest` supports ``warm_start=True`` which
allows you to add more trees to an already fitted model::

>>> from sklearn.ensemble import IsolationForest
>>> import numpy as np
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [0, 0], [-20, 50], [3, 5]])
>>> clf = IsolationForest(n_estimators=10, warm_start=True)
>>> clf.fit(X) # fit 10 trees # doctest: +SKIP
>>> clf.set_params(n_estimators=20) # add 10 more trees # doctest: +SKIP
>>> clf.fit(X) # fit the added trees # doctest: +SKIP

.. topic:: Examples:

* See :ref:`sphx_glr_auto_examples_ensemble_plot_isolation_forest.py` for
Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new/v0.21.rst
Expand Up @@ -158,6 +158,10 @@ Support for Python 3.4 and below has been officially dropped.
- |Enhancement| Minimized the validation of X in
:class:`ensemble.AdaBoostClassifier` and :class:`ensemble.AdaBoostRegressor`
:issue:`13174` by :user:`Christos Aridas <chkoar>`.

- |Enhancement| :class:`ensemble.IsolationForest` now exposes ``warm_start``
parameter, allowing iterative addition of trees to an isolation
forest. :issue:`13496` by :user:`Peter Marko <petibear>`.

- |Efficiency| Make :class:`ensemble.IsolationForest` more memory efficient
by avoiding keeping in memory each tree prediction. :issue:`13260` by
Expand Down
10 changes: 9 additions & 1 deletion sklearn/ensemble/iforest.py
Expand Up @@ -120,6 +120,12 @@ class IsolationForest(BaseBagging, OutlierMixin):
verbose : int, optional (default=0)
Controls the verbosity of the tree building process.
warm_start : bool, optional (default=False)
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble, otherwise, just fit a whole
new forest. See :term:`the Glossary <warm_start>`.
.. versionadded:: 0.21
Attributes
----------
Expand Down Expand Up @@ -173,7 +179,8 @@ def __init__(self,
n_jobs=None,
behaviour='old',
random_state=None,
verbose=0):
verbose=0,
warm_start=False):
super().__init__(
base_estimator=ExtraTreeRegressor(
max_features=1,
Expand All @@ -185,6 +192,7 @@ def __init__(self,
n_estimators=n_estimators,
max_samples=max_samples,
max_features=max_features,
warm_start=warm_start,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose)
Expand Down
22 changes: 22 additions & 0 deletions sklearn/ensemble/tests/test_iforest.py
Expand Up @@ -295,6 +295,28 @@ def test_score_samples():
clf2.score_samples([[2., 2.]]))


@pytest.mark.filterwarnings('ignore:default contamination')
@pytest.mark.filterwarnings('ignore:behaviour="old"')
def test_iforest_warm_start():
"""Test iterative addition of iTrees to an iForest """

rng = check_random_state(0)
X = rng.randn(20, 2)

# fit first 10 trees
clf = IsolationForest(n_estimators=10, max_samples=20,
random_state=rng, warm_start=True)
clf.fit(X)
# remember the 1st tree
tree_1 = clf.estimators_[0]
# fit another 10 trees
clf.set_params(n_estimators=20)
clf.fit(X)
# expecting 20 fitted trees and no overwritten trees
assert len(clf.estimators_) == 20
assert clf.estimators_[0] is tree_1


@pytest.mark.filterwarnings('ignore:default contamination')
@pytest.mark.filterwarnings('ignore:behaviour="old"')
def test_deprecation():
Expand Down

0 comments on commit 49cdee6

Please sign in to comment.