# In [None]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# ------------------------
# Gini Impurity Calculation
# ------------------------
def gini_impurity(y):
    """Return the Gini impurity of the labels in ``y``.

    The impurity is ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of
    entries in ``y`` that equal the k-th distinct value found by
    ``np.unique``.
    """
    # Only the per-value counts matter; the distinct values are discarded.
    _, label_counts = np.unique(y, return_counts=True)
    proportions = label_counts / label_counts.sum()
    return 1 - np.sum(proportions ** 2)

# ------------------------
# Best Split Function
# ------------------------
def best_split(X, y, n_classes):
    """Search every feature for the threshold with the largest Gini gain.

    For each column of ``X`` the samples are sorted by that column and each
    boundary between two consecutive, distinct values is scored. The score
    is the parent impurity minus the size-weighted impurity of the two
    sides.

    Args:
        X: 2-D array of samples (rows) by features (columns).
        y: Labels paired with the rows of ``X``. Each label is used as an
            index into per-class count arrays of length ``n_classes``.
        n_classes: Number of slots in the per-class count arrays.

    Returns:
        ``(feature_index, threshold)`` for the best split found, where the
        threshold is the midpoint of the two neighbouring sorted values.
        ``(None, None)`` when ``X`` has at most one row or when no candidate
        has a gain strictly greater than zero.
    """
    m, n = X.shape
    if m <= 1:
        return None, None

    parent_impurity = gini_impurity(y)
    class_ids = range(n_classes)
    top_gain = 0
    top_feature, top_threshold = None, None

    for feature in range(n):
        # Sort (value, label) pairs so a single sweep visits every boundary.
        ordered = sorted(zip(X[:, feature], y))
        values = [value for value, _ in ordered]
        labels = [label for _, label in ordered]

        # Everything starts on the right; samples migrate left one by one.
        left_counts = [0] * n_classes
        right_counts = np.bincount(labels, minlength=n_classes)

        for split_at in range(1, m):
            moved = labels[split_at - 1]
            left_counts[moved] += 1
            right_counts[moved] -= 1

            # Equal neighbouring values cannot be separated by a threshold.
            if values[split_at] == values[split_at - 1]:
                continue

            n_left = split_at
            n_right = m - split_at
            gini_left = 1.0 - sum(
                (left_counts[k] / n_left) ** 2 for k in class_ids
            )
            gini_right = 1.0 - sum(
                (right_counts[k] / n_right) ** 2 for k in class_ids
            )

            weighted_gini = (n_left * gini_left + n_right * gini_right) / m
            gain = parent_impurity - weighted_gini

            if gain > top_gain:
                top_gain = gain
                top_feature = feature
                top_threshold = (values[split_at] + values[split_at - 1]) / 2

    return top_feature, top_threshold

# ------------------------
# Decision Tree Node
# ------------------------
class Node:
    """One node of the decision tree.

    A node is created with the statistics of the samples that reached it.
    Its split fields (``feature_index``, ``threshold``) and its children
    (``left``, ``right``) start as ``None`` and are filled in afterwards by
    whoever builds the tree.
    """

    def __init__(self, gini, num_samples, num_samples_per_class, predicted_class):
        # Statistics of the samples at this node.
        self.gini, self.num_samples = gini, num_samples
        self.num_samples_per_class = num_samples_per_class
        self.predicted_class = predicted_class

        # Split description and children; None until a split is assigned.
        self.feature_index = self.threshold = None
        self.left = self.right = None

# ------------------------
# Decision Tree Classifier
# ------------------------
class DecisionTreeClassifierScratch:
    """Binary decision tree classifier grown greedily with Gini splits.

    Args:
        max_depth: Maximum number of split levels below the root. ``None``
            (the default) places no limit on depth; growth then stops only
            when ``best_split`` finds no split with positive gain.

    Attributes:
        n_classes_: Number of class slots, set by ``fit``.
        n_features_: Number of columns in the training matrix, set by ``fit``.
        tree_: Root ``Node`` of the fitted tree, set by ``fit``.
    """

    def __init__(self, max_depth=None):
        self.max_depth = max_depth
        self.n_classes_ = None
        self.n_features_ = None
        self.tree_ = None

    def fit(self, X, y):
        """Grow the tree from training samples ``X`` and labels ``y``.

        Labels are used as indices into per-class count arrays, so they are
        expected to be non-negative integers.
        """
        # Accept any array-like; the tree builder relies on boolean-mask
        # indexing, which plain lists do not support.
        X = np.asarray(X)
        y = np.asarray(y)
        # Size the count arrays by the largest label rather than by the
        # number of distinct labels, so a training set with a missing class
        # (e.g. labels {0, 2}) does not index out of range.
        self.n_classes_ = int(y.max()) + 1
        self.n_features_ = X.shape[1]
        self.tree_ = self._grow_tree(X, y)

    def _grow_tree(self, X, y, depth=0):
        """Recursively build and return the subtree for the given samples."""
        num_samples_per_class = [np.sum(y == i) for i in range(self.n_classes_)]
        predicted_class = np.argmax(num_samples_per_class)
        node = Node(
            gini=gini_impurity(y),
            num_samples=len(y),
            num_samples_per_class=num_samples_per_class,
            predicted_class=predicted_class,
        )

        # max_depth=None means "no limit"; comparing an int with None would
        # raise TypeError.
        if self.max_depth is None or depth < self.max_depth:
            idx, thr = best_split(X, y, self.n_classes_)
            if idx is not None:
                indices_left = X[:, idx] < thr
                # Refuse a split that leaves one side empty: it separates
                # nothing and would recurse forever on the same samples
                # when depth is unlimited.
                if indices_left.any() and not indices_left.all():
                    X_left, y_left = X[indices_left], y[indices_left]
                    X_right, y_right = X[~indices_left], y[~indices_left]
                    node.feature_index = idx
                    node.threshold = thr
                    node.left = self._grow_tree(X_left, y_left, depth + 1)
                    node.right = self._grow_tree(X_right, y_right, depth + 1)
        return node

    def predict(self, X):
        """Return a list with the predicted class for each row of ``X``."""
        return [self._predict(inputs) for inputs in X]

    def _predict(self, inputs):
        """Walk from the root to a leaf and return that leaf's class."""
        node = self.tree_
        # Nodes are split with both children assigned together, so a missing
        # left child identifies a leaf.
        while node.left is not None:
            if inputs[node.feature_index] < node.threshold:
                node = node.left
            else:
                node = node.right
        return node.predicted_class

# ------------------------
# Testing on Iris dataset
# ------------------------
# Load the Iris samples and hold out 30% of them for evaluation.
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

# Fit a depth-limited tree on the training portion.
tree = DecisionTreeClassifierScratch(max_depth=3)
tree.fit(X_train, y_train)

# Accuracy is the fraction of held-out samples predicted correctly.
predictions = tree.predict(X_test)
accuracy = np.mean(np.asarray(predictions) == y_test)
print("Decision Tree Accuracy from scratch:", accuracy)

# Output: Decision Tree Accuracy from scratch: 0.9555555555555556
