In [1]:
import numpy as np
# Set a random seed so that the algorithm's results are reproducible.
np.random.seed(0)

In [71]:
# Toy classification dataset: each row holds two numeric feature values
# followed by a class label (0 or 1) in the last column.
dataset = [[2.771244718,1.784783929,0],
    [1.728571309,1.169761413,0],
    [3.678319846,2.81281357,0],
    [3.961043357,2.61995032,0],
    [2.999208922,2.209014212,0],
    [7.497545867,3.162953546,1],
    [9.00220326,3.339047188,1],
    [7.444542326,0.476683375,1],
    [10.12493903,3.234550982,1],
    [6.642287351,3.319983761,1]]

In [2]:
# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

In [3]:
################################################### Criterion: gini ########################################
# Calculate the Gini index for a split dataset
def get_gini(groups, classes):
    """Compute the size-weighted Gini index of a candidate split.

    Input:
    ------
        groups: sequence of groups, each a list of rows whose last element is
            the class label
        classes: the class labels to score

    Output:
    ------
        gini(float): sum over the non-empty groups of
            ``(1 - sum_c p_c**2) * (group size / total size)``, where ``p_c``
            is the fraction of rows in the group labelled ``c``.
            Returns 0.0 when every group is empty.
    """
    # count all samples at the split point
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # an empty group contributes nothing (and would divide by zero)
        if size == 0:
            continue
        # extract the labels once per group instead of once per class
        labels = [row[-1] for row in group]
        score = 0.0
        for class_val in classes:
            p = labels.count(class_val) / size
            score += p * p
        # weight the group impurity by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Select the best split point for a dataset
def get_split_gini(dataset):
    """Search every (column, row value) pair for the lowest-Gini split.

    Each row's value in each non-label column is tried as a threshold. The
    first candidate with the strictly lowest Gini index wins.

    Input:
    ------
        dataset: list of rows; the last element of a row is its class label

    Output:
    ------
        dict with keys 'index' (column of the split), 'value' (threshold) and
        'groups' (the ``(left, right)`` partition). If no candidate is
        evaluated, 'index' and 'value' are ``inf`` and 'groups' is ``None``.
    """
    labels = list(set(row[-1] for row in dataset))
    best = {'index': float('inf'), 'value': float('inf'), 'groups': None}
    lowest = float('inf')
    n_features = len(dataset[0]) - 1
    for column in range(n_features):
        for candidate in dataset:
            threshold = candidate[column]
            partition = test_split(column, threshold, dataset)
            impurity = get_gini(partition, labels)
            # strict comparison keeps the earliest of equally good candidates
            if impurity < lowest:
                lowest = impurity
                best = {'index': column, 'value': threshold, 'groups': partition}
    return best

In [4]:
#################################################### Criterion: EMSE ########################################
######### TODO: This section is known to be incorrect; it is only a placeholder for the framework.
######### Re-read the paper and work out the actual EMSE computation.

#EMSE is a criterion adapted from conventional MSE. It is used to find the causal effect using a tree algorithm
def get_emse(train_lst):
    """Compute the (placeholder) EMSE score of a list of outcome values.

    The score is the conventional MSE of the whole sample minus an "adaptive"
    term computed from a random split of the sample into two halves.

    NOTE(review): the file itself flags this formula as unfinished; only the
    crash bugs are fixed here. Verify the formula against the paper.

    Input:
    ------
        train_lst: sequence of outcome values (assumed 1-D)

    Output:
    ------
        emse(float): ``mse - adaptive``. When the sample is too small to give
        two non-empty halves (fewer than 2 values) the adaptive term cannot be
        computed and the plain MSE is returned instead.
    """
    # work on a copy so the in-place shuffle below never touches the caller's data
    train_lst = np.array(train_lst)

    # conventional MSE of the full sample
    trt_mse = get_mse(train_lst)

    # randomly split the sample into a treatment half and an estimation half
    np.random.shuffle(train_lst)
    mid_point = round(train_lst.size / 2)
    trt_lst = train_lst[:mid_point]
    est_lst = train_lst[mid_point:]

    n_trt = trt_lst.size
    n_est = est_lst.size
    # with an empty half, 1 / n and the variance are undefined: skip the correction
    if n_trt == 0 or n_est == 0:
        return trt_mse

    # share of the sample that landed in the treatment half
    p = n_trt / (n_trt + n_est)

    # adaptive (variance) part
    adaptive = (1 / n_trt + 1 / n_est) * (np.var(trt_lst) / p + np.var(est_lst) / (1 - p))

    return trt_mse - adaptive

In [5]:
def get_split_emse(dataset):
    """Search every (column, row value) pair for the lowest-EMSE split.

    Candidates that leave one side empty are skipped. The score of a candidate
    is ``get_emse(left outcomes) + get_emse(right outcomes)``.
    NOTE(review): the two child scores are added without weighting by group
    size; confirm this is intended.

    Input:
    ------
        dataset: non-empty list of rows; the last element of a row is its outcome

    Output:
    ------
        dict with keys 'index', 'value' and 'groups' (a ``(left, right)``
        tuple of row lists). When no candidate produces two non-empty sides
        (e.g. all rows share the same feature values), a degenerate split on
        column 0 with an empty left group is returned, so that ``split`` turns
        the node into a terminal instead of failing on ``groups=None``.
    """
    b_index, b_value, b_score, b_groups = None, None, float('inf'), None

    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            # skip when a side of the split has no data point
            if not groups[0] or not groups[1]:
                continue
            left_lst = [item[-1] for item in groups[0]]
            right_lst = [item[-1] for item in groups[1]]
            emse = get_emse(left_lst) + get_emse(right_lst)
            if emse < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], emse, groups

    if b_groups is None:
        # no usable split: hand back (empty, everything) so callers can unpack it
        b_index, b_value, b_groups = 0, dataset[0][0], ([], list(dataset))

    return {'index': b_index, 'value': b_value, 'groups': b_groups}

In [95]:
# Inspect the first row of the toy dataset.
dataset[0]

[2.771244718, 1.784783929, 0]

In [91]:
#################################################### Criterion: MSE ########################################
def get_mse(true_lst):
    """
    Calculate the mean squared error of a set of values around their own mean.

    Input:
    ------
        true_lst(array-like): the observed values

    Output:
    ------
        mse(float): mean of the squared deviations from the mean of
            ``true_lst``; NaN when ``true_lst`` is empty

    """
    true_lst = np.asarray(true_lst)

    # An empty input has no mean. Return NaN explicitly rather than letting
    # numpy emit a RuntimeWarning and produce the same NaN implicitly.
    if true_lst.size == 0:
        return float('nan')

    # squared loss is minimized at the mean, so use it as the prediction
    avg = true_lst.mean()
    mse = ((true_lst - avg) ** 2).mean()

    return mse



# Select the best split point for a dataset
def get_split_mse(dataset):
    """Search every (column, row value) pair for the lowest-MSE split.

    Candidates that leave one side empty are skipped. The score of a candidate
    is ``get_mse(left outcomes) + get_mse(right outcomes)``.
    NOTE(review): the two child MSEs are added without weighting by group
    size; confirm this is intended.

    Input:
    ------
        dataset: non-empty list of rows; the last element of a row is its outcome

    Output:
    ------
        dict with keys 'index', 'value' and 'groups' (a ``(left, right)``
        tuple of row lists). When no candidate produces two non-empty sides
        (e.g. all rows share the same feature values), a degenerate split on
        column 0 with an empty left group is returned, so that ``split`` turns
        the node into a terminal instead of failing on ``groups=None``.
        The returned dict is also printed.
    """
    b_index, b_value, b_score, b_groups = None, None, float('inf'), None

    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            # skip when a side of the split has no data point
            if not groups[0] or not groups[1]:
                continue
            left_lst = [item[-1] for item in groups[0]]
            right_lst = [item[-1] for item in groups[1]]
            mse = get_mse(left_lst) + get_mse(right_lst)
            if mse < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], mse, groups

    if b_groups is None:
        # no usable split: hand back (empty, everything) so callers can unpack it
        b_index, b_value, b_groups = 0, dataset[0][0], ([], list(dataset))

    ret_dict = {'index': b_index, 'value': b_value, 'groups': b_groups}
    # debug trace of the chosen split, kept from the original implementation
    print(ret_dict)
    return ret_dict

In [93]:
# get the split based on criterion
def get_split(dataset, criterion):
    """Find the best split of ``dataset`` using the named criterion.

    Input:
    ------
        dataset: list of rows handed to the criterion-specific search
        criterion(str): one of 'mse', 'emse' or 'gini'

    Output:
    ------
        the dict returned by the matching ``get_split_*`` function

    Raises:
    ------
        ValueError: if ``criterion`` is not a supported name (previously this
            silently returned None and failed later inside ``split``)
    """
    if criterion == 'mse':
        return get_split_mse(dataset)
    if criterion == 'emse':
        return get_split_emse(dataset)
    if criterion == 'gini':
        return get_split_gini(dataset)
    raise ValueError(
        "unknown criterion %r; expected 'mse', 'emse' or 'gini'" % (criterion,)
    )

    
# Create a terminal node value
def to_terminal_categorical(group):
    """Return the most frequent last-column value among the rows of ``group``.

    Raises ValueError (from ``max``) when ``group`` is empty.
    """
    labels = []
    for record in group:
        labels.append(record[-1])
    distinct = set(labels)
    return max(distinct, key=lambda label: labels.count(label))

def to_terminal_continuous(group):
    """Return the arithmetic mean of the last-column values of ``group``."""
    targets = []
    for record in group:
        targets.append(record[-1])
    return np.mean(targets)

def to_terminal(group, criterion):
    """Build the terminal (leaf) value for ``group`` under ``criterion``.

    'gini' yields the majority label; 'mse' and 'emse' yield the mean outcome.
    Any other criterion yields None.
    """
    if criterion == 'gini':
        return to_terminal_categorical(group)
    if criterion in ('mse', 'emse'):
        return to_terminal_continuous(group)
    return None


# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth, criterion):
    """Recursively grow the tree below ``node`` in place.

    ``node`` must carry a 'groups' entry holding its ``(left, right)`` row
    partition. The entry is removed, and 'left' / 'right' are filled with
    either a terminal value or a child node dict.

    Input:
    ------
        node(dict): node produced by ``get_split``
        max_depth: depth at which both children become terminals
        min_size: a side with at most this many rows becomes a terminal
        depth: depth of ``node`` (the root is built with depth 1)
        criterion: criterion name passed on to ``get_split`` / ``to_terminal``
    """
    partition = node['groups']
    left, right = partition
    del node['groups']

    # one side is empty: no real split, both children share one terminal
    if not (left and right):
        leaf = to_terminal(left + right, criterion)
        node['left'] = leaf
        node['right'] = leaf
        return

    # depth limit reached: stop growing on both sides
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left, criterion), to_terminal(right, criterion)
        return

    # left child first, then right; each either terminates or recurses
    for side, rows in (('left', left), ('right', right)):
        if len(rows) <= min_size:
            node[side] = to_terminal(rows, criterion)
        else:
            child = get_split(rows, criterion)
            node[side] = child
            split(child, max_depth, min_size, depth + 1, criterion)

# Build a decision tree
def build_tree(train, max_depth, min_size, criterion):
    """Fit a decision tree on ``train`` and return its root node dict.

    The root split is chosen with ``get_split`` and then expanded in place by
    ``split``, starting at depth 1.
    """
    tree = get_split(train, criterion)
    split(tree, max_depth, min_size, depth=1, criterion=criterion)
    return tree

# Print a decision tree
def print_tree(node, depth=0):
    """Print the tree rooted at ``node``, one line per node.

    Each line is indented by one space per level of depth. Internal nodes
    (dicts) are shown as ``[X<column+1> < <threshold>]``; anything else is
    treated as a terminal value and shown as ``[<value>]``.
    """
    indent = depth * ' '
    if not isinstance(node, dict):
        print('%s[%s]' % (indent, node))
        return
    print('%s[X%d < %.3f]' % (indent, node['index'] + 1, node['value']))
    for side in ('left', 'right'):
        print_tree(node[side], depth + 1)




In [94]:
# Toy regression dataset: two numeric feature values per row followed by a
# continuous outcome in the last column.
dataset_mse = [[2.771244718,1.784783929,0.2456],
    [1.728571309,1.169761413,0.87673],
    [3.678319846,2.81281357,0.92357],
    [3.961043357,2.61995032,0.12341],
    [2.999208922,2.209014212,0.01134],
    [7.497545867,3.162953546,1.56433],
    [9.00220326,3.339047188,1.23461],
    [7.444542326,0.476683375,0.34534],
    [10.12493903,3.234550982,0.84662],
    [6.642287351,3.319983761,0.34563]]
# Build a tree with max_depth=3 and min_size=1 using the 'mse' criterion, then print it.
tree_mse = build_tree(dataset_mse, 3, 1, 'mse')
print_tree(tree_mse)

{'index': 0, 'value': 7.497545867, 'groups': ([[2.771244718, 1.784783929, 0.2456], [1.728571309, 1.169761413, 0.87673], [3.678319846, 2.81281357, 0.92357], [3.961043357, 2.61995032, 0.12341], [2.999208922, 2.209014212, 0.01134], [7.444542326, 0.476683375, 0.34534], [6.642287351, 3.319983761, 0.34563]], [[7.497545867, 3.162953546, 1.56433], [9.00220326, 3.339047188, 1.23461], [10.12493903, 3.234550982, 0.84662]])}
{'index': 0, 'value': 2.771244718, 'groups': ([[1.728571309, 1.169761413, 0.87673]], [[2.771244718, 1.784783929, 0.2456], [3.678319846, 2.81281357, 0.92357], [3.961043357, 2.61995032, 0.12341], [2.999208922, 2.209014212, 0.01134], [7.444542326, 0.476683375, 0.34534], [6.642287351, 3.319983761, 0.34563]])}
{'index': 0, 'value': 2.999208922, 'groups': ([[2.771244718, 1.784783929, 0.2456]], [[3.678319846, 2.81281357, 0.92357], [3.961043357, 2.61995032, 0.12341], [2.999208922, 2.209014212, 0.01134], [7.444542326, 0.476683375, 0.34534], [6.642287351, 3.319983761, 0.34563]])}
{'inde

In [97]:
# Toy classification dataset: two numeric feature values per row followed by a
# class label (0 or 1) in the last column.
dataset_gini = [[2.771244718,1.784783929,0],
    [1.728571309,1.169761413,0],
    [3.678319846,2.81281357,0],
    [3.961043357,2.61995032,0],
    [2.999208922,2.209014212,0],
    [7.497545867,3.162953546,1],
    [9.00220326,3.339047188,1],
    [7.444542326,0.476683375,1],
    [10.12493903,3.234550982,1],
    [6.642287351,3.319983761,1]]
# The gini run is currently disabled.
# NOTE(review): the tree shown as this cell's output presumably comes from an
# earlier run with these lines enabled - confirm, or re-run and refresh the output.
#tree_gini = build_tree(dataset_gini, 3, 1, 'gini')
#print_tree(tree_gini)

[X1 < 6.642]
 [X1 < 2.771]
  [0]
  [X1 < 2.771]
   [0]
   [0]
 [X1 < 7.498]
  [X1 < 7.445]
   [1]
   [1]
  [X1 < 7.498]
   [1]
   [1]


In [68]:
# Toy dataset for the EMSE criterion: two numeric feature values per row
# followed by a continuous outcome in the last column.
dataset_emse = [[2.771244718,1.784783929,0.2456],
    [1.728571309,1.169761413,0.87673],
    [3.678319846,2.81281357,0.92357],
    [3.961043357,2.61995032,0.12341],
    [2.999208922,2.209014212,0.01134],
    [7.497545867,3.162953546,1.56433],
    [9.00220326,3.339047188,1.23461],
    [7.444542326,0.476683375,0.34534],
    [10.12493903,3.234550982,0.84662],
    [6.642287351,3.319983761,0.34563]]
# The EMSE run is currently disabled.
# NOTE(review): presumably because the EMSE criterion is still unfinished - confirm.
#tree_emse = build_tree(dataset_emse, 2, 1, 'emse')
#print_tree(tree_emse)

In [69]:
# # Make a prediction with a decision tree
# def predict(node, row):
#     if row[node['index']] < node['value']:
#         if isinstance(node['left'], dict):
#             return predict(node['left'], row)
#         else:
#             return node['left']
#     else:
#         if isinstance(node['right'], dict):
#             return predict(node['right'], row)
#         else:
#             return node['right']

# dataset = [[2.771244718,1.784783929,0],
#     [1.728571309,1.169761413,0],
#     [3.678319846,2.81281357,0],
#     [3.961043357,2.61995032,0],
#     [2.999208922,2.209014212,0],
#     [7.497545867,3.162953546,1],
#     [9.00220326,3.339047188,1],
#     [7.444542326,0.476683375,1],
#     [10.12493903,3.234550982,1],
#     [6.642287351,3.319983761,1]]
 
# #  predict with a stump
# stump = {'index': 0, 'right': 1, 'value': 7.445, 'left': 0}
# for row in dataset:
#     prediction = predict(stump, row)
#     print('Expected=%d, Got=%d' % (row[-1], prediction))