### 二分类信息熵

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def entropy(p, base=None):
    """Return the binary entropy H(p) = -p*log(p) - (1-p)*log(1-p).

    Parameters
    ----------
    p : float or array_like
        Probability of one of the two classes; expected to lie in [0, 1].
    base : float, optional
        Logarithm base. ``None`` (the default) uses the natural logarithm,
        exactly as before; pass ``2`` to get bits, the unit used by
        ``calc_entropy`` in this notebook.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Entropy of each element of ``p`` (a scalar for scalar input). Using
        the convention 0 * log(0) = 0, the end points p = 0 and p = 1 yield 0
        instead of nan.
    """
    p = np.asarray(p, dtype=float)
    # 0 * log(0) evaluates to nan (and emits RuntimeWarnings); silence those
    # here and substitute the limit value 0 below.
    with np.errstate(divide="ignore", invalid="ignore"):
        h = -p * np.log(p) - (1 - p) * np.log(1 - p)
        if base is not None:
            h = h / np.log(base)
    h = np.where((p == 0) | (p == 1), 0.0, h)
    # np.where always returns an ndarray; [()] turns a 0-d result back into a
    # scalar and is a no-op view for higher-dimensional results.
    return h[()]

In [None]:
# Sample 100 probabilities strictly inside (0, 1), staying clear of the end
# points where the entropy formula involves 0 * log(0), and plot H(p).
x = np.linspace(start=0.001, stop=0.999, num=100)
plt.plot(x, entropy(x))
plt.show()

### 数据集

In [None]:
from sklearn import datasets

# Keep only feature columns 1 and 2 (sepal width and petal length in
# scikit-learn's iris feature order) so the samples fit in a 2-D plane.
iris = datasets.load_iris()
X, y = iris.data[:, 1:3], iris.target

In [None]:
# Scatter the two kept features against each other, colouring every sample
# by its class label so the three classes can be told apart.
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()

In [None]:
from sklearn.tree import DecisionTreeClassifier

# A depth-2 tree whose splits are chosen with the entropy criterion.
dt_clf = DecisionTreeClassifier(criterion="entropy", max_depth=2)
dt_clf.fit(X, y)

In [None]:
def decision_boundary_plot(X, y, clf, resolution=1000):
    """Plot the decision regions of a fitted two-feature classifier plus the data.

    Parameters
    ----------
    X : numpy.ndarray
        Sample matrix; only columns 0 and 1 are used (they become the plot axes).
    y : array_like
        Labels used to colour the scattered samples.
    clf : estimator
        Fitted classifier whose ``predict`` accepts an (n, 2) array.
    resolution : int, default 1000
        Number of grid points per axis. The classifier is evaluated on
        ``resolution ** 2`` points, so smaller values render much faster.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    from matplotlib.colors import ListedColormap

    axis_x1_min, axis_x1_max = X[:, 0].min(), X[:, 0].max()
    axis_x2_min, axis_x2_max = X[:, 1].min(), X[:, 1].max()

    x1, x2 = np.meshgrid(
        np.linspace(axis_x1_min, axis_x1_max, resolution),
        np.linspace(axis_x2_min, axis_x2_max, resolution),
    )
    # Predict one label per grid point, then fold the flat predictions back
    # into the grid's shape for contourf.
    z = clf.predict(np.c_[x1.ravel(), x2.ravel()])
    z = z.reshape(x1.shape)

    # Three fixed light background colours (presumably sized for the three
    # iris classes -- adjust if used with a different number of classes).
    custom_cmap = ListedColormap(['#F5B9FF', '#FFFFFF', '#F9F9CB'])

    plt.contourf(x1, x2, z, cmap=custom_cmap)
    plt.scatter(X[:, 0], X[:, 1], c=y)
    plt.show()

In [None]:
# Show the regions the fitted depth-2 tree assigns to each class.
decision_boundary_plot(X=X, y=y, clf=dt_clf)

In [None]:
from sklearn.tree import plot_tree
# Render the structure of the fitted tree dt_clf.
plot_tree(dt_clf)

### 最优划分条件

In [None]:
from collections import Counter
# Number of samples per class label; Counter is also used by calc_entropy.
Counter(y)

In [None]:
def calc_entropy(y):
    """Return the Shannon entropy (in bits) of the label distribution of ``y``.

    Every distinct label with relative frequency p contributes -p * log2(p).
    An empty ``y`` yields 0.
    """
    total = len(y)
    probs = (count / total for count in Counter(y).values())
    return sum(-p * np.log2(p) for p in probs)

In [None]:
# Entropy (bits) of the full label set, before any split.
calc_entropy(y=y)

In [None]:
def split_dataset(X, y, dim, value):
    """Partition (X, y) on feature ``dim`` at threshold ``value``.

    Returns ``(X_left, X_right, y_left, y_right)``: the left side holds rows
    with ``X[:, dim] <= value``, the right side rows with ``X[:, dim] > value``.
    Both masks are computed explicitly, so rows whose feature is NaN satisfy
    neither comparison and appear on neither side.
    """
    column = X[:, dim]
    left = column <= value
    right = column > value
    return X[left], X[right], y[left], y[right]

In [None]:
def try_split(X, y):
    """Find the single axis-aligned split of (X, y) with the lowest weighted entropy.

    For each feature, the candidate thresholds are midpoints between
    consecutive distinct values in sorted order. A candidate is scored by the
    size-weighted mean of the left and right entropies (``calc_entropy``).
    Only a strictly lower score replaces the current best, so ties go to the
    first candidate found.

    Returns ``(best_dim, best_value, best_entropy, best_entropy_left,
    best_entropy_right)``. When no candidate exists (every column has fewer
    than two distinct values), the result is ``(-1, -1, inf, None, None)``.
    """
    n_samples = X.shape[0]
    n_features = X.shape[1]
    # (dim, threshold, weighted entropy, left entropy, right entropy)
    best = (-1, -1, float('inf'), None, None)
    for dim in range(n_features):
        order = np.argsort(X[:, dim])
        for lo, hi in zip(order[:-1], order[1:]):
            a, b = X[lo, dim], X[hi, dim]
            if a == b:
                continue  # equal neighbours give no threshold between them
            threshold = (a + b) / 2
            X_l, X_r, y_l, y_r = split_dataset(X, y, dim, threshold)
            h_l = calc_entropy(y_l)
            h_r = calc_entropy(y_r)
            weighted = (len(X_l) * h_l + len(X_r) * h_r) / n_samples
            if weighted < best[2]:
                best = (dim, threshold, weighted, h_l, h_r)
    return best


In [None]:
# Best root split: (dim, threshold, weighted entropy, left entropy, right entropy).
try_split(X=X, y=y)

In [None]:
# Split at the best root condition reported by try_split rather than
# hard-coding the printed dim/threshold, then search for the best split of
# the right-hand partition.
best_dim, best_value = try_split(X, y)[:2]
x_left, x_right, y_left, y_right = split_dataset(X, y, best_dim, best_value)
try_split(x_right, y_right)
