In [44]:
import numpy as np
import pandas as pd

In [45]:
# The first CSV column has an empty header and holds 0, 1, 2, ... — a saved
# index. Read it back as the index instead of keeping it as a spurious
# "Unnamed: 0" feature column.
data = pd.read_csv('data.csv', index_col=0)
assert {'f1', 'f2', 'f3', 'y'} <= set(data.columns), "missing expected columns"
data.head(10)

Unnamed: 0,f1,f2,f3,y
0,13.27,4.28,2.26,2
1,11.56,2.05,3.23,1
2,11.82,1.72,1.88,1
3,13.05,3.86,2.32,1
4,12.2,3.03,2.32,2
5,12.93,2.81,2.7,2
6,13.39,1.77,2.62,0
7,11.64,2.06,2.46,1
8,12.69,1.53,2.26,1
9,13.48,1.81,2.41,0


In [46]:
def get_split_H(target):
    """Gini impurity of a set of class labels.

    H = sum_k p_k * (1 - p_k), where p_k is the share of label k in `target`.

    Parameters
    ----------
    target : array-like, shape (n,)
        Class labels of the objects that fall into a node.

    Returns
    -------
    float
        Gini impurity; 0 for an empty or a pure node.
    """
    target = np.asarray(target)
    n_objects = target.shape[0]
    if n_objects == 0:
        return 0.0
    _, counts = np.unique(target, return_counts=True)
    shares = counts / n_objects
    return float(np.sum(shares * (1.0 - shares)))


def find_split(feature_vector, target_vector):
    """Search the best threshold for splitting a node on one feature (Gini criterion).

    Candidate thresholds are midpoints between neighbouring *distinct* sorted
    feature values. An object goes to the left child if its feature value is
    `< threshold` and to the right child if it is `>= threshold`.
    For every threshold t the criterion

        Q(t) = -|L|/n * H(L) - |R|/n * H(R)

    is computed, i.e. the *negated* weighted Gini impurity, so the best split is
    the one with the largest Q (= the smallest weighted impurity).

    Parameters
    ----------
    feature_vector : array-like, shape (n,)
        Values of a single feature.
    target_vector : array-like, shape (n,)
        Class labels aligned with `feature_vector`.

    Returns
    -------
    thresholds : np.ndarray
        Candidate thresholds in ascending order.
    ginis : np.ndarray
        Criterion Q for each threshold.
    threshold_best : float
        Threshold with the largest Q (the smallest such threshold on ties).
    gini_best : float
        The largest Q.

    Raises
    ------
    ValueError
        If the inputs have different lengths, or the feature has fewer than two
        distinct values (no split puts objects on both sides).
    """
    feature_vector = np.asarray(feature_vector)
    target_vector = np.asarray(target_vector)
    n_objects = feature_vector.shape[0]
    if target_vector.shape[0] != n_objects:
        raise ValueError("feature_vector and target_vector must have the same length")

    # Use distinct values only: a midpoint between two equal neighbours equals a
    # data value and reproduces an already-listed partition, which would show up
    # as a spurious tie for the best split.
    distinct_values = np.unique(feature_vector)
    if distinct_values.shape[0] < 2:
        raise ValueError("feature must have at least two distinct values to be split")
    thresholds = (distinct_values[:-1] + distinct_values[1:]) / 2

    ginis = np.empty(thresholds.shape[0])
    for i, threshold in enumerate(thresholds):
        left_mask = feature_vector < threshold
        right_mask = feature_vector >= threshold
        h_left = get_split_H(target_vector[left_mask])
        h_right = get_split_H(target_vector[right_mask])
        ginis[i] = (
            - np.count_nonzero(left_mask) / n_objects * h_left
            - np.count_nonzero(right_mask) / n_objects * h_right
        )

    best = int(np.argmax(ginis))
    return thresholds, ginis, thresholds[best], ginis[best]

# Дерево 1

Первый сплит

In [47]:
# Root split of tree 1: best threshold on feature f1 over the whole dataset.
f1, y = data['f1'].to_numpy(), data['y'].to_numpy()
thresholds, ginis, threshold_best, gini_best = find_split(
    feature_vector=f1,
    target_vector=y,
)
threshold_best

12.73

In [48]:
# Count thresholds whose criterion ties with the best one. A small absolute
# tolerance is used because splits with mathematically equal criteria can differ
# in the last floating-point bits, and an exact `>=` would miss such ties.
np.isclose(ginis, gini_best, rtol=0.0, atol=1e-12).sum()

1

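Видим, что минимальный взвешенный джини (в коде — максимум `ginis`, так как там хранится джини со знаком минус) достигается только на одном пороге, а значит, наилучший порог единственный. Отбраковываем деревья 3 и 4.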

In [49]:
# Split the dataset by the best f1 threshold (left: f1 < t, right: f1 >= t)
# and extract f2 and the target of each child for the second-level splits.
left_data, right_data = (
    data.loc[data['f1'] < threshold_best],
    data.loc[data['f1'] >= threshold_best],
)
left_f2, left_y = left_data['f2'].to_numpy(), left_data['y'].to_numpy()
right_f2, right_y = right_data['f2'].to_numpy(), right_data['y'].to_numpy()

Второй уровень: сплит левой вершины

In [50]:
# Best f2 threshold inside the left child of the root.
# Note: this overwrites threshold_best; the next cell relies on the new value.
thresholds, ginis, threshold_best, gini_best = find_split(
    feature_vector=left_f2, target_vector=left_y
)
threshold_best

2.545

In [51]:
# Class composition of the two nodes obtained by splitting the left child on f2.
left_left = left_data.loc[left_data['f2'] < threshold_best]
left_right = left_data.loc[left_data['f2'] >= threshold_best]
display(
    left_left['y'].value_counts(),
    left_right['y'].value_counts(),
)

1    8
Name: y, dtype: int64

2    1
Name: y, dtype: int64

Второй уровень: сплит правой вершины

In [52]:
# Best f2 threshold inside the right child of the root.
# Note: this overwrites threshold_best; the next cell relies on the new value.
thresholds, ginis, threshold_best, gini_best = find_split(
    feature_vector=right_f2, target_vector=right_y
)
threshold_best

2.1

In [53]:
# Class composition of the two nodes obtained by splitting the right child on f2.
right_left = right_data.loc[right_data['f2'] < threshold_best]
right_right = right_data.loc[right_data['f2'] >= threshold_best]
display(
    right_left['y'].value_counts(),
    right_right['y'].value_counts(),
)

0    4
2    1
Name: y, dtype: int64

2    3
1    1
Name: y, dtype: int64

Окей, дерево 1 подходит. Поймём, что не так с деревом 2.

# Дерево 2

In [54]:
# Root split of tree 2: the same computation as the first split of tree 1,
# repeated because threshold_best was overwritten by the cells above.
f1 = data['f1'].to_numpy()
y = data['y'].to_numpy()
(thresholds, ginis,
 threshold_best, gini_best) = find_split(feature_vector=f1, target_vector=y)
threshold_best

12.73

In [55]:
# Split the dataset by the best f1 threshold (left: f1 < t, right: f1 >= t).
# In tree 2 the left child is examined on f2 and the right child on f3.
left_data, right_data = (
    data.loc[data['f1'] < threshold_best],
    data.loc[data['f1'] >= threshold_best],
)
left_f2, left_y = left_data['f2'].to_numpy(), left_data['y'].to_numpy()
right_f3, right_y = right_data['f3'].to_numpy(), right_data['y'].to_numpy()

In [56]:
# Best f2 threshold inside the left child of tree 2's root.
thresholds, ginis, threshold_best, gini_best = find_split(
    feature_vector=left_f2, target_vector=left_y
)
threshold_best

2.545

In [57]:
# Count thresholds whose criterion ties with the best one. A small absolute
# tolerance is used because splits with mathematically equal criteria can differ
# in the last floating-point bits, and an exact `>=` would miss such ties.
np.isclose(ginis, gini_best, rtol=0.0, atol=1e-12).sum()

1

Наилучший сплит левой вершины единственный (порог 2.545 по f2). Значит, дерево 2 построено некорректно.