[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees #12887

Merged · 94 commits, merged Aug 20, 2019 · changes shown from 15 commits

Commits
a5f295a
ENH: Adds cost complexity pruning
thomasjpfan Dec 28, 2018
9569e9f
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Dec 28, 2018
1a554f6
DOC: Update
thomasjpfan Dec 28, 2018
84dbc05
DOC: Adds comments to algorithm
thomasjpfan Dec 28, 2018
5e10962
RFC: Small
thomasjpfan Dec 28, 2018
745cd18
RFC: Moves some logic to cython
thomasjpfan Dec 28, 2018
c1cd149
DOC: More comments
thomasjpfan Dec 28, 2018
90c294e
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Dec 28, 2018
5c36185
DOC: Removes unused parameter
thomasjpfan Dec 29, 2018
4b277b9
DOC: Rewords
thomasjpfan Dec 29, 2018
b83b135
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Dec 29, 2018
ffece26
ENH: Adds support for extra trees
thomasjpfan Dec 29, 2018
b2e2a52
DOC: Updates whats_new
thomasjpfan Dec 29, 2018
e95829f
RFC: Makes prune_tree public
thomasjpfan Dec 29, 2018
c313151
RFC: Less diffs
thomasjpfan Dec 29, 2018
fd5be88
RFC: Moves prune_tree closer to the end of fit
thomasjpfan Jan 1, 2019
2e348db
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jan 1, 2019
75709a0
BUG: Fix
thomasjpfan Jan 1, 2019
efe9793
BUG: Fix
thomasjpfan Jan 1, 2019
568eb04
RFC: Addresses code review
thomasjpfan Jan 29, 2019
eb28d50
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jan 29, 2019
847e1f0
RFC: Minimize diffs
thomasjpfan Jan 29, 2019
fa9c83c
RFC: Uses memoryviews
thomasjpfan Jan 29, 2019
81c776e
RFC: Deterministic ordering
thomasjpfan Jan 29, 2019
0d85747
ENH: Returns tree with greatest CCP less than alpha
thomasjpfan Jan 30, 2019
57963d5
RFC: Rename alpha to ccp_alpha
thomasjpfan Jan 31, 2019
25910e0
DOC: Uses ccp_alpha
thomasjpfan Jan 31, 2019
e2cd686
ENH: Users cython for pruning
thomasjpfan Feb 4, 2019
a43972a
ENH: Adds ccp_alpha to forest
thomasjpfan Feb 5, 2019
39dbccd
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Feb 5, 2019
1a347f8
BUG: Fixes doctest
thomasjpfan Feb 5, 2019
43a656b
ENH: Releases gil
thomasjpfan Feb 5, 2019
e59b662
BUG: Fix
thomasjpfan Feb 5, 2019
6465355
RFC Address comments
thomasjpfan Feb 7, 2019
bcfbfc3
STY Flake8
thomasjpfan Feb 7, 2019
b17433c
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Feb 7, 2019
71d0513
DOC adds raw for math
thomasjpfan Feb 7, 2019
7ee455e
RFC Address comments
thomasjpfan Feb 8, 2019
2e62490
DOC Adds pruning to user guide
thomasjpfan Feb 8, 2019
bba792d
DOC English
thomasjpfan Feb 8, 2019
ded8552
DOC Adds forests to whats_new
thomasjpfan Feb 8, 2019
3623657
ENH Adds pruning to gradient boosting
thomasjpfan Feb 8, 2019
2a3b554
DOC Fixes whats_new
thomasjpfan Feb 11, 2019
b0d76fc
DOC Show plt at the end
thomasjpfan Feb 12, 2019
97229ec
RFC Removes unneeded code
thomasjpfan Feb 13, 2019
013ca9e
STY pep257
thomasjpfan Feb 15, 2019
791077d
TST Adds prune all leaves test
thomasjpfan Feb 16, 2019
0fa13ed
RFC Address comments
thomasjpfan Feb 16, 2019
af54d21
DOC Adds more details
thomasjpfan Feb 27, 2019
88f0011
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Feb 28, 2019
ccd47d1
CLN Address comments
thomasjpfan Mar 12, 2019
ec1b9fc
DOC Fix
thomasjpfan Mar 12, 2019
4a4b2ac
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Apr 18, 2019
8132d2d
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Apr 26, 2019
3e5486d
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Apr 26, 2019
188ccb8
ENH Adds cost complexity pruning path
thomasjpfan Apr 26, 2019
218311f
DOC Adds docstring
thomasjpfan Apr 26, 2019
b8a2769
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 6, 2019
697a383
DOC Move whats_new
thomasjpfan May 6, 2019
2de7dfd
ENH Adds impurity tracking to pruning
thomasjpfan May 6, 2019
7452f1f
DOC New example using path function
thomasjpfan May 7, 2019
a199ce8
DOC Adjust titles
thomasjpfan May 7, 2019
dc6b6fd
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 7, 2019
45b5cdc
ENH Returns a bunch when calcuating path
thomasjpfan May 20, 2019
abf41ca
BUG Uses bunch in tests
thomasjpfan May 21, 2019
8cc77ca
DOC Adds more details in example
thomasjpfan May 21, 2019
971f85a
CLN Adds more comments
thomasjpfan May 21, 2019
e81f2a3
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 21, 2019
bc956ca
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 31, 2019
7f620a8
DOC Removes last node in all plots
thomasjpfan Jun 3, 2019
d610101
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jun 3, 2019
b9247fc
DOC Adjust layout
thomasjpfan Jun 4, 2019
cc5f1a9
CLN Address comments
thomasjpfan Jun 6, 2019
5e2ace3
CLN Adds error message to MemoryError
thomasjpfan Jun 17, 2019
5b50196
CLN Adds alpha dependency of t
thomasjpfan Jun 17, 2019
f612457
DOC Update wording
thomasjpfan Jun 17, 2019
86fdbc6
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jul 17, 2019
dda0f5e
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jul 30, 2019
40bab1a
CLN Remove file
thomasjpfan Jul 30, 2019
0a06e46
CLN Address comments
thomasjpfan Jul 30, 2019
9bf7d83
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jul 30, 2019
31e7816
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Aug 16, 2019
7994897
CLN Address NicolasHug's comments
thomasjpfan Aug 16, 2019
2a42e0c
CLN Refactors tests to use pruning_path
thomasjpfan Aug 16, 2019
e8e3967
TST Adds single node tree test
thomasjpfan Aug 16, 2019
17b4112
STY flake8
thomasjpfan Aug 16, 2019
1a8f07e
TST Adds test on impurities from path
thomasjpfan Aug 16, 2019
17d3888
DOC Adds words
thomasjpfan Aug 16, 2019
9b01fc8
DOC Adds words
thomasjpfan Aug 16, 2019
073fd00
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Aug 16, 2019
1774b8c
DOC Better words
thomasjpfan Aug 16, 2019
73cdf1e
DOC Adds docstring to ccp_pruning_path
thomasjpfan Aug 16, 2019
a688f60
DOC Uses new standrad
thomasjpfan Aug 16, 2019
82f3aa1
CLN Address joels comments
thomasjpfan Aug 18, 2019
5 changes: 5 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -192,6 +192,11 @@ Support for Python 3.4 and below has been officially dropped.
  and :class:`tree.ExtraTreeRegressor`.
  :issue:`12300` by :user:`Adrin Jalali <adrinjalali>`.

- |Feature| Adds minimal cost complexity pruning to
  :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`,
  :class:`tree.ExtraTreeClassifier`, and :class:`tree.ExtraTreeRegressor`.
  :issue:`6557` by :user:`Thomas Fan <thomasjpfan>`.
Member:

Suggested change
-  :issue:`6557` by :user:`Thomas Fan <thomasjpfan>`.
+  :issue:`12887` by :user:`Thomas Fan <thomasjpfan>`.

This should be the PR number.

- |Fix| Fixed an issue with :class:`tree.BaseDecisionTree`
  and consequently all estimators based
  on it, including :class:`tree.DecisionTreeClassifier`,
67 changes: 67 additions & 0 deletions examples/tree/plot_cost_complexity_pruning.py
@@ -0,0 +1,67 @@
r"""
Member:

I like this example!

A few remarks:

  • you don't need a raw string (remove r in r""")
  • there are still typos https://github.com/scikit-learn/scikit-learn/pull/12887/files#r244643624
  • you need to add print(__doc__) before the imports
  • imports should be at the top (there's a matplotlib import in the middle of the file)
  • fig.show() immediately closes the window, I think you'll need to use plt.show instead
  • maybe we can have both plots on the same fig with sharex=True?

Member Author:

I tend toward optimizing my examples for HTML output. This usually means keeping the description close to the plotting code and plotting one thing at a time.

Given that, I am okay with combining the two plots.

NicolasHug (Member), Feb 7, 2019:

Ok, makes sense.

========================================================
Post pruning decision trees with cost complexity pruning
========================================================

In this example, decision tree classifiers are trained with a post pruning
technique called minimal cost complexity pruning. This technique is
parameterized by the complexity parameter, :math:`\alpha`. Greater values of
:math:`\alpha` will prune more of the tree, thus creating a smaller trees.
"""

###############################################################################
# Train decision tree classifiers
# -------------------------------
# Train 40 decision tree classifiers with :math:`\alpha` from 0.00 to
# 0.40.

Member:

It ranges from 0 to 0.04 in your example (np.linspace(0, 0.04, 40))
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

alphas = np.linspace(0, 0.04, 40)
clfs = []
for alpha in alphas:
    clf = DecisionTreeClassifier(random_state=0, alpha=alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

###############################################################################
# Plot training and test scores vs alpha
# --------------------------------------
# Calcuate and plot the the training scores and test accuracy scores

Member:

Just a few typos (I'll document myself on tree pruning and will try to provide a more in-depth review later)

  • Calcuate
  • the the
  • above "a smaller trees"

also I think you should avoid the `:math:` notation in comments

Member Author:

The :math: notation is currently used in other examples such as: https://github.com/scikit-learn/scikit-learn/blob/master/examples/svm/plot_svm_scale_c.py. Are we discouraging the usage of :math: in our examples?

Member:

It's OK in the docstrings since it will be rendered like regular rst by sphinx, but in the comments it is not necessary.

Member:

Ooh ok I didn't know it worked like that, sorry

# for our classifiers. With :math:`\alpha` equal to 0.0, the decision tree is
# overfitting with a 1.0 training accuracy score. As the decision tree is
# pruned the testing accuracy score increases up to a point and then decreases.
import matplotlib.pyplot as plt

train_scores = []
test_scores = []
for clf in clfs:
    train_scores.append(clf.score(X_train, y_train))
    test_scores.append(clf.score(X_test, y_test))

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(alphas, train_scores, label="train")
ax.plot(alphas, test_scores, label="test")
ax.legend()
fig.show()

###############################################################################
# Plot total number of nodes vs alpha
# -----------------------------------
# Plot the total number of nodes for our classifiers. As :math:`\alpha`
# increases, the number of nodes decreases.
node_counts = [clf.tree_.node_count for clf in clfs]
fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("number of nodes")
ax.set_title("Number of nodes vs alpha")
ax.plot(alphas, node_counts)
fig.show()
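Following up on the reviewer's sharex=True suggestion above, a sketch of what the combined figure could look like (assuming the alphas, train_scores, test_scores, and node_counts computed earlier in this example; not part of the diff):

# Sketch of the suggested combined figure; assumes alphas, train_scores,
# test_scores and node_counts from the cells above.
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

ax1.plot(alphas, train_scores, label="train")
ax1.plot(alphas, test_scores, label="test")
ax1.set_ylabel("accuracy")
ax1.set_title("Accuracy and number of nodes vs alpha")
ax1.legend()

ax2.plot(alphas, node_counts)
ax2.set_ylabel("number of nodes")
ax2.set_xlabel("alpha")  # x axis is shared, so only label the bottom panel

plt.show()  # plt.show keeps the figure open, unlike fig.show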
10 changes: 10 additions & 0 deletions sklearn/tree/_tree.pxd
@@ -103,3 +103,13 @@ cdef class TreeBuilder:
                np.ndarray sample_weight=*,
                np.ndarray X_idx_sorted=*)
    cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)


# =============================================================================
# Build Pruned Tree
# =============================================================================

cpdef build_pruned_tree(
        Tree tree,
        Tree orig_tree,
        np.ndarray[np.npy_uint8, ndim=1] leaves_in_subtree)
98 changes: 98 additions & 0 deletions sklearn/tree/_tree.pyx
@@ -1132,3 +1132,101 @@ cdef class Tree:
        Py_INCREF(self)
        arr.base = <PyObject*> self
        return arr

# =============================================================================
# Build Pruned Tree
# =============================================================================

cpdef build_pruned_tree(
        Tree tree,
        Tree orig_tree,
        np.ndarray[np.npy_uint8, ndim=1] leaves_in_subtree):
"""Builds a pruned tree.

Builds a pruned tree from the original tree. The values and nodes from the
original tree are copied into the pruned tree.

    Parameters
    ----------
    tree : Tree
        Location to place the pruned tree
    orig_tree : Tree
        Original tree
    leaves_in_subtree : numpy.ndarray, dtype=np.npy_uint8
        Original node ids that are leaves in pruned tree
    """

    cdef SIZE_t capacity = np.sum(leaves_in_subtree)
    tree._resize(capacity)

    cdef SIZE_t orig_node_id
    cdef SIZE_t new_node_id
    cdef SIZE_t depth
    cdef SIZE_t parent
    cdef bint is_left
    cdef bint is_leaf

    # value_stride for original tree and new tree are the same
    cdef SIZE_t value_stride = orig_tree.value_stride
    cdef SIZE_t max_depth_seen = -1
    cdef int rc = 0
    cdef Node* node
    cdef double* orig_value_ptr
    cdef double* new_value_ptr

    # Only uses the start, depth, parent, and is_left variables
    cdef Stack stack = Stack(INITIAL_STACK_SIZE)
    cdef StackRecord stack_record

    with nogil:
        # push root node onto stack
        rc = stack.push(0, 0, 0, _TREE_UNDEFINED, 0, 0.0, 0)
        if rc == -1:
            with gil:
                raise MemoryError()

        while not stack.is_empty():
            stack.pop(&stack_record)

            orig_node_id = stack_record.start
            depth = stack_record.depth
            parent = stack_record.parent
            is_left = stack_record.is_left

            is_leaf = leaves_in_subtree[orig_node_id]
            node = &orig_tree.nodes[orig_node_id]

            new_node_id = tree._add_node(
                parent, is_left, is_leaf, node.feature, node.threshold,
                node.impurity, node.n_node_samples,
                node.weighted_n_node_samples)

            if new_node_id == <SIZE_t>(-1):
                rc = -1
                break

            # copy value from original tree to new tree
            orig_value_ptr = orig_tree.value + value_stride * orig_node_id
            new_value_ptr = tree.value + value_stride * new_node_id
            memcpy(new_value_ptr, orig_value_ptr, sizeof(double) * value_stride)

            if not is_leaf:
                # Push right child on stack
                rc = stack.push(
                    node.right_child, 0, depth + 1, new_node_id, 0, 0.0, 0)
                if rc == -1:
                    break

                # push left child on stack
                rc = stack.push(
                    node.left_child, 0, depth + 1, new_node_id, 1, 0.0, 0)
                if rc == -1:
                    break

            if depth > max_depth_seen:
                max_depth_seen = depth

    if rc >= 0:
        tree.max_depth = max_depth_seen
    if rc == -1:
        raise MemoryError()
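For intuition, a rough pure-Python sketch of the same stack-based copy (the function and argument names here are hypothetical stand-ins; the Cython function above is the actual implementation):

# Pure-Python sketch of the traversal above, illustration only.
def build_pruned_tree_py(orig_children_left, orig_children_right,
                         leaves_in_subtree):
    """Return the original node ids kept by the pruned-tree copy, in the
    order the Cython code would add them with tree._add_node."""
    kept = []
    stack = [0]  # start from the root, node id 0
    while stack:
        node_id = stack.pop()
        kept.append(node_id)  # mirrors copying the node and its value array
        if not leaves_in_subtree[node_id]:
            # Interior node of the pruned subtree: descend into both children.
            # Pushing the right child first matches the Cython push order, so
            # the left subtree is processed first.
            stack.append(orig_children_right[node_id])
            stack.append(orig_children_left[node_id])
    return kept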
68 changes: 68 additions & 0 deletions sklearn/tree/tests/test_tree.py
@@ -1828,3 +1828,71 @@ def test_empty_leaf_infinite_threshold():
    infinite_threshold = np.where(~np.isfinite(tree.tree_.threshold))[0]
    assert len(infinite_threshold) == 0
    assert len(empty_leaf) == 0


@pytest.mark.parametrize("criterion", CLF_CRITERIONS)
@pytest.mark.parametrize(
    "dataset", set(DATASETS.keys()) - {"reg_small", "boston"})
@pytest.mark.parametrize(
    "tree_cls", [DecisionTreeClassifier, ExtraTreeClassifier])
def test_prune_tree_clf_are_subtrees(criterion, dataset, tree_cls):
    dataset = DATASETS[dataset]
    X, y = dataset["X"], dataset["y"]
    assert_pruning_creates_subtree(tree_cls, X, y)


@pytest.mark.parametrize("criterion", REG_CRITERIONS)
@pytest.mark.parametrize("dataset", DATASETS.keys())
@pytest.mark.parametrize(
    "tree_cls", [DecisionTreeRegressor, ExtraTreeRegressor])
def test_prune_tree_reg_are_subtrees(criterion, dataset, tree_cls):
    dataset = DATASETS[dataset]
    X, y = dataset["X"], dataset["y"]
    assert_pruning_creates_subtree(tree_cls, X, y)


def assert_pruning_creates_subtree(estimator_cls, X, y):
    estimators = []
    for alpha in np.linspace(0.0, 0.2, 11):
        est = estimator_cls(
            max_leaf_nodes=20, alpha=alpha, random_state=0).fit(X, y)
        estimators.append(est)

    for prev_est, next_est in zip(estimators[:-1], estimators[1:]):

Member:

Suggested change
-    for prev_est, next_est in zip(estimators[:-1], estimators[1:]):
+    for prev_est, next_est in zip(estimators, estimators[1:]):

Just a nitpick: zip only iterates until the shortest iterable is exhausted

        assert_is_subtree(prev_est.tree_, next_est.tree_)


def assert_is_subtree(tree, subtree):
    assert tree.node_count >= subtree.node_count
    assert tree.max_depth >= subtree.max_depth

    tree_c_left = tree.children_left
    tree_c_right = tree.children_right
    subtree_c_left = subtree.children_left
    subtree_c_right = subtree.children_right

    stack = [(0, 0)]
    while len(stack) > 0:

Member:

Suggested change
-    while len(stack) > 0:
+    while stack:

        tree_n_idx, subtree_n_idx = stack.pop()

Member:

I would propose tree_node_idx and subtree_node_idx to be more explicit (unless it makes the indentation look worse)

        assert_array_almost_equal(
            tree.value[tree_n_idx], subtree.value[subtree_n_idx])
        assert_almost_equal(
            tree.impurity[tree_n_idx], subtree.impurity[subtree_n_idx])
        assert_almost_equal(
            tree.n_node_samples[tree_n_idx],
            subtree.n_node_samples[subtree_n_idx])
        assert_almost_equal(
            tree.weighted_n_node_samples[tree_n_idx],
            subtree.weighted_n_node_samples[subtree_n_idx])

        if (subtree_c_left[subtree_n_idx] == subtree_c_right[subtree_n_idx]):
            # is a leaf
            assert_almost_equal(-2, subtree.threshold[subtree_n_idx])

Member:

Use TREE_UNDEFINED instead of -2

        else:
            # not a leaf
            assert_almost_equal(
                tree.threshold[tree_n_idx], subtree.threshold[subtree_n_idx])
            stack.append(
                (tree_c_left[tree_n_idx], subtree_c_left[subtree_n_idx]))
            stack.append(
                (tree_c_right[tree_n_idx], subtree_c_right[subtree_n_idx]))
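For context on where the API landed, per the commit log above ("RFC: Rename alpha to ccp_alpha", "ENH Adds cost complexity pruning path", "ENH Returns a bunch when calcuating path"), a sketch of the merged interface:

# Sketch of the merged interface; parameter and helper names follow the
# commits listed above (ccp_alpha, cost_complexity_pruning_path).
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# The path helper returns a Bunch with the effective alphas of the pruning
# path and the corresponding total leaf impurities.
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X, y)
print(path.ccp_alphas, path.impurities)

# Refitting with one of the effective alphas yields the pruned subtree.
pruned = DecisionTreeClassifier(random_state=0,
                                ccp_alpha=path.ccp_alphas[-2])
pruned.fit(X, y)
print(pruned.tree_.node_count)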