<a href="https://colab.research.google.com/github/isa-ulisboa/greends-pml/blob/main/notebooks/prune_decision_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import _tree

# Recursively prune a fitted tree structure in place: a split node becomes a leaf when the
# combined ``value`` of its two children for one class (class 1 by default) is at or below
# a threshold.
def prune_index(inner_tree, index, threshold, class_index=1):
    """Recursively collapse split nodes of a fitted tree structure into leaves.

    The subtree rooted at ``index`` is visited in post-order (children before
    their parent). A split node is turned into a leaf by setting both of its
    child pointers to ``_tree.TREE_LEAF`` when::

        value[left][0][class_index] + value[right][0][class_index] <= threshold

    The collapsed node keeps its own ``value`` row; its former descendants
    simply become unreachable.

    Parameters
    ----------
    inner_tree : sklearn.tree._tree.Tree
        The ``tree_`` attribute of a fitted estimator; modified in place.
    index : int
        Id of the node to start from (0 is the root).
    threshold : float
        Inclusive upper bound on the children's combined value for
        ``class_index``.
    class_index : int, optional
        Class column of ``value`` to compare. Defaults to 1, the column that
        was previously hard-coded.

    Returns
    -------
    None
    """
    left_child = inner_tree.children_left[index]
    # A leaf has no children to merge, so there is nothing to prune below it.
    if left_child == _tree.TREE_LEAF:
        return
    right_child = inner_tree.children_right[index]

    # Prune deeper levels first so whole subtrees can collapse in a single pass.
    # The recursive calls only modify descendants' links, never this node's.
    prune_index(inner_tree, left_child, threshold, class_index)
    prune_index(inner_tree, right_child, threshold, class_index)

    # ``value`` is indexed as [node, output, class], i.e. its shape is
    # (node_count, n_outputs, max_n_classes); only output 0 is read here.
    # NOTE(review): depending on the scikit-learn version, ``value`` holds either
    # weighted class counts or per-node class fractions for classifiers, which
    # changes what ``threshold`` means — confirm against the installed version.
    children_value = (
        inner_tree.value[left_child][0][class_index]
        + inner_tree.value[right_child][0][class_index]
    )
    if children_value <= threshold:
        inner_tree.children_left[index] = _tree.TREE_LEAF
        inner_tree.children_right[index] = _tree.TREE_LEAF

# Fit a decision tree once, then prune it until its validation accuracy reaches a target
# or pruning can make no further progress.
def prune_tree(tree, X_train, y_train, X_val, y_val, threshold, prune_threshold=None):
    """Fit ``tree`` and prune it while its validation accuracy is below ``threshold``.

    The tree is fitted only once. Refitting inside the loop would rebuild it from
    scratch and discard every pruning step. The loop could then never terminate
    whenever the first accuracy was below ``threshold``.

    Parameters
    ----------
    tree : DecisionTreeClassifier
        Estimator to fit and prune; it is modified in place and also returned.
    X_train, y_train : array-like
        Training data passed to ``tree.fit``.
    X_val, y_val : array-like
        Validation data used to measure accuracy after each pruning step.
    threshold : float
        Target validation accuracy; pruning stops as soon as it is reached.
    prune_threshold : float, optional
        Threshold forwarded to ``prune_index``. Defaults to ``threshold`` to keep
        the existing behaviour. The two quantities (an accuracy versus a per-class
        node value) are unrelated, so they usually deserve different values.

    Returns
    -------
    DecisionTreeClassifier
        The fitted (and possibly pruned) ``tree``. Its validation accuracy may
        still be below ``threshold`` if pruning stopped changing the tree first.
    """
    if prune_threshold is None:
        prune_threshold = threshold

    # Fit exactly once; every later step operates on this single fitted structure.
    tree.fit(X_train, y_train)

    while True:
        acc = accuracy_score(y_val, tree.predict(X_val))
        if acc >= threshold:
            break

        # tree.tree_ is the low-level binary tree (node arrays) behind the estimator.
        # Snapshot the child links so we can detect whether pruning changed anything.
        links_before = tree.tree_.children_left.copy()
        prune_index(tree.tree_, 0, prune_threshold)
        # prune_index only ever removes links. An unchanged tree therefore means
        # further passes are pointless, so stop instead of looping forever.
        if (tree.tree_.children_left == links_before).all():
            break

    return tree

# load the iris dataset
iris = load_iris()

# split the dataset into training and validation sets (fixed seed for reproducibility)
X_train, X_val, y_train, y_val = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)

# create a decision tree classifier. random_state is fixed because the estimator
# randomly permutes the features at each split; with tied splits, an unseeded
# run could produce a different tree (and accuracy) every time.
tree = DecisionTreeClassifier(random_state=42)

# fit and prune the decision tree; threshold is the target validation accuracy
pruned_tree = prune_tree(tree, X_train, y_train, X_val, y_val, threshold=0.9)

# evaluate the pruned decision tree on the validation set. This is the same set
# that guided pruning, so the score is an optimistic estimate; an unbiased one
# would need a separate held-out test split.
y_pred = pruned_tree.predict(X_val)
acc = accuracy_score(y_val, y_pred)
print(f"Accuracy: {acc}")
