In [1]:
import random
from collections import Counter

# Gini impurity
def gini(y):
    """Return the Gini impurity of the labels in ``y``.

    The score starts at 1 and the squared relative frequency of each
    distinct label is subtracted from it. A collection holding a single
    distinct label therefore scores 0, and an empty collection scores 1
    because there is nothing to subtract.
    """
    label_counts = Counter(y)
    # Lazy on purpose: len(y) is only evaluated when there is a label to weigh.
    shares = (occurrences / len(y) for occurrences in label_counts.values())
    score = 1
    for share in shares:
        score -= share ** 2
    return score

In [2]:
def split_data(X, y, feature, value):
    """Partition rows and their labels around a threshold on one feature.

    Rows whose ``feature`` entry is ``<= value`` go to the left side; all
    other rows go to the right side. Input order is kept within each side.

    Returns:
        A 4-tuple ``(left_X, left_y, right_X, right_y)`` of lists.
    """
    at_or_below = ([], [])  # (rows, labels) with row[feature] <= value
    above = ([], [])        # (rows, labels) for everything else
    for idx, row in enumerate(X):
        rows, labels = at_or_below if row[feature] <= value else above
        rows.append(row)
        labels.append(y[idx])
    return at_or_below[0], at_or_below[1], above[0], above[1]

In [None]:
def best_split(X, y, features):
    """Find the (feature, threshold) pair with the lowest weighted Gini impurity.

    For every feature index in ``features``, each distinct value seen in that
    column is tried as a ``<=`` threshold. Candidates that leave either side
    empty are skipped. Each remaining candidate is scored by the
    size-weighted sum of the Gini impurity of its two label groups.

    Returns:
        ``(feature, value)`` of the first candidate reaching the lowest
        score, or ``(None, None)`` when no candidate yields two non-empty
        sides.
    """
    winner = (None, None)
    lowest = float("inf")
    for feat in features:
        for threshold in set(row[feat] for row in X):
            _, left_labels, _, right_labels = split_data(X, y, feat, threshold)
            if not left_labels or not right_labels:
                continue  # degenerate split: everything landed on one side
            left_weight = len(left_labels) / len(y)
            right_weight = len(right_labels) / len(y)
            score = left_weight * gini(left_labels) + right_weight * gini(right_labels)
            # Strict comparison: on ties the earliest candidate is kept.
            if score < lowest:
                lowest = score
                winner = (feat, threshold)
    return winner

In [None]:
def build_tree(X, y, depth=0, max_depth=3):
    """Recursively grow a decision tree over a random subset of features.

    At each node a random sample of ``int(sqrt(n_features))`` feature indices
    is drawn with the module-level ``random`` generator, and the best
    threshold among those features is chosen by ``best_split``.

    Args:
        X: Non-empty sequence of rows; every row is indexable by feature.
        y: Labels aligned with ``X``; must be non-empty.
        depth: Depth of the node being built (0 for the root).
        max_depth: Depth at which growth stops. The limit is passed on to
            every recursive call, so it applies to the whole tree.

    Returns:
        A leaf, which is the most common label in ``y``, or a dict node with
        keys ``"feature"``, ``"value"``, ``"left"`` and ``"right"``. The
        left subtree holds rows with ``row[feature] <= value``.
    """
    # Stop on a pure node or at the depth limit. ``>=`` also stops when the
    # caller starts with depth already past max_depth.
    if len(set(y)) == 1 or depth >= max_depth:
        return Counter(y).most_common(1)[0][0]

    n_features = len(X[0])
    features = random.sample(range(n_features), int(n_features ** 0.5))
    f, v = best_split(X, y, features)

    # No sampled feature separates the rows: fall back to a majority leaf.
    if f is None:
        return Counter(y).most_common(1)[0][0]

    lX, ly, rX, ry = split_data(X, y, f, v)

    return {
        "feature": f,
        "value": v,
        # Pass max_depth down explicitly; otherwise children would silently
        # fall back to the default limit.
        "left": build_tree(lX, ly, depth + 1, max_depth),
        "right": build_tree(rX, ry, depth + 1, max_depth),
    }