From 9cbc4e572d1362f76f4bb377a7a5344916455abd Mon Sep 17 00:00:00 2001
From: svaante
Date: Thu, 20 Apr 2017 15:03:36 +0200
Subject: [PATCH] fix test

---
 id3/id3.py               | 13 ++++++-------
 id3/splitter.py          |  7 ++++---
 id3/tests/test_common.py |  8 +-------
 id3/tests/test_id3.py    |  2 +-
 4 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/id3/id3.py b/id3/id3.py
index 28486b5..1727f49 100644
--- a/id3/id3.py
+++ b/id3/id3.py
@@ -2,10 +2,8 @@
 This is a module to be used as a reference for building other modules
 """
 import numpy as np
-from sklearn.base import BaseEstimator, ClassifierMixin
+from sklearn.base import BaseEstimator
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
-from sklearn.utils.multiclass import unique_labels
-from sklearn.metrics import euclidean_distances
 from sklearn.model_selection import train_test_split
 from .node import Node
@@ -48,8 +46,8 @@ def _build(self, examples_idx, features_idx):
         if features_idx.size == 0 or unique.size == 1:
             return Node(classification_name)
 
-        calc_record = self._splitter.calc(examples_idx, features_idx)
-        split_records = self._splitter.split(examples_idx, calc_record)
+        calc_record = self.splitter_.calc(examples_idx, features_idx)
+        split_records = self.splitter_.split(examples_idx, calc_record)
         new_features_idx = np.delete(features_idx,
                                      np.where(features_idx ==
                                               calc_record.feature_idx))
@@ -94,9 +92,9 @@ def fit(self, X, y, feature_names=None, check_input=True, pruner=None):
         self : object
             Returns self.
         """
+        X_, y = check_X_y(X, y)
         self.feature_names = feature_names
         self.pruner = pruner
-        X_, y = check_X_y(X, y)
         prune = isinstance(self.pruner, BasePruner)
         if prune:
             X_, X_test, y, y_test = train_test_split(X_, y, test_size=0.2)
@@ -121,7 +119,7 @@ def fit(self, X, y, feature_names=None, check_input=True, pruner=None):
         self.y_encoder = ExtendedLabelEncoder()
         self.y = self.y_encoder.fit_transform(y)
 
-        self._splitter = Splitter(self.X,
+        self.splitter_ = Splitter(self.X,
                                   self.y,
                                   self.is_numerical,
                                   self.X_encoders,
@@ -147,6 +145,7 @@ def predict(self, X):
         y : array of shape = [n_samples]
             Returns :math:`x^2` where :math:`x` is the first column of `X`.
         """
+        check_is_fitted(self, 'tree_')
         X = check_array(X)
         X_ = np.zeros(X.shape)
         ret = np.empty(X.shape[0], dtype=X.dtype)
diff --git a/id3/splitter.py b/id3/splitter.py
index c4a2f59..5da228d 100644
--- a/id3/splitter.py
+++ b/id3/splitter.py
@@ -177,13 +177,14 @@ def calc(self, examples_idx, features_idx):
                 tmp_calc_record = self._info_numerical(feature, y_)
             else:
                 tmp_calc_record = self._info_nominal(feature, y_)
-            if self.feature_names is not None:
-                tmp_calc_record.feature_name = self.feature_names[idx]
             if tmp_calc_record < calc_record:
                 ft_idx = features_idx[idx]
                 calc_record = tmp_calc_record
                 calc_record.feature_idx = ft_idx
-                calc_record.feature_name = self.feature_names[ft_idx]
+                if self.feature_names is not None:
+                    calc_record.feature_name = self.feature_names[ft_idx]
+                else:
+                    calc_record.feature_name = str(ft_idx)
         calc_record.entropy, calc_record.class_counts = self._entropy(y_, True)
         return calc_record
 
diff --git a/id3/tests/test_common.py b/id3/tests/test_common.py
index 5bb9c12..014621d 100644
--- a/id3/tests/test_common.py
+++ b/id3/tests/test_common.py
@@ -1,12 +1,6 @@
 from sklearn.utils.estimator_checks import check_estimator
-from id3 import (Id3Estimator, TemplateClassifier)
+from id3 import Id3Estimator
 
 
 def test_estimator():
     return check_estimator(Id3Estimator)
-
-
-def test_classifier():
-    return check_estimator(TemplateClassifier)
-
-
diff --git a/id3/tests/test_id3.py b/id3/tests/test_id3.py
index 8e7ca56..cfc7a91 100644
--- a/id3/tests/test_id3.py
+++ b/id3/tests/test_id3.py
@@ -5,7 +5,7 @@
 from id3.data import load_contact_lenses, load_will_wait, load_weather
 from sklearn.preprocessing import LabelEncoder
 from sklearn.datasets import load_iris
-from id3 import export_pdf, export_graphviz
+from id3 import export_graphviz
 
 
 id3Estimator = Id3Estimator()