In [2]:
import numpy as np


In [128]:
class CART:
    """Binary classification tree grown from Gini-impurity decision stumps.

    The purity test and the leaf voting below assume labels are +1 / -1.
    The tree is kept in ``self.branches`` as nested lists: an internal node is
    ``[[s, theta, dim], left_subtree, right_subtree]`` and a leaf is ``[label]``.
    """

    def __init__(self, x, y, prune = False):
        """Store the training data and grow the tree right away.

        x     -- 2-D array of samples (rows) by features (columns)
        y     -- 1-D array of labels, one per row of ``x``
        prune -- when True, grow a single stump with two voted leaves
        """
        self.x = x
        self.y = y
        self.branches = self.buildTree(x,y,prune)

    def sign(self,x):
        """Return -1 for negative ``x`` and +1 otherwise (zero maps to +1)."""
        if x<0:
            return -1
        else:
            return 1

    def hfunc(self, x, s, theta):
        """Decision stump: ``sign(x - theta)``, negated when ``s`` is falsy."""
        if s:
            return self.sign(x-theta)
        else:
            return -self.sign(x-theta)

    def divideData(self, x, y, s, theta, dim):
        """Split the samples by the stump ``(s, theta)`` applied to feature ``dim``.

        Returns ``[left_x, left_y], [right_x, right_y]`` where "left" holds the
        samples for which ``hfunc`` yields -1. Row order is preserved.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        if len(y) == 0:
            # Nothing to route; hand back empty halves without indexing a column.
            return [x[:0], y[:0]], [x[:0], y[:0]]

        # Same rule as sign(): a sample is "below" when x - theta < 0, so values
        # exactly on the threshold count as not-below.
        below = (x[:, dim] - theta) < 0
        # hfunc() == -1 means "below" for s=True and "not below" for s=False.
        go_left = below if s else ~below
        return [x[go_left], y[go_left]], [x[~go_left], y[~go_left]]

    def computeGini(self, x, y):
        """Gini impurity of the labels ``y`` (0 for an empty set); ``x`` is unused."""
        lens = len(y) # N
        if lens == 0:
            return 0
        y = np.asarray(y)
        pos_num = np.sum(y==1)
        neg_num = np.sum(y!=1)
        gini = 1 - ((pos_num)/lens)**2 - ((neg_num)/lens)**2
        return gini

    def computeDivideGini(self, x, y, s, theta, dim):
        """Score a split: sum over both halves of (sample count * Gini impurity)."""
        left, right = self.divideData(x, y, s, theta, dim)
        left_gini = self.computeGini(left[0], left[1])
        right_gini = self.computeGini(right[0], right[1])
        # compute impurity*weight
        return len(left[1])*left_gini + len(right[1])*right_gini

    def branchStump(self,x,y):
        """Find the stump with the lowest weighted Gini score over all features.

        Returns ``branch, left, right`` where ``branch`` is ``[s, theta, dim]``
        and ``left`` / ``right`` are the ``[x, y]`` halves produced by it.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        dimensions = len(x[0])
        best_theta = 0
        best_s = True
        best_dim = 0
        # The score is weighted by sample count, so it has no fixed upper bound;
        # start from infinity so the first candidate is always accepted.
        best_gini = float('inf')
        for dim in range(dimensions):
            thetas = np.sort(x[:,dim])
            for i, value in enumerate(thetas):
                if i > 0:
                    # midpoint between neighbouring sorted values
                    theta = (value + thetas[i-1])/2
                else:
                    # first candidate: half of the smallest value
                    theta = value/2

                for s in (True, False):
                    gini = self.computeDivideGini(x, y, s, theta, dim)
                    # strict comparison: the earliest candidate wins ties
                    if gini < best_gini:
                        best_gini = gini
                        best_s = s
                        best_theta = theta
                        best_dim = dim
        branch = [best_s, best_theta, best_dim]
        left, right = self.divideData(x, y, best_s, best_theta, best_dim)
        return branch, left, right

    def _majorityLeaf(self, labels):
        """Leaf holding the majority label of ``labels`` (ties and empty sets vote +1)."""
        if sum(labels) >= 0:
            return [1]
        return [-1]

    def buildTree(self,x,y,prune=False):
        """Grow a tree for ``(x, y)`` and return it in nested-list form.

        With ``prune`` set, only one stump is grown and each half becomes a
        majority-vote leaf. Otherwise the tree is grown until every leaf is pure
        or the best stump can no longer separate the samples.
        """
        if prune:
            # only a single level of branching
            branch, left, right = self.branchStump(x,y)
            return [branch, self._majorityLeaf(left[1]), self._majorityLeaf(right[1])]

        # fully grown
        if abs(sum(y)) == len(y):
            # every label is identical: pure leaf
            return [y[0]]

        branch, left, right = self.branchStump(x,y)
        if len(left[1]) == 0 or len(right[1]) == 0:
            # The best stump left one side empty, so recursing would either index
            # an empty label array or repeat this exact node forever. Stop here
            # with a majority vote instead.
            return self._majorityLeaf(y)

        return [branch,
                self.buildTree(left[0],left[1]),
                self.buildTree(right[0],right[1])]

    def fit(self,x,branch):
        """Classify the single sample ``x`` by walking the subtree ``branch``.

        Despite the name this does no training; it returns the label stored in
        the leaf that ``x`` reaches.
        """
        if len(branch) == 3:
            # internal node: [[s, theta, dim], left_subtree, right_subtree]
            dim = branch[0][2]
            y = self.hfunc(x[dim],branch[0][0],branch[0][1])
            if y == -1:
                return self.fit(x,branch[1])
            else:
                return self.fit(x,branch[2])
        else:
            # leaf: [label]
            return branch[0]

    def predict(self,x):
        """Return an array with the predicted label for every row of ``x``."""
        return np.array([self.fit(sample, self.branches) for sample in x])



In [130]:
if __name__ == '__main__':
    # Load the training set: the last column is the label, the rest are features.
    data = np.genfromtxt('train12.txt')
    train_x = data[:,:-1]
    train_y = data[:,-1]
    cart = CART(train_x, train_y)
    # The constructor already grew the full tree; reuse it instead of calling
    # buildTree a second time and repeating the same search.
    branches = cart.branches
    # Print the root stump, then the left and right subtrees, one per line.
    for branch in branches:
        print(branch)

[True, 0.626233, 1]
[[True, 0.22443950000000001, 0], [[True, 0.11515275, 1], [1.0], [-1.0]], [[True, 0.541508, 0], [[True, 0.3586205, 1], [[True, 0.501625, 0], [1.0], [-1.0]], [[True, 0.2607515, 0], [1.0], [-1.0]]], [[True, 0.285925, 1], [[True, 0.2660385, 1], [1.0], [-1.0]], [1.0]]]]
[[True, 0.8781715, 0], [-1.0], [1.0]]


In [127]:
if __name__ == '__main__':
    # Load the training set: the last column is the label, the rest are features.
    data = np.genfromtxt('train12.txt')
    train_x = data[:,:-1]
    train_y = data[:,-1]
    cart = CART(train_x, train_y)
    # The constructor already grew the full tree; reuse it instead of calling
    # buildTree a second time and repeating the same search.
    branches = cart.branches
    print(branches)

[[True, 0.626233, 1], [[True, 0.22443950000000001, 0], [[True, 0.11515275, 1], [1.0]]]]
