# CART Decision Tree (Classification)

In [1149]:
def unique_vals(rows, col):
    """Return the set of distinct values found in column `col` of `rows`."""
    return {record[col] for record in rows}

In [1100]:
def class_counts(X, y):
    """Tally how many samples carry each target label.

    Only the first len(X) entries of `y` are read. The result is a plain
    dict mapping label -> number of occurrences, keyed in first-seen order.
    """
    tally = {}
    for idx in range(len(X)):
        key = y[idx]
        tally[key] = tally.get(key, 0) + 1
    return tally

In [1101]:
def is_numeric(value):
    """Return True when `value` is a real number.

    Question uses this to choose its comparison: numeric values are
    tested with ">=" (threshold split), everything else with "=="
    (exact match).

    numbers.Real is checked instead of the concrete int/float types so
    that NumPy integer scalars (e.g. np.int64, which does not subclass
    int) are also recognised as numeric instead of being treated as
    categories. Plain int, float and bool values are accepted as before.
    """
    # Imported locally so the function stays self-contained.
    import numbers

    return isinstance(value, numbers.Real)

In [1102]:
class Question:
    """A yes/no test used to split the dataset.

    Stores the index of the feature column being tested, the value to
    compare against, and the list of feature names (used only when the
    question is rendered as text).
    """

    def __init__(self, column, value, feature_names):
        self.column, self.value, self.feature_names = column, value, feature_names

    def match(self, example):
        """Return whether `example` satisfies this question.

        A numeric feature value passes when it is >= the stored value;
        any other feature value passes only when it equals the stored value.
        """
        candidate = example[self.column]
        if not is_numeric(candidate):
            return candidate == self.value
        return candidate >= self.value

    def __repr__(self):
        """Render the question as readable text instead of a memory address."""
        operator = '>=' if is_numeric(self.value) else '=='
        name = self.feature_names[self.column]
        return 'Is %s %s %s?' % (name, operator, str(self.value))

In [1103]:
def partition(rows, target, question):
    """Split a dataset in two according to `question`.

    Rows for which question.match(row) is truthy go to the "true" side,
    all other rows go to the "false" side; each row's target value
    follows it. Returns (true_rows, false_rows, true_y, false_y).
    """
    true_rows, true_y = [], []
    false_rows, false_y = [], []
    for idx in range(len(rows)):
        sample = rows[idx]
        if question.match(sample):
            side_rows, side_y = true_rows, true_y
        else:
            side_rows, side_y = false_rows, false_y
        side_rows.append(sample)
        side_y.append(target[idx])
    return true_rows, false_rows, true_y, false_y

In [1104]:
def find_best_split(X, target, feature_names):
    """Find the best question to ask by trying every feature / value pair.

    Each distinct value of each feature column is turned into a Question,
    the data is partitioned with it, and the information gain of that
    partition is measured against the Gini impurity of the whole input.
    Candidates that leave one side of the partition empty are skipped.

    Returns a (best_gain, best_question) tuple. For an empty dataset, or
    when no candidate divides the rows, the result is (0, None). Because
    candidates are accepted with ">=", a question with zero gain may be
    returned alongside best_gain == 0.
    """
    # An empty dataset has no first row to read the feature count from,
    # so report "no useful split" instead of raising IndexError.
    if len(X) == 0:
        return 0, None

    best_gain = 0  # highest information gain seen so far
    best_question = None  # the feature / value test that produced it

    # Impurity of the whole sample before any split.
    current_uncertainity = gini_impurity(X, target)
    n_features = len(X[0])  # number of feature columns
    for col in range(n_features):
        values = {row[col] for row in X}  # distinct values in this column
        for val in values:
            question = Question(col, val, feature_names)
            true_rows, false_rows, true_y, false_y = partition(X, target, question)

            # A candidate that sends every row to the same side does not
            # divide the dataset, so it cannot be used as a split.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            gain = info_gain(true_rows, false_rows, true_y, false_y, current_uncertainity)

            # ">=" means that on a tie the most recently tried candidate wins.
            if gain >= best_gain:
                best_gain, best_question = gain, question
    return best_gain, best_question

# Since we're building a CART tree, the impurity criterion is the Gini index

In [1105]:
def gini_impurity(X, y):
    """Compute the Gini impurity of a labelled dataset.

    Starts from 1 and subtracts the squared relative frequency of every
    class present in `y`, i.e. 1 - sum(p_k ** 2). Tree building tries to
    minimise the size-weighted impurity of the child nodes.
    """
    n_samples = float(len(X))
    remaining = 1
    # Subtract one class at a time so the floating-point result is
    # accumulated in a fixed order.
    for occurrences in class_counts(X, y).values():
        remaining -= (occurrences / n_samples) ** 2
    return remaining
    
def info_gain(X_left, X_right, y_left, y_right, current_uncertainity):
    """Return the parent's impurity minus the weighted impurity of its children.

    "Left" is the true side of the split and "right" the false side.
    Each child's Gini impurity is weighted by its share of the rows.
    """
    n_left = len(X_left)
    weight_left = float(n_left) / (n_left + len(X_right))
    left_impurity = gini_impurity(X_left, y_left)
    right_impurity = gini_impurity(X_right, y_right)
    return current_uncertainity - weight_left * left_impurity - (1 - weight_left) * right_impurity

In [1106]:
class Leaf:
    """Terminal node of the tree.

    `predictions` maps each class label to the number of training rows
    with that label that reached this leaf.
    """

    def __init__(self, X, y):
        label_tally = class_counts(X, y)
        self.predictions = label_tally

In [1107]:
class Decision_Node:
    """Internal node: a question plus the subtree to follow for each answer."""

    def __init__(self, question, true_branch, false_branch):
        self.question, self.true_branch, self.false_branch = (
            question,
            true_branch,
            false_branch,
        )

In [1108]:
def build_tree(X, y, feature_names):
    """Recursively grow a decision tree for rows `X` with labels `y`.

    The best split is looked up first. If it brings no information gain
    the rows become a Leaf; otherwise the rows are partitioned by the
    winning question, both halves are grown recursively, and a
    Decision_Node joining them is returned. There is no depth limit.
    """
    gain, question = find_best_split(X, y, feature_names)
    if gain != 0:
        # A useful question exists: split on it and recurse into each side.
        yes_rows, no_rows, yes_y, no_y = partition(X, y, question)
        yes_subtree = build_tree(yes_rows, yes_y, feature_names)
        no_subtree = build_tree(no_rows, no_y, feature_names)
        return Decision_Node(question, yes_subtree, no_subtree)
    # Base case: nothing left to gain from splitting.
    return Leaf(X, y)

In [1109]:
def print_tree(node, spacing=''):
    """Print the tree rooted at `node`, indenting one extra space per level.

    Decision nodes print their question followed by the true and false
    subtrees; leaves print their class counts.
    """
    if not isinstance(node, Leaf):
        print(spacing + str(node.question))
        deeper = spacing + " "
        branches = (
            ('--> True:', node.true_branch),
            ('--> False:', node.false_branch),
        )
        for heading, child in branches:
            print(spacing + heading)
            print_tree(child, deeper)
        return

    # Leaf reached: show the label counts stored on it.
    print(spacing+"Predict", node.predictions)
    
    

In [1110]:
def classify(row, node):
    """Walk the tree from `node` for `row` and return the reached leaf's counts.

    At every decision node the row is tested against the node's question
    to pick the true or false branch; the walk stops at the first Leaf.
    """
    current = node
    while not isinstance(current, Leaf):
        if current.question.match(row):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.predictions

In [1111]:
def print_leaf(counts):
    """Convert a leaf's label counts into percentage strings.

    Returns a dict mapping each label to "<n>%", where n is that label's
    share of the total count truncated to a whole number. Nothing is
    printed; the caller decides what to do with the result.
    """
    total = sum(counts.values()) * 1.0
    return {
        label: str(int(occurrences * 100 / total)) + "%"
        for label, occurrences in counts.items()
    }

In [1120]:
def predict(counts):
    """Return the most frequent label in a leaf's class counts.

    `counts` is a label -> count dictionary such as the one returned by
    classify(). The label with the highest count is the prediction; when
    several labels share the highest count, the one that appears first
    in the dictionary wins. An empty dictionary yields '' (the value the
    previous implementation returned for that case).

    The previous implementation returned whichever label happened to be
    last in the dictionary, which is wrong for leaves holding more than
    one class.
    """
    if not counts:
        return ''
    return max(counts, key=counts.get)

In [950]:
from sklearn.model_selection import train_test_split as split
from sklearn.datasets import load_iris as iris

In [951]:
# Pull the samples and their labels from the Iris loader.
# Note: the loader is invoked separately for each attribute.
data = iris().data
target = iris().target

In [1142]:
# Hold out 33% of the samples for testing; random_state=42 makes the split repeatable.
X_train, X_test, y_train, y_test = split(data, target, test_size=0.33, random_state=42)

In [1144]:
# Fit the tree on the training split, using the dataset's feature names for readable questions.
iris_dt = build_tree(X_train, y_train, iris().feature_names)
# Count the test rows whose predicted label equals the true label.
count = 0
for i in range(len(X_test)):
    if predict(classify(X_test[i], iris_dt)) == y_test[i]:
               count += 1
# Report the share of correctly classified test rows as a percentage.
print ('Prediction %age', count*100/float(len(X_test)))
#print(len(X_train))
#print_tree(my_tree)

50
Prediction %age 100.0


In [1146]:
# Dump the fitted tree: one question per decision node, class counts per leaf.
print_tree(iris_dt)

Is petal width (cm) >= 1.0?
--> True:
 Is petal width (cm) >= 1.8?
 --> True:
  Is petal length (cm) >= 4.9?
  --> True:
   Predict {2: 28}
  --> False:
   Is sepal width (cm) >= 3.2?
   --> True:
    Predict {1: 1}
   --> False:
    Predict {2: 2}
 --> False:
  Is petal length (cm) >= 5.6?
  --> True:
   Predict {2: 2}
  --> False:
   Is petal width (cm) >= 1.7?
   --> True:
    Is petal length (cm) >= 5.0?
    --> True:
     Predict {1: 1}
    --> False:
     Predict {2: 1}
   --> False:
    Is petal length (cm) >= 5.0?
    --> True:
     Is petal width (cm) >= 1.6?
     --> True:
      Predict {1: 1}
     --> False:
      Predict {2: 1}
    --> False:
     Predict {1: 32}
--> False:
 Predict {0: 31}


In [1150]:
# Bare expression: the notebook echoes the training feature matrix below.
X_train

array([[5.7, 2.9, 4.2, 1.3],
       [7.6, 3. , 6.6, 2.1],
       [5.6, 3. , 4.5, 1.5],
       [5.1, 3.5, 1.4, 0.2],
       [7.7, 2.8, 6.7, 2. ],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 3.4, 1.4, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.9, 0.4],
       [5. , 2. , 3.5, 1. ],
       [6.3, 2.7, 4.9, 1.8],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [5.6, 2.7, 4.2, 1.3],
       [5.1, 3.4, 1.5, 0.2],
       [5.7, 3. , 4.2, 1.2],
       [7.7, 3.8, 6.7, 2.2],
       [4.6, 3.2, 1.4, 0.2],
       [6.2, 2.9, 4.3, 1.3],
       [5.7, 2.5, 5. , 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [6. , 3. , 4.8, 1.8],
       [5.8, 2.7, 5.1, 1.9],
       [6. , 2.2, 4. , 1. ],
       [5.4, 3. , 4.5, 1.5],
       [6.2, 3.4, 5.4, 2.3],
       [5.5, 2.3, 4. , 1.3],
       [5.4, 3.9, 1.7, 0.4],
       [5. , 2.3, 3.3, 1. ],
       [6.4, 2.7, 5.3, 1.9],
       [5. , 3.3, 1.4, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 2.4, 3.8, 1.1],
       [6.7, 3