In [286]:
training_data = [
    [ 'Green' , 3 , 'Apple'],
    [ 'Yellow' , 3 , 'Apple'],
    [ 'Red' , 1 , 'Grape'],
    [ 'Red' , 1 , 'Grape'],
    [ 'Yellow' , 3 , 'Lemon']
    
]

In [315]:
from sklearn.datasets import load_digits

In [320]:
dataset = load_digits()

In [321]:
data , label = dataset['data'] , dataset['target']

In [322]:
data

array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ..., 10.,  0.,  0.],
       [ 0.,  0.,  0., ..., 16.,  9.,  0.],
       ...,
       [ 0.,  0.,  1., ...,  6.,  0.,  0.],
       [ 0.,  0.,  2., ..., 12.,  0.,  0.],
       [ 0.,  0., 10., ..., 12.,  1.,  0.]])

In [266]:
from collections import Counter

In [267]:
def gini(rows):
    labels = [ rows[i][-1] for i in range(len(rows)) ]
    count_dict = Counter(labels)
    impurity = 1
    for label in count_dict:
        probab_of_indivisual_label = count_dict[label] / len(rows)
        impurity -= probab_of_indivisual_label**2
    return impurity
        
    

In [268]:
gini(training_data)

0.6399999999999999

In [269]:
class Question:
    def __init__(self,column,value):
        self.column = column
        self.value = value
    def __str__(self):
        return 'Question column: {column} , value: {value}'.format(column=self.column,value=self.value)

In [270]:
def split_dataset(rows,question):
    question_value = question.value
    question_column = question.column
    
    true_dataset = list()
    false_dataset = list()
    for row in rows:
        data_point = row[question_column]
        if str(data_point).isnumeric():
            if float(data_point) <=question_value:
                true_dataset.append(row)
            else:
                false_dataset.append(row)
        else:
            if data_point == question_value:
                true_dataset.append(row)
            else:
                false_dataset.append(row)
    return true_dataset , false_dataset

In [271]:
question = Question(0,'Green')

In [272]:
true_rows , false_rows = split_dataset(training_data,question)
true_rows

[['Green', 3, 'Apple']]

In [273]:
false_rows

[['Yellow', 3, 'Apple'],
 ['Red', 1, 'Grape'],
 ['Red', 1, 'Grape'],
 ['Yellow', 3, 'Lemon']]

In [274]:
def info_gain(left , right , current):
    fraction = len(left) / (len(left) + len(right))
    return current - (fraction * gini(left) + (1-fraction) * gini(right) )

In [275]:
info_gain(true_rows , false_rows , 0.64)

0.14

In [276]:
def find_best_split(rows):
    columns = len(rows[0]) - 1
    best_question = None
    best_gain = 0
    
    for col in range(columns):
        column_values = set(rows[i][col] for i in range(len(rows)))
        for value in column_values:
            question = Question(col , value)
            true_rows , false_rows = split_dataset(rows , question)
            gain = info_gain(true_rows , false_rows , gini(rows))
            if gain > best_gain:
                best_gain = gain
                best_question = question
    return best_gain , best_question
        

In [277]:
gain , question = find_best_split(training_data)

In [278]:
print(question.column , question.value)

0 Red


In [279]:
class Node:
    def __init__(self,question , left , right):
        self.question = question
        self.left = left
        self.right = right

In [280]:
class Leaf:
    def __init__(self,rows):
        self.rows = rows
        self.prediction = Counter(rows[i][-1] for i in range(len(rows)))

In [281]:
def make_decision_tree(rows):
    gain , question = find_best_split(rows)
    if gain == 0:
        return Leaf(rows)
    true_rows , false_rows = split_dataset(rows,question)
    left_tree = make_decision_tree(true_rows)
    
    right_tree = make_decision_tree(false_rows)
    
    return Node(question , left_tree , right_tree)
    

In [323]:
node = make_decision_tree(data)

In [324]:
print(node.question.column , node.question.value)

55 0.0


In [325]:
def print_tree(node, spacing=""):
    """World's most elegant tree printing function."""

    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        print (spacing + "Predict", node.prediction)
        return

    # Print the question at this node
    print (spacing + str(node.question))

    # Call this function recursively on the true branch
    print (spacing + '--> True:')
    print_tree(node.left, spacing + "  ")

    # Call this function recursively on the false branch
    print (spacing + '--> False:')
    print_tree(node.right, spacing + "  ")

In [326]:
print_tree(node)

Question column: 55 , value: 0.0
--> True:
  Question column: 62 , value: 16.0
  --> True:
    Question column: 52 , value: 16.0
    --> True:
      Question column: 11 , value: 13.0
      --> True:
        Predict Counter({9.0: 2})
      --> False:
        Question column: 26 , value: 0.0
        --> True:
          Question column: 4 , value: 15.0
          --> True:
            Predict Counter({8.0: 1})
          --> False:
            Predict Counter({7.0: 2})
        --> False:
          Predict Counter({5.0: 2})
    --> False:
      Question column: 27 , value: 0.0
      --> True:
        Predict Counter({6.0: 3})
      --> False:
        Question column: 2 , value: 11.0
        --> True:
          Question column: 4 , value: 7.0
          --> True:
            Predict Counter({6.0: 1})
          --> False:
            Predict Counter({1.0: 3})
        --> False:
          Question column: 25 , value: 0.0
          --> True:
            Question column: 12 , value: 16.0
         