In [70]:
data0 = [[3, 5, 6, 1], 
        [8, 8, 2, 0],
        [2, 3, 5, 1],
        [1, 9, 8, 1],
       ]

In [71]:
import operator

def gini_index(regions, classes):
    '''Get gini index in a group of samples
    Args:
        regions: list of regions containing samples in each region
        classes: all distinct classes in all regions
    Return:
        gini index for current regions
    Note:
        Gini index for a region: G_r = sum(p_class(1-p_class))
                                    = 1 - sum(p_class**2) for each class
        Gini_index = sum(n_region/n_total * G_r) for each region
    '''

    n_samples = sum(len(region) for region in regions)
    gini = 0
    for region in regions:
        if not region:
            continue
        score = 0
        for c in classes:
            n_class = len([row for row in region if row[-1] == c])
            score += ((n_class * 1.0) / len(region)) ** 2
        gini += (1 - score) * n_class / n_samples

    return gini


def split(samples, classes):
    '''Make a split in a group of samples
    Args:
        samples: A group of samples
        classes: all classes in the samples
    Return: 
        The best split as a dict with keys:
            index: the index for the feature to split
            value: the value to split
            regions: the split regions
    '''
    
    gini_index_parent = gini_index([samples], classes)
    split = {'index': -1, 'value': None, 'gini': gini_index_parent}
    for index in range(len(samples[0]) - 1):
        values = sorted(list(set(row[index] for row in samples)))
        for value in values:
            region1 = [row for row in samples if row[index] < value]
            region2 = [row for row in samples if row[index] >= value]
            gini = gini_index([region1, region2], classes)
            if gini < split['gini']:
                split['index'] = index
                split['value'] = value   
                split['gini'] = gini
                print region1, region2
    return split


def get_leaf(samples, classes):
    '''Generate a leaf with a class
    Args:
        samples: a group of samples
        classes: all classes in the samples
    Return:
        the majority class in the samples
    '''
    
    class_counts = {c: 0 for c in classes}
    for row in samples:        
        class_counts[row[-1]] += 1    
    max_c_cnt = max(class_counts.iteritems(), key=operator.itemgetter(1))
    return max_c_cnt[0], max_c_cnt[1] / 1.0 / len(samples)


def get_split(samples):
    '''Do recursive splitting for a group of samples
    Args:
        samples: a group of samples
    Return:
        tree: split for each node
    Note:
        stop cases:
        iterates:
    '''
    
    
def predict():
    '''Make prediction given a new group of samples
    '''

In [72]:
get_leaf(data0, [0, 1])


(1, 0.75)

In [73]:
split(data0, [0, 1])

[[1, 9, 8, 1]] [[3, 5, 6, 1], [8, 8, 2, 0], [2, 3, 5, 1]]
[[2, 3, 5, 1], [1, 9, 8, 1]] [[3, 5, 6, 1], [8, 8, 2, 0]]
[[3, 5, 6, 1], [2, 3, 5, 1], [1, 9, 8, 1]] [[8, 8, 2, 0]]


{'gini': 0.0, 'index': 0, 'value': 8}

In [74]:
print gini_index([data0[1:3], [data0[0], data0[3]]], list(set([row[-1] for row in data0])))
print gini_index([data0[:2], data0[2:]], list(set([row[-1] for row in data0])))
print gini_index([data0], list(set([row[-1] for row in data0])))
print gini_index([data0, []], list(set([row[-1] for row in data0])))

0.125
0.125
0.28125
0.28125


In [25]:
def a():
    print 'a'
    
def b():
    a()

In [29]:
gini_index

<function __main__.gini_index>