In [136]:
import queue

In [1]:
# Toy fruit dataset. Each row is [color, diameter, label]; the class label is
# the last element of every row (see col_names for the column order).
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [2]:
# Column names indexed by column position; Question.__repr__ looks names up here.
col_names=['color','diameter','label']

In [9]:
def unique_labels(data,col):
    """Return the set of distinct values found in column ``col`` of ``data``.

    data: iterable of indexable rows.
    col:  index of the column to inspect (works for any column, not only
          the label column).
    """
    distinct=set()
    for record in data:
        distinct.add(record[col])
    return distinct

In [11]:
unique_labels(training_data,0)

{'Green', 'Red', 'Yellow'}

In [16]:
def class_counts(data,label_loc=-1):
    """Count how many rows of ``data`` carry each label.

    data:      iterable of indexable rows.
    label_loc: index of the label column inside each row. Defaults to -1,
               i.e. the label is assumed to be the last column.

    Returns a dict mapping label -> number of rows with that label. Keys
    appear in the order in which each label is first seen.
    """
    counts={}
    for row in data:
        # Honour label_loc instead of always reading the last column.
        label=row[label_loc]
        counts[label]=counts.get(label,0)+1
    return counts

In [17]:
class_counts(training_data)

{'Apple': 2, 'Grape': 2, 'Lemon': 1}

In [18]:
def is_numeric(value):
    """Return True when ``value`` is exactly an ``int`` or a ``float``.

    The check compares the exact type, so subclasses such as ``bool`` are
    reported as non-numeric.
    """
    return type(value) in (float,int)

In [26]:
is_numeric(90),is_numeric('abc')

(True, False)

In [44]:
class Question():
    """A yes/no test applied to a single column of a row.

    column: index of the column the question inspects.
    value:  value the column is compared against.
    """

    def __init__(self,column,value):
        self.column,self.value=column,value

    def match(self,example):
        """Return whether ``example`` satisfies this question.

        Numeric cells are tested with ``>=`` against the stored value; any
        other cell is tested for equality.
        """
        candidate=example[self.column]
        return candidate>=self.value if is_numeric(candidate) else candidate==self.value

    def __repr__(self):
        # The displayed operator is chosen from the stored value's type.
        operator='>=' if is_numeric(self.value) else '=='
        return f'Is {col_names[self.column]} {operator} {self.value!s} ?'

In [50]:
q=Question(1,3)
q

Is diameter >= 3 ?

In [54]:
q.match(training_data[-1]) #last sample indeed has a value >=3

True

In [60]:
def partition(data,question):
    """Split ``data`` into rows that satisfy ``question`` and rows that do not.

    Returns a ``(matching_rows, non_matching_rows)`` tuple; each list keeps
    the rows in their original relative order.
    """
    matched,unmatched=[],[]
    for record in data:
        bucket=matched if question.match(record) else unmatched
        bucket.append(record)
    return matched,unmatched

In [76]:
true_row,false_row=partition(training_data,q)
true_row,false_row

([['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']],
 [['Red', 1, 'Grape'], ['Red', 1, 'Grape']])

In [62]:
def gini_impurity(rows):
    """Return the Gini impurity of ``rows``: 1 minus the sum of squared label shares.

    Labels are tallied with ``class_counts``. For an empty ``rows`` no share
    is subtracted, so the result is 1.
    """
    total=len(rows)
    impurity=1
    # Subtract one squared share at a time, in the order class_counts yields labels.
    for occurrences in class_counts(rows).values():
        impurity-=(occurrences/total)**2
    return impurity

In [75]:
gini_impurity(false_row)

0.0

In [77]:
def info_gain(left_split,right_split,orignal_gini):
    """Return how much a split lowers impurity.

    The result is ``orignal_gini`` minus the size-weighted average of the
    Gini impurities of ``left_split`` and ``right_split``.
    """
    n_left,n_right=len(left_split),len(right_split)
    total=n_left+n_right
    weighted=(n_left/total)*gini_impurity(left_split)+(n_right/total)*gini_impurity(right_split)
    return orignal_gini-weighted

In [90]:
q=Question(1,3)
print(q)
true_row,false_row=partition(training_data,q)
info_gain(true_row,false_row,gini_impurity(training_data))

Is diameter >= 3 ?


0.37333333333333324

In [107]:
def best_split(data):
    """Find the question that yields the highest information gain on ``data``.

    Every distinct value of every feature column (all columns except the
    last, which holds the label) is tried as a ``Question``.

    Returns ``(best_gain, best_question)``. For an empty dataset there is
    nothing to split, so ``(0, None)`` is returned.
    """
    if len(data)==0:
        return 0,None

    n_features=len(data[0])-1  # the last column is the label
    best_gain=0
    best_question=None
    # Impurity of the unsplit data is the same for every candidate; compute it once.
    current_impurity=gini_impurity(data)

    for f in range(n_features):
        # Set iteration order is unspecified, and ties go to the candidate visited last.
        for value in {row[f] for row in data}:
            question=Question(f,value)
            true_rows,false_rows=partition(data,question)
            gain=info_gain(true_rows,false_rows,current_impurity)

            if gain>=best_gain:
                best_gain=gain
                best_question=question

    return best_gain,best_question

In [108]:
best_split(training_data)

(0.37333333333333324, Is diameter >= 3 ?)

In [149]:
class Leaf_Node():
    """Terminal node of the tree, built from the training rows that reach it."""

    def __init__(self,rows):
        # predictions: dict of label -> number of rows carrying that label.
        self.predictions=class_counts(rows)
        # prediction: the label with the highest count. Rank labels by their
        # count (not by the characters of the label itself); ties go to the
        # label seen first, and an empty leaf predicts None.
        self.prediction=max(self.predictions,key=self.predictions.get,default=None)

In [150]:
class Decision_Node():
    """Internal tree node: a question plus the two subtrees it routes rows to."""

    def __init__(self,left_split,right_split,question):
        # classify() descends into right_split when the question matches a
        # row and into left_split otherwise.
        self.left_split,self.right_split=left_split,right_split
        self.question=question
    

In [151]:
def build_tree(dataset):
    """Recursively build a decision tree for ``dataset``.

    The best split is looked up with ``best_split``. When it offers no gain
    the rows become a ``Leaf_Node``; otherwise the rows are partitioned on
    the chosen question and a subtree is built for each side.
    """
    gain,question=best_split(dataset)

    if gain!=0:
        matched,unmatched=partition(dataset,question)
        true_branch=build_tree(matched)
        false_branch=build_tree(unmatched)
        # Decision_Node takes (left_split, right_split, question): the
        # non-matching subtree goes left, the matching subtree goes right.
        return Decision_Node(false_branch,true_branch,question)

    return Leaf_Node(dataset)
    

In [152]:
root_decision=build_tree(training_data)

In [158]:
def print_decision_tree(root):
    """Print the tree rooted at ``root`` in breadth-first order.

    Decision nodes are printed as ``"<depth> <question>"``, where depth is
    the node's distance from the root (the root is 0). Leaves are printed as
    ``"Predictions <label counts>"``.
    """
    pending=queue.Queue()
    # Carry each node's depth alongside it so sibling decision nodes report
    # the same level instead of a running count of nodes visited.
    pending.put((root,0))
    while not pending.empty():
        node,depth=pending.get()

        if isinstance(node,Leaf_Node):
            print('Predictions',str(node.predictions))

        else:
            print(str(depth)+' '+str(node.question))
            pending.put((node.left_split,depth+1))
            pending.put((node.right_split,depth+1))

In [159]:
print_decision_tree(root_decision)

0 Is diameter >= 3 ?
Predictions {'Grape': 2}
1 Is color == Green ?
Predictions {'Apple': 1, 'Lemon': 1}
Predictions {'Apple': 1}


In [163]:
# Running log of every label produced by classify(), in call order.
predictions=[]
def classify(data,root):
    """Route the row ``data`` through the tree at ``root`` and return its label.

    At each decision node the walk follows ``right_split`` when the node's
    question matches the row and ``left_split`` otherwise, until a leaf is
    reached.

    Returns the leaf's ``prediction``. The same label is also appended to
    the module-level ``predictions`` list, so code that reads that list
    keeps working.
    """
    node=root
    while not isinstance(node,Leaf_Node):
        node=node.right_split if node.question.match(data) else node.left_split

    predictions.append(node.prediction)
    return node.prediction
    

In [164]:
# Held-out rows laid out like training_data: [color, diameter, label].
testing_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 4, 'Apple'],
    ['Red', 2, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
# classify() appends one label per call to the module-level `predictions`
# list, so re-running this cell without resetting that list adds duplicates.
for row in testing_data:
    classify(row,root_decision)

In [165]:
predictions

['Apple', 'Apple', 'Grape', 'Grape', 'Apple']