In [45]:
import numpy as np

In [46]:
class Node:
    """A single node of the decision tree.

    An internal node stores the split it applies (``feature``, ``threshold``),
    its two children (``left``, ``right``) and the ``info_gain`` of the split.
    A leaf node stores only ``value``; every argument defaults to None.
    """
    def __init__(self, feature=None, threshold=None, left=None, right=None, info_gain=None, value=None):
        # Split description (internal nodes)
        self.feature, self.threshold = feature, threshold
        # Children (internal nodes)
        self.left, self.right = left, right
        # Quality of the split (internal nodes)
        self.info_gain = info_gain

        # Stored label (leaf nodes)
        self.value = value

In [47]:
class DecisionTreeClassifier:
    """Decision tree classifier

    Attributes:
    -----------
    criterion: {'gini', 'entropy'}, default to 'gini'
        The function to measure the quality of a split

    max_depth: int or None, default to 3
        The maximum depth of the tree. None means the depth is not limited

    min_samples_split: int, default=2
        The minimum number of samples required to split an internal node

    root: Node or None
        Root of the fitted tree; None until fit() has been called
    """
    def __init__(self, criterion='gini', max_depth=3, min_samples_split=2):
        self.root = None
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split

    def _calulate_leaf_value(self, y):
        """Return the most frequent label in y

        On ties the label that appears first in y wins. Raises ValueError
        when y is empty.
        """
        # One counting pass instead of calling list.count for every element.
        # Dicts keep insertion order and max() returns the first maximum, so
        # ties resolve to the first-seen label.
        counts = {}
        for label in y:
            counts[label] = counts.get(label, 0) + 1
        return max(counts, key=counts.get)

    def _entropy(self, y):
        """Compute the Shannon entropy (base 2) of the labels in y
        """
        # np.unique only reports labels that occur, so every probability is
        # strictly positive and log2 is always defined.
        _, counts = np.unique(y, return_counts=True)
        prob = counts / len(y)
        return -np.sum(prob * np.log2(prob))

    def _gini_impurity(self, y):
        """Compute the Gini impurity of the labels in y
        """
        _, counts = np.unique(y, return_counts=True)
        prob = counts / len(y)
        return 1 - np.sum(np.square(prob))

    def _information_gain(self, parent, left, right, criterion):
        """Compute the information gain of a split

        The gain is the parent impurity minus the size-weighted impurities
        of the two children. Raises ValueError for an unknown criterion.
        """
        if criterion == 'gini':
            impurity = self._gini_impurity
        elif criterion == 'entropy':
            impurity = self._entropy
        else:
            raise ValueError('Criterion should be "gini" or "entropy"')

        n_parent = len(parent)
        left_weight = len(left) / n_parent
        right_weight = len(right) / n_parent
        return impurity(parent) - left_weight * impurity(left) - right_weight * impurity(right)

    def _split(self, X, y, feature, threshold):
        """Partition (X, y) on `feature`: values <= threshold go left, values > threshold go right

        Returns X_left, y_left, X_right, y_right.
        """
        column = X[:, feature]
        left = column <= threshold
        right = column > threshold
        return X[left], y[left], X[right], y[right]

    def _get_best_split(self, X, y, criterion):
        """Find the best split

        Tries every distinct value of every feature as a threshold and keeps
        the split with the highest information gain. Returns a dict with the
        keys 'feature', 'threshold', 'X_left', 'y_left', 'X_right', 'y_right'
        and 'info_gain', or an empty dict when no threshold produces two
        non-empty children.
        """
        best_split = dict()
        max_info_gain = -float("inf")

        # Loop over all features
        for feature in range(X.shape[1]):
            # Loop over all possible thresholds of this feature
            for threshold in np.unique(X[:, feature]):
                X_left, y_left, X_right, y_right = self._split(X, y, feature, threshold)
                # A split that leaves one side empty separates nothing
                if len(y_left) == 0 or len(y_right) == 0:
                    continue
                current_info_gain = self._information_gain(y, y_left, y_right, criterion)
                if current_info_gain > max_info_gain:
                    max_info_gain = current_info_gain
                    best_split = {
                        'feature': feature,
                        'threshold': threshold,
                        'X_left': X_left,
                        'y_left': y_left,
                        'X_right': X_right,
                        'y_right': y_right,
                        'info_gain': current_info_gain,
                    }
        return best_split

    def _build_tree(self, X, y, current_depth):
        """Build the tree recursively

        Returns the Node that is the root of the (sub)tree for X, y.
        """
        n_samples = X.shape[0]

        # Stopping conditions: too few samples, depth limit reached
        # (max_depth=None disables the limit), or a node that is already pure.
        # Splitting a pure node can only produce leaves with the same label,
        # so turning it into a leaf right away does not change predictions.
        depth_allows_split = self.max_depth is None or current_depth < self.max_depth
        if (n_samples >= self.min_samples_split and depth_allows_split
                and np.unique(y).size > 1):
            best_split = self._get_best_split(X, y, self.criterion)  # Find the best split
            # best_split is empty when no valid split exists (e.g. all rows
            # share the same feature values); fall through to a leaf then.
            if best_split and best_split['info_gain'] >= 0:
                left_subtree = self._build_tree(
                    best_split['X_left'], best_split['y_left'], current_depth + 1)
                right_subtree = self._build_tree(
                    best_split['X_right'], best_split['y_right'], current_depth + 1)

                return Node(best_split['feature'], best_split['threshold'],
                            left_subtree, right_subtree, best_split['info_gain'])

        leaf_value = self._calulate_leaf_value(y)
        return Node(value=leaf_value)

    def fit(self, X, y):
        """Train the data to build tree

        X is a 2-D array-like of shape (n_samples, n_features) and y holds
        one label per sample. Raises ValueError when X is not 2-D, when X and
        y disagree on the number of samples, or when there are no samples.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError('X should be a 2-D array of shape (n_samples, n_features)')
        if y.shape[0] != X.shape[0]:
            raise ValueError('X and y should contain the same number of samples')
        if X.shape[0] == 0:
            raise ValueError('Cannot fit a tree on an empty dataset')
        self.root = self._build_tree(X, y, current_depth=0)

    def _make_prediction(self, tree, x):
        """Walk from `tree` down to a leaf for the sample x and return the leaf value
        """
        node = tree
        # Identity check: a stored label such as 0 must still count as a leaf
        while node.value is None:
            if x[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value

    def predict(self, X):
        """Predict labels for new data

        Returns an integer array with one predicted label per row of X.
        Raises RuntimeError when called before fit().
        """
        if self.root is None:
            raise RuntimeError('This DecisionTreeClassifier is not fitted yet; call fit() first')
        predictions = [self._make_prediction(self.root, x) for x in np.asarray(X)]
        # np.int was only an alias of the builtin int and no longer exists in
        # recent NumPy releases; the builtin gives the same dtype.
        return np.array(predictions, dtype=int)

In [48]:
# Test
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Load the iris dataset as a (features, labels) pair.
X, y = load_iris(return_X_y=True)
# Hold out part of the data for evaluation.
# NOTE(review): no random_state is passed, so the split presumably differs
# between runs and the scores printed below are not reproducible -- confirm.
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Alternative hand-made two-sample training set (disabled):
# X_train = np.array([[0, 0], [1, 1]])
# y_train = np.array([0, 1])

In [49]:
# Fit the from-scratch tree with the entropy criterion and score it on the
# held-out split.
cls1 = DecisionTreeClassifier(criterion='entropy', max_depth=3, min_samples_split=2)
cls1.fit(X_train, y_train)
preds = cls1.predict(X_test)
# accuracy_score takes (y_true, y_pred); the value is the same either way,
# but the documented order keeps the call correct if the metric is swapped
# for an asymmetric one later.
accuracy_score(y_test, preds)

0.9736842105263158

In [50]:
# Reference run with scikit-learn's tree on the same split, using the same
# depth and min_samples_split settings. No criterion is passed here, so the
# library default applies, whereas cls1 was built with 'entropy'.
from sklearn import tree
cls2 = tree.DecisionTreeClassifier(max_depth=3, min_samples_split=2)
cls2.fit(X_train, y_train)
preds = cls2.predict(X_test)
# accuracy_score takes (y_true, y_pred)
accuracy_score(y_test, preds)

0.9736842105263158