the ID3 algorithm and print the tree using the Iris dataset.

In [1]:
from sklearn.datasets import load_iris

iris = load_iris()

In [2]:
import numpy as np
from collections import Counter

def calculate_entropy(labels):
    """Return the Shannon entropy (in bits) of a collection of labels.

    Each distinct label contributes ``-p * log2(p)``, where ``p`` is the
    fraction of entries carrying that label. An empty collection has an
    entropy of 0.
    """
    total = len(labels)
    if total == 0:
        return 0

    result = 0
    # Only the frequencies matter, so iterate over the counts directly.
    for count in Counter(labels).values():
        share = count / total
        result -= share * np.log2(share)
    return result

def split_dataset(data, feature_index, threshold):
    """Partition the rows of ``data`` on one feature column.

    Returns a ``(left, right)`` pair: ``left`` holds the rows whose value in
    column ``feature_index`` is ``<= threshold``, ``right`` holds the rows
    whose value is ``> threshold``.
    """
    column = data[:, feature_index]
    # Two explicit comparisons (rather than negating one mask) so that rows
    # which satisfy neither test, e.g. NaN, end up in neither half.
    return data[column <= threshold], data[column > threshold]

def build_decision_tree(data, feature_indices, max_depth=None, current_depth=0):
    """Recursively grow a binary decision tree by maximising information gain.

    ``data`` is a 2-D array whose last column holds the class labels;
    ``feature_indices`` lists the columns that may still be split on.

    Returns either a leaf (a label value taken from the last column) or a
    dict with the keys ``'feature_index'``, ``'threshold'``, ``'left'`` and
    ``'right'``, where ``'left'`` covers rows with value ``<= threshold`` and
    ``'right'`` covers rows with value ``> threshold``.

    Recursion stops with a leaf when the node is pure, when no candidate
    features remain, when ``max_depth`` (if given) has been reached, or when
    no split yields a positive information gain. A feature is dropped from
    the candidate list once it has been used, so it is split on at most once
    along any root-to-leaf path.
    """
    labels = data[:, -1]
    distinct = np.unique(labels)

    def majority_label():
        # Most frequent label among the rows that reached this node.
        return Counter(labels).most_common(1)[0][0]

    # Pure node: every row already agrees on the label.
    if len(distinct) == 1:
        return distinct[0]
    if len(feature_indices) == 0:
        return majority_label()
    if max_depth is not None and current_depth >= max_depth:
        return majority_label()

    parent_entropy = calculate_entropy(labels)
    n_rows = len(data)
    top_gain, top_feature, top_threshold = -1, None, None

    for candidate in feature_indices:
        # Every observed value of the feature is tried as a cut point.
        for cut in np.unique(data[:, candidate]):
            lower, upper = split_dataset(data, candidate, cut)

            # A cut that leaves one side empty separates nothing.
            if len(lower) == 0 or len(upper) == 0:
                continue

            children_entropy = (len(lower) / n_rows * calculate_entropy(lower[:, -1]) +
                                len(upper) / n_rows * calculate_entropy(upper[:, -1]))
            gain = parent_entropy - children_entropy

            # Strict comparison: on ties the earliest candidate is kept.
            if gain > top_gain:
                top_gain, top_feature, top_threshold = gain, candidate, cut

    if top_gain <= 0:
        return majority_label()

    lower, upper = split_dataset(data, top_feature, top_threshold)
    unused = [index for index in feature_indices if index != top_feature]
    next_depth = current_depth + 1

    return {
        'feature_index': top_feature,
        'threshold': top_threshold,
        'left': build_decision_tree(lower, unused, max_depth, next_depth),
        'right': build_decision_tree(upper, unused, max_depth, next_depth),
    }

In [5]:
# Append the class label as the final column so every row carries its target.
data = np.column_stack((iris.data, iris.target))
# Every column except the trailing label column is a candidate feature.
feature_indices = list(range(data.shape[1] - 1))
decision_tree = build_decision_tree(data, feature_indices)
print(decision_tree)

{'feature_index': 2, 'threshold': np.float64(1.9), 'left': np.float64(0.0), 'right': {'feature_index': 3, 'threshold': np.float64(1.7), 'left': {'feature_index': 0, 'threshold': np.float64(7.0), 'left': {'feature_index': 1, 'threshold': np.float64(2.8), 'left': np.float64(1.0), 'right': np.float64(1.0)}, 'right': np.float64(2.0)}, 'right': {'feature_index': 0, 'threshold': np.float64(5.9), 'left': {'feature_index': 1, 'threshold': np.float64(3.0), 'left': np.float64(2.0), 'right': np.float64(1.0)}, 'right': np.float64(2.0)}}}


## Visualize the decision tree


In [4]:
def print_tree(node, depth=0):
    """Print a decision tree as indented text.

    A dict ``node`` is treated as a decision node: its test is printed,
    followed by a ``Left:`` and a ``Right:`` section, each rendered
    recursively at ``depth + 1``. Any other value is treated as a leaf and
    printed as ``Predict: <value>``.

    Note that a nested decision node is printed at the same indentation as
    its ``Left:``/``Right:`` label, while a leaf is printed two spaces
    deeper.
    """
    pad = "  " * depth
    if not isinstance(node, dict):
        print(f"{pad}  Predict: {node}")
        return

    print(f"{pad}Feature {node['feature_index']} <= {node['threshold']:.2f}")
    for title, branch in (("Left", 'left'), ("Right", 'right')):
        print(f"{pad}  {title}:")
        print_tree(node[branch], depth + 1)

# Render the tree stored in `decision_tree` as indented text.
print_tree(decision_tree)

Feature 2 <= 1.90
  Left:
    Predict: 0.0
  Right:
  Feature 3 <= 1.70
    Left:
    Feature 0 <= 7.00
      Left:
      Feature 1 <= 2.80
        Left:
          Predict: 1.0
        Right:
          Predict: 1.0
      Right:
        Predict: 2.0
    Right:
    Feature 0 <= 5.90
      Left:
      Feature 1 <= 3.00
        Left:
          Predict: 2.0
        Right:
          Predict: 1.0
      Right:
        Predict: 2.0
