In [1]:
from __future__ import print_function

In [2]:
data=[
    ['Yellow',3,'Apple'],
    ['Green',3,'Apple'],
    ['Red',1,'Grape'],
    ['Red',1,'Grape'],
    ['Yellow',3,'Lemon']
]

In [3]:
header=["color","diameter","label"]

In [6]:
def uniq_val(rows,cols):
    return set ([rows[cols]] for rows in rows)

In [9]:
def class_counts(rows):
    counts={}
    for row in rows:
        global label
        label=row[-1]
        if label not in counts:
            counts[label]=0
        counts[label]=0
    return counts
class_counts(data)

{'Apple': 0, 'Grape': 0, 'Lemon': 0}

In [10]:
def numeric(value):
    global numeric
    return isinstance(value,int) or isinstance(value,float)
print(numeric(7))
numeric("red")

True


False

In [18]:
class Question():
    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        # Compare the feature value in an example to the feature value in this question.
        val = example[self.column]
        if numeric(val):
            return val >= self.value
        else:
            return val == self.value

    def __repr__(self):
        # This is just a helper method to print the question in a readable format.
        condition = "=="
        if numeric(self.value):
            condition = ">="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))
print(Question(1, 3))
Question(0,'Green')

Is diameter >= 3?


Is color == Green?

In [19]:
# Partitions of a dataset.
def partition(rows, question):
    true_rows, false_rows = [], []
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
    return true_rows, false_rows
true_rows, false_rows = partition(data, Question(0, 'Red'))
print(true_rows)
false_rows

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]


[['Yellow', 3, 'Apple'], ['Green', 3, 'Apple'], ['Yellow', 3, 'Lemon']]

In [20]:
def gini(rows):
    global gini
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl**2
    return impurity

In [21]:
def info_gain(left, right, current_uncertainty):
    p = float(len(left)) / (len(left) + len(right))
    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)

In [23]:
true_rows, false_rows = partition(data, Question(0,'Red'))
print(true_rows)
false_rows

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]


[['Yellow', 3, 'Apple'], ['Green', 3, 'Apple'], ['Yellow', 3, 'Lemon']]

In [25]:
true_rows, false_rows = partition(data, Question(0,'Green'))
print(true_rows)
false_rows

[['Green', 3, 'Apple']]


[['Yellow', 3, 'Apple'],
 ['Red', 1, 'Grape'],
 ['Red', 1, 'Grape'],
 ['Yellow', 3, 'Lemon']]

In [26]:
 def find_best_split(rows):
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep train of the feature / value that produced it
    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of columns

    for col in range(n_features):  # for each feature

        values = set([row[col] for row in rows])  # unique values in the column

        for val in values:  # for each value

            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the dataset.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

In [28]:
best_gain, best_question = find_best_split(data)
best_question

Is diameter >= 3?

In [29]:
class Leaf:
    def __init__(self, rows):
        self.predictions = class_counts(rows)

In [30]:
class Decision_Node:
    def __init__(self,
                 question,
                 true_branch,
                 false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [31]:
def build_tree(rows):

    # Try partitioing the dataset on each of the unique attribute, calculate the information gain, and return the question that produces the highest gain.
    gain, question = find_best_split(rows)

    # we'll return a leaf
    if gain == 0:
        return Leaf(rows)

    # If we reach here, we have found a useful feature / value to partition on
    true_rows, false_rows = partition(rows, question)

    # Recursively build the true branch.
    true_branch = build_tree(true_rows)

    # Recursively build the false branch.
    false_branch = build_tree(false_rows)

    return Decision_Node(question, true_branch, false_branch)

In [32]:
def print_tree(node, spacing=""):

    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        print (spacing + "Predict", node.predictions)
        return

    # Print the question at this node
    print (spacing + str(node.question))

    # Call this function recursively on the true branch
    print (spacing + '--> True:')
    print_tree(node.true_branch, spacing + "  ")

    # Call this function recursively on the false branch
    print (spacing + '--> False:')
    print_tree(node.false_branch, spacing + "  ")

In [34]:
my_tree = build_tree(data)

In [35]:
print_tree(my_tree)

Predict {'Apple': 0, 'Grape': 0, 'Lemon': 0}


In [36]:
def classify(row, node):
    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        return node.predictions
    if node.question.match(row):
        return classify(row, node.true_branch)
    else:
        return classify(row, node.false_branch)

In [37]:
classify(data[0], my_tree)

{'Apple': 0, 'Grape': 0, 'Lemon': 0}

In [38]:
def print_leaf(counts):
    total = sum(counts.values()) * 1.0
    probs = {}
    for lbl in counts.keys():
        probs[lbl] = str(float(counts[lbl] / total * 100)) + "%"
    return probs

In [39]:
# Evaluate
testing_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 4, 'Apple'],
    ['Red', 2, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]