## 木の生成過程

In [1]:
import numpy as np
import pandas as pd

最初のノードを分割する

In [2]:
# Node splitting
def build(x, y):
    '''
    Find the single split of a node that reduces Gini impurity the most.

    Each feature column is scanned in turn. Candidate thresholds are the
    midpoints between consecutive distinct values of the column. A split is
    scored by the impurity of the node before splitting minus the
    size-weighted impurity of the two child nodes.

    x: explanatory variables, a 2-D array indexed as x[:, feature]
    y: target variable, one label per row of x

    Returns (best_feature, best_threshold): the column index and threshold
    of the best split, or (None, None) when no candidate split yields a
    positive impurity reduction (e.g. the node is already pure or every
    feature is constant).
    '''

    num_data = x.shape[0]  # number of samples
    num_features = x.shape[1]  # number of features

    # A split must beat zero gain to be accepted.
    best_gini_index = 0.0
    best_feature = None
    best_threshold = None

    # Impurity of the node before splitting; the gain of every candidate
    # split is measured against this value.
    parent_gini = gini_func(y)

    for f in range(num_features):
        column = x[:, f]
        # np.unique returns the sorted distinct values, so neighbouring
        # entries can be averaged to obtain the candidate thresholds.
        data_f = np.unique(column)
        points = (data_f[:-1] + data_f[1:]) / 2.0

        for threshold in points:

            # Split the labels in two around the threshold.
            y_r = y[column < threshold]
            y_l = y[column >= threshold]

            # Gini impurity of each child node, weighted by its share of
            # the samples.
            gini_r = gini_func(y_r)
            gini_l = gini_func(y_l)
            pr = float(y_r.shape[0]) / num_data
            pl = float(y_l.shape[0]) / num_data
            gini_index = parent_gini - (pl * gini_l + pr * gini_r)

            # Keep the split when its gain exceeds the best seen so far.
            if gini_index > best_gini_index:
                best_gini_index = gini_index
                best_feature = f
                best_threshold = threshold
    return best_feature, best_threshold

In [3]:
# Gini impurity is used as the impurity measure here
def gini_func(target):
    '''
    Return the Gini impurity of a label array.

    target: array of class labels

    The result is 1 minus the sum of squared class proportions; an empty
    array yields 1.0 because nothing is subtracted.
    '''
    total = len(target)
    impurity = 1.0
    for label in np.unique(target):
        # Share of the samples that carry this label.
        share = float(np.count_nonzero(target == label)) / total
        impurity -= share ** 2.0

    return impurity

In [4]:
# Quick experiment on a small sample: only the first 10 rows are read
df = pd.read_csv('../data/basedata/train.csv', nrows=10)

# Column 0 is taken as the target; every remaining column is a feature.
y = np.array(df.iloc[:, 0])
x = np.array(df.iloc[:, 1:])
x_columns = list(df.columns[1:])

best_feature, best_threshold = build(x, y)
print('Best Feature: {}, Threshold: {}'.format(
    x_columns[best_feature], best_threshold))

Best Feature: Sex_female, Threshold: 0.5


In [5]:
# Check the result: split the labels with the chosen feature and threshold
y_r, y_l = (
    y[x[:, best_feature] < best_threshold],
    y[x[:, best_feature] >= best_threshold],
)
print('Right: {} \nLeft: {}'.format(y_r, y_l))

Right: [0 0 0 0 0] 
Left: [1 1 1 1 1]


In [6]:
# Check the result with a different feature (two columns before the best one)
# NOTE(review): the threshold is still the one found for the best feature,
# not one tuned for this column - presumably intended as a rough comparison.
y_r, y_l = (
    y[x[:, best_feature - 2] < best_threshold],
    y[x[:, best_feature - 2] >= best_threshold],
)
print('Right: {} \nLeft: {}'.format(y_r, y_l))

Right: [1 1 0 1] 
Left: [0 1 0 0 0 1]
