[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees #12887

Merged
merged 94 commits into from Aug 20, 2019

Member

@thomasjpfan thomasjpfan commented Dec 29, 2018

Reference Issues/PRs

Fixes #6557

What does this implement/fix? Explain your changes.

This PR implements Minimal Cost-Complexity Pruning based on L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification and Regression Trees", Wadsworth, Belmont, CA, 1984.

This implementation mostly follows the literature. There are two differences:

  1. In Breiman, r(t) is an estimate of the probability of misclassification. This PR uses the impurity as r(t).
  2. The weighted number of samples is used to compute the probability of a point landing on a node.
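
Written out, the two differences amount to roughly the following. The notation is introduced here for illustration only and is not taken from the PR: w(t) is the weighted number of samples reaching node t, i(t) is the impurity of t, t_0 is the root, and \tilde{T} is the set of terminal nodes of T. In Breiman, i(t) would be replaced by the misclassification estimate.

```latex
p(t) = \frac{w(t)}{w(t_0)}, \qquad
R(t) = p(t)\, i(t), \qquad
R(T) = \sum_{t \in \tilde{T}} R(t), \qquad
R_\alpha(T) = R(T) + \alpha\,|\tilde{T}|
```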

A cost complexity parameter, alpha, was added to __init__ to control cost complexity pruning. The post pruning is done at the end of fit.

The code performing Minimal Cost-Complexity Pruning is mostly done in Python. The Python part produces the node ids that will become the leaves of the new subtree. These leaves are passed to a Cython function called build_pruned_tree that builds a tree. This was written in Cython since the tree building API is in Cython.

In Cython, the Stack class is used to traverse the tree. Not all the fields of StackRecord are used. This is a trade-off between the code complexity of adding yet another Stack class and being a little memory inefficient.

Currently, prune_tree is public, which allows for the following use case:

clf = DecisionTreeClassifier(alpha=0.0)
clf.fit(X, y)
clf.set_params(alpha=0.1)
clf.prune_tree()

If we prefer, we can make prune_tree private and not encourage this use case.

Member

@jnothman jnothman left a comment

Some API comments.

Some questions for the novice:

  • does calling prune_tree with the same alpha repeatedly return the same tree?
  • does calling prune_tree with increasing alpha return a strict sub-tree?

@@ -510,6 +515,110 @@ def decision_path(self, X, check_input=True):
X = self._validate_X_predict(X, check_input)
return self.tree_.decision_path(X)

def prune_tree(self):
Member

@jnothman jnothman Jan 1, 2019

I think this needs to be called after fit automatically to facilitate cross validation etc.

I wonder if this should instead be a public function?

Member Author

@thomasjpfan thomasjpfan Jan 1, 2019

Currently, prune_tree is a public function that is called at the end of fit. It should work with our cross validation classes/functions.

Member Author

@thomasjpfan thomasjpfan commented Jan 1, 2019

does calling prune_tree with the same alpha repeatedly return the same tree?

As long as the original tree is the same, using the same alpha will return the same tree. I will add a test for this behavior.

does calling prune_tree with increasing alpha return a strict sub-tree?

When alpha gets high enough, the entire tree can be pruned, leaving just the root node.

###############################################################################
# Plot training and test scores vs alpha
# --------------------------------------
# Calcuate and plot the the training scores and test accuracy scores
Member

@NicolasHug NicolasHug Jan 1, 2019

Just a few typos (I'll read up on tree pruning and will try to provide a more in-depth review later)

  • Calcuate
  • the the
  • above "a smaller trees"

also I think you should avoid the `:math:` notation in comments

Member Author

@thomasjpfan thomasjpfan Jan 1, 2019

The :math: notation is currently used in other examples such as: https://github.com/scikit-learn/scikit-learn/blob/master/examples/svm/plot_svm_scale_c.py. Are we discouraging the usage of :math: in our examples?

Member

@NicolasHug NicolasHug Jan 1, 2019

It's OK in the docstrings since it will be rendered like regular rst by sphinx, but in the comments it is not necessary.

Member Author

@thomasjpfan thomasjpfan Jan 1, 2019

Member

@NicolasHug NicolasHug Jan 1, 2019

Ooh ok I didn't know it worked like that, sorry

Member

@NicolasHug NicolasHug left a comment

A first pass of cosmetic comments.

@thomasjpfan it seems to me that in general, and unless there's a compelling reason not to, sklearn code uses (potentially long) descriptive variable names.

For example par_idx could be renamed to parent_idx.
Same for cur_alpha, cur_idx, etc.


# bubble up values to ancestor nodes
for idx in leaf_idicies:
cur_R = r_node[idx]
Member

@NicolasHug NicolasHug Jan 1, 2019

Avoid upper-case in variable names (same for R_diff)

leaves_in_subtree = np.zeros(shape=n_nodes, dtype=np.uint8)

stack = [(0, -1)]
while len(stack) > 0:
Member

@NicolasHug NicolasHug Jan 1, 2019

while stack is more pythonic (same below)


stack = [(0, -1)]
while len(stack) > 0:
node_id, parent = stack.pop()
Member

@NicolasHug NicolasHug Jan 1, 2019

node_idx to stay consistent with the rest of the function.
I would also suggest parent_idx instead of parent,

and the parents array could just be named parent.

# computes number of leaves in all branches and the overall impurity of
# the branch. The overall impurity is the sum of r_node in its leaves.
n_leaves = np.zeros(shape=n_nodes, dtype=np.int32)
leaf_idicies, = np.where(leaves_in_subtree)
Member

@NicolasHug NicolasHug Jan 1, 2019

leaf_indices

r_branch[leaf_idicies] = r_node[leaf_idicies]

# bubble up values to ancestor nodes
for idx in leaf_idicies:
Member

@NicolasHug NicolasHug Jan 1, 2019

for leaf_idx in...?


# descendants of branch are not in subtree
stack = [cur_idx]
while len(stack) > 0:
Member

@NicolasHug NicolasHug Jan 1, 2019

while stack

inner_nodes[idx] = False
leaves_in_subtree[idx] = 0
in_subtree[idx] = False
n_left, n_right = child_l[idx], child_r[idx]
Member

@NicolasHug NicolasHug Jan 1, 2019

Usually n_something denotes a count of a number. Here those are just indices right?

leaves_in_subtree[cur_idx] = 1

# updates number of leaves
cur_leaves, n_leaves[cur_idx] = n_leaves[cur_idx], 0
Member

@NicolasHug NicolasHug Jan 1, 2019

I would propose

n_pruned_leaves = n_leaves[cur_idx] - 1
n_leaves[cur_idx] = 0

and accordingly update n_leaves[cur_idx] below


# bubble up values to ancestors
cur_idx = parents[cur_idx]
while cur_idx != _tree.TREE_LEAF:
Member

@NicolasHug NicolasHug Jan 1, 2019

It's a bit weird to bubble up to a leaf.

Whatever you're comparing to here should explicitly be the same value as what you used for defining the root's parent above (stack = [(0, -1)])

I would simply use while cur_idx != -1:
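
A minimal sketch of the suggestion above, assuming a tree stored as two child arrays. The names NO_PARENT, compute_parents and ancestors are hypothetical and not taken from the PR. The point is that the value used for the root's parent is defined once and compared against explicitly, rather than reusing the leaf marker.

```python
LEAF = -1       # child id meaning "no child" (assumed convention for this sketch)
NO_PARENT = -1  # sentinel for the root's parent: same value, different meaning


def compute_parents(left, right):
    """Fill a parent array with an explicit stack, starting from the root."""
    parent = [NO_PARENT] * len(left)
    stack = [(0, NO_PARENT)]
    while stack:
        node_idx, parent_idx = stack.pop()
        parent[node_idx] = parent_idx
        if left[node_idx] != LEAF:
            stack.append((left[node_idx], node_idx))
            stack.append((right[node_idx], node_idx))
    return parent


def ancestors(node_idx, parent):
    """Walk up from a node until the root's sentinel parent is reached."""
    path = []
    cur_idx = parent[node_idx]
    while cur_idx != NO_PARENT:
        path.append(cur_idx)
        cur_idx = parent[cur_idx]
    return path


# Toy tree: node 0 is the root, node 1 is internal, nodes 2, 3, 4 are leaves.
left = [1, 3, LEAF, LEAF, LEAF]
right = [2, 4, LEAF, LEAF, LEAF]
parent = compute_parents(left, right)
print(parent)
print(ancestors(3, parent))
```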

Member

@adrinjalali adrinjalali commented Jan 1, 2019

Some random thoughts:

  • In the context of ensembles and random forests, the parameter needs to be exposed there as well.
  • It'd be nice if we knew the overhead of the pruning. Especially once the user uses it in the context of random forests, the overhead is multiplied by the number of trees, and that times the number of fits in a grid search might be significant. Related to that, a few points come to mind:
    • having some numbers related to the overhead would be nice.
    • a potential warm_start for the pruning, maybe (because the rest of fit doesn't have to be run again for different alpha values).
    • contemplating moving the code to cython might be an idea.
  • I'm not sure if it's necessary to create a copy of the tree for the pruned one. Probably having a masked version of the tree would be optimal for trying out multiple alpha values and a potential warm_start. That also depends on how much overhead that copying has.

Member

@NicolasHug NicolasHug commented Jan 1, 2019

@jnothman , to add to @thomasjpfan answers:

does calling prune_tree with the same alpha repeatedly return the same tree?

The procedure is deterministic so calling prune_tree with same alpha and same original tree will give you the same pruned tree. Also as far as I understand, tree.prune_tree(alpha) == tree.prune_tree(alpha).prune_tree(alpha).

does calling prune_tree with increasing alpha return a strict sub-tree?

A subtree yes, but not necessarily a strict one:

With alpha_1 > alpha_2, tree.prune_tree(alpha_1) is a subtree of tree.prune_tree(alpha_2), but they may also be equal. This is because the alpha parameter is only used as a threshold here.
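
The threshold behaviour can be illustrated with a small pure-Python sketch of weakest-link pruning on a toy tree. This is not the PR's implementation: the array layout, the helper names and the r_node values are made up for illustration, and the stopping rule (stop once the smallest effective alpha exceeds the threshold) is an assumption.

```python
LEAF = -1  # child id meaning "no child" (assumed convention for this sketch)


def subtree_leaves(node, left, right, collapsed):
    """Leaves of the branch rooted at node, treating collapsed nodes as leaves."""
    if left[node] == LEAF or node in collapsed:
        return [node]
    return (subtree_leaves(left[node], left, right, collapsed)
            + subtree_leaves(right[node], left, right, collapsed))


def internal_nodes(node, left, right, collapsed):
    """Non-terminal nodes of the current, partially pruned tree."""
    if left[node] == LEAF or node in collapsed:
        return []
    return ([node]
            + internal_nodes(left[node], left, right, collapsed)
            + internal_nodes(right[node], left, right, collapsed))


def effective_alpha(node, left, right, r_node, collapsed):
    """(R(t) - R(T_t)) / (number of leaves of T_t - 1) for a non-terminal node."""
    leaves = subtree_leaves(node, left, right, collapsed)
    r_branch = sum(r_node[leaf] for leaf in leaves)
    return (r_node[node] - r_branch) / (len(leaves) - 1)


def prune(left, right, r_node, alpha):
    """Return the sorted leaf ids of the subtree kept for threshold alpha."""
    collapsed = set()
    while True:
        candidates = internal_nodes(0, left, right, collapsed)
        if not candidates:
            break  # only the root is left
        weakest = min(
            candidates,
            key=lambda t: effective_alpha(t, left, right, r_node, collapsed),
        )
        # Assumed stopping rule: keep every branch whose effective alpha
        # is strictly above the threshold.
        if effective_alpha(weakest, left, right, r_node, collapsed) > alpha:
            break
        collapsed.add(weakest)  # turn the weakest link into a leaf
    return sorted(subtree_leaves(0, left, right, collapsed))


# Toy tree: node 0 is the root, node 1 is internal, nodes 2, 3, 4 are leaves.
left = [1, 3, LEAF, LEAF, LEAF]
right = [2, 4, LEAF, LEAF, LEAF]
# r_node[t]: sample-weighted impurity of node t (made-up values).
r_node = [0.5, 0.1875, 0.125, 0.0625, 0.0625]

for alpha in (0.0, 0.05, 0.1, 0.15, 0.2):
    print(alpha, prune(left, right, r_node, alpha))
```

In this toy tree the effective alphas are 0.0625 for node 1 and, once node 1 is collapsed, 0.1875 for the root, so thresholds 0.0 and 0.05 give the same tree, as do 0.1 and 0.15. Each larger threshold keeps a subtree of what a smaller one keeps.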

in_subtree = np.ones(shape=n_nodes, dtype=np.bool)

cur_alpha = 0
while cur_alpha < self.alpha:
Member

@NicolasHug NicolasHug Jan 1, 2019

2 thoughts:

  • on the resources that I read (here and here), the very first pruning step is to remove all the pure leaves (equivalent to using alpha=0 apparently). This is not done here since cur_alpha is immediately overwritten. I wonder if this is done by default in the tree growing algorithm.
  • As you check for cur_alpha < self.alpha and cur_alpha is computed before the tree is pruned in the loop, this means that the alpha of the returned pruned tree will be greater than self.alpha. It would seem more natural to me to return a tree whose alpha is less than self.alpha. In any case we would need to explain how alpha is used in the docs, something like "subtrees whose scores are less than alpha are discarded. The score is computed as ..."

Member Author

@thomasjpfan thomasjpfan Jan 7, 2019

  • When alpha = 0, the number of leaves does not contribute to the cost-complexity measure, which I interpreted as "do not prune". Removing the leaves when alpha=0, will increase the cost-complexity measure.

  • Returning a tree whose alpha is less than self.alpha makes sense and should be documented.

Member

@NicolasHug NicolasHug Jan 7, 2019

Removing the leaves when alpha=0, will increase the cost-complexity measure.

It cannot increase the cost-complexity: the first step is to prune (some of) the pure leaves. That is if a node N has 2 child leaves where all the samples in both leaves belong to the same class, then the first step will remove those 2 leaves and make N a leaf (which will still be pure). The process is repeated with N and its sibling if needed.
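
The argument above can be checked with a small worked example. The leaf risks are made-up numbers and cost_complexity is a hypothetical helper, not part of the PR. Merging two pure sibling leaves of the same class into their (still pure) parent leaves R(T) unchanged and removes one leaf, so the cost-complexity measure cannot go up for any alpha >= 0.

```python
def cost_complexity(leaf_risks, alpha):
    """R_alpha(T) = R(T) + alpha * (number of leaves)."""
    return sum(leaf_risks) + alpha * len(leaf_risks)


# Weighted impurities of the leaves (made-up values).
before = [0.0, 0.0, 0.125]  # two pure sibling leaves plus one impure leaf
after = [0.0, 0.125]        # the pure siblings merged into their pure parent

for alpha in (0.0, 0.01):
    print(alpha, cost_complexity(before, alpha), cost_complexity(after, alpha))
```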

Member Author

@thomasjpfan thomasjpfan Jan 7, 2019

That is if a node N has 2 child leaves where all the samples in both leaves belong to the same class, then the first step will remove those 2 leaves and make N a leaf

This makes sense. I will review the tree building code to see if this can happen.

To prevent future confusion, I want to get on the same page with our definition of a pure leaf. From my understanding, a pure leaf is a leaf whose samples belong to the same class, independent of all other leaves. From reading your response, you consider two leaves to be pure if they are siblings and their samples belong to the same class. Is this correct?

Member

@NicolasHug NicolasHug Jan 7, 2019

From my understanding, a pure leaf is a leaf whose samples belong to the same class, independent of all other leaves

I meant this as well

Member

@NicolasHug NicolasHug Jan 7, 2019

I just looked at the tree code, I think we can assume that this "first step" is not needed here after all, since a node is made a leaf according to the min_impurity_decrease (or deprecated min_impurity_split) parameter.

That is, if a node is pure according to min_impurity_decrease, it will be made a leaf, and thus the example case I mentioned above (a node with 2 pure leaves) cannot exist.

Member

@jnothman jnothman commented Jan 2, 2019

Where I was going with my questions was the idea of warm_start as well...

Member Author

@thomasjpfan thomasjpfan commented Jan 7, 2019

@adamgreenhall @jnothman

  • Exposing this to the ensemble trees makes sense.
  • I will do some experimenting to benchmark the overhead of pruning, in its current form and a Cython version of it. I'll post the results here.
  • The masked tree together with the warm_start parameter are great ideas. A masked tree would allow for the level of pruning to be adjusted without growing the tree again, which looks really nice. The current copying approach allows for the original tree to be deleted, and the pruned tree will take up less space.

Member

@jnothman jnothman commented Jan 8, 2019

Member

@NicolasHug NicolasHug left a comment

Thanks Thomas, last minor comments but LGTM!

:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node
Member

@NicolasHug NicolasHug Aug 16, 2019

Suggested change
- :math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node
+ :math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|T|-1)`. A non-terminal node

removed tilde

Minimal Cost-Complexity Pruning
===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree after it
Member

@NicolasHug NicolasHug Aug 16, 2019

Please add ref here to L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees (Chapter 3)

ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")
Member

@NicolasHug NicolasHug Aug 16, 2019

Suggested change
- ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")
+ ax.plot(ccp_alphas, train_scores, marker='o', label="train", drawstyle="steps-post")

Member

@NicolasHug NicolasHug Aug 16, 2019

same below

Tree orig_tree,
DOUBLE_t ccp_alpha):
"""Build a pruned tree from the original tree by transforming the nodes in
leaves_in_subtree into leaves.
Member

@NicolasHug NicolasHug Aug 16, 2019

Remove this one??


cdef:
UINT32_t total_items = path_finder.count
np.ndarray ccp_alphas = np.empty(shape=total_items,
Member

@NicolasHug NicolasHug Aug 16, 2019

ping but not super important I think

Member

@jnothman jnothman left a comment

Incidental comments

where :math:`|\tilde{T}|` is the number of terminal nodes in :math:`T` and
:math:`R(T)` is traditionally defined as the total misclassification rate of
the terminal nodes. Alternatively, scikit-learn uses the total sample weighted
impurity of the terminal nodes for :math:`R(T)`. As shown in the previous
Member

@jnothman jnothman Aug 18, 2019

best to say "above" rather than "in the previous section" or link to it so it can withstand change.

:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node with
Member

@jnothman jnothman Aug 18, 2019

use \frac since this is not easily readable anyway.
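
For reference, the effective alpha in the quoted passage follows from equating the cost-complexity of the single node t (one leaf) with that of its branch T_t. In the sketch below, |\tilde{T}_t| denotes the number of terminal nodes of the branch T_t. That subscripted notation is an assumption made here for clarity; the quoted documentation writes the count without the subscript.

```latex
\begin{aligned}
R_\alpha(t) &= R(t) + \alpha, \\
R_\alpha(T_t) &= R(T_t) + \alpha\,|\tilde{T}_t|, \\
R_\alpha(T_t) = R_\alpha(t)
  \;&\Longrightarrow\;
  \alpha_{eff}(t) = \frac{R(t) - R(T_t)}{|\tilde{T}_t| - 1}.
\end{aligned}
```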

===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree, described
in Chapter 3 of [BRE]_. This algorithm is parameterized by :math:`\alpha\ge0`
Member

@jnothman jnothman Aug 18, 2019

It might be worth adding a small note to say that this is one method used to avoid over-fitting in trees.

:mod:`sklearn.tree`
...................

- |Feature| Adds minimal cost complexity pruning to
Member

@jnothman jnothman Aug 18, 2019

might be worth mentioning what the public api is... i.e. is it just ccp_alpha?

@NicolasHug NicolasHug merged commit 67c94c7 into scikit-learn:master Aug 20, 2019
15 of 17 checks passed
Member

@NicolasHug NicolasHug commented Aug 20, 2019

Failure is unrelated and Joel's comments were addressed so I guess it's good to merge 🎉

Thanks @thomasjpfan !

Member

@jnothman jnothman commented Aug 20, 2019

sebp added a commit to sebp/scikit-survival that referenced this issue Apr 9, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
sebp added a commit to sebp/scikit-survival that referenced this issue Apr 9, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports
sebp added a commit to sebp/scikit-survival that referenced this issue Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)
sebp added a commit to sebp/scikit-survival that referenced this issue Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)

Do not add ccp_alpha to SurvivalTree, because
it relies on node_impurity, which is not set for SurvivalTree.

@TrigonaMinima TrigonaMinima commented May 14, 2020

I have a question: in the literature[1], the authors first prune the max grown tree and then prune it according to the different alpha. Following that, they either use a test set or cross validation to find the best alpha or the corresponding "best pruned tree". Here, we have selected the tree before alpha crosses the ccp_alpha. Am I right or did I miss something? Is the activity of selecting the "best" ccp_alpha left to the user?

1: L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees. Wadsworth, Belmont, CA, 1984.

Member Author

@thomasjpfan thomasjpfan commented May 14, 2020

Am I right or did I miss something? Is the activity of selecting the "best" ccp_alpha left to the user?

Yes this needs to be done with our cross-validation tools.

There is a more interesting way to do this by setting aside some of the training data for validation, in such a way that the tree can automatically find an alpha. This has not been implemented here.
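
One way to do the cross-validated selection mentioned above is sketched here. It assumes scikit-learn 0.22 or later, where the merged feature is exposed as the ccp_alpha parameter together with the cost_complexity_pruning_path method (the method is not discussed in this thread). The dataset and the 5-fold setting are arbitrary choices for illustration, and taking the candidate alphas from a tree fit on the full data is a simplification.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Candidate alphas: the effective alphas at which the fully grown tree changes.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)
# Clip tiny negative values caused by round-off and drop the last alpha,
# which prunes the tree all the way down to the root.
candidates = np.unique(np.clip(path.ccp_alphas[:-1], 0.0, None)).tolist()

# Pick the alpha with the best mean cross-validated accuracy.
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"ccp_alpha": candidates},
    cv=5,
)
search.fit(X, y)
print(search.best_params_["ccp_alpha"], search.best_score_)
```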


@LEEPEIQIN LEEPEIQIN commented Jun 6, 2020

Thank you for your great work; it has benefited me a lot.
However, I am curious why you use Gini impurity instead of the misclassification rate as the cost function. Could you give a reference or some keywords I could search for?
Thank you very much!


@LEEPEIQIN LEEPEIQIN commented Jun 6, 2020

I found some further discussion in "performance learning" (Johannes Fürnkranz, Eyke Hüllermeier), pp. 87-88.
Thank you very much.

Member Author

@thomasjpfan thomasjpfan commented Jun 6, 2020

Using the criterion allows pruning to be extended to regression trees. (the criterion for classification defaults to gini impurity)


@LEEPEIQIN LEEPEIQIN commented Jun 7, 2020

Thank you very much.
