In [69]:
import numpy as np

In [2]:
# Toy training set: every row is [colour, diameter, label], with the label last.
tr_data = [
    ["Green", 3, "Apple"],
    ["Red", 1, "Grapes"],
    ["Red", 1, "Grapes"],
    ["Yellow", 3, "Lemon"],
    ["Yellow", 3, "Apple"],
]
# Column names, index-aligned with the rows; Question.__repr__ reads them.
head = ["colour", "dia", "label"]

In [3]:
def unique_vals(Data,col):
    """Return the set of distinct values found at index ``col`` of each row.

    ``Data`` is an iterable of indexable rows; an empty ``Data`` yields an
    empty set.
    """
    return {record[col] for record in Data}

In [4]:
# Distinct values in column 0 (the "colour" column) of the training rows.
unique_vals(tr_data,0)

{'Green', 'Red', 'Yellow'}

In [5]:
def class_cnt(Data):
    """Tally how many rows carry each label.

    The label is read from the last element of every row. Returns a dict
    mapping label -> number of rows with that label (empty for no rows).
    """
    tally = {}
    for record in Data:
        key = record[-1]
        tally[key] = tally.get(key, 0) + 1
    return tally

In [6]:
# Number of training rows per label (the label is the last entry of each row).
class_cnt(tr_data)

{'Apple': 2, 'Grapes': 2, 'Lemon': 1}

In [7]:
class Question:
    """An equality test applied to one column of a row.

    ``column`` is the index inspected in each row and ``value`` is what the
    entry at that index is compared against.
    """

    def __init__(self,column,value):
        self.column = column
        self.value = value

    def match(self,example):
        """Return whether ``example[self.column]`` equals ``self.value``."""
        return example[self.column] == self.value

    def __repr__(self):
        # The column name is looked up in the module-level ``head`` list, so
        # ``head`` has to be defined before a Question is displayed.
        parts = (head[self.column], " == ", str(self.value))
        return "IS %s %s %s ?" % parts

In [8]:
# Displaying a Question calls __repr__, which reads the column name from `head`.
Question(0,"no")

IS colour  ==  no ?

In [77]:
# Ask "is column 1 (dia) equal to 3?" and test it on the first training row.
q = Question(1,3)
q.match(tr_data[0])

True

In [78]:
# Bare expression: the cell output is the question's repr.
q

IS dia  ==  3 ?

In [1]:
def partition(Data,question):
    """Split ``Data`` by the answer each row gives to ``question``.

    Returns ``(matching_rows, other_rows)``. ``question.match(row)`` is
    evaluated once per row, and both lists keep the rows in input order.
    """
    matched, unmatched = [], []
    for record in Data:
        bucket = matched if question.match(record) else unmatched
        bucket.append(record)
    return matched, unmatched

In [80]:
# Split the training rows on "colour == Red"; true_r / false_r stay in the
# notebook namespace and are reused by the gini and info_gain demo cells.
true_r,false_r = partition(tr_data,Question(0,"Red"))
print(true_r)
print(false_r)

[['Red', 1, 'Grapes'], ['Red', 1, 'Grapes']]
[['Green', 3, 'Apple'], ['Yellow', 3, 'Lemon'], ['Yellow', 3, 'Apple']]


In [81]:
def gini(Data):
    """Compute the Gini impurity of the labels in ``Data``.

    Starts from 1 and subtracts the squared share of every label reported by
    ``class_cnt``. With no rows nothing is subtracted, so 1 is returned.
    """
    n_rows = float(len(Data))
    remaining = 1
    # Subtract one label at a time so the floating-point result does not
    # depend on how the squared shares would sum.
    for n_label in class_cnt(Data).values():
        share = n_label / n_rows
        remaining -= share**2
    return remaining

In [82]:
# Impurity of true_r, the rows assigned by the partition demo cell.
gini(true_r)

0.0

In [83]:
def info_gain(l,r,curr_uncertainity):
    """Return how much splitting into ``l`` and ``r`` lowers the impurity.

    The result is ``curr_uncertainity`` minus the Gini impurity of each side,
    weighted by that side's share of the combined rows. Raises
    ZeroDivisionError when both sides are empty.
    """
    n_left = len(l)
    n_all = n_left + len(r)
    left_share = float(n_left / n_all)
    after_left = curr_uncertainity - left_share * gini(l)
    return after_left - (1 - left_share) * gini(r)

In [84]:
# Gain of the true_r / false_r split, measured against the impurity of all
# training rows.
info_gain(true_r,false_r,gini(tr_data))

0.37333333333333324

In [85]:
def find_best_split(Data):
    """Find the equality question that yields the largest information gain.

    Every (column, value) pair taken from the feature columns of ``Data`` is
    tried as a ``Question``; the last column of each row is treated as the
    label and is never asked about.

    Returns ``(best_gain, best_ques)``. ``(0, None)`` is returned when
    ``Data`` is empty or when no question separates the rows into two
    non-empty groups.
    """
    # Without rows there is no first row to size the feature columns from and
    # nothing to split, so report "no useful question" instead of raising
    # IndexError on Data[0].
    if len(Data) == 0:
        return 0, None

    best_gain = 0
    best_ques = None
    curr_uncertainity = gini(Data)
    n_features = len(Data[0]) - 1  # the last column holds the label
    for col in range(n_features):
        # unique_vals returns a set, so candidates within a column are tried
        # in set-iteration order.
        for val in unique_vals(Data, col):
            question = Question(col, val)
            true_r, false_r = partition(Data, question)
            # A question that leaves one side empty does not split anything.
            if len(true_r) == 0 or len(false_r) == 0:
                continue
            gain = info_gain(true_r, false_r, curr_uncertainity)
            # ">=" means that among equal gains the candidate examined last wins.
            if gain >= best_gain:
                best_gain, best_ques = gain, question
    return best_gain, best_ques

In [86]:
# Best question for the full training set and the gain it achieves.
best_gain,best_ques=find_best_split(tr_data)
print(best_gain)
print(best_ques)

0.37333333333333324
IS dia  ==  3 ?


In [87]:
class Leaf:
    """Terminal tree node holding the label counts of the rows that reach it."""

    def __init__(self,Data):
        # ``predictions`` maps each label found in ``Data`` to its row count.
        label_counts = class_cnt(Data)
        self.predictions = label_counts

In [88]:
class Decision_Node:
    """Internal tree node: a question plus the subtree for each answer.

    ``true_branch`` is followed by rows that match ``question`` and
    ``false_branch`` by the rest.
    """

    def __init__(self,question,true_branch,false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [89]:
def build_tree(Data,i=0):
    """Recursively grow a decision tree for ``Data``.

    Returns a ``Leaf`` over ``Data`` when the best available split has zero
    gain; otherwise a ``Decision_Node`` whose branches are grown from the two
    partitions (true branch first). ``i`` is only forwarded unchanged to the
    recursive calls.
    """
    best_gain, best_ques = find_best_split(Data)
    if best_gain != 0:
        matched, unmatched = partition(Data, best_ques)
        return Decision_Node(
            best_ques,
            build_tree(matched, i),
            build_tree(unmatched, i),
        )
    return Leaf(Data)

In [90]:
# Grow the tree from the training rows. print() shows only the default object
# repr because neither node class defines __repr__; use prnt_Tree to see it.
my_tree = build_tree(tr_data)
print(my_tree)

<__main__.Decision_Node object at 0x7efdd0570a58>


In [91]:
def prnt_Tree(node,spacing=""):
    """Print the tree rooted at ``node``, indenting one tab per level.

    A ``Leaf`` prints its label counts; any other node prints its question
    followed by its true and false subtrees.
    """
    if not isinstance(node, Leaf):
        deeper = spacing + "\t"
        print(spacing + str(node.question))
        # Branch attributes are fetched one at a time, true branch first.
        for caption, attr in (("--> True:", "true_branch"), ("--> False:", "false_branch")):
            print(spacing + caption)
            prnt_Tree(getattr(node, attr), deeper)
        return
    print(spacing + "Predict", node.predictions)

In [92]:
# Text rendering of the tree built from the training rows.
prnt_Tree(my_tree)

IS dia  ==  3 ?
--> True:
	IS colour  ==  Yellow ?
	--> True:
		Predict {'Lemon': 1, 'Apple': 1}
	--> False:
		Predict {'Apple': 1}
--> False:
	Predict {'Grapes': 2}


In [93]:
# Show the training rows again for reference.
print(tr_data)

[['Green', 3, 'Apple'], ['Red', 1, 'Grapes'], ['Red', 1, 'Grapes'], ['Yellow', 3, 'Lemon'], ['Yellow', 3, 'Apple']]


In [95]:
def prnt_leaf(cnt):
    """Convert a leaf's label counts into whole-number percentage strings.

    ``cnt`` maps label -> count. Returns a dict mapping each label to its
    share of the total count, truncated to an integer and suffixed with "%"
    (e.g. ``{'Lemon': 1, 'Apple': 1}`` -> ``{'Lemon': '50%', 'Apple': '50%'}``).
    An empty ``cnt`` yields an empty dict.
    """
    total = float(sum(cnt.values()))
    # Multiply before dividing: count/total*100 can land just below a whole
    # number (29/100*100 == 28.999999999999996) and would then truncate to a
    # percentage that is one too low.
    return {lbl: str(int(n * 100 / total)) + "%" for lbl, n in cnt.items()}

In [96]:
def classify(row,node):
    """Follow ``row`` down the tree from ``node`` and return the leaf's counts.

    At each non-leaf node the node's question decides whether to continue
    into the true or the false branch; the ``predictions`` dict of the
    ``Leaf`` that is reached is returned.
    """
    current = node
    while not isinstance(current, Leaf):
        if current.question.match(row):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.predictions

In [97]:
# Held-out rows in the same [colour, diameter, label] layout as the training set.
test_data = [
    ["Red", 1, "Apple"],
    ["Yellow", 3, "Apple"],
]

In [103]:
# For each test row, print its true label (last entry) next to the percentage
# breakdown of the leaf the tree sends it to.
for row in test_data:
    print("Actual: %s. Prediction : %s" %(row[-1],prnt_leaf(classify(row,my_tree))))

Actual: Apple. Prediction : {'Grapes': '100%'}
Actual: Apple. Prediction : {'Lemon': '50%', 'Apple': '50%'}
