Skip to content

Commit

Permalink
[MRG+1] correct condition in decision tree construction (#7441)
Browse files Browse the repository at this point in the history
* remove unused condition in decision tree construction

* edit is_leaf condition for min_weight_leaf

* edit ordering of statements

* remove extra parens and add whats new
  • Loading branch information
nelson-liu authored and glouppe committed Sep 29, 2016
1 parent 05e702c commit 2f4b661
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -20,6 +20,13 @@ New features
Enhancements
............

- Edited criterion for leaf nodes in decision tree criterion by declaring a
node as a leaf if the weighted number of samples at the node is less than
2 * the minimum weight specified to be at a node. This makes growth more
efficient, but trees using parameters that modify the weight at each leaf
will be grown differently. (`#7441
<https://github.com/scikit-learn/scikit-learn/pull/7441>`_) by `Nelson
Liu`_.

Bug fixes
.........
Expand Down
18 changes: 9 additions & 9 deletions sklearn/tree/_tree.pyx
Expand Up @@ -216,10 +216,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
n_node_samples = end - start
splitter.node_reset(start, end, &weighted_n_node_samples)

is_leaf = ((depth >= max_depth) or
(n_node_samples < min_samples_split) or
(n_node_samples < 2 * min_samples_leaf) or
(weighted_n_node_samples < min_weight_leaf))
is_leaf = (depth >= max_depth or
n_node_samples < min_samples_split or
n_node_samples < 2 * min_samples_leaf or
weighted_n_node_samples < 2 * min_weight_leaf)

if first:
impurity = splitter.node_impurity()
Expand Down Expand Up @@ -436,11 +436,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
impurity = splitter.node_impurity()

n_node_samples = end - start
is_leaf = ((depth > self.max_depth) or
(n_node_samples < self.min_samples_split) or
(n_node_samples < 2 * self.min_samples_leaf) or
(weighted_n_node_samples < self.min_weight_leaf) or
(impurity <= min_impurity_split))
is_leaf = (depth > self.max_depth or
n_node_samples < self.min_samples_split or
n_node_samples < 2 * self.min_samples_leaf or
weighted_n_node_samples < 2 * self.min_weight_leaf or
impurity <= min_impurity_split)

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
Expand Down

0 comments on commit 2f4b661

Please sign in to comment.