Skip to content

Commit

Permalink
Added max depth as pre pruning and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ONordander committed Apr 25, 2017
1 parent 1799b48 commit 3fb0358
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
3 changes: 2 additions & 1 deletion id3/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def _build(self, tree, examples_idx, features_idx, depth=0):
classification_name = self.y_encoder.inverse_transform(classification)
if (features_idx.size == 0
or items.size == 1
or examples_idx.size < self.min_samples_split):
or examples_idx.size < self.min_samples_split
or depth >= self.max_depth):
node = Node(classification_name)
tree.classification_nodes.append(node)
return node
Expand Down
19 changes: 17 additions & 2 deletions id3/tests/test_id3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,24 @@ def test_simple():
"""


def test_breast_cancer():
def test_fit():
bunch = load_breast_cancer()

id3Estimator = Id3Estimator(prune=True, min_samples_split=20)
id3Estimator = Id3Estimator()
id3Estimator.fit(bunch.data, bunch.target, bunch.feature_names)
assert_equal(id3Estimator.tree_.root.value, "worst perimeter")
assert_equal(len(id3Estimator.tree_.classification_nodes), 64)
assert_equal(len(id3Estimator.tree_.feature_nodes), 63)
export_graphviz(id3Estimator.tree_, "cancer.dot")

id3Estimator = Id3Estimator(max_depth=2)
id3Estimator.fit(bunch.data, bunch.target, bunch.feature_names)
assert_equal(id3Estimator.tree_.root.value, "worst perimeter")
assert_equal(len(id3Estimator.tree_.classification_nodes), 4)
assert_equal(len(id3Estimator.tree_.feature_nodes), 3)

id3Estimator = Id3Estimator(min_samples_split=20)
id3Estimator.fit(bunch.data, bunch.target, bunch.feature_names)
assert_equal(id3Estimator.tree_.root.value, "worst perimeter")
assert_equal(len(id3Estimator.tree_.classification_nodes), 35)
assert_equal(len(id3Estimator.tree_.feature_nodes), 34)

0 comments on commit 3fb0358

Please sign in to comment.