Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
svaante committed Apr 20, 2017
1 parent 135944b commit 9cbc4e5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 18 deletions.
13 changes: 6 additions & 7 deletions id3/id3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
This is a module to be used as a reference for building other modules
"""
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import euclidean_distances
from sklearn.model_selection import train_test_split

from .node import Node
Expand Down Expand Up @@ -48,8 +46,8 @@ def _build(self, examples_idx, features_idx):
if features_idx.size == 0 or unique.size == 1:
return Node(classification_name)

calc_record = self._splitter.calc(examples_idx, features_idx)
split_records = self._splitter.split(examples_idx, calc_record)
calc_record = self.splitter_.calc(examples_idx, features_idx)
split_records = self.splitter_.split(examples_idx, calc_record)
new_features_idx = np.delete(features_idx,
np.where(features_idx ==
calc_record.feature_idx))
Expand Down Expand Up @@ -94,9 +92,9 @@ def fit(self, X, y, feature_names=None, check_input=True, pruner=None):
self : object
Returns self.
"""
X_, y = check_X_y(X, y)
self.feature_names = feature_names
self.pruner = pruner
X_, y = check_X_y(X, y)
prune = isinstance(self.pruner, BasePruner)
if prune:
X_, X_test, y, y_test = train_test_split(X_, y, test_size=0.2)
Expand All @@ -121,7 +119,7 @@ def fit(self, X, y, feature_names=None, check_input=True, pruner=None):
self.y_encoder = ExtendedLabelEncoder()
self.y = self.y_encoder.fit_transform(y)

self._splitter = Splitter(self.X,
self.splitter_ = Splitter(self.X,
self.y,
self.is_numerical,
self.X_encoders,
Expand All @@ -147,6 +145,7 @@ def predict(self, X):
y : array of shape = [n_samples]
Returns :math:`x^2` where :math:`x` is the first column of `X`.
"""
check_is_fitted(self, 'tree_')
X = check_array(X)
X_ = np.zeros(X.shape)
ret = np.empty(X.shape[0], dtype=X.dtype)
Expand Down
7 changes: 4 additions & 3 deletions id3/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,14 @@ def calc(self, examples_idx, features_idx):
tmp_calc_record = self._info_numerical(feature, y_)
else:
tmp_calc_record = self._info_nominal(feature, y_)
if self.feature_names is not None:
tmp_calc_record.feature_name = self.feature_names[idx]
if tmp_calc_record < calc_record:
ft_idx = features_idx[idx]
calc_record = tmp_calc_record
calc_record.feature_idx = ft_idx
calc_record.feature_name = self.feature_names[ft_idx]
if self.feature_names is not None:
calc_record.feature_name = self.feature_names[ft_idx]
else:
calc_record.feature_name = str(ft_idx)
calc_record.entropy, calc_record.class_counts = self._entropy(y_, True)
return calc_record

Expand Down
8 changes: 1 addition & 7 deletions id3/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
from sklearn.utils.estimator_checks import check_estimator
from id3 import (Id3Estimator, TemplateClassifier)
from id3 import Id3Estimator


def test_estimator():
return check_estimator(Id3Estimator)


def test_classifier():
return check_estimator(TemplateClassifier)


2 changes: 1 addition & 1 deletion id3/tests/test_id3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from id3.data import load_contact_lenses, load_will_wait, load_weather
from sklearn.preprocessing import LabelEncoder
from sklearn.datasets import load_iris
from id3 import export_pdf, export_graphviz
from id3 import export_graphviz

id3Estimator = Id3Estimator()

Expand Down

0 comments on commit 9cbc4e5

Please sign in to comment.