<a href="https://colab.research.google.com/github/szymonszwedzinskiii/DataTemplates/blob/main/DecisionTree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from collections import Counter


In [25]:
X,y = load_iris(return_X_y=True)
def count_gini(y):
    if len(y) == 0:
        return 0
    counts = np.bincount(y)
    probabilities = counts / len(y)
    gini = 1 - np.sum(probabilities ** 2)
    return gini


In [26]:
def most_common_class(y):
    return Counter(y).most_common(1)[0][0]

In [29]:
def find_best_split(X, y):
    n_samples, n_features = X.shape
    best_gini = 1.0
    best_feature, best_threshold = None, None
    for feature_index in range(n_features):
        thresholds = np.unique(X[:, feature_index])
        for threshold in thresholds:
            left_mask = X[:, feature_index] < threshold
            right_mask = ~left_mask

            if sum(left_mask) == 0 or sum(right_mask) == 0:
                continue

            y_left, y_right = y[left_mask], y[right_mask]
            gini_left = count_gini(y_left)
            gini_right = count_gini(y_right)

            weighted_gini = (len(y_left) * gini_left + len(y_right) * gini_right) / n_samples

            if weighted_gini < best_gini:
                best_gini = weighted_gini
                best_feature = feature_index
                best_threshold = threshold

    return best_feature, best_threshold

In [37]:
def build_tree(X,y,depth,max_depth=5,min_samples_split=2):
  if np.unique(y).size == 1:
      return {'leaf': y[0]}

  if depth >= max_depth or len(y) < min_samples_split:
      return {'leaf': most_common_class(y)}

  best_feature, best_threshold = find_best_split(X, y)

  if best_feature is None:
      return {'leaf': most_common_class(y)}

  left_idx = X[:, best_feature] < best_threshold
  right_idx = ~left_idx

  if np.sum(left_idx) == 0 or np.sum(right_idx) == 0:
      return {'leaf': most_common_class(y)}

  left_branch = build_tree(X[left_idx], y[left_idx], depth + 1, max_depth, min_samples_split)
  right_branch = build_tree(X[right_idx], y[right_idx], depth + 1, max_depth, min_samples_split)

  return {
      'feature': best_feature,
      'threshold': best_threshold,
      'left': left_branch,
      'right': right_branch
  }
tree = build_tree(X,y,0,5)
def predict_tree(x, tree):
    if 'leaf' in tree:
        return tree['leaf']
    if x[tree['feature']] < tree['threshold']:
        return predict_tree(x, tree['left'])
    else:
        return predict_tree(x, tree['right'])

predict = [predict_tree(x,tree) for x in X]
print(predict)

[np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1)

In [34]:
def print_tree(tree, depth=0):
    indent = "  " * depth
    if 'leaf' in tree:
        print(f"{indent}Leaf: {tree['leaf']}")
    else:
        print(f"{indent}X[{tree['feature']}] < {tree['threshold']}")
        print(f"{indent}Left:")
        print_tree(tree['left'], depth + 1)
        print(f"{indent}Right:")
        print_tree(tree['right'], depth + 1)

print_tree(tree)

X[2] < 3.0
Left:
  Leaf: 0
Right:
  X[3] < 1.8
  Left:
    X[2] < 5.0
    Left:
      X[3] < 1.7
      Left:
        Leaf: 1
      Right:
        Leaf: 2
    Right:
      X[3] < 1.6
      Left:
        Leaf: 2
      Right:
        X[0] < 7.2
        Left:
          Leaf: 1
        Right:
          Leaf: 2
  Right:
    X[2] < 4.9
    Left:
      X[0] < 6.0
      Left:
        Leaf: 1
      Right:
        Leaf: 2
    Right:
      Leaf: 2
