# In [19]:
import numpy as np

class Node:
    """One node of the tree, holding the rows that reached it.

    The last column of ``data`` is read as the target: ``value`` is its mean
    and ``std`` its standard deviation.  The split fields and child links
    start out as ``None`` and are filled in from outside the class.
    """

    def __init__(self, data, parent=None):
        """Store ``data``, summarise its last column and link to ``parent``."""
        targets = data[:, -1]

        self.data = data
        self.value = np.mean(targets)
        self.std = np.std(targets)

        # Split description; both remain None until the node is split.
        self.split_feat_idx = self.split_feat_value = None

        # Links to the children and back to the parent node.
        self.left = self.right = None
        self.parent = parent

# In [21]:
class Tree():
    """Regression tree grown greedily by reducing target standard deviation.

    The last column of ``data`` is the target; every other column is a
    candidate split feature.  Nodes are ``Node`` instances.  A split sends
    rows whose feature value is strictly below the stored threshold to
    ``left`` and all remaining rows to ``right``.

    Attributes:
        root: Node holding the full training data.
        current_node: Node that ``split`` and ``check_split`` operate on.
        criterion: Stored as given; no method of this class reads it.
        error_thres: Nodes whose target standard deviation is below this
            value are not split further.
        split_fold_n: Number of equal-sized folds used to place candidate
            cut positions along each sorted feature.  ``None`` or a value
            below 2 means every position between two rows is tried.
    """

    def __init__(self, data, criterion, error_thres, split_fold_n):
        self.root = Node(data)
        self.current_node = self.root
        self.criterion = criterion
        self.error_thres = error_thres
        self.split_fold_n = split_fold_n

    def _candidate_positions(self, n_rows):
        """Return the row offsets at which a sorted column may be cut.

        Every returned offset ``p`` satisfies ``1 <= p < n_rows`` so that both
        sides of the cut are non-empty.  ``n_rows`` must be at least 2.
        """
        folds = self.split_fold_n
        if folds is None or folds < 2:
            return range(1, n_rows)

        # More folds than rows would give a zero-width interval.
        folds = min(int(folds), n_rows)
        interval = n_rows // folds
        return range(interval, folds * interval, interval)

    def split(self):
        """Split ``self.current_node`` at the best cut found, if any.

        Each feature column is sorted and the candidate positions from
        ``split_fold_n`` are scored with
        ``std(parent) - std(left) - std(right)``; the highest strictly
        positive score wins.  On success the node's ``split_feat_idx``,
        ``split_feat_value``, ``left`` and ``right`` are set.

        Returns:
            True if the node was split, False if no candidate improved on
            the parent (the node is then left untouched).
        """
        node = self.current_node
        data = node.data
        n_rows = len(data)
        if n_rows < 2:
            return False

        parent_std = node.std
        positions = self._candidate_positions(n_rows)

        best_gain = 0
        best = None  # (feature index, cut position, threshold)

        for feat_idx in range(data.shape[1] - 1):
            # Always sort the node's own rows with a stable sort so the order
            # scored here is exactly the order used for the final partition.
            order = data[:, feat_idx].argsort(kind="stable")
            feature = data[order, feat_idx]
            targets = data[order, -1]

            for pos in positions:
                # A cut between equal values would leave rows equal to the
                # threshold on the left, contradicting the `<` test in predict.
                if feature[pos - 1] == feature[pos]:
                    continue

                # NOTE(review): the gain is not weighted by subset size, unlike
                # the usual standard-deviation-reduction formula; kept as is.
                gain = parent_std - np.std(targets[:pos]) - np.std(targets[pos:])

                if gain > best_gain:
                    best_gain = gain
                    best = (feat_idx, pos, feature[pos])

        if best is None:
            return False

        feat_idx, pos, threshold = best
        ordered = data[data[:, feat_idx].argsort(kind="stable")]

        node.split_feat_idx = feat_idx
        node.split_feat_value = threshold
        node.left = Node(data=ordered[:pos], parent=node)
        node.right = Node(data=ordered[pos:], parent=node)
        return True

    def check_split(self):
        """Return whether ``self.current_node`` is eligible for splitting.

        A node is not split when it has fewer than two rows, when its target
        standard deviation is below ``error_thres``, or when all of its
        target values are identical.
        """
        node = self.current_node

        if len(node.data) < 2:
            return False

        if node.std < self.error_thres:
            return False

        # The targets live in the last data column (node.value is their mean).
        targets = node.data[:, -1]
        if np.all(targets == targets[0]):
            return False

        return True

    def generate(self):
        """Grow the subtree below ``self.current_node``.

        Nodes are processed with an explicit stack instead of recursion, so
        deep trees cannot hit the interpreter's recursion limit.
        ``current_node`` is restored to its starting node afterwards.
        """
        start = self.current_node
        pending = [start]

        while pending:
            # split() and check_split() act on self.current_node.
            self.current_node = pending.pop()

            if not self.check_split():
                continue
            if not self.split():
                continue

            # Push right first so the left child is processed first.
            pending.append(self.current_node.right)
            pending.append(self.current_node.left)

        self.current_node = start

    def reset(self):
        """Point ``current_node`` back at the root."""
        self.current_node = self.root

    def predict(self, X):
        """Return the prediction for a single sample.

        Args:
            X: Indexable sample; ``X[i]`` is the value of feature column ``i``.

        Returns:
            The ``value`` of the node where the descent from the root stops.
        """
        self.reset()
        node = self.root

        while node.split_feat_idx is not None:
            if X[node.split_feat_idx] < node.split_feat_value:
                child = node.left
            else:
                child = node.right

            if child is None:
                break
            node = child

        return node.value

    def prune(self):
        """Placeholder for pruning; currently does nothing."""
        pass
            
        