[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees #12887
Changes from 15 commits
@@ -0,0 +1,67 @@
r""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this example! A few remarks:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tend toward optimizing my examples for HTML output. This usually means keeping the description close to the plotting code and plotting one thing at a time. Given that, I am okay with combine the two plots. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok makes sense |
========================================================
Post pruning decision trees with cost complexity pruning
========================================================

In this example, decision tree classifiers are trained with a post pruning
technique called minimal cost complexity pruning. This technique is
parameterized by the complexity parameter, :math:`\alpha`. Greater values of
:math:`\alpha` will prune more of the tree, thus creating smaller trees.
"""

###############################################################################
# Train decision tree classifiers
# -------------------------------
# Train 40 decision tree classifiers with :math:`\alpha` from 0 to 0.04.
Review comment: It ranges from 0 to 0.04 in your example.
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

alphas = np.linspace(0, 0.04, 40)
clfs = []
for alpha in alphas:
    clf = DecisionTreeClassifier(random_state=0, alpha=alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

###############################################################################
# Plot training and test scores vs alpha
# --------------------------------------
# Calculate and plot the training and test accuracy scores for our
# classifiers. With :math:`\alpha` equal to 0.0, the decision tree is
# overfitting with a 1.0 training accuracy score. As the decision tree is
# pruned, the test accuracy score increases up to a point and then decreases.
Review comment: Just a few typos (I'll document myself on tree pruning and will try to provide a more in-depth review later). Also, I think you should avoid the ``:math:`` role here.
Review comment: It's OK in the docstrings since it will be rendered like regular rst by sphinx, but in the comments it is not necessary.
Reply: These comments are rendered into html: https://42001-843222-gh.circle-artifacts.com/0/doc/auto_examples/tree/plot_cost_complexity_pruning.html
Review comment: Ooh ok, I didn't know it worked like that, sorry.
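The convention the thread refers to is how sphinx-gallery turns an example script into an HTML page: the module docstring and any ``#``-prefixed comment block that follows a long ``#`` separator are rendered as rST (so ``:math:`` roles are typeset there), while the remaining lines appear as code blocks. A minimal sketch of that layout, separate from this PR's example:

"""
Example title
=============

The module docstring is rendered as rST at the top of the generated page.
"""
print("first code block")

###############################################################################
# A rendered text block
# ---------------------
# Comments after the separator are rendered as rST too, so roles such as
# :math:`\alpha` show up typeset in the HTML output.
print("second code block")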
import matplotlib.pyplot as plt

train_scores = []
test_scores = []
for clf in clfs:
    train_scores.append(clf.score(X_train, y_train))
    test_scores.append(clf.score(X_test, y_test))

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(alphas, train_scores, label="train")
ax.plot(alphas, test_scores, label="test")
ax.legend()
fig.show()

###############################################################################
# Plot total number of nodes vs alpha
# -----------------------------------
# Plot the total number of nodes for our classifiers. As :math:`\alpha`
# increases, the number of nodes decreases.
node_counts = [clf.tree_.node_count for clf in clfs]
fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("number of nodes")
ax.set_title("Number of nodes vs alpha")
ax.plot(alphas, node_counts)
fig.show()
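For reference, the :math:`\alpha` used throughout the example is the complexity parameter of minimal cost-complexity pruning (Breiman et al.), which scores a tree :math:`T` by :math:`R_\alpha(T) = R(T) + \alpha|\widetilde{T}|`, where :math:`R(T)` is the total sample-weighted impurity of the leaves and :math:`|\widetilde{T}|` is the number of leaves. The sketch below computes this measure from the public ``tree_`` attributes of a fitted, unpruned classifier; it illustrates the criterion only and does not use the pruning code added in this PR.

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
tree = DecisionTreeClassifier(random_state=0).fit(X, y).tree_

# Leaves are the nodes whose left and right children are both TREE_LEAF (-1).
is_leaf = tree.children_left == tree.children_right
# R(T): impurity of each leaf weighted by the fraction of samples it holds.
leaf_weights = (tree.weighted_n_node_samples[is_leaf]
                / tree.weighted_n_node_samples[0])
r_of_t = np.sum(leaf_weights * tree.impurity[is_leaf])

alpha = 0.01
cost_complexity = r_of_t + alpha * np.count_nonzero(is_leaf)  # R_alpha(T)
print(cost_complexity)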
@@ -1828,3 +1828,71 @@ def test_empty_leaf_infinite_threshold():
    infinite_threshold = np.where(~np.isfinite(tree.tree_.threshold))[0]
    assert len(infinite_threshold) == 0
    assert len(empty_leaf) == 0


@pytest.mark.parametrize("criterion", CLF_CRITERIONS)
@pytest.mark.parametrize(
    "dataset", set(DATASETS.keys()) - {"reg_small", "boston"})
@pytest.mark.parametrize(
    "tree_cls", [DecisionTreeClassifier, ExtraTreeClassifier])
def test_prune_tree_clf_are_subtrees(criterion, dataset, tree_cls):
    dataset = DATASETS[dataset]
    X, y = dataset["X"], dataset["y"]
    assert_pruning_creates_subtree(tree_cls, X, y)


@pytest.mark.parametrize("criterion", REG_CRITERIONS)
@pytest.mark.parametrize("dataset", DATASETS.keys())
@pytest.mark.parametrize(
    "tree_cls", [DecisionTreeRegressor, ExtraTreeRegressor])
def test_prune_tree_reg_are_subtrees(criterion, dataset, tree_cls):
    dataset = DATASETS[dataset]
    X, y = dataset["X"], dataset["y"]
    assert_pruning_creates_subtree(tree_cls, X, y)


def assert_pruning_creates_subtree(estimator_cls, X, y):
    estimators = []
    for alpha in np.linspace(0.0, 0.2, 11):
        est = estimator_cls(
            max_leaf_nodes=20, alpha=alpha, random_state=0).fit(X, y)
        estimators.append(est)

    for prev_est, next_est in zip(estimators[:-1], estimators[1:]):
Review comment: Suggested change (just a nitpick).
        assert_is_subtree(prev_est.tree_, next_est.tree_)


def assert_is_subtree(tree, subtree):
    assert tree.node_count >= subtree.node_count
    assert tree.max_depth >= subtree.max_depth

    tree_c_left = tree.children_left
    tree_c_right = tree.children_right
    subtree_c_left = subtree.children_left
    subtree_c_right = subtree.children_right

    stack = [(0, 0)]
    while len(stack) > 0:
        tree_n_idx, subtree_n_idx = stack.pop()
        assert_array_almost_equal(
            tree.value[tree_n_idx], subtree.value[subtree_n_idx])
        assert_almost_equal(
            tree.impurity[tree_n_idx], subtree.impurity[subtree_n_idx])
        assert_almost_equal(
            tree.n_node_samples[tree_n_idx],
            subtree.n_node_samples[subtree_n_idx])
        assert_almost_equal(
            tree.weighted_n_node_samples[tree_n_idx],
            subtree.weighted_n_node_samples[subtree_n_idx])

        if (subtree_c_left[subtree_n_idx] == subtree_c_right[subtree_n_idx]):
            # is a leaf
            assert_almost_equal(-2, subtree.threshold[subtree_n_idx])
        else:
            # not a leaf
            assert_almost_equal(
                tree.threshold[tree_n_idx], subtree.threshold[subtree_n_idx])
            stack.append(
                (tree_c_left[tree_n_idx], subtree_c_left[subtree_n_idx]))
            stack.append(
                (tree_c_right[tree_n_idx], subtree_c_right[subtree_n_idx]))
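As a quick sanity check outside the test suite, the subtree property asserted above has an easy-to-verify consequence: as :math:`\alpha` grows, node count and depth can only shrink. A minimal sketch, assuming the ``alpha`` keyword exactly as it is named in this PR's diff:

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
# One tree per alpha; a larger alpha should prune the previous tree further.
trees = [DecisionTreeClassifier(alpha=a, random_state=0).fit(X, y).tree_
         for a in np.linspace(0.0, 0.05, 6)]

for less_pruned, more_pruned in zip(trees[:-1], trees[1:]):
    # A subtree can never have more nodes or a greater depth than its parent.
    assert more_pruned.node_count <= less_pruned.node_count
    assert more_pruned.max_depth <= less_pruned.max_depth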
Review comment: This should be the PR number.