In [1]:
import numpy as np
import copy
import statistics
from collections import deque
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
array = np.array

In [2]:
class DecisionTreeClassifier:
    def __init__(self, max_depth = 5):
        self.is_trained = False
        self.tree = []
        self.root = None
        self.max_depth = max_depth
    def fit(self, X, y):
        self.root = Node(X,y)
        to_be_split = deque()
        to_be_split.append(self.root)
        while len(to_be_split) > 0:
            node = to_be_split.popleft()
            self.tree.append(node)
            if node.tree_level == self.max_depth:
                node.is_leaf = True
                continue
            else:
                node.best_split()
                if node.has_split == True:
                    to_be_split.append(node.left_branch)
                    to_be_split.append(node.right_branch)
        self.is_trained = True
    def predict(self, X):
        if len(X) == 1:
            predictions = self.predict_one(X)
        else:
            predictions = [self.predict_one(x) for x in X]
        return predictions
    def predict_one(self, x):
        if self.is_trained == False:
            print('You need to train the model first!')
        else:
            currentNode = self.root
            while currentNode.is_leaf == False:
                split_idx, split_thresh = currentNode.split_on
                if x[split_idx] < split_thresh:
                    currentNode = currentNode.left_branch
                else:
                    currentNode = currentNode.right_branch
            return statistics.mode(currentNode.y)
        
class Node:
    def __init__(self, X, y, tree_level=0):
        self.X = X
        self.y = y
        self.num_features = self.X.shape[1]
        self.num_samples_per_class = array([sum(self.y == C) for C in classes])
        self.G = gini_impurity(self.num_samples_per_class)
        self.has_split = False
        self.split_on = None
        self.left_branch = None
        self.right_branch = None
        self.child_G = 1
        self.is_leaf = False
        self.tree_level = tree_level

    def best_split(self):
        best_G = 1
        split_idx = None
        split_thresh = None
        
        for feat_idx in range(self.num_features):
            thresh, G = best_split_for_feature(self, feat_idx)
 
            if G < best_G:
                best_G = G
                split_idx = feat_idx
                split_thresh = thresh

        if best_G < self.G:
            split(self, split_idx, split_thresh)
            self.split_on = (split_idx, split_thresh)
        else:
            self.is_leaf = True

In [3]:
#Some helper functions

#Split parent_Node on split_idx against split_thresh
#Create left and right branch and compute the resulting gini impurity 
def split(parent_Node, split_idx, split_thresh):
    left_idx = parent_Node.X[:,split_idx] < split_thresh
    right_idx = np.logical_not(left_idx)
    left_X, left_y = parent_Node.X[left_idx], parent_Node.y[left_idx]
    right_X, right_y = parent_Node.X[right_idx], parent_Node.y[right_idx]
    parent_Node.left_branch = Node(left_X, left_y, parent_Node.tree_level+1)
    parent_Node.right_branch = Node(right_X, right_y, parent_Node.tree_level+1)
    parent_Node.has_split = True
    gini_impurity_of_split(parent_Node)

#Find the threshold value that leads to smallest gini impurity for feat_idx
def best_split_for_feature(node, feat_idx):
    node_copy = copy.deepcopy(node)
    feature_array = node_copy.X[:,feat_idx]
    thresh_vals = np.unique(feature_array)
    best_thresh = None
    best_G = 1
    for thresh_val in thresh_vals:
        split(node_copy, feat_idx, thresh_val)
        if node_copy.child_G < best_G:
            best_G = node_copy.child_G
            best_thresh = thresh_val
    return best_thresh, best_G

#Gini impurity of a sample set is the probability of misclassiying a random sample
#that is labeled according to the probability distribution of labels in the set.
#Sum of (probability of class C)*(probability not of class C) across all Cs.
#(1-pk)pk reduces to 1-pk^2 as the sum of pk across all C is 1.

def gini_impurity(num_samples_per_class):
    num_samples = sum(num_samples_per_class)
    if num_samples == 0:
        G = 1
    else:
        G = 1 - sum([(n_k/num_samples)**2 for n_k in num_samples_per_class])
    return G

#The gini impurity of a split is the weighted average of the gini impurities

def gini_impurity_of_split(node):
    if node.has_split == False:
        pass
    else:
        left_y = node.left_branch.y
        right_y = node.right_branch.y
        num_samples = len(left_y) + len(right_y)
        L_num_samples_per_class = array([sum(left_y == C) for C in classes])
        R_num_samples_per_class = array([sum(right_y == C) for C in classes])
        num_L_samples = sum(L_num_samples_per_class)
        num_R_samples = sum(R_num_samples_per_class)
        gini_L = gini_impurity(L_num_samples_per_class)
        gini_R = gini_impurity(R_num_samples_per_class)
        G = (gini_L*num_L_samples+gini_R*num_R_samples)/num_samples
        node.child_G = G

In [4]:
data = datasets.load_iris()
X,y = data.data, data.target
classes = set(y)

trainX, valX, trainy, valy = train_test_split(X,y)
myTreeClassifier = DecisionTreeClassifier()
myTreeClassifier.fit(trainX,trainy)

predictions = myTreeClassifier.predict(valX)

confusion_matrix(valy, predictions)

array([[14,  0,  0],
       [ 0, 10,  1],
       [ 0,  0, 13]])