In [30]:
from pandas import *
from numpy import *


# Toy training set. Each row is [X1, X2, X3, label]: three numeric feature
# columns followed by the class label (0 or 1) in the last position, which is
# where the tree code reads it from (row[-1]).
dataset = [
    [65,  65, 23.4, 1],
    [47,  15, 36.5, 0],
    [34,  75, 77.2, 1],
    [97,  40, 25.4, 0],
    [36, 100, 99,   1],
    [24,  60, 26,   1],
    [43,  56, 43.2, 0],
    [23,  90, 22.6, 1],
    [12,  23, 23.2, 1],
]

def test_split(index, value, dataset):
    L, R = list(), list()
    for row in dataset:
        if row[index] < value:
            L.append(row)
        else:
            R.append(row)
    return L, R


def gini_index(groups, classes):
    """Weighted Gini impurity of a candidate split.

    Args:
        groups: sequence of row groups (e.g. the ``(left, right)`` pair from
            ``test_split``); it is iterated twice, so it must be re-iterable.
            Each row carries its class label in ``row[-1]``.
        classes: the class labels to score.

    Returns:
        float: for each non-empty group, ``1 - sum_k p_k**2`` (``p_k`` being
        the share of that group labelled ``classes[k]``), weighted by the
        group's share of all rows, summed over groups. 0.0 for a perfectly
        pure split, and also 0.0 when every group is empty.
    """
    # List form kept on purpose: in this notebook `sum` may be numpy.sum
    # (star import), which handles a list cleanly but not a generator.
    n_instances = float(sum([len(group) for group in groups]))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            # An empty side contributes nothing (and would divide by zero).
            continue
        # Extract the labels once per group rather than once per class.
        labels = [row[-1] for row in group]
        score = 0.0
        for class_val in classes:
            p = labels.count(class_val) / size
            score += p * p
        gini += (1.0 - score) * (size / n_instances)
    return gini

def get_split(dataset, verbose=True):
    """Find the single best binary split of ``dataset`` by Gini impurity.

    Every feature column (all but the last, which holds the class label) is
    tried, using each row's own value in that column as the threshold. The
    candidate with the lowest weighted Gini score wins; because the
    comparison is strict, ties keep the first candidate examined.

    Args:
        dataset: non-empty list of rows; ``row[-1]`` is the class label.
        verbose: when True (the default, preserving the original behaviour)
            print one ``X<i> < <value> Gini=<g>`` line per candidate. Pass
            False to suppress this trace, which otherwise repeats for every
            recursive call made while growing a tree.

    Returns:
        dict with ``'index'`` (0-based feature column), ``'value'`` (the
        threshold) and ``'groups'`` (the ``(left, right)`` pair from
        ``test_split`` for that threshold).
    """
    class_values = list(set(row[-1] for row in dataset))
    # inf instead of a magic 999: any real Gini score beats it.
    b_index, b_value, b_score, b_groups = 999, 999, float('inf'), None
    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if verbose:
                print('X%d < %.3f Gini=%.3f' % ((index+1), row[index], gini))
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}


# Best single split of the whole dataset. Bound to `root_split` rather than
# `split`: the next cell defines a recursive `split()` function, and reusing
# the name here would overwrite it with a dict whenever this cell is re-run,
# breaking build_tree ("'dict' object is not callable").
root_split = get_split(dataset)
print('Split: [X%d < %.3f]' % ((root_split['index']+1), root_split['value']))


X1 < 65.000 Gini=0.429
X1 < 47.000 Gini=0.333
X1 < 34.000 Gini=0.333
X1 < 97.000 Gini=0.333
X1 < 36.000 Gini=0.267
X1 < 24.000 Gini=0.381
X1 < 43.000 Gini=0.167
X1 < 23.000 Gini=0.417
X1 < 12.000 Gini=0.444
X2 < 65.000 Gini=0.267
X2 < 15.000 Gini=0.444
X2 < 75.000 Gini=0.333
X2 < 40.000 Gini=0.429
X2 < 100.000 Gini=0.417
X2 < 60.000 Gini=0.167
X2 < 56.000 Gini=0.333
X2 < 90.000 Gini=0.381
X2 < 23.000 Gini=0.333
X3 < 23.400 Gini=0.381
X3 < 36.500 Gini=0.400
X3 < 77.200 Gini=0.381
X3 < 25.400 Gini=0.333
X3 < 99.000 Gini=0.417
X3 < 26.000 Gini=0.433
X3 < 43.200 Gini=0.444
X3 < 22.600 Gini=0.444
X3 < 23.200 Gini=0.417
Split: [X1 < 43.000]


In [31]:

def to_terminal(group):
    """Return the majority class label of ``group``; used as a leaf value.

    Each row carries its label in ``row[-1]``. Ties go to the label that
    appears first in ``group``. The previous ``max(set(...))`` form broke
    ties by set iteration order, which for str labels changes between
    interpreter runs because of hash randomization.

    Raises:
        ValueError: if ``group`` is empty (as ``max`` on an empty set did).
    """
    from collections import Counter

    counts = Counter(row[-1] for row in group)
    if not counts:
        raise ValueError('to_terminal() requires a non-empty group')
    # most_common orders equal counts by first occurrence.
    return counts.most_common(1)[0][0]

def split(node, max_depth, min_size, depth):
    """Recursively grow the tree below ``node``, modifying it in place.

    ``node`` is a dict from ``get_split``. Its ``'groups'`` entry is removed
    and replaced by ``'L'`` and ``'R'`` children. Each child is either a
    nested split dict (grown recursively) or a terminal label from
    ``to_terminal``.

    Args:
        node: split dict holding a ``(left, right)`` pair under ``'groups'``.
        max_depth: once ``depth`` reaches this, both children become leaves.
        min_size: a side with at most this many rows becomes a leaf.
        depth: depth of ``node`` (``build_tree`` passes 1 for the root).
    """
    left, right = node['groups']
    del node['groups']

    # One side is empty, so the split separated nothing: both children get
    # the majority label of all the rows.
    if not (left and right):
        node['L'] = node['R'] = to_terminal(left + right)
        return

    if depth >= max_depth:
        node['L'] = to_terminal(left)
        node['R'] = to_terminal(right)
        return

    # Left side is fully grown before the right side is touched.
    for side, rows in (('L', left), ('R', right)):
        if len(rows) <= min_size:
            node[side] = to_terminal(rows)
        else:
            node[side] = get_split(rows)
            split(node[side], max_depth, min_size, depth + 1)

def build_tree(train, max_depth, min_size):
    """Build a binary decision tree on ``train`` using Gini-scored splits.

    Returns the root split dict. Every split dict holds ``'index'``,
    ``'value'``, ``'L'`` and ``'R'``; children are nested split dicts or
    terminal class labels (see ``split`` for how they are grown).
    """
    root_node = get_split(train)
    split(root_node, max_depth, min_size, depth=1)  # the root counts as depth 1
    return root_node

def print_tree(node, depth=0):
    """Print ``node`` and its subtree, one line per node, indented by depth.

    A split node prints as ``[X<feature> < <threshold>]`` (the feature number
    is 1-based), followed by its ``'L'`` subtree and then its ``'R'``
    subtree. Anything that is not a dict is a leaf and prints as
    ``[<label>]``.
    """
    indent = ' ' * depth
    if not isinstance(node, dict):
        print('%s[%s]' % (indent, node))
        return
    print('%s[X%d < %.3f]' % (indent, node['index'] + 1, node['value']))
    for child_key in ('L', 'R'):
        print_tree(node[child_key], depth + 1)


# Stop splitting at depth 2, or as soon as a side holds <= 2 rows.
tree = build_tree(dataset, max_depth=2, min_size=2)
print_tree(tree)

X1 < 65.000 Gini=0.429
X1 < 47.000 Gini=0.333
X1 < 34.000 Gini=0.333
X1 < 97.000 Gini=0.333
X1 < 36.000 Gini=0.267
X1 < 24.000 Gini=0.381
X1 < 43.000 Gini=0.167
X1 < 23.000 Gini=0.417
X1 < 12.000 Gini=0.444
X2 < 65.000 Gini=0.267
X2 < 15.000 Gini=0.444
X2 < 75.000 Gini=0.333
X2 < 40.000 Gini=0.429
X2 < 100.000 Gini=0.417
X2 < 60.000 Gini=0.167
X2 < 56.000 Gini=0.333
X2 < 90.000 Gini=0.381
X2 < 23.000 Gini=0.333
X3 < 23.400 Gini=0.381
X3 < 36.500 Gini=0.400
X3 < 77.200 Gini=0.381
X3 < 25.400 Gini=0.333
X3 < 99.000 Gini=0.417
X3 < 26.000 Gini=0.433
X3 < 43.200 Gini=0.444
X3 < 22.600 Gini=0.444
X3 < 23.200 Gini=0.417
X1 < 34.000 Gini=0.000
X1 < 36.000 Gini=0.000
X1 < 24.000 Gini=0.000
X1 < 23.000 Gini=0.000
X1 < 12.000 Gini=0.000
X2 < 75.000 Gini=0.000
X2 < 100.000 Gini=0.000
X2 < 60.000 Gini=0.000
X2 < 90.000 Gini=0.000
X2 < 23.000 Gini=0.000
X3 < 77.200 Gini=0.000
X3 < 99.000 Gini=0.000
X3 < 26.000 Gini=0.000
X3 < 22.600 Gini=0.000
X3 < 23.200 Gini=0.000
X1 < 65.000 Gini=0.250
X1 < 47.0

In [32]:

def predict(node, row):
    """Return the class label the tree rooted at ``node`` assigns to ``row``.

    At each split dict, ``row[node['index']] < node['value']`` selects the
    left child ``'L'``; otherwise the right child is taken. The walk stops at
    the first child that is not a dict and returns it as the label.

    The right child is read from ``'R'``, the key that ``split``/``build_tree``
    write. It falls back to ``'right'`` for hand-built nodes such as the demo
    stump. Previously only ``'right'`` was read, so predicting with a tree from
    ``build_tree`` raised KeyError on any row that went right.
    """
    # Iterative descent: same visiting order as recursion, with no depth limit.
    while isinstance(node, dict):
        if row[node['index']] < node['value']:
            node = node['L']
        else:
            node = node['R'] if 'R' in node else node['right']
    return node

# Hand-built one-split stump: X1 < 1 -> class 0, otherwise class 1.
# Its right child is stored under 'right' (build_tree produces 'R').
# No row in `dataset` has X1 < 1, so every row follows the right branch.
stump = dict(index=0, right=1, value=1, L=0)
for row in dataset:
    prediction = predict(stump, row)
    print('expected=%d, model_output=%d' % (row[-1], prediction))

expected=1, model_output=1
expected=0, model_output=1
expected=1, model_output=1
expected=0, model_output=1
expected=1, model_output=1
expected=1, model_output=1
expected=0, model_output=1
expected=1, model_output=1
expected=1, model_output=1
