# Summary

A decision tree works by defining a tree where each node splits the data to minimize splitting the data to "unmix" the data as much as possible.

In [12]:
from collections import Counter
from numbers import Number

# Gini Impurity

between 0 and 1 where lower values mean less mixing at a node
chance of being incorrect if you randomly assign a label to an item in the set
https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity
"The Gini impurity can be computed by summing the probability p_i of an item with label_i being chosen times the probability of a mistake in categorizing that item. It reaches its minimum (zero) when all cases in the node fall into a single target category."

In [13]:
def gini(labels):    
    impurity = 0
    for label, count in Counter(labels).items():
        prob_chosen = count / float(len(labels))
        impurity += prob_chosen * (1 - prob_chosen)
    
    return impurity

In [14]:
gini(['a','b','c', 'd', 'e'])

0.8000000000000002

In [15]:
gini(['a','a','c'])

0.4444444444444445

In [16]:
gini(['a','a','a'])

0.0

# Information gain

Used to find the question that reduces uncertainty the most. Describes how much a question helps unmix labels at a node.

It is calculated by finding the diff of the impurity before the split and the weighted avg of the impurity in each of the outputs after the split. Weighted avg is used because the size of the split matters (e.g. splitting one item and leaving a group with a lot of impurity).

In [17]:
def info_gain(labels_in, labels_out_left, labels_out_right):
    gini_in = gini(labels_in)
    p = len(labels_out_left) / float(len(labels_in))
    
    return gini(labels_in) - p * gini(labels_out_left) - (1 - p) * gini(labels_out_right)

In [18]:
info_gain(['a','b'], ['a'], ['b'])

0.5

In [19]:
info_gain(['a','b', 'b'], ['a', 'b'], ['b'])

0.11111111111111116

# Define all possible node rules
The set of rules is defined from the input data. Try every value for every label.

In [20]:
class Rule:
    def __init__(self, column_name, column_index, column_value):
        self.name = column_name
        self.index = column_index
        self.value = column_value
        self.is_numeric = isinstance(column_value, Number)

    def match(self, row):
        val = row[self.index]
        if self.is_numeric:
            return val >= self.value
        else:
            return val == self.value

    def __repr__(self):
        # This is just a helper method to print
        # the question in a readable format.
        condition = "=="
        if self.is_numeric:
            condition = ">="
        return "Is %s %s %s?" % (
            self.name, condition, str(self.value))

In [53]:
# header = ['color', 'weight', 'label']
# label_index = 2
# data = [
#     ['green', 10.5, 'pear'],
#     ['green', 9.2, 'pear'],
#     ['red', 5.4, 'tomato']
# ]

header = ["color", "diameter", "label"]
data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
label_index = 2

In [22]:
Rule('name', 0, 'green')

Is name == green?

In [23]:
Rule('name', 1, 2)

Is name >= 2?

In [24]:
Rule('name', 1, 2.3)

Is name >= 2.3?

In [77]:
def all_rules(header, data, label_index):
    rules = []
    
    for index, name in enumerate(header):
        if index != label_index:
            # get unique values for index
            vals = set()
            for row in data:
                vals.add(row[index])
                
            for val in vals:    
                rules.append(Rule(name, index, val))
            
    return set(rules)

In [78]:
rules = all_rules(header, data, label_index)
rules

{Is color == Green?,
 Is color == Red?,
 Is color == Yellow?,
 Is diameter >= 1?,
 Is diameter >= 3?}

# Find best split
Next define a method that finds the rule that maximizes the information gain. 

In [105]:
def get_labels(data, label_index):
    return list(map(lambda row: row[label_index], data))

def best_split(data, rules):
    """
    split the data using each rule and see which one has the maximum info gain
    """
#     print("\n+++++++++")
#     print("input: {}".format(get_labels(data, label_index)))

    max_info_gain = 0.
    max_rule = None
    max_left = None
    max_right = None
    
    for rule in rules:
        # for each rule bucket the results into two groups
        left = []
        right = []
        for row in data:
            if rule.match(row):
                right.append(row)
            else:
                left.append(row)
        
        new_info_gain = info_gain(get_labels(data, label_index), get_labels(left, label_index), get_labels(right, label_index))
        
#         print("rule: {}".format(rule))
#         print("info_gain: {}".format(new_info_gain))
#         print("left: {}".format(get_labels(left, label_index)))
#         print("right: {}".format(get_labels(right, label_index)))
#         print("\n")
        
        if new_info_gain > max_info_gain:
            max_info_gain = new_info_gain
            max_rule = rule
            max_left = left
            max_right = right
            
#     print("max info gain: {}".format(max_info_gain))
#     print("max rule: {}".format(max_rule))
#     print("max left: {}".format(max_left))
#     print("max right: {}".format(max_right))
#     print("+++++++++")        
            
    return max_info_gain, max_rule, max_left, max_right             

In [98]:
best_split(data, rules)


+++++++++
input: ['Apple', 'Apple', 'Grape', 'Grape', 'Lemon']
rule: Is color == Red?
info_gain: 0.37333333333333335
left: ['Apple', 'Apple', 'Lemon']
right: ['Grape', 'Grape']


rule: Is color == Yellow?
info_gain: 0.17333333333333334
left: ['Apple', 'Grape', 'Grape']
right: ['Apple', 'Lemon']


rule: Is color == Green?
info_gain: 0.14
left: ['Apple', 'Grape', 'Grape', 'Lemon']
right: ['Apple']


rule: Is diameter >= 3?
info_gain: 0.37333333333333335
left: ['Grape', 'Grape']
right: ['Apple', 'Apple', 'Lemon']


rule: Is diameter >= 1?
info_gain: 0.0
left: []
right: ['Apple', 'Apple', 'Grape', 'Grape', 'Lemon']


max info gain: 0.37333333333333335
max rule: Is color == Red?
max left: [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]
max right: [['Red', 1, 'Grape'], ['Red', 1, 'Grape']]
+++++++++


(0.37333333333333335,
 Is color == Red?,
 [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']],
 [['Red', 1, 'Grape'], ['Red', 1, 'Grape']])

# Building the full tree

In [116]:
class Node:
    def __init__(self,
                 rule,
                 false_node,
                 true_node):
        self.rule = rule
        self.false_node = false_node
        self.true_node = true_node
        
    def is_leaf(self):
        return False

## Leaf nodes

When a leaf node is reached the decision tree returns the label value for objects end at this leaf

In [117]:
class Leaf:
    def __init__(self, labels):
        tot_count = len(labels)
        self.predictions = {}
        for label, count in Counter(labels).items():
            self.predictions[label] = float(count) / tot_count
        
    def is_leaf(self):
        return True   

## Build tree

In [118]:
def build_tree(header, training_data, label_index):
    # create all rules from training data
    rules = all_rules(header, training_data, label_index)
    
    # keep finding best split until cannot be split further
    return build_tree_rec(training_data, rules, label_index)
    
    
def build_tree_rec(training_data, rules, label_index):
    # base case: impurity is 0 (completely separated)
    labels = get_labels(training_data, label_index)
    if gini(labels) == 0:
        print("unmixed: {}\n".format(labels))
        return Leaf(labels)
    
    # split data
    info_gain, rule, false_data, true_data = best_split(training_data, rules)
    
    # stop if info gain is 0. that means it can't be split further
    if info_gain == 0:
        print("no more info gain: {}\n".format(labels))
        return Leaf(labels)
    
    # build decision tree for each outcome
    print(rule)
    print("new node false: {}".format(false_data))
    print("new node true: {}\n".format(true_data))
    return (Node(
        rule, 
        build_tree_rec(false_data, rules, label_index),
        build_tree_rec(true_data, rules, label_index)
    ))
    

In [119]:
tree = build_tree(header, data, label_index)

Is diameter >= 3?
new node false: [['Red', 1, 'Grape'], ['Red', 1, 'Grape']]
new node true: [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]

unmixed: ['Grape', 'Grape']

Is color == Green?
new node false: [['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]
new node true: [['Green', 3, 'Apple']]

no more info gain: ['Apple', 'Lemon']

unmixed: ['Apple']



In [120]:
def print_tree(node, spacing=""):
    """World's most elegant tree printing function."""

    # Base case: we've reached a leaf
    if node.is_leaf():
        print (spacing + "Predict", node.predictions)
        return

    # Print the question at this node
    print (spacing + str(node.rule))

    # Call this function recursively on the true branch
    print (spacing + '--> True:')
    print_tree(node.true_node, spacing + "  ")

    # Call this function recursively on the false branch
    print (spacing + '--> False:')
    print_tree(node.false_node, spacing + "  ")

In [121]:
print_tree(tree)

Is diameter >= 3?
--> True:
  Is color == Green?
  --> True:
    Predict {'Apple': 1.0}
  --> False:
    Predict {'Apple': 0.5, 'Lemon': 0.5}
--> False:
  Predict {'Grape': 1.0}


# Using decision tree

In [70]:
def classify(row, node):
    """See the 'rules of recursion' above."""

    # Base case: we've reached a leaf
    if node.is_leaf():
        return node.predictions

    # Decide whether to follow the true-branch or the false-branch.
    # Compare the feature / value stored in the node,
    # to the example we're considering.
    if node.rule.match(row):
        return classify(row, node.true_node)
    else:
        return classify(row, node.false_node)

In [124]:
print(data[0])
classify(data[0], tree)

['Green', 3, 'Apple']


{'Apple': 1.0}

In [127]:
classify(['Yellow', 3], tree)

{'Apple': 0.5, 'Lemon': 0.5}

# Tree Stats

In [129]:
def tree_depth(tree):
    if tree.is_leaf():
        return 1
    
    return 1 + max(tree_depth(tree.false_node), tree_depth(tree.true_node))

In [132]:
tree_depth(tree)

3

In [134]:
def leaf_count(tree):
    if tree.is_leaf():
        return 1
    
    return leaf_count(tree.false_node) + leaf_count(tree.true_node)

In [135]:
leaf_count(tree)

3

# References

https://www.youtube.com/watch?v=LDRbO9a6XPU