## Pseudocode: Decision Tree (greedy, impurity-based binary splits)

```
function build_tree(data, labels, depth=0):
    if stop_condition(data, labels, depth):
        return LeafNode(class = majority_class(labels))
    
    best_gain = 0
    best_feat, best_thresh = None, None
    parent_impurity = impurity(labels)

    for each feature j in 0…D-1:
        for each threshold t in unique_values(data[:,j]):
            left_labels  = labels[data[:,j] ≤ t]
            right_labels = labels[data[:,j] >  t]

            if len(left_labels)==0 or len(right_labels)==0: continue

            gain = parent_impurity \
                   - (len(left_labels)/len(labels))*impurity(left_labels) \
                   - (len(right_labels)/len(labels))*impurity(right_labels)

            if gain > best_gain:
                best_gain, best_feat, best_thresh = gain, j, t

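    # No usable split was found (every candidate left one side empty or had gain <= 0),
    # or the best split does not reduce impurity enough: make this node a leaf.
    if best_feat is None or best_gain < min_impurity_decrease:
        return LeafNode(class = majority_class(labels))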

    left_data, left_labels  = split(data, labels, best_feat, best_thresh, side="left")
    right_data, right_labels = split(data, labels, best_feat, best_thresh, side="right")

    left_subtree  = build_tree(left_data,  left_labels,  depth+1)
    right_subtree = build_tree(right_data, right_labels, depth+1)

    return DecisionNode(
        feature_index = best_feat,
        threshold     = best_thresh,
        left          = left_subtree,
        right         = right_subtree
    )

```

In [1]:
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 100)
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from mlcore.decision_tree import CustomDecisionTreeClassifier

In [2]:
def accuracy(y_true, y_pred):
    """Return the fraction of positions where ``y_pred`` equals ``y_true``.

    Parameters
    ----------
    y_true, y_pred : array-like
        Label sequences of identical shape. Python lists, tuples and NumPy
        arrays are all accepted.

    Returns
    -------
    float
        Share of matching entries, in the range [0, 1].

    Raises
    ------
    ValueError
        If the two inputs differ in shape, or if they are empty.
    """
    # Convert first: ``==`` on two plain Python lists compares the lists as a
    # whole (one bool), not element-wise.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Refuse to broadcast, e.g. (n,) vs (n, 1) would silently become (n, n).
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.size == 0:
        raise ValueError("accuracy is undefined for empty inputs")
    return float(np.mean(y_true == y_pred))
# Synthetic classification problem (make_classification produces 2 classes by default).
X, y = datasets.make_classification(
    n_samples=10000,
    n_features=10,
    random_state=4,
)
# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=1234,
)

In [7]:
# Same hyperparameter values as the scikit-learn baseline cell, for a like-for-like comparison.
# NOTE(review): with max_features="sqrt" each implementation samples features with its own
# RNG, so random_state=42 need not yield identical trees in both libraries.
custom_tree_params = {
    "max_depth": 5,
    "max_features": "sqrt",
    "min_samples_split": 2,
    "min_samples_leaf": 1,
    "criterion": "gini",
    "min_impurity_decrease": 1e-7,
    "random_state": 42,
}
tree = CustomDecisionTreeClassifier(**custom_tree_params)
tree.fit(X_train, y_train)
y_pred = tree.predict(X_test)

print("Accuracy:", accuracy(y_test, y_pred))

Accuracy: 0.8905


In [8]:
# scikit-learn reference model, configured with the same values as the custom tree.
sklearn_tree_params = {
    "max_depth": 5,
    "max_features": "sqrt",
    "min_samples_split": 2,
    "min_samples_leaf": 1,
    "criterion": "gini",
    "min_impurity_decrease": 1e-7,
    "random_state": 42,
}
sk_tree = DecisionTreeClassifier(**sklearn_tree_params)
sk_tree.fit(X_train, y_train)
y_sk_pred = sk_tree.predict(X_test)

print("Sklearn Accuracy:", accuracy(y_test, y_sk_pred))

Sklearn Accuracy: 0.8915


In [5]:
# Render the fitted custom tree as indented text: split feature, threshold and gain for
# internal nodes; predicted class and sample count for leaves.
# NOTE(review): this cell ran as In [5], before the fit cell (In [7]), so the output below
# may come from an earlier fit — Restart Kernel & Run All to refresh it.
CustomDecisionTreeClassifier.print_tree(tree)

Feature[5] ≤ 0.0018  |  Gain=0.2738
→ True branch:
    Feature[0] ≤ 1.0656  |  Gain=0.0104
    → True branch:
        Feature[0] ≤ 0.7066  |  Gain=0.0724
        → True branch:
            Feature[6] ≤ 0.8228  |  Gain=0.0150
            → True branch:
                Feature[5] ≤ -0.1705  |  Gain=0.0129
                → True branch:
                    Predict: 1 (samples=2035)
                → False branch:
                    Predict: 1 (samples=609)
            → False branch:
                Feature[8] ≤ -1.3127  |  Gain=0.0235
                → True branch:
                    Predict: 0 (samples=12)
                → False branch:
                    Predict: 1 (samples=102)
        → False branch:
            Feature[2] ≤ 0.6896  |  Gain=0.0884
            → True branch:
                Feature[8] ≤ 2.2682  |  Gain=0.0052
                → True branch:
                    Predict: 0 (samples=303)
                → False branch:
                    Predict: 1 (samples=5)
      