In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
class BinaryTree:

    def __init__(self):

        self.feature = None
        self.threshold = None
        self.left = None
        self.right = None

In [13]:
class MyDecisionTree:

    n_splits = 5

    def fit(self, X, y):

        def _gini_impurity(labels):
            count_instances = {}
            for label in labels:
                if label in count_instances:
                    count_instances[label] += 1
                else:
                    count_instances[label] = 1
            gini, len_data = 1, len(labels)
            for count in count_instances.values():
                gini -= (count / len_data) ** 2

            return gini
        
        def _best_split(X, y):

            ginis, thresholds = [], []

            for column in X.T:

                min_cost, min_threshold = 1e10, 1e10
                costs = []
                column_sorted, y_sorted = zip(*sorted(zip(column, y)))

                for i in range(1, len(column_sorted) - 1):

                    cost =  (i / len(column_sorted)) * _gini_impurity(y_sorted[:i]) + \
                    ((len(column_sorted) - i) / len(column_sorted)) * _gini_impurity(y_sorted[i:])
                    costs.append((cost, column_sorted[i]))

                    if cost < min_cost:
                        min_cost = cost
                        min_threshold = column_sorted[i]
                
                ginis.append(min_cost)
                thresholds.append(min_threshold)
                print(costs)
            
            print(ginis, thresholds)

            print(f'Best split at X[{np.argmin(ginis)}]: >={thresholds[np.argmin(ginis)]}\nImpurity at split: {ginis[np.argmin(ginis)]}')

            return int(np.argmin(ginis)), thresholds[np.argmin(ginis)]

        def _build_tree(X, y, depth=0):

            if len(set(y)) == 1:
                return BinaryTree()
            
            feature, threshold = _best_split(X, y)
            feature = int(feature)

            node = BinaryTree()
            node.feature = feature
            node.threshold = threshold

            left_idx = X[:, feature] < threshold
            right_idx = ~left_idx

            node.left = _build_tree(X[left_idx], y[left_idx], depth + 1)
            node.right = _build_tree(X[right_idx], y[right_idx], depth + 1)

            return node
        
        self.root = _build_tree(X, y)
        return self
    
    def predict(self, X):

        def _traverse_tree(x, node):

            if node.left is None and node.right is None:
                return 1
            
            if x[node.feature] < node.threshold:
                return _traverse_tree(x, node.left)
            else:
                return _traverse_tree(x, node.right)
        
        predictions = [ _traverse_tree(x, self.root) for x in X ]

        return np.array(predictions)

In [23]:
X = np.array([
    [10, 3, 2],
    [20, 4, 3],
    [30, 2, 4],
    [10, 4, 2],
    [20, 3, 3],
    [30, 4, 4],
    [10, 2, 2],
    [20, 2, 3],
    [30, 3, 4],
    [10, 3, 2],
    [20, 4, 3],
    [30, 2, 4],
    [10, 4, 2],
    [20, 3, 3],
    [30, 4, 4],
    [10, 2, 2],
])
y = np.array([0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0])

tree = MyDecisionTree()
tree.fit(X, y)

[(0.44999999999999996, 10), (0.40178571428571425, 10), (0.34615384615384615, 10), (0.28125, 10), (0.20454545454545447, 10), (0.11249999999999996, 20), (0.0, 20), (0.109375, 20), (0.19444444444444442, 20), (0.2625, 20), (0.3181818181818182, 30), (0.3645833333333332, 30), (0.40384615384615385, 30), (0.4375, 30)]
[(0.44999999999999996, 2), (0.40178571428571425, 2), (0.34615384615384615, 2), (0.42708333333333337, 2), (0.46818181818181814, 3), (0.42916666666666675, 3), (0.373015873015873, 3), (0.421875, 3), (0.4563492063492064, 3), (0.47916666666666663, 4), (0.4409090909090909, 4), (0.3645833333333332, 4), (0.40384615384615385, 4), (0.4375, 4)]
[(0.44999999999999996, 2), (0.40178571428571425, 2), (0.34615384615384615, 2), (0.28125, 2), (0.20454545454545447, 2), (0.11249999999999996, 3), (0.0, 3), (0.109375, 3), (0.19444444444444442, 3), (0.2625, 3), (0.3181818181818182, 4), (0.3645833333333332, 4), (0.40384615384615385, 4), (0.4375, 4)]
[0.0, 0.34615384615384615, 0.0] [20, 2, 3]
Best split 

ValueError: not enough values to unpack (expected 2, got 0)

In [22]:
def print_tree(node, depth=0):
    if node is not None:
        print_tree(node.left, depth + 1)
        print(" " * depth * 4, f"Feature {node.feature} >= {node.threshold}")
        print_tree(node.right, depth + 1)

print_tree(tree.root)

     Feature None >= None
 Feature 0 >= 20
     Feature None >= None
