<a href="https://colab.research.google.com/github/seabay/ml_practice/blob/master/decision_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import numpy as np

In [94]:

class Node:

  def __init__(self, f_index, val, l, r, is_leaf, data):
    self.f_index=f_index
    self.val=val
    self.left=l
    self.right=r
    self.is_leaf=is_leaf
    self.y=data

  def __str__(self):
    return str(self.f_index) + ", " + str(self.val) + ", " + str(self.is_leaf)


class RegressionTree:

  def __init__(self):
    self.root=None

  def mse(self, y):
    if y.shape[0]==0: return 0
    my = np.mean(y, axis=0)
    return np.sum((y-my)*(y-my),axis=0)

  def find_split(self, data):

    f_size = data.shape[1]-1  ## exclude y
    min_mse=float("+inf")
    f_i=None
    f_v=None
    f_d_l=None
    f_d_r=None

    for i in range(f_size):
      vals = np.percentile(data[:, i], [25, 50, 75, 100])
      print(vals)
      d_l = None
      d_r = None
      for v in vals:
        mask = data[:, i] <=v
        d_l=data[mask, :]
        d_r=data[~mask, :]  
        mse_val=self.mse(d_l[:, -1]) + self.mse(d_r[:, -1])
        #print(d_l.shape, d_r.shape, mse_val)
        if mse_val<min_mse:
          min_mse=mse_val
          f_i=i
          f_v=v
          f_d_l=d_l
          f_d_r=d_r

    print("select: ", f_i, f_v)
    print()
    
    return f_i, f_v, f_d_l, f_d_r


  def fit(self, data):
    self.root=self.fit_helper(data)
  
  def fit_helper(self, data):

    if data.shape[0]<=2:
      return Node(None, None, None, None, True, data[:, -1])
    ret = self.find_split(data)
    cnode=Node(ret[0], ret[1], None, None, False, None)
    cnode.left=self.fit_helper(ret[2])
    cnode.right=self.fit_helper(ret[3])
    return cnode


  def predict(self, data):

    d=data.tolist()
    ps=[]
    for item in d:
      n=self.root
      while not n.is_leaf:
        i=n.f_index
        v=n.val
        if item[i]<=v:
          n=n.left
        else: n=n.right

      if n.is_leaf:
        ps.append(np.mean(n.y))
    
    return np.concatenate([np.array(data), np.array(ps)[:, None]], axis=1)









In [95]:
X=np.random.randn(30, 3) * np.sqrt(1/20)

In [96]:
X[:,-1]

array([-0.16517177, -0.11388138, -0.00981753, -0.0088032 , -0.06991879,
       -0.2332216 ,  0.2527685 ,  0.01230981,  0.04027444,  0.08807233,
       -0.06450929, -0.03025672, -0.099153  ,  0.33695184,  0.10297544,
        0.15340058,  0.14821389, -0.12515908, -0.1451662 , -0.05365323,
       -0.18869403,  0.00252288, -0.08408438, -0.39112446, -0.08416915,
        0.32479155, -0.12083745, -0.34628263, -0.18605494,  0.06005404])

In [97]:
y=np.random.randn(30, 1)

In [98]:
data=np.concatenate([X[0:25, :], y[0:25, :]],axis=1)

In [99]:
test=np.concatenate([X[25:, :], y[25:, :]],axis=1)

In [100]:
data.shape

(25, 4)

In [101]:
test.shape

(5, 4)

In [102]:
t=RegressionTree()

In [103]:
#t.find_split(data)

In [104]:
n=t.fit(data)

[-0.19479343 -0.07195022  0.15261597  0.47762782]
[-0.14056708  0.04150299  0.19106192  0.45660026]
[-0.11388138 -0.05365323  0.04027444  0.33695184]
select:  1 -0.14056707713469624

[0.02269555 0.08337925 0.29517549 0.46595211]
[-0.25413153 -0.18831049 -0.15315325 -0.14056708]
[-0.10651719 -0.03025672  0.15080724  0.2527685 ]
select:  1 -0.25413153001825306

[-0.02173847  0.08337925  0.2550039   0.33534709]
[-0.18831049 -0.1550045  -0.151302   -0.14056708]
[-0.11388138 -0.099153    0.15340058  0.2527685 ]
select:  1 -0.15130200481251194

[-0.0618393   0.16919157  0.27508969  0.33534709]
[-0.202122   -0.17165749 -0.15407887 -0.151302  ]
[-0.13258454 -0.10651719 -0.01117263  0.2527685 ]
select:  0 -0.061839304312189966

[0.16919157 0.2550039  0.29517549 0.33534709]
[-0.21593352 -0.18831049 -0.16980625 -0.151302  ]
[-0.1512877  -0.11388138  0.06944356  0.2527685 ]
select:  0 0.25500389589937517

[-0.21213672 -0.11878053  0.00404086  0.47762782]
[0.02237309 0.05353906 0.22661276 0.4566002

In [105]:
predicts=t.predict(test)

In [106]:
predicts

array([[ 0.16109923, -0.05686799,  0.32479155,  1.19752503,  1.05689551],
       [ 0.11768796, -0.14071309, -0.12083745,  0.06787457,  0.62885777],
       [-0.2817895 ,  0.17656166, -0.34628263,  0.12009405,  0.00750235],
       [ 0.08646051,  0.06435416, -0.18605494,  1.35130389,  1.89465993],
       [ 0.11397557, -0.29141067,  0.06005404, -1.53481391,  0.4669313 ]])

'<__main__.Node object at 0x7f296c105dd0>'