Skip to content

Commit

Permalink
[MRG+1] feature: add beta-threshold early stopping for decision tree …
Browse files Browse the repository at this point in the history
…growth (#6954)

* feature: add beta-threshold early stopping for decision tree growth

* check if value of beta is greater than or equal to 0

* test if default value of beta is 0 and edit input validation error message

* feature: separately validate beta for reg. and clf., and add tests for it

* feature: add beta to forest-based ensemble methods

* feature: add separate condition to determine that beta is float

* feature: add beta to gradient boosting estimators

* rename parameter to min_impurity_split, edit input validation and associated tests

* chore: fix spacing in forest and force recompilation of grad boosting extension

* remove trivial comment in grad boost and add whats new

* edit wording in test comment / rebuild

* rename constant with the same name as our parameter

* edit line length for what's new

* remove constant and set min_impurity_split to 1e-7 by default

* fix docstrings for new default

* fix defaults in gradientboosting and forest classes
  • Loading branch information
nelson-liu authored and glouppe committed Jul 27, 2016
1 parent d829091 commit 376aa50
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 18 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ New features
<https://github.com/scikit-learn/scikit-learn/pull/6667>`_) by `Nelson
Liu`_.

- Added a weighted impurity-based early stopping criterion for decision tree
growth. (`#6954
<https://github.com/scikit-learn/scikit-learn/pull/6954>`_) by `Nelson
Liu`_.

Enhancements
............

Expand Down
40 changes: 35 additions & 5 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -899,6 +903,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=True,
oob_score=False,
n_jobs=1,
Expand All @@ -911,7 +916,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -928,6 +933,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class RandomForestRegressor(ForestRegressor):
Expand Down Expand Up @@ -1001,6 +1007,10 @@ class RandomForestRegressor(ForestRegressor):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -1064,6 +1074,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=True,
oob_score=False,
n_jobs=1,
Expand All @@ -1075,7 +1086,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1091,6 +1102,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class ExtraTreesClassifier(ForestClassifier):
Expand Down Expand Up @@ -1160,6 +1172,10 @@ class ExtraTreesClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -1255,6 +1271,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=False,
oob_score=False,
n_jobs=1,
Expand All @@ -1267,7 +1284,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1284,6 +1301,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class ExtraTreesRegressor(ForestRegressor):
Expand Down Expand Up @@ -1355,6 +1373,10 @@ class ExtraTreesRegressor(ForestRegressor):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -1419,6 +1441,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=False,
oob_score=False,
n_jobs=1,
Expand All @@ -1430,7 +1453,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1446,6 +1469,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class RandomTreesEmbedding(BaseForest):
Expand Down Expand Up @@ -1500,6 +1524,10 @@ class RandomTreesEmbedding(BaseForest):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
sparse_output : bool, optional (default=True)
Whether or not to return a sparse CSR matrix, as default behavior,
or to return a dense array compatible with dense pipeline operators.
Expand Down Expand Up @@ -1544,6 +1572,7 @@ def __init__(self,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_leaf_nodes=None,
min_impurity_split=1e-7,
sparse_output=True,
n_jobs=1,
random_state=None,
Expand All @@ -1554,7 +1583,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=False,
oob_score=False,
Expand All @@ -1570,6 +1599,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = 1
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.sparse_output = sparse_output

def _set_oob_score(self, X, y):
Expand Down
23 changes: 17 additions & 6 deletions sklearn/ensemble/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ class BaseGradientBoosting(six.with_metaclass(ABCMeta, BaseEnsemble,
@abstractmethod
def __init__(self, loss, learning_rate, n_estimators, criterion,
min_samples_split, min_samples_leaf, min_weight_fraction_leaf,
max_depth, init, subsample, max_features,
max_depth, min_impurity_split, init, subsample, max_features,
random_state, alpha=0.9, verbose=0, max_leaf_nodes=None,
warm_start=False, presort='auto'):

Expand All @@ -736,6 +736,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
self.subsample = subsample
self.max_features = max_features
self.max_depth = max_depth
self.min_impurity_split = min_impurity_split
self.init = init
self.random_state = random_state
self.alpha = alpha
Expand Down Expand Up @@ -1358,6 +1359,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
init : BaseEstimator, None, optional (default=None)
An estimator object that is used to compute the initial
predictions. ``init`` has to provide ``fit`` and ``predict``.
Expand Down Expand Up @@ -1437,8 +1442,8 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, init=None, random_state=None,
max_features=None, verbose=0,
max_depth=3, min_impurity_split=1e-7, init=None,
random_state=None, max_features=None, verbose=0,
max_leaf_nodes=None, warm_start=False,
presort='auto'):

Expand All @@ -1450,7 +1455,9 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
max_depth=max_depth, init=init, subsample=subsample,
max_features=max_features,
random_state=random_state, verbose=verbose,
max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
max_leaf_nodes=max_leaf_nodes,
min_impurity_split=min_impurity_split,
warm_start=warm_start,
presort=presort)

def _validate_y(self, y):
Expand Down Expand Up @@ -1711,6 +1718,10 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.
min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
alpha : float (default=0.9)
The alpha-quantile of the huber loss function and the quantile
loss function. Only if ``loss='huber'`` or ``loss='quantile'``.
Expand Down Expand Up @@ -1791,7 +1802,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, init=None, random_state=None,
max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
warm_start=False, presort='auto'):

Expand All @@ -1801,7 +1812,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_depth=max_depth, init=init, subsample=subsample,
max_features=max_features,
max_features=max_features, min_impurity_split=min_impurity_split,
random_state=random_state, alpha=alpha, verbose=verbose,
max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
presort=presort)
Expand Down
2 changes: 2 additions & 0 deletions sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause

Expand Down Expand Up @@ -95,6 +96,7 @@ cdef class TreeBuilder:
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
cdef double min_weight_leaf # Minimum weight in a leaf
cdef SIZE_t max_depth # Maximal tree depth
cdef double min_impurity_split # Impurity threshold for early stopping

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=*,
Expand Down
16 changes: 11 additions & 5 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Joel Nothman <joel.nothman@gmail.com>
# Fares Hedayati <fares.hedayati@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause

Expand Down Expand Up @@ -63,7 +64,6 @@ TREE_UNDEFINED = -2
cdef SIZE_t _TREE_LEAF = TREE_LEAF
cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED
cdef SIZE_t INITIAL_STACK_SIZE = 10
cdef DTYPE_t MIN_IMPURITY_SPLIT = 1e-7

# Repeat struct definition for numpy
NODE_DTYPE = np.dtype({
Expand Down Expand Up @@ -131,12 +131,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):

def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
SIZE_t min_samples_leaf, double min_weight_leaf,
SIZE_t max_depth):
SIZE_t max_depth, double min_impurity_split):
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.min_impurity_split = min_impurity_split

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
Expand Down Expand Up @@ -166,6 +167,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_leaf = self.min_samples_leaf
cdef double min_weight_leaf = self.min_weight_leaf
cdef SIZE_t min_samples_split = self.min_samples_split
cdef double min_impurity_split = self.min_impurity_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
Expand Down Expand Up @@ -223,7 +225,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
impurity = splitter.node_impurity()
first = 0

is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)
is_leaf = (is_leaf or
(impurity <= min_impurity_split))

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
Expand Down Expand Up @@ -289,13 +292,15 @@ cdef class BestFirstTreeBuilder(TreeBuilder):

def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
SIZE_t min_samples_leaf, min_weight_leaf,
SIZE_t max_depth, SIZE_t max_leaf_nodes):
SIZE_t max_depth, SIZE_t max_leaf_nodes,
double min_impurity_split):
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
Expand Down Expand Up @@ -421,6 +426,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
cdef SIZE_t n_node_samples
cdef SIZE_t n_constant_features = 0
cdef double weighted_n_samples = splitter.weighted_n_samples
cdef double min_impurity_split = self.min_impurity_split
cdef double weighted_n_node_samples
cdef bint is_leaf
cdef SIZE_t n_left, n_right
Expand All @@ -436,7 +442,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
(n_node_samples < self.min_samples_split) or
(n_node_samples < 2 * self.min_samples_leaf) or
(weighted_n_node_samples < self.min_weight_leaf) or
(impurity <= MIN_IMPURITY_SPLIT))
(impurity <= min_impurity_split))

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
Expand Down
Loading

0 comments on commit 376aa50

Please sign in to comment.