From 3fb0358845ee57ce719abdaa3b1364f516ad788d Mon Sep 17 00:00:00 2001 From: Flytet Date: Tue, 25 Apr 2017 16:41:07 +0200 Subject: [PATCH] Added max depth as pre pruning and tests --- id3/builder.py | 3 ++- id3/tests/test_id3.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/id3/builder.py b/id3/builder.py index 04e3bf4..818e691 100644 --- a/id3/builder.py +++ b/id3/builder.py @@ -70,7 +70,8 @@ def _build(self, tree, examples_idx, features_idx, depth=0): classification_name = self.y_encoder.inverse_transform(classification) if (features_idx.size == 0 or items.size == 1 - or examples_idx.size < self.min_samples_split): + or examples_idx.size < self.min_samples_split + or depth >= self.max_depth): node = Node(classification_name) tree.classification_nodes.append(node) return node diff --git a/id3/tests/test_id3.py b/id3/tests/test_id3.py index 354250d..1de3939 100644 --- a/id3/tests/test_id3.py +++ b/id3/tests/test_id3.py @@ -42,9 +42,24 @@ def test_simple(): """ -def test_breast_cancer(): +def test_fit(): bunch = load_breast_cancer() - id3Estimator = Id3Estimator(prune=True, min_samples_split=20) + id3Estimator = Id3Estimator() id3Estimator.fit(bunch.data, bunch.target, bunch.feature_names) + assert_equal(id3Estimator.tree_.root.value, "worst perimeter") + assert_equal(len(id3Estimator.tree_.classification_nodes), 64) + assert_equal(len(id3Estimator.tree_.feature_nodes), 63) export_graphviz(id3Estimator.tree_, "cancer.dot") + + id3Estimator = Id3Estimator(max_depth=2) + id3Estimator.fit(bunch.data, bunch.target, bunch.feature_names) + assert_equal(id3Estimator.tree_.root.value, "worst perimeter") + assert_equal(len(id3Estimator.tree_.classification_nodes), 4) + assert_equal(len(id3Estimator.tree_.feature_nodes), 3) + + id3Estimator = Id3Estimator(min_samples_split=20) + id3Estimator.fit(bunch.data, bunch.target, bunch.feature_names) + assert_equal(id3Estimator.tree_.root.value, "worst perimeter") + assert_equal(len(id3Estimator.tree_.classification_nodes), 35) + assert_equal(len(id3Estimator.tree_.feature_nodes), 34)