In [34]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris

In [35]:
import matplotlib.pyplot as plt
import seaborn as sns
import plotly as py
py.offline.init_notebook_mode(connected=True)
%matplotlib inline

In [36]:
iris = load_iris()
x=iris.data
y=iris.target

In [37]:
data=np.c_[x,y]

In [38]:
header=['sepal_length','sepal_width','petal_length','petal_width','species']
iris_df=pd.DataFrame(data=data,columns=header)

In [39]:
class Question:
    #initialise column and value variables->
    #eg->if ques is ->is sepal_length>=1cm then
    #sepal_length==col and 1cm=value
    def __init__(self,column,value):
        self.column=column
        self.value=value
    #it matches wheter the given data is in accordace with the value set or not
    #returns true and false accordingly
    def match(self,data):
        value=data[self.column]
        return value>=self.value
    # This is just a helper method to print
    # the question in a readable format.
    def __repr__(self):
        condition = ">="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))

In [40]:
#count the unique values of labels and store them in a dictionary
def count_values(rows):
    #will return a dictionary with species values as key and frequency as values
    count={}
    #takes whole dataset in as argument
    for row in  rows:
        #traverse on each datapoint
        label=row[-1]
        #labels are in the last column
        #if label is not even once come initialise it
        if label not in count:
            count[label]=0
        #increase the count of present label by 1
        count[label]+=1
    return count 

In [41]:
#spliting the data based on the respective ques.
def partition(rows,question):
    #intialise two seprate lists 
    true_row,false_row=[],[]
    for row in rows:
        #traverse on each datapoint
        #match the given datapoint with the respective question
        if question.match(row):
            #if question.match returns true aka value is satisfied
            #append the given row in true row list
            true_row.append(row)
        else:
            false_row.append(row)
    return true_row,false_row

In [42]:
def gini(rows):
    #stores dictionary of frequency of labels
    count=count_values(rows)
    #initialise impurity as 1
    impurity=1
    for label in count:
        #probablity of a unique label
        probab_of_label=count[label]/float(len(rows))
        #calculation gini impurity acc to formula
        impurity-=probab_of_label**2
    return impurity

In [43]:
def entropy(rows):
    #initialise entropy
    entropy=0
    from math import log
    #calculating log(x) in base 2
    log2=lambda x:log(x)/log(2)
    count=count_values(rows)
    #storing and traversing the dictionary
    for label in count:
        #probablity of each unique label
        p=count[label]/float(len(rows))
        #calculating entropy
        entropy-=p*log2(p)
    return entropy

In [44]:
def info_gain_gini(current,left,right):
    #porbab of one branch
    p =float(len(left))/len(left)+len(right)
    #formula for info gian
    return current-p*gini(left)-(1-p)*gini(right)

In [45]:
def info_gain_entropy(current,left,right):
    p =float(len(left))/len(left)+len(right)
    return current-p*entropy(left)-(1-p)*entropy(right)

In [46]:
def best_split(rows):
    #initialise best gain and best question
    best_gain=0
    best_question=None
    #calculate the current_gain
    current=gini(rows)
    #total number of features
    features=len(rows[0])-1
    for col in range(features):
        #collects all unique classes for a feature
        values=set([row[col] for row in rows])
        for val in values:
            #traverse each unique classs
            #ask the corresponding question
            question=Question(col,val)
            #devide the data based on that ques
            true_rows,false_rows=partition(rows,question)
            if len(true_rows)==0 or len(false_rows) ==0:
                #no use go to next iteration
                continue
            #calculate corresponding gain
            gain=info_gain_gini(current,true_rows,false_rows)
            #if gain is > than the best replace
            if gain>=best_gain:
                best_gain,best_question=gain,question
            #iterate through each unique class of each feature and return the best gain and best question     
    return best_gain,best_question

In [47]:
class DecisionNode:
    def __init__(self,question,true_branch,false_branch):
        #question object stores col and val variables regarding the question of that node
        self.question = question
        #this stores the branch that is true
        self.true_branch = true_branch
        #this stores the false branch
        self.false_branch = false_branch

In [48]:
class Leaf:
    def __init__(self,rows):
        #stores unique labels and their values in predictio
        self.predictions=count_values(rows)

In [49]:
def build_tree(rows):
    #takes the whole dataset as argument
    #gets the best gain and best question
    gain,question=best_split(rows)
    
    #if gian=0 i.e. leaf conditions are satisfied
    if gain==0:
        #make a leaf object and return
        return Leaf(rows)
    # If we reach here, we have found a useful feature / value
    # to partition on.
    true_rows, false_rows = partition(rows, question)

    # Recursively build the true branch.
    true_branch = build_tree(true_rows)

    # Recursively build the false branch.
    false_branch = build_tree(false_rows)

    #returns the root question node storing branches as well as the quesiton
    return DecisionNode(question, true_branch, false_branch)

In [50]:
tree=build_tree(data)

In [51]:
def print_tree(node,indentation=""):
    '''printing function'''
    #base case means we have reached the leaf
    #if the node object is of leaf type
    if isinstance(node,Leaf):
        print(indentation+"PREDICTION",node.predictions)
        return 
    #print the question at node
    print(indentation + str(node.question))
    
    #call the function on true branch 
    print(indentation+ "True Branch")
    print_tree(node.true_branch,indentation + " ")
    
    #on flase branch
    print(indentation+ "False Branch")
    print_tree(node.false_branch,indentation + " ")

In [52]:
print_tree(tree)

Is petal_length >= 6.9?
True Branch
 PREDICTION {2.0: 1}
False Branch
 Is sepal_width >= 4.4?
 True Branch
  PREDICTION {0.0: 1}
 False Branch
  Is sepal_width >= 4.2?
  True Branch
   PREDICTION {0.0: 1}
  False Branch
   Is sepal_length >= 7.9?
   True Branch
    PREDICTION {2.0: 1}
   False Branch
    Is sepal_width >= 4.1?
    True Branch
     PREDICTION {0.0: 1}
    False Branch
     Is sepal_width >= 4.0?
     True Branch
      PREDICTION {0.0: 1}
     False Branch
      Is petal_length >= 6.7?
      True Branch
       PREDICTION {2.0: 2}
      False Branch
       Is petal_length >= 6.6?
       True Branch
        PREDICTION {2.0: 1}
       False Branch
        Is petal_length >= 6.3?
        True Branch
         PREDICTION {2.0: 1}
        False Branch
         Is sepal_length >= 7.7?
         True Branch
          PREDICTION {2.0: 1}
         False Branch
          Is sepal_length >= 7.4?
          True Branch
           PREDICTION {2.0: 1}
          False Branch
           Is 