Skip to content

Commit

Permalink
Added tests for entropy and info functions in splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
ONordander committed Apr 20, 2017
1 parent 180ca50 commit c022307
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
5 changes: 2 additions & 3 deletions id3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .id3 import Id3Estimator
from . import id3
from .export import export_graphviz
from . import splitter

__all__ = ['Id3Estimator',
'id3',
'export_graphviz']
__all__ = ['Id3Estimator', 'id3', 'export_graphviz', 'splitter']
4 changes: 2 additions & 2 deletions id3/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def __init__(self,
attribute_counts=None,
class_counts=None):
self.split_type = split_type
self.info = info
self.feature_idx = feature_idx
self.feature_name = feature_name
self.entropy = entropy
self.info = info
self.pivot = pivot
self.class_counts = class_counts
self.attribute_counts = attribute_counts
Expand Down Expand Up @@ -100,7 +100,7 @@ def _info_nominal(self, x, y):
for value, p in zip(items, count):
info += p * self._entropy(y[x == value])
return CalcRecord(CalcRecord.NOM, info * np.true_divide(1, n),
attribute_counts=np.stack((items, count), axis=-1))
attribute_counts=np.vstack((items, count)).T)

def _info_numerical(self, x, y):
""" Info for numerical feature feature_values
Expand Down
30 changes: 25 additions & 5 deletions id3/tests/test_id3.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
from id3 import Id3Estimator
from sklearn.datasets import load_breast_cancer
from numpy.testing import assert_almost_equal, assert_equal
import numpy as np
from id3 import Id3Estimator
from id3 import export_graphviz
from id3.splitter import Splitter

id3Estimator = Id3Estimator(pruner="ReducedError")

y = np.array([0, 1, 2, 2, 3])
x_nominal_col = np.array(['nom', 'nom', 'nan', 'nom', 'nan'])
x_numerical_col = np.array([1, 2, 5, 5, 1])
test_splitter = Splitter(None, None, None, None, None)


"""
def test_entropy():
y = np.array([0, 1, 2, 2, 3])
x = 1 / 5. * np.log2(1 / (1 / 5.)) + 1 / 5. * np.log2(1 / (1 / 5.)) + \
2 / 5. * np.log2(1 / (2 / 5.)) + 1 / 5. * np.log2(1 / (1 / 5.))
assert_almost_equal(Id3Estimator()._entropy(y), x)
"""
assert_almost_equal(test_splitter._entropy(y), x)


def test_info_nominal():
record = test_splitter._info_nominal(x_nominal_col, y)
assert_equal(record.split_type, 1)
assert_equal(record.attribute_counts.size, 4)
assert_almost_equal(record.info, 1.3509775004326936)


def test_info_numerical():
record = test_splitter._info_numerical(x_numerical_col, y)
assert_equal(record.split_type, 0)
assert_equal(record.attribute_counts.size, 10)
assert_almost_equal(record.pivot, 2)
assert_almost_equal(record.info, 0.9)


def test_breast_cancer():
Expand Down

0 comments on commit c022307

Please sign in to comment.