Modify the from-scratch Decision Tree code from our lecture so that:
- It accepts a hyperparameter <code>max_depth</code> and keeps growing the tree until <code>max_depth</code> is reached.
- Everything is put into a class <code>DecisionTree</code> with at least two methods, <code>fit()</code> and <code>predict()</code>.
- The iris data is loaded and used to try out your class.

In [43]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Load the iris features and integer class labels, then hold out 30% of
# the samples as a test set (fixed seed so the split is reproducible).
iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

In [40]:
# Building block of the tree: every node records its impurity, class counts
# and majority class; internal nodes additionally hold a split and children.
class Node:
    """One node of a binary classification tree.

    Parameters
    ----------
    gini : float
        Gini impurity of the samples that reached this node.
    num_samples : int
        Number of training samples that reached this node.
    num_samples_per_class : list
        Per-class sample counts, indexed by class label.
    predicted_class : int
        Majority class among the samples (returned when this node is a leaf).
    """

    def __init__(self, gini, num_samples, num_samples_per_class, predicted_class):
        self.gini = gini
        self.num_samples = num_samples
        self.num_samples_per_class = num_samples_per_class
        self.predicted_class = predicted_class
        # Split rule "X[:, feature_index] < threshold -> left"; the defaults
        # are placeholders and only meaningful once children are attached.
        self.feature_index = 0
        self.threshold = 0
        self.left = None
        self.right = None

    @property
    def is_leaf(self):
        """True when the node was never split (it has no children)."""
        return self.left is None

    def __repr__(self):
        if self.is_leaf:
            return (f"Node(leaf, predicted_class={self.predicted_class}, "
                    f"num_samples={self.num_samples})")
        return (f"Node(X[:, {self.feature_index}] < {self.threshold}, "
                f"num_samples={self.num_samples})")


class DecisionTree:
    """Classification tree grown greedily by minimising the Gini impurity.

    Parameters
    ----------
    max_depth : int or None, default None
        Maximum depth of the tree; the root is at depth 0, so ``max_depth=0``
        yields a single leaf. ``None`` keeps splitting until no split lowers
        the impurity any further.

    Attributes
    ----------
    root_ : Node or None
        Root of the tree built by the most recent ``fit`` call.
    n_classes_ : int or None
        Number of classes used by the most recent ``fit`` call.
    """

    def __init__(self, max_depth=None):
        self.max_depth = max_depth
        self.root_ = None
        self.n_classes_ = None

    def find_split(self, X, y, n_classes):
        """Find the split whose children have the lowest weighted Gini impurity.

        Every feature is scanned; each boundary between two distinct
        consecutive sorted values is a candidate, with the threshold placed at
        their midpoint. A split is only accepted if its weighted impurity is
        strictly lower than the parent's impurity; otherwise we stop.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
        y : ndarray of shape (n_samples,) with integer labels in [0, n_classes)
        n_classes : int

        Returns
        -------
        (feature_index, threshold), or (None, None) when there are fewer than
        two samples or no split improves on the parent.
        """
        n_samples, n_features = X.shape
        if n_samples <= 1:
            return None, None

        parent_counts = [np.sum(y == c) for c in range(n_classes)]
        # The parent's impurity is the bar every candidate split must beat.
        best_gini = 1.0 - sum((n / n_samples) ** 2 for n in parent_counts)
        best_feature, best_threshold = None, None

        for feature in range(n_features):
            order = np.argsort(X[:, feature])
            values = X[order, feature]
            labels = y[order]

            left_counts = [0] * n_classes
            right_counts = parent_counts.copy()

            # Candidate i puts sorted samples [0, i) left and [i, n) right.
            for i in range(1, n_samples):
                c = labels[i - 1]
                left_counts[c] += 1
                right_counts[c] -= 1

                # Identical values must end up on the same side of any split.
                if values[i] == values[i - 1]:
                    continue

                n_right = n_samples - i
                gini_left = 1.0 - sum(
                    (left_counts[k] / i) ** 2 for k in range(n_classes)
                )
                gini_right = 1.0 - sum(
                    (right_counts[k] / n_right) ** 2 for k in range(n_classes)
                )
                weighted_gini = ((i / n_samples) * gini_left) + (n_right / n_samples) * gini_right

                # Strict "<" keeps the first best split found on ties.
                if weighted_gini < best_gini:
                    best_gini = weighted_gini
                    best_feature = feature
                    best_threshold = (values[i] + values[i - 1]) / 2  # midpoint

        return best_feature, best_threshold

    def fit(self, Xtrain, ytrain, n_classes=None, depth=0):
        """Grow the tree on the training data and store it in ``root_``.

        Parameters
        ----------
        Xtrain : array-like of shape (n_samples, n_features)
        ytrain : array-like of shape (n_samples,) with integer labels
            in [0, n_classes).
        n_classes : int, optional
            Number of classes. Inferred as ``max(ytrain) + 1`` when omitted.
        depth : int, default 0
            Depth assigned to the root; kept for compatibility with the
            original recursive signature and normally left at 0.

        Returns
        -------
        Node
            The root of the fitted tree (also stored as ``self.root_``).

        Raises
        ------
        ValueError
            If the training set is empty.
        """
        Xtrain = np.asarray(Xtrain)
        ytrain = np.asarray(ytrain)
        if ytrain.size == 0:
            raise ValueError("Cannot fit a decision tree on an empty training set.")
        if n_classes is None:
            n_classes = int(ytrain.max()) + 1

        self.n_classes_ = n_classes
        self.root_ = self._grow(Xtrain, ytrain, n_classes, depth)
        return self.root_

    def _grow(self, X, y, n_classes, depth):
        """Recursively build and return the subtree for samples (X, y)."""
        n_samples = y.size
        counts = [np.sum(y == c) for c in range(n_classes)]
        node = Node(
            gini=1 - sum((n / n_samples) ** 2 for n in counts),
            num_samples=n_samples,
            num_samples_per_class=counts,
            # Majority vote; np.argmax breaks ties toward the lowest label.
            predicted_class=np.argmax(counts),
        )

        if self.max_depth is not None and depth >= self.max_depth:
            return node

        feature, threshold = self.find_split(X, y, n_classes)
        if feature is None:
            return node

        go_left = X[:, feature] < threshold
        node.feature_index = feature
        node.threshold = threshold
        node.left = self._grow(X[go_left], y[go_left], n_classes, depth + 1)
        node.right = self._grow(X[~go_left], y[~go_left], n_classes, depth + 1)
        return node

    def predict(self, X, tree=None):
        """Predict a class label for every row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        tree : Node, optional
            Root to predict with; defaults to the tree from the last ``fit``.

        Returns
        -------
        list
            One predicted label per row of ``X``.

        Raises
        ------
        ValueError
            If no tree is given and the model has not been fitted.
        """
        root = self.root_ if tree is None else tree
        if root is None:
            raise ValueError("This DecisionTree is not fitted yet; call fit() first.")

        X = np.asarray(X)
        predicted = []
        for row in X:
            node = root
            # Follow the split rules down to a leaf.
            while not node.is_leaf:
                node = node.left if row[node.feature_index] < node.threshold else node.right
            predicted.append(node.predicted_class)
        return predicted


# Backward-compatible alias for the original (misspelled) class name.
DescissionTree = DecisionTree

In [50]:
# Fit a depth-limited tree on the training split and evaluate it on the
# held-out split. The class count comes from the dataset rather than a
# hard-coded 3, so the cell stays correct if the data changes.
model = DescissionTree(max_depth=10)
n_classes = len(iris.target_names)
tree = model.fit(X_train, y_train, n_classes)
predicted = model.predict(X_test, tree)

print(classification_report(y_test, predicted))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      0.85      0.92        13
           2       0.87      1.00      0.93        13

    accuracy                           0.96        45
   macro avg       0.96      0.95      0.95        45
weighted avg       0.96      0.96      0.96        45

