In [1]:
import numpy as np

def gini_index(groups, classes):
    total_sample = float(sum([len(group) for group in groups]))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        score = sum([
            (row[-1]==c) for row in group for c in classes
        ])/size
        gini += (1.0 - sum([score**2 for c in classes])) * (size/total_sample)
        return gini

In [2]:
dataset = [
    [2.8, 'Yes'],
    [1.2, 'No'],
    [3.6, 'Yes'],
    [4.5, 'No'],
    [5.1, 'Yes']
]

In [3]:
def split_data(dataset, feature_index, threshold):
    left = [row for row in dataset if row[feature_index] < threshold]
    right = [row for row in dataset if row[feature_index] >= threshold]
    return left,right


In [4]:
groups = split_data(dataset,0,3.0)
classes = ['Yes','No']
gini = gini_index(groups,classes)
print(f'Gini Index: {gini:.4f}')

Gini Index: -0.4000


In [5]:
class TreeNode:
    def __init__(self, feature_index=None, threshold=None, left=None, right=None, label=None):
        self.feature_index = feature_index
        self.threshold = threshold
        self.left = left
        self.right = right
        self.label = label
class DecisionTree:
    def __init__(self, max_depth = 3):
        self.max_depth = max_depth
        self.root = None
    def gini_index(self, groups, classes):
        total_samples = sum([len(group) for group in groups])
        gini = 0.0
        for group in groups:
            size = len(group)
            if size == 0:
                continue
            score = 0.0
            for class_val in classes:
                proportion = [row[-1] for row in group].count(class_val)/size
                score += proportion**2
            gini += (1.0 -score) *(size/total_samples)
        return gini

    def split_data(self,dataset, feature_index,threshold):
        left = [row for row in dataset if row[feature_index] < threshold]
        right = [row for row in dataset if row[feature_index] >= threshold]
        return left,right
    def best_split(self, dataset):
        class_values = list(set(row[-1] for row in dataset))
        best_index, best_threshold, best_score, best_groups = None,None,float('inf'), None
        for index in range(len(dataset[0]) -1):
            for row in dataset:
                groups = self.split_data(dataset, index, row[index])
                gini = self.gini_index(groups, class_values)
                if gini < best_score:
                    best_index, best_threshold, best_score, best_groups = index,row[index],gini, groups
        return best_index, best_threshold, best_groups
    def build_tree(self, dataset, depth = 0):
        class_values = [row[-1] for row in dataset]
        if len(set(class_values)) == 1 or depth >= self.max_depth:
            return TreeNode(label = max(set(class_values), key  = class_values.count))
        
        feature_index , threshold, (left,right) = self.best_split(dataset)
        if not left or not right:
            return TreeNode(label = max(set(class_values), key = class_values.count))
        
        left_node = self.build_tree(left, depth+1)
        right_node = self.build_tree(right,depth+1)
        
        return TreeNode(feature_index, threshold, left_node, right_node)
    def fit(self, dataset):
        self.root = self.build_tree(dataset)
    def print_tree(self, node = None, depth = 0):
        if node is None:
            node = self.root
        if node.label is not None:
            print(f"{' '*depth} [Leaf] Label: {node.label}")
        else:
            print(f"{' '*depth} [Node] Feature: {node.feature_index} <= {node.threshold}")
            self.print_tree(node.left, depth+1)
            self.print_tree(node.right, depth+1)

In [6]:
dataset = [
    [2.8, 'Yes'],
    [1.2, 'No'],
    [3.6, 'Yes'],
    [4.5, 'No'],
    [5.1, 'Yes']
]

In [7]:
tree = DecisionTree ( max_depth =3)
tree . fit ( dataset )

print('Decision tree:')
tree.print_tree()

Decision tree:
 [Node] Feature: 0 <= 2.8
  [Leaf] Label: No
  [Node] Feature: 0 <= 4.5
   [Leaf] Label: Yes
   [Node] Feature: 0 <= 5.1
    [Leaf] Label: No
    [Leaf] Label: Yes
