In [347]:
import numpy as np
import pandas as pd
from sklearn import datasets
# Load the Iris dataset and turn it into a list of rows of the form
# [feature_1, ..., feature_n, label], with the class label as the last element.
iris=datasets.load_iris()
names=iris.target_names
columns=iris.feature_names
x = np.array(iris.data, dtype = float)
y = np.array(iris.target, dtype = float)
# ndarray.tolist() already converts every nested row into a plain Python list,
# so no further per-row conversion is required.
x=x.tolist()
# Append each sample's label to its own row (x and y are aligned by position).
for row, label in zip(x, y):
    row.append(label)

In [339]:
def class_count(rows,idx=0): #for counting total number of each example in the dataset
    """Count how many rows carry each class label.

    The label of a row is its last element (``row[-1]``).

    Args:
        rows: Iterable of indexable rows whose last element is the label.
        idx: Unused; kept only so existing calls that pass it keep working.

    Returns:
        dict mapping each label to its number of occurrences, with keys in
        order of first appearance. Empty dict for empty ``rows``.
    """
    count={}
    for row in rows:
        label=row[-1]
        # dict.get supplies 0 for a label seen for the first time.
        count[label]=count.get(label,0)+1
    return count

# Class distribution of the full dataset (the label is the last element of each row).
class_count(x)

{0.0: 50, 1.0: 50, 2.0: 50}

In [340]:
def gini_impurity(rows): #for calculating gini impurity at one level
    """Return the Gini impurity of the class labels in ``rows``.

    Computed as ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of rows
    carrying label ``k`` according to ``class_count``.

    Args:
        rows: Rows whose last element is the class label.

    Returns:
        The impurity of ``rows``; ``0.0`` when ``rows`` is empty.
    """
    count=class_count(rows)
    total=sum(count.values()) # loop-invariant, so it is computed only once
    if total==0:
        # No rows means no class mix; without this guard the formula yields 1.
        return 0.0
    sum_of_squares=0
    # Explicit accumulation (rather than sum()) keeps the floating-point
    # result identical regardless of how sum() adds floats.
    for n_rows in count.values():
        sum_of_squares+=(n_rows/total)**2
    return 1-sum_of_squares
# Gini impurity of the whole dataset, i.e. before any split.
gini_impurity(x)

0.6666666666666667

In [341]:
def gini_split(left,right,curr_gini_impurity): #for calculating the info gain after splitting into child nodes
    """Return the reduction in Gini impurity achieved by a split.

    Each child's impurity is weighted by its share of the rows and
    subtracted from the impurity of the parent.

    Args:
        left: Rows in the first child.
        right: Rows in the second child.
        curr_gini_impurity: Gini impurity of the parent before splitting.

    Returns:
        Parent impurity minus the weighted impurities of both children.

    Raises:
        ZeroDivisionError: If both ``left`` and ``right`` are empty.
    """
    n_left=len(left)
    n_total=n_left+len(right)
    left_weight=float(n_left/n_total)
    gain=curr_gini_impurity-left_weight*gini_impurity(left)-(1-left_weight)*gini_impurity(right)
    return gain

In [342]:
def partition(rows,feature,value): #partitioning dataset on basis of a feature
    """Split ``rows`` by comparing one feature against a threshold.

    Args:
        rows: Rows to split; each row is indexed by column position.
        feature: Feature name; its position is looked up in the
            module-level ``columns``.
        value: Threshold the feature is compared with.

    Returns:
        ``(true_rows, false_rows)``: the rows whose feature is ``> value``
        and the remaining rows, each in their original order.

    Raises:
        ValueError: If ``feature`` is not in ``columns``.
    """
    col=columns.index(feature)
    above,not_above=[],[]
    for row in rows:
        # Pick the destination list first, then append once.
        bucket=above if row[col]>value else not_above
        bucket.append(row)
    return above,not_above


In [343]:
def get_best_feature(rows):
    """Find the single-feature threshold split with the largest Gini gain.

    Every distinct value of every column in ``columns`` is tried as a
    threshold (``feature > value``). Splits that leave one side empty are
    skipped.

    Args:
        rows: Rows whose last element is the class label.

    Returns:
        ``(best_gini, best_feature, best_feature_val)``: the best gain and
        the feature name and threshold producing it. ``(0, None, 0)`` when
        no candidate split has a gain greater than 0.
    """
    best_gini=0
    best_feature=None
    best_feature_val=0
    # Use gini_impurity here: no `gini` function is defined in the visible
    # cells, so the previous call could only work with leftover kernel state.
    curr_gini=gini_impurity(rows)
    for col in columns: #for each column
        idx=columns.index(col) # same lookup partition() uses for this feature
        values=set((row[idx]) for row in rows) #unique values in a column
        for val in values: #for each value
            true,false=partition(rows,col,val) #splitting the data
            if len(true)==0 or len(false)==0: #skipping this split if it doesn't divide the data
                continue
            gini_split_val=gini_split(true,false,curr_gini)
            # Strict comparison: on ties the first candidate found is kept.
            if gini_split_val>best_gini:
                best_gini=gini_split_val
                best_feature=col
                best_feature_val=val
    return best_gini,best_feature,best_feature_val
# Best first split over the whole dataset: (gain, feature name, threshold).
get_best_feature(x)


(0.3333333333333334, 'petal length (cm)', 1.9)

In [344]:
class Leaf:
    """Terminal tree node holding the class counts of the rows that reached it."""
    def __init__(self,rows):
        """Store the result of ``class_count(rows)`` as ``self.predictions``."""
        label_counts=class_count(rows)
        self.predictions=label_counts

class Decision_Node:
    """Internal tree node: a feature/threshold pair plus two child subtrees."""
    def __init__(self,feature,val,left,right):
        """Store the split description and both children unchanged.

        Args:
            feature: Name of the feature the node splits on.
            val: Threshold value of the split.
            left: First child; ``build_tree`` passes the subtree built from
                the rows for which ``feature > val``.
            right: Second child; ``build_tree`` passes the subtree built
                from the remaining rows.
        """
        self.feature,self.val=feature,val
        self.left,self.right=left,right

In [350]:
def build_tree(rows,level=0):
    """Recursively grow a decision tree over ``rows``, printing a trace.

    For every node the level, the current Gini impurity, the chosen
    question (or a leaf notice) and the per-class row counts are printed.

    Args:
        rows: Rows whose last element is the class label.
        level: Depth of the node being built; the root is level 0.

    Returns:
        A ``Leaf`` when ``get_best_feature`` reports a gain of 0, otherwise
        a ``Decision_Node`` whose children are built from the two
        partitions of ``rows``.
    """
    print('Level ',level)
    curr_count=class_count(rows)
    curr_gini=gini_impurity(rows)
    gini_change,feature,val=get_best_feature(rows)
    print('Current gini impurity is: ',curr_gini)
    is_leaf=gini_change==0
    if is_leaf:
        print('Reached leaf node')
    else:
        print(feature,'>',val,'?')
    # The class counts are reported for leaves and decision nodes alike.
    for label,n_rows in curr_count.items():
        print('Count of ',label,': ',n_rows)
    if is_leaf:
        return Leaf(rows)
    true_rows,false_rows=partition(rows,feature,val)
    true_branch=build_tree(true_rows,level+1)
    false_branch=build_tree(false_rows,level+1)
    return Decision_Node(feature,val,true_branch,false_branch)


In [351]:
# Grow the tree on the full dataset; a trace of every node is printed and the root node is returned.
build_tree(x)

Level  0
Current gini impurity is:  0.6666666666666667
petal length (cm) > 1.9 ?
Count of  0.0 :  50
Count of  1.0 :  50
Count of  2.0 :  50
Level  1
Current gini impurity is:  0.5
petal width (cm) > 1.7 ?
Count of  1.0 :  50
Count of  2.0 :  50
Level  2
Current gini impurity is:  0.04253308128544431
petal length (cm) > 4.8 ?
Count of  1.0 :  1
Count of  2.0 :  45
Level  3
Current gini impurity is:  0.0
Reached leaf node
Count of  2.0 :  43
Level  3
Current gini impurity is:  0.4444444444444444
sepal length (cm) > 5.9 ?
Count of  1.0 :  1
Count of  2.0 :  2
Level  4
Current gini impurity is:  0.0
Reached leaf node
Count of  2.0 :  2
Level  4
Current gini impurity is:  0.0
Reached leaf node
Count of  1.0 :  1
Level  2
Current gini impurity is:  0.16803840877914955
petal length (cm) > 4.9 ?
Count of  1.0 :  49
Count of  2.0 :  5
Level  3
Current gini impurity is:  0.4444444444444444
petal width (cm) > 1.5 ?
Count of  1.0 :  2
Count of  2.0 :  4
Level  4
Current gini impurity is:  0.44444

<__main__.Decision_Node at 0x19142d4de50>