In [1]:
import pandas as pd
import numpy as np


In [2]:
class Node:
    """One vertex of the binary decision tree produced by ``cart``.

    ``cart`` creates two flavours of node:

    * leaf -- only ``label`` is set (the predicted class);
    * internal -- ``feature`` and ``threshold`` describe the split and
      ``left`` / ``right`` hold the child nodes.

    Every attribute defaults to ``None``.
    """

    def __init__(self, label=None, feature=None, threshold=None, left=None, right=None):
        # Predicted class; remains None on internal nodes.
        self.label = label
        # Split rule: rows with ``feature <= threshold`` are sent to ``left``,
        # rows with ``feature > threshold`` to ``right``.
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right


def cart(data, target_attr, attrs):
    """Recursively grow a binary decision tree by minimising weighted Gini impurity.

    Parameters
    ----------
    data : pandas.DataFrame
        Training rows; must contain ``target_attr`` and every column in ``attrs``.
    target_attr : str
        Name of the column holding the class label.
    attrs : list
        Columns still available for splitting. A column chosen at a node is
        removed from the candidates of both subtrees.

    Returns
    -------
    Node
        A leaf (``label`` set) or an internal node (``feature``, ``threshold``,
        ``left`` and ``right`` set).

    Raises
    ------
    ValueError
        If ``data`` has no rows.

    Notes
    -----
    Candidate thresholds are the distinct values of a column except the largest,
    and rows are routed with ``<=`` / ``>``. For string-valued columns that
    comparison is lexicographic, not a category-membership test.
    """
    if len(data) == 0:
        raise ValueError("cannot build a decision tree from empty data")

    labels = data[target_attr]

    # Pure node: every row has the same class.
    if len(np.unique(labels)) == 1:
        return Node(label=labels.iloc[0])

    # Nothing left to split on: predict the most frequent class.
    if len(attrs) == 0:
        return Node(label=labels.value_counts().idxmax())

    best_feature, best_threshold, best_gini_index = None, None, None

    for attr in attrs:
        # Dropping the largest distinct value guarantees a non-empty right side.
        for threshold in np.unique(data[attr])[:-1]:
            left_data = data[data[attr] <= threshold]
            right_data = data[data[attr] > threshold]
            gini_index = calc_gini_index(left_data[target_attr], right_data[target_attr])
            # Strict '<' keeps the first split found among equally good ones.
            if best_gini_index is None or gini_index < best_gini_index:
                best_feature, best_threshold, best_gini_index = attr, threshold, gini_index

    # Every remaining attribute is constant on these rows, so no threshold was
    # generated. Fall back to a majority-class leaf instead of indexing the
    # frame with ``None``.
    if best_feature is None:
        return Node(label=labels.value_counts().idxmax())

    left_data = data[data[best_feature] <= best_threshold]
    right_data = data[data[best_feature] > best_threshold]

    remaining_attrs = [attr for attr in attrs if attr != best_feature]
    left_node = cart(left_data, target_attr, remaining_attrs)
    right_node = cart(right_data, target_attr, remaining_attrs)

    return Node(feature=best_feature, threshold=best_threshold, left=left_node, right=right_node)


def calc_gini_index(left_data, right_data):
    """Return the size-weighted Gini impurity of a two-way split.

    Each side's impurity (from ``calc_gini``) is weighted by that side's share
    of the combined number of labels, and the two products are added.
    Raises ``ZeroDivisionError`` when both sides are empty.
    """
    left_size, right_size = len(left_data), len(right_data)
    total = left_size + right_size
    return (left_size / total) * calc_gini(left_data) + (right_size / total) * calc_gini(right_data)


def calc_gini(data):
    """Return the Gini impurity ``1 - sum(p_k ** 2)`` of a label Series.

    ``p_k`` is the fraction of entries equal to the k-th distinct value.
    An empty Series has no distinct values, so the result is ``1``.
    """
    total = len(data)
    impurity = 1
    for cls in data.unique():
        # Number of entries equal to this class, as a plain int.
        matches = int((data == cls).sum())
        impurity -= (matches / total) ** 2
    return impurity



def print_tree(node, depth=0):
    """Print the tree rooted at ``node``, indenting two spaces per level.

    A node with a label prints as ``--> <label>``. Otherwise a node with a
    feature prints as ``<feature> <= <threshold>`` followed by its left and
    then right subtree one level deeper. A node with neither prints nothing.
    """
    prefix = depth * "  "

    if node.label is not None:
        print(f"{prefix}--> {node.label}")
        return

    if node.feature is None:
        return

    print(f"{prefix}{node.feature} <= {node.threshold}")
    for child in (node.left, node.right):
        print_tree(child, depth + 1)

In [4]:

# Fit a tree on the buy_car spreadsheet and print its structure.
target_attr = "Buys_Car"
data = pd.read_excel(r"../dataset/buy_car.xlsx")

# Every column other than the target is offered as a split candidate.
attrs = [column for column in data.columns if column != target_attr]

tree = cart(data, target_attr, attrs)
print_tree(tree)


Age <= MiddleAge
  --> Yes
  Marital_Status <= No
    Income <= High
      --> No
      Cred_rating <= Excellent
        --> No
        --> Yes
    Cred_rating <= Excellent
      Income <= Low
        --> No
        --> Yes
      --> Yes
