# Decision Tree

## Import libraries

In [73]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from matplotlib.text import Annotation

print("Libraries imported!")

Libraries imported!


## Model Architecture

In [74]:
class DecisionTree():

    def __init__(self, max_depth=None):
        self.max_depth = max_depth

    def fit(self, X, y):
        self.n_classes_ = len(np.unique(y))
        self.n_features_ = X.shape[1]
        self.tree_ = self._grow_tree(X, y)

    def predict(self, X):
        y_pred = []
        for x in X:
            node = self.tree_
            while node.left:
                if x[node.feature_idx] < node.threshold:
                    node = node.left
                else:
                    node = node.right
            y_pred.append(node.majority_class)

        return y_pred


    def _grow_tree(self, X, y, depth=0):
        samples_per_class = [np.sum(y == i) for i in range(self.n_classes_)]
        majority_class = np.argmax(samples_per_class)

        node = Node(majority_class=majority_class)
        if depth < self.max_depth:
            feature_idx, threshold = self._best_split(X, y)
            if feature_idx is not None:
                node.num_samples = len(y)
                left_indices = X[:, feature_idx] < threshold
                X_left, y_left = X[left_indices], y[left_indices]
                X_right, y_right = X[~left_indices], y[~left_indices]
                node.feature_idx = feature_idx
                node.threshold = threshold
                node.left = self._grow_tree(X_left, y_left, depth + 1)
                node.right = self._grow_tree(X_right, y_right, depth + 1)
                
        
        return node
    
    def _best_split(self, X, y):
        m = y.size

        # if there is only one class in y, return None, None
        if m <= 1:
            print("hello")
            return None, None
        
        # count of each class in y
        parent_class_count = np.unique(y, return_counts=True)[1]

        # initialize the best gini index with gini index of parent node
        best_gini = self._gini(y)

        # initialize the best feature and threshold with None, None
        best_feature_idx, best_threshold = None, None

        for feature_idx in range(self.n_features_):
            # Sort feature values and corresponding class labels
            thresholds, classes = zip(*sorted(zip(X[:, feature_idx], y)))

            for i in range(1, m):
                left_indices = X[:, feature_idx] < thresholds[i]
                y_left = y[left_indices]
                y_right = y[~left_indices]
                gini_left  = self._gini(y_left)
                gini_right = self._gini(y_right)
                weighted_gini = (len(y_left) / len(y)) * gini_left + (len(y_right) / len(y)) * gini_right
                if thresholds[i] == thresholds[i - 1]:
                    continue

            # Update best index and threshold if current Gini is lower
                if weighted_gini < best_gini:
                    best_gini = weighted_gini
                    best_feature_idx = feature_idx
                    best_threshold = (thresholds[i] + thresholds[i - 1]) / 2

        return best_feature_idx, best_threshold


    def _gini(self, y):
         # returns the count of each value in the array
        _, counts = np.unique(y, return_counts=True)
        # calculate the gini impurity
        impurity = 1 - np.sum(np.square(counts / len(y)))
        return impurity
    
    def traverse_tree(self, node=None, depth=0):
        if node is None:
            node = self.tree_
        if node.is_leaf_node():
            print("  " * depth + f"Leaf Node: Majority class = {node.majority_class}")
        else:
            print("  " * depth + f"Node: Majority class = {node.majority_class}")
            if node.feature_idx is not None:
                print("  " * depth + f"  Split on feature {node.feature_idx} with threshold {node.threshold}")
                print("  " * depth + "  Left:")
                self.traverse_tree(node.left, depth + 1)
                print("  " * depth + "  Right:")
                self.traverse_tree(node.right, depth + 1)
        print("Number of samples in node: {}".format(node.num_samples))

class Node():
    
    def __init__(self, majority_class):
        self.majority_class = majority_class  
        self.feature_idx = None
        self.threshold = None
        self.left = None
        self.right = None
        self.num_samples = None

    def is_leaf_node(self):
        return self.left is None and self.right is None


In [75]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load the iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the decision tree
tree = DecisionTree(max_depth=3)
tree.fit(X_train, y_train)

# Make predictions on the test set
y_pred = tree.predict(X_test)

# Compute the accuracy of the predictions
accuracy = accuracy_score(y_test, y_pred)

print(f"Accuracy: {accuracy}")

Accuracy: 1.0


In [77]:
tree.traverse_tree()

Node: Majority class = 1
  Split on feature 2 with threshold 2.45
  Left:
  Leaf Node: Majority class = 0
Number of samples in node: None
  Right:
  Node: Majority class = 1
    Split on feature 2 with threshold 4.75
    Left:
    Node: Majority class = 1
      Split on feature 3 with threshold 1.65
      Left:
      Leaf Node: Majority class = 1
Number of samples in node: None
      Right:
      Leaf Node: Majority class = 2
Number of samples in node: None
Number of samples in node: 37
    Right:
    Node: Majority class = 2
      Split on feature 3 with threshold 1.75
      Left:
      Leaf Node: Majority class = 1
Number of samples in node: None
      Right:
      Leaf Node: Majority class = 2
Number of samples in node: None
Number of samples in node: 43
Number of samples in node: 80
Number of samples in node: 120
