In [1]:
# Input the given training data
# Column names for every row; the last column holds the class label.
header = ["Age", "Salary", "Class"]
training_data = [
    [30, 65, "G"],
    [23, 15, "B"],
    [40, 75, "G"],
    [55, 40, "B"],
    [55, 100, "G"],
    [45, 60, "G"],
]

# Helper functions: the set of unique values in a column, and a dictionary mapping each class label in the rows to the number of rows with that label.
def unique_vals(rows, col):
    """Return the set of distinct values found in column ``col`` across ``rows``."""
    return {row[col] for row in rows}

def class_counts(rows):
    """Count how many rows carry each class label.

    The label is the last element of each row. Keys appear in the order in
    which their label is first encountered.
    """
    tally = {}
    for row in rows:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    return tally



In [2]:
#A class to create a question for a given attribute. It takes the index of the column in each row and the threshold value to compare against.
#In this model, I have not pre-sorted the values, since the input data is small.
#Therefore, instead of using the midpoint (vi + vi+1)/2 between consecutive sorted values as the threshold, I use the observed value vi itself, for simplicity of the code.
class Question:
    """A binary test of the form ``example[col_index] <= value``.

    ``col_index`` selects the column of an example (and, for display, the
    matching entry of the module-level ``header`` list); ``value`` is the
    threshold the example's value is compared against.
    """

    def __init__(self, col_index, value):
        self.col_index, self.value = col_index, value

    def match(self, example):
        """Return True when the example's value in this column is <= the threshold."""
        return example[self.col_index] <= self.value

    def __repr__(self):
        # The column name is looked up in the global ``header`` list.
        return f"Is {header[self.col_index]} <= {self.value!s}?"

#Question(1,3)
#q=Question(0,35)
#example=training_data[0]
#q.match(example)


In [3]:
#Partitions the rows into true and false/ left and right rows, with the question given to the function.
def partition(rows, question):
    """Split ``rows`` by ``question``.

    Returns ``(true_rows, false_rows)``. Rows for which ``question.match``
    is truthy go on the left, all others on the right. The original row
    order is preserved on both sides.
    """
    matched = []
    unmatched = []
    for row in rows:
        (matched if question.match(row) else unmatched).append(row)
    return matched, unmatched
#true_rows, false_rows = partition(training_data, Question(0, 35))
#true_rows

In [4]:
#Calculates the gini_index of the input dataframe/ rows.
def gini_index(rows):
    """Gini impurity of ``rows``: 1 minus the sum of squared class proportions.

    Class counts come from ``class_counts`` (label = last element of each row).
    For an empty ``rows`` there is nothing to subtract, so the result is 1.
    """
    total = len(rows)
    impurity = 1
    for count in class_counts(rows).values():
        impurity -= (count / total) ** 2
    return impurity

# Calculates the Gini gain of a split: the parent's impurity minus the size-weighted Gini impurity of the two child partitions.
def info_gain(left, right, current_uncertainty):
    """Gini gain obtained by splitting a node into ``left`` and ``right``.

    Parameters
    ----------
    left, right : list of rows
        The two partitions produced by a candidate question.
    current_uncertainty : float
        Gini impurity of the parent node, i.e. of ``left + right``.

    Returns
    -------
    float
        ``current_uncertainty`` minus the size-weighted Gini impurity of the
        two children. If both partitions are empty there is nothing to split,
        and 0.0 is returned instead of raising ZeroDivisionError.
    """
    total = len(left) + len(right)
    if total == 0:
        return 0.0
    p = len(left) / total  # fraction of the rows that fall on the left side
    return current_uncertainty - (p * gini_index(left) + (1 - p) * gini_index(right))


In [5]:
#function to find the best split from the given dataframe
def find_best_split(rows):
    """Search every feature/value pair for the split with the highest Gini gain.

    Each distinct value ``v`` of each feature column ``col`` becomes a candidate
    ``Question(col, v)`` ("is row[col] <= v?"). Candidates that send every row
    to the same side are skipped. Because gains are compared with ``>=``, ties
    go to the candidate examined last.

    Returns
    -------
    (best_gain, best_question)
        ``best_question`` stays ``None`` (and ``best_gain`` stays 0) when no
        candidate was accepted, e.g. when ``rows`` is empty or every feature
        column is constant.
    """
    best_gain = 0
    best_question = None
    if not rows:
        # Nothing to split; previously this crashed on ``rows[0]``.
        return best_gain, best_question

    current_uncertainty = gini_index(rows)
    n_features = len(rows[0]) - 1  # every column except the trailing class label

    for col in range(n_features):
        for val in unique_vals(rows, col):
            question = Question(col, val)
            true_rows, false_rows = partition(rows, question)
            if not true_rows or not false_rows:
                continue
            gain = info_gain(true_rows, false_rows, current_uncertainty)
            if gain >= best_gain:
                best_gain, best_question = gain, question
    return best_gain, best_question

In [6]:
#Defines two classes, Leaf and Decision Node, used for storage purposes
class Leaf:
    """Terminal tree node holding the class counts of the rows that reached it."""

    def __init__(self, rows):
        # Mapping of class label -> number of rows, as returned by class_counts.
        counts = class_counts(rows)
        self.predictions = counts
class Decision_Node:
    """Internal tree node: a question plus the subtrees for the rows it sends True/False."""

    def __init__(self, question, true_branch, false_branch):
        self.question, self.true_branch, self.false_branch = (
            question,
            true_branch,
            false_branch,
        )
       

In [7]:
# The main function where the tree is built. It is recursive: the root split is fixed to "Is Age <= 35?", and every other node becomes a Decision_Node if its best split has non-zero gain, or a Leaf if the gain is zero.
def build_tree(rows, root_split=(0, 35)):
    """Recursively grow a decision tree from ``rows``.

    Parameters
    ----------
    rows : list of rows
        Training examples; the last element of each row is the class label.
    root_split : (col_index, value) tuple or None
        Split forced at the top of the tree, as ``Question(*root_split)``.
        The default ``(0, 35)`` reproduces this notebook's "Is Age <= 35?" root.
        Pass ``None`` to choose the root greedily like every other node. The
        forced split is ignored (greedy search is used instead) when it would
        leave one side empty.

    Returns
    -------
    Leaf or Decision_Node
        A Leaf when no split has non-zero gain, otherwise a Decision_Node.
    """
    if not rows:
        return Leaf(rows)

    if root_split is not None:
        question = Question(*root_split)
        true_rows, false_rows = partition(rows, question)
        # Only honour the forced split if it actually divides the data;
        # otherwise recursing on an unchanged row set would never terminate.
        if true_rows and false_rows:
            # Children are built with root_split=None, so only the root is forced.
            return Decision_Node(
                question,
                build_tree(true_rows, None),
                build_tree(false_rows, None),
            )

    gain, question = find_best_split(rows)
    if gain == 0:
        return Leaf(rows)
    true_rows, false_rows = partition(rows, question)
    return Decision_Node(
        question,
        build_tree(true_rows, None),
        build_tree(false_rows, None),
    )

In [8]:
# Function used to print the tree.
def print_tree(node, spacing=""):
    """Pretty-print a tree, indenting two extra spaces per level.

    A Leaf prints its class counts. Any other node prints its question,
    then its True branch, then its False branch.
    """
    if not isinstance(node, Leaf):
        print(spacing + str(node.question))
        child_spacing = spacing + "  "
        for heading, branch_attr in (("--> True:", "true_branch"),
                                     ("--> False:", "false_branch")):
            print(spacing + heading)
            print_tree(getattr(node, branch_attr), child_spacing)
    else:
        print(spacing + "Predict", node.predictions)

In [9]:
# Build the tree from the training data and print the result
my_tree = build_tree(training_data)
print_tree(my_tree)

Is Age <= 35?
--> True:
  Is Salary <= 15?
  --> True:
    Predict {'B': 1}
  --> False:
    Predict {'G': 1}
--> False:
  Is Salary <= 40?
  --> True:
    Predict {'B': 1}
  --> False:
    Predict {'G': 3}
