Skip to content

Commit

Permalink
Update id3.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tiepvupsu committed Jan 14, 2018
1 parent 4ffa6a0 commit 2172218
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions id3.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def fit(self, data, target):
def _entropy(self, ids):
# calculate entropy of a node with index ids
if len(ids) == 0: return 0
ids = [i+1 for i in ids]
# remove 0 freq since log is not defined at 0
ids = [i+1 for i in ids] # pandas Series index starts from 1
freq = np.array(self.target[ids].value_counts())
return entropy(freq)

Expand All @@ -83,7 +82,7 @@ def _split(self, node):
for val in values:
sub_ids = sub_data.index[sub_data[att] == val].tolist()
splits.append([sub_id-1 for sub_id in sub_ids])
# don't split if a node has too small points
# don't split if a node has too few points
if min(map(len, splits)) < self.min_samples_split: continue
# information gain
HxS= 0
Expand Down Expand Up @@ -124,4 +123,4 @@ def predict(self, new_data):
y = df.iloc[:, -1]
tree = DecisionTreeID3(max_depth = 3, min_samples_split = 2)
tree.fit(X, y)
print(tree.predict(X))
print(tree.predict(X))

0 comments on commit 2172218

Please sign in to comment.