Skip to content

Commit

Permalink
Added a max_to_prune argument to pruning_order
Browse files Browse the repository at this point in the history
  • Loading branch information
Steve Genoud committed Jul 10, 2012
1 parent c51698a commit c139f6a
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions sklearn/tree/tree.py
Expand Up @@ -305,14 +305,23 @@ def _copy(self):
return new_tree


def pruning_order(self):
@property
def leaves(self):
    """Return the result of ``_get_leaves`` applied to ``self.children``.

    NOTE(review): presumably the collection of leaf nodes of this tree;
    ``_get_leaves`` is defined outside this view -- confirm its return
    type. ``prune`` takes ``len()`` of this value.
    """
    return _get_leaves(self.children)

def pruning_order(self, max_to_prune=None):
"""
Compute the order in which the tree should be pruned.
The algorithm used is weakest link pruning. It first removes the nodes
that improve the tree the least.
Parameters
----------
max_to_prune : int, optional (default=all the nodes)
maximum number of nodes to prune
Returns
-------
nodes : numpy array
Expand All @@ -326,14 +335,17 @@ def pruning_order(self):
"""

if max_to_prune is None:
max_to_prune = self.node_count

children = self.children.copy()
nodes = list()

while True:
node = _next_to_prune(self, children)
nodes.append(node)

if node == 0:
if (len(nodes) == max_to_prune) or (node == 0):
return np.array(nodes)

#Remove the subtree from the children array
Expand Down Expand Up @@ -364,17 +376,16 @@ def prune(self, n_leaves):
"""

nodes = self.pruning_order()
to_remove_count = len(nodes) - n_leaves + 1
nodes_to_remove = nodes[:to_remove_count]
to_remove_count = self.node_count - len(self.leaves) - n_leaves
nodes_to_remove = self.pruning_order(to_remove_count)

out_tree = self._copy()

for node in nodes_to_remove:
#TODO: Add a Tree method to remove a branch of a tree
out_tree.children[out_tree.children[node], :] = Tree.UNDEFINED
out_tree.children[node, :] = Tree.LEAF
out_tree.node_count -= 1
out_tree.node_count -= 2

return out_tree

Expand Down Expand Up @@ -1169,3 +1180,4 @@ def cv_scores_vs_n_leaves(clf, X, y, max_n_leaves=10, cv=10):
scores.append(loc_scores)

return zip(*scores)

0 comments on commit c139f6a

Please sign in to comment.