In [1]:
import pandas as pd 
import numpy as np
import math


class Node:
    def __init__(self):
        self.children = []
        self.value = ""
        self.isLeaf = False
        self.pred = ""


def entropy(examples):
    pos =0.0
    neg=0.0
    for _,row in examples.iterrows():
        if row["classification"] == "Yes":
            pos +=1
        else:
            neg +=1
    if pos == 0.0 or neg == 0.0:
        return 0.00
    else:
        p = pos/(pos+neg)
        n = neg/(pos+neg)
        return -(p*math.log(p,2)+ n*math.log(n,2))
        
def info_gain(examples,attr):
    uniq = np.unique(examples[attr])
    gain = entropy(examples)
    for u in uniq:
        subdata = examples[examples[attr] == u]
        sub_e = entropy(subdata)
        gain -= (float(len(subdata))/float(len(examples))) * sub_e
    print(gain)
    return gain 

def ID3(examples,attr):
    root = Node()
    max_gain = 0
    max_feat= ""
    print(attr)
    for feature in attr:
        gain = info_gain(examples,feature)
        if gain > max_gain:
            max_feat = feature
            max_gain = gain
    root.value = max_feat
    print("max_feat",max_feat)
    uniq = np.unique(examples[max_feat])
    print(uniq)
    for u in uniq:
        subdata = examples[examples[max_feat] == u]
        print("entropy",entropy(subdata))
        if entropy(subdata) == 0.0:
            newNode = Node()
            newNode.isLeaf = True
            newNode.value = u
            newNode.pred = np.unique(subdata["classification"])
            root.children.append(newNode)
        else:
            dummyNode = Node()
            dummyNode.value = u
            new_attrs = attr.copy()
            new_attrs.remove(max_feat)
            child = ID3(subdata,new_attrs)
            dummyNode.children.append(child)
            root.children.append(dummyNode)
    return root

def printTree(root:Node,depth = 0):
    for i in range(depth):
        print("\t",end="")
    print(root.value,end="")
    if root.isLeaf:
        print("->",root.pred)
    else:
        for child in root.children:
            printTree(child,depth+1)

data=pd.read_csv('../data/play.csv')

features = [feat for feat in data]
features.remove("classification")
d = ID3(data,features)
printTree(d)


['A1', 'A2', 'A3']
0.034851554559677256
0.034851554559677256
0.09127744624168022
max_feat A3
['High' 'Normal']
entropy 0.6500224216483541
['A1', 'A2']
0.3166890883150208
0.10917033867559889
max_feat A1
[False  True]
entropy 1.0
['A2']
1.0
max_feat A2
['Cool' 'Hot']
entropy 0.0
entropy 0.0
entropy 0.0
entropy 1.0
['A1', 'A2']
0.31127812445913283
0.31127812445913283
max_feat A1
[False  True]
entropy 0.9182958340544896
['A2']
0.0
max_feat A2
['Cool']
entropy 0.9182958340544896
[]
max_feat 


KeyError: ''