In [1]:
import numpy as np
# Toy data set with six samples: x1 and x2 are the two attribute columns
# (stacked into the feature matrix further down), y holds the class labels.
x1 = [0, 1, 1, 2, 2, 2]
x2 = [0, 0, 1, 1, 1, 0]
y = np.array([0, 0, 0, 1, 1, 0])

In [22]:
# Get the indices of the samples that fall into each group
def partition(a):
    """Group the positions of `a` by the value stored at each position.

    Parameters
    ----------
    a : array_like
        Sequence of attribute values (a list or an ndarray).

    Returns
    -------
    dict
        Maps every distinct value of `a` (as produced by `np.unique`) to an
        integer ndarray with the first-axis indices at which it occurs.
    """
    # Coerce up front so that `values == value` is always an element-wise
    # NumPy comparison, even when the caller hands in a plain Python list.
    values = np.asarray(a)
    return {value: (values == value).nonzero()[0] for value in np.unique(values)}

# Show which sample indices share each distinct value of x1.
groups_by_value = partition(x1)
for value in groups_by_value:
    print(value, groups_by_value[value])

0 [0]
1 [1 2]
2 [3 4 5]


In [17]:
def entropy(s):
    """Shannon entropy, in bits, of the empirical value distribution of `s`.

    Parameters
    ----------
    s : array_like
        Observed values. `np.unique` flattens its input, so every element
        of a nested sequence counts as one observation.

    Returns
    -------
    float
        -sum(p * log2(p)) over the relative frequencies p of the distinct
        values; 0.0 for an empty or single-valued input.
    """
    _, counts = np.unique(s, return_counts=True)

    # Normalise by the number of counted elements rather than len(s): the two
    # differ for nested input, where len(s) is only the outer length.
    total = counts.sum()
    if total == 0:
        return 0.0
    freqs = counts.astype(float) / total

    # Every count is at least 1, so log2 never sees a zero frequency.
    # `0.0 - x` (instead of `-x`) keeps a pure sample at +0.0 rather than -0.0.
    return float(0.0 - np.sum(freqs * np.log2(freqs)))

In [18]:
def mutual_information(y, x):
    """Information gained about the labels `y` by knowing the attribute `x`.

    Computed as H(y) - sum_v P(x = v) * H(y | x = v), where H is `entropy`
    and v runs over the distinct values found in `x`.

    Parameters
    ----------
    y : array_like
        Labels, one per sample.
    x : array_like
        Attribute values, aligned element-for-element with `y`.

    Returns
    -------
    float
        The entropy of `y` minus the frequency-weighted entropy of `y`
        within each group of equal `x`.
    """
    # Boolean-mask indexing needs real ndarrays: a list `y` cannot be indexed
    # by a mask, and a list `x` would not reliably compare element-wise.
    y = np.asarray(y)
    x = np.asarray(x)

    res = entropy(y)

    # We partition x, according to attribute values x_i
    val, counts = np.unique(x, return_counts=True)
    freqs = counts.astype(float) / counts.sum()

    # We calculate a weighted average of the entropy
    for p, v in zip(freqs, val):
        # `x == v` is an element-wise comparison that yields a boolean mask
        # selecting the labels of the samples whose attribute equals v.
        res -= p * entropy(y[x == v])

    return res

In [19]:
# Information gain about the labels y obtained by splitting on attribute x1.
mutual_information(y, x1)

0.45914791702724478

In [20]:
# Stacking the lists puts one attribute per row (shape 2 x 6); the tree code
# below transposes this so that each row is a sample.
np.array([x1, x2])

array([[0, 1, 1, 2, 2, 2],
       [0, 0, 1, 1, 1, 0]])

In [21]:
from pprint import pprint

def is_pure(s):
    """Report whether `s` contains exactly one distinct value.

    An empty `s` has zero distinct values and is therefore not pure.
    The elements of `s` must be hashable.
    """
    distinct_values = frozenset(s)
    return 1 == len(distinct_values)

def _branch_label(attr_index, value):
    """Build the dict key that names the branch `x_<attr_index> = <value>`."""
    try:
        # Integral values keep the "%d" rendering (e.g. "x_0 = 2").
        if value == int(value):
            return "x_%d = %d" % (attr_index, value)
    except (TypeError, ValueError, OverflowError):
        # Not convertible to int (strings, NaN, infinities): fall through.
        pass
    # "%d" would truncate non-integral numbers (making distinct values share
    # one key) and reject strings, so render those values with "%s".
    return "x_%d = %s" % (attr_index, value)


def recursive_split(x, y, min_gain=1e-6):
    """Recursively split the samples on the attribute with the highest gain.

    Parameters
    ----------
    x : array_like
        Attribute matrix with one row per sample and one column per attribute.
    y : array_like
        Labels, one per row of `x`.
    min_gain : float, optional
        Splitting stops when every attribute's mutual information with `y`
        is below this threshold.

    Returns
    -------
    numpy.ndarray or dict
        The labels `y` themselves when no split is made; otherwise a dict
        mapping "x_<attribute index> = <value>" to the result of splitting
        the samples that carry that value.
    """
    # `.T` and `.take` below require ndarrays; this is a no-op for array input.
    x = np.asarray(x)
    y = np.asarray(y)

    # If there could be no split, just return the original set
    if is_pure(y) or len(y) == 0:
        return y

    # Mutual information of every attribute (one column of x) with the labels
    gain = np.array([mutual_information(y, x_attr) for x_attr in x.T])

    # With no attributes at all (argmax would raise on an empty array) or no
    # gain worth having, nothing has to be done: return the original set
    if gain.size == 0 or np.all(gain < min_gain):
        return y

    # We get attribute that gives the highest mutual information
    selected_attr = np.argmax(gain)

    # We split using the selected attribute
    sets = partition(x[:, selected_attr])

    res = {}
    for k, v in sets.items():
        # v holds the row indices of the samples whose attribute equals k
        y_subset = y.take(v, axis=0)
        x_subset = x.take(v, axis=0)

        res[_branch_label(selected_attr, k)] = recursive_split(x_subset, y_subset, min_gain)

    return res

# Build the feature matrix with one row per sample and one column per
# attribute, then grow the tree and pretty-print it.
attributes_as_rows = np.array([x1, x2])
X = attributes_as_rows.T
tree = recursive_split(X, y)
pprint(tree)

{'x_0 = 0': array([0]),
 'x_0 = 1': array([0, 0]),
 'x_0 = 2': {'x_1 = 0': array([0]), 'x_1 = 1': array([1, 1])}}


## Source

[Implementing Decision Trees in Python](http://gabrielelanaro.github.io/blog/2016/03/03/decision-trees.html)