### 使用信息熵寻找最优划分

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from sklearn import datasets

iris = datasets.load_iris()
# Keep only the last two feature columns (columns 2 onward) so the data is
# two-dimensional; in sklearn's iris data these are presumably petal length
# and petal width — confirm against the dataset's feature_names.
X = iris.data[:, 2:]
y = iris.target

In [3]:
from sklearn.tree import DecisionTreeClassifier

# Reference model: sklearn's decision tree restricted to depth 2 and using the
# entropy criterion, to compare with the hand-written split search below.
dt_clf = DecisionTreeClassifier(max_depth=2, criterion='entropy')

In [4]:
# Fit the reference tree on the two selected features; the cell output below
# is the fitted estimator's repr.
dt_clf.fit(X, y)

DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=2,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

### 模拟信息熵进行划分

In [5]:
def split(X, y, d, value):
    """Partition the samples on feature column ``d`` at threshold ``value``.

    Rows whose ``d``-th feature is ``<= value`` go to the left side and rows
    whose feature is ``> value`` go to the right side. Both masks are built
    explicitly, so a row whose value compares false both ways (e.g. NaN) is
    placed on neither side.

    Parameters
    ----------
    X : 2-D array indexed as ``X[row, feature]``.
    y : array of labels aligned with the rows of ``X``.
    d : int, index of the feature (dimension) to threshold on.
    value : the threshold.

    Returns
    -------
    (X_left, X_right, y_left, y_right)
    """
    column = X[:, d]
    goes_left = column <= value
    goes_right = column > value
    X_left, X_right = X[goes_left], X[goes_right]
    y_left, y_right = y[goes_left], y[goes_right]
    return X_left, X_right, y_left, y_right

In [6]:
from collections import Counter
from math import log

def entropy(y):
    """Return the Shannon entropy of the labels in ``y``, in nats.

    Computed as ``-sum(p * ln(p))`` over the relative frequency ``p`` of each
    distinct label. An empty ``y`` yields ``0.0``.
    """
    # Counter(y) is evaluated eagerly; len(y) is evaluated per label, as before.
    probabilities = (count / len(y) for count in Counter(y).values())
    total = 0.0
    for prob in probabilities:
        total -= prob * log(prob)
    return total

def try_split(X, y, *, weighted=False):
    """Exhaustively search for the best single axis-aligned binary split.

    For every feature column ``d``, the values are sorted. Each midpoint
    between two adjacent, distinct values is tried as a threshold ``v``. The
    samples are partitioned with ``split(X, y, d, v)``, and the candidate is
    scored from the entropies of the two children. The lowest score wins. On
    ties, the first candidate found is kept (smaller ``d``, then smaller
    ``v``), because the comparison is strict.

    Parameters
    ----------
    X : 2-D array, one row per sample, one column per feature.
    y : array of labels aligned with the rows of ``X``.
    weighted : bool, keyword-only, default False
        If False (the original behaviour), the score is the plain sum
        ``entropy(y_l) + entropy(y_r)``. That sum ignores how many samples
        fall on each side, so it favours splits that peel off a small pure
        group. If True, each child's entropy is weighted by its share of the
        samples, giving the conditional entropy minimised by the
        information-gain criterion (as in sklearn's ``criterion='entropy'``).

    Returns
    -------
    (best_entropy, best_d, best_v)
        ``(inf, -1, -1)`` when no split exists (fewer than two samples, or
        every column is constant).
    """
    n_samples = len(X)
    best_entropy = float('inf')
    best_d, best_v = -1, -1
    for d in range(X.shape[1]):
        # Sorted values of this feature; candidate thresholds lie between neighbours.
        values = np.sort(X[:, d])
        for lo, hi in zip(values[:-1], values[1:]):
            # Equal neighbours admit no threshold that separates them.
            if lo == hi:
                continue
            v = (lo + hi) / 2
            _, _, y_l, y_r = split(X, y, d, v)
            if weighted:
                e = (len(y_l) * entropy(y_l)
                     + len(y_r) * entropy(y_r)) / n_samples
            else:
                e = entropy(y_l) + entropy(y_r)
            if e < best_entropy:
                best_entropy, best_d, best_v = e, d, v
    return best_entropy, best_d, best_v

In [9]:
# First level: search the whole dataset for the best split (feature index
# and threshold) and print its score.
best_entropy, best_d, best_v = try_split(X, y)
print("best_entropy = ", best_entropy)
print("best_d = ", best_d)
print("best_v = ", best_v)

best_entropy =  0.6931471805599453
best_d =  0
best_v =  2.45


In [10]:
# Apply the best first-level split to get the left and right partitions.
X1_l, X1_r, y1_l, y1_r = split(X, y, best_d, best_v)

In [11]:
# Entropy of the left partition; 0.0 means it holds a single label.
entropy(y1_l)

0.0

In [12]:
# Entropy of the right partition, which is still mixed and is split further.
entropy(y1_r)

0.6931471805599453

In [13]:
# Second level: search again, only within the right partition of the first
# split (the left partition already has zero entropy).
best_entropy2, best_d2, best_v2 = try_split(X1_r, y1_r)
print("best_entropy2 = ", best_entropy2)
print("best_d2 = ", best_d2)
print("best_v2 = ", best_v2)

best_entropy2 =  0.4132278899361904
best_d2 =  1
best_v2 =  1.75


In [14]:
# Apply the best second-level split to the right partition of the first split.
# NOTE(review): 'X12_l' breaks the X2_*/y2_* naming of the other three names and
# is presumably meant to be 'X2_l'; left unchanged in case later cells use it.
X12_l, X2_r, y2_l, y2_r = split(X1_r, y1_r, best_d2, best_v2)

In [15]:
# Entropy of the left child of the second-level split.
entropy(y2_l)

0.30849545083110386

In [16]:
# Entropy of the right child of the second-level split. With try_split's default
# unweighted score, this value plus entropy(y2_l) equals best_entropy2.
entropy(y2_r)

0.10473243910508653