In [701]:
import numpy as np

In [752]:
class Node:
    """One node of a binary decision tree grown with the Gini criterion.

    A node is built from the training points that reach it and, when allowed,
    immediately splits itself into ``left_link`` / ``right_link`` children.
    Points whose ``best_feature`` value is ``< threshold`` go left, all other
    points go right.

    Parameters
    ----------
    x : 2-D array-like, one row per training point reaching this node.
    y : 1-D array-like of class labels, aligned with the rows of ``x``.
    depth : maximum depth of the tree; a node splits only while
        ``depth > depth_of_node``.
    depth_of_node : depth of this node (the root is built with 0).
    min_points : a node splits only when it holds more than this many points.
    n_class : iterable of the known class labels, or an int ``k`` meaning the
        labels ``0 .. k-1``.
    """

    def __init__(self, x, y, depth, depth_of_node, min_points, n_class=2):
        self.best_feature = None        # index of the feature used to split
        self.threshold = None           # split value on that feature
        self.left_link = None           # child for feature value <  threshold
        self.right_link = None          # child for every other value
        self.x = np.asarray(x)          # training points reaching this node
        self.y = np.asarray(y)          # their labels
        self.depth = depth              # maximum depth of the tree
        self.depth_of_node = depth_of_node
        self.min_points = min_points

        # class label -> number of points of that class in this node.
        # An int n_class is expanded to range(n_class); iterating the int
        # directly (the old default of 2) raised TypeError.
        if isinstance(n_class, (int, np.integer)):
            known_labels = range(n_class)
        else:
            known_labels = n_class
        self.n_class = {label: 0 for label in known_labels}
        for label, count in zip(*np.unique(self.y, return_counts=True)):
            self.n_class[label] = int(count)

        if len(self.x) > self.min_points and self.depth > self.depth_of_node:
            self.devision()

    def weighted_gini_impurity(self, left_prob, right_prob, left_points, right_points, total_points):
        """Return the size-weighted Gini impurity of a two-way split.

        Gini impurity of one side is ``1 - sum(p_c ** 2)`` over its class
        probabilities ``p_c``; the two sides are weighted by their share of
        ``total_points``.

        left_prob, right_prob : class probabilities on each side.
        left_points, right_points : number of points on each side.
        total_points : number of points being split (must be non-zero).
        """
        left_prob = np.asarray(left_prob, dtype=float)
        right_prob = np.asarray(right_prob, dtype=float)
        gini_left = 1 - np.sum(left_prob ** 2)
        gini_right = 1 - np.sum(right_prob ** 2)
        return (left_points / total_points) * gini_left + (right_points / total_points) * gini_right

    def devision(self):
        """Split this node on the (feature, threshold) with the lowest weighted Gini.

        Candidate thresholds are the sorted feature values; a candidate is
        only considered where the value differs from its predecessor, because
        a ``<`` test cannot separate equal values (scoring such positions used
        to select thresholds that left one child empty).  If the node is pure
        or no feature has two distinct values, the node stays a leaf.
        After a successful split the node's own ``x`` / ``y`` are released.
        """
        n_samples = len(self.y)
        if n_samples < 2:
            return

        classes, codes = np.unique(self.y, return_inverse=True)
        if len(classes) < 2:
            return                                  # already pure: nothing to gain
        codes = codes.reshape(-1)
        total_counts = np.bincount(codes, minlength=len(classes))

        best_gini = np.inf
        best_feature = None
        best_threshold = None
        for feature in range(self.x.shape[1]):
            column = self.x[:, feature]
            order = np.argsort(column, kind="stable")
            sorted_values = column[order]
            sorted_codes = codes[order]

            # Running class counts of the first j sorted points, so each
            # candidate costs O(classes) instead of re-counting both halves.
            left_counts = np.zeros(len(classes), dtype=np.int64)
            for j in range(1, n_samples):
                left_counts[sorted_codes[j - 1]] += 1
                if not sorted_values[j] > sorted_values[j - 1]:
                    continue                        # tie (or NaN): not a realisable split
                right_counts = total_counts - left_counts
                gini = self.weighted_gini_impurity(
                    left_counts / j,
                    right_counts / (n_samples - j),
                    j,
                    n_samples - j,
                    n_samples,
                )
                if gini < best_gini:                # strict: first best candidate wins
                    best_gini = gini
                    best_feature = feature
                    best_threshold = sorted_values[j]

        if best_feature is None:
            return                                  # every feature is constant here

        self.best_feature = best_feature
        self.threshold = best_threshold

        # ~goes_left (rather than >=) keeps training consistent with predict(),
        # which sends everything that is not "< threshold" to the right.
        goes_left = self.x[:, self.best_feature] < self.threshold
        goes_right = ~goes_left

        self.left_link = Node(
            self.x[goes_left],
            self.y[goes_left],
            depth=self.depth,
            depth_of_node=self.depth_of_node + 1,
            min_points=self.min_points,
            n_class=self.n_class,
        )
        self.right_link = Node(
            self.x[goes_right],
            self.y[goes_right],
            depth=self.depth,
            depth_of_node=self.depth_of_node + 1,
            min_points=self.min_points,
            n_class=self.n_class,
        )

        del self.x, self.y                          # children now own the data

    def predict(self, xq):
        """Return the predicted class label for a single query point.

        xq : one query point, indexable by feature index.

        A leaf answers with its most frequent training label (the smallest
        label wins a tie, since ``np.unique`` returns labels sorted).
        Raises ValueError if the leaf holds no training points.
        """
        if self.left_link is None and self.right_link is None:
            labels, counts = np.unique(self.y, return_counts=True)
            if len(labels) == 0:
                raise ValueError("cannot predict from a node that holds no training points")
            return labels[np.argmax(counts)]

        if xq[self.best_feature] < self.threshold:
            return self.left_link.predict(xq)
        return self.right_link.predict(xq)
            
            

In [775]:
class Decision_Tree:
    """Decision-tree classifier that delegates all splitting to ``Node``.

    depth : maximum depth of the tree (the root has depth 0).

    min_points : a node is split only when it holds more than this many points.
    """

    def __init__(self, depth=5, min_points=10):
        self.depth = depth                  # maximum depth of the tree
        self.min_points = min_points        # split threshold on node size
        self.x = None                       # training data, set by fit()
        self.y = None                       # training labels, set by fit()
        self.node = None                    # root Node, set by fit()
        self.no_of_classes = None           # sorted unique labels (not a count), set by fit()

    def fit(self, x, y):
        """Build the tree from training data.

        x : 2-D array-like, one row per sample.
        y : array-like of class labels, one per row of ``x``.

        Raises ValueError when ``x`` is not 2-D, when ``x`` and ``y`` differ
        in length, or when there are no samples.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        # Validate up front so that bad input fails here with a clear message
        # instead of somewhere inside the recursive node construction.
        if x.ndim != 2:
            raise ValueError("x must be 2-D (samples x features), got %d-D" % x.ndim)
        if len(x) != len(y):
            raise ValueError("x and y must have the same length, got %d and %d" % (len(x), len(y)))
        if len(x) == 0:
            raise ValueError("cannot fit a tree on an empty data set")

        self.x = x
        self.y = y
        self.no_of_classes = np.unique(y)
        self.node = Node(
            x,
            y,
            depth=self.depth,
            depth_of_node=0,
            min_points=self.min_points,
            n_class=self.no_of_classes,
        )

    def predict(self, xq):
        """Predict the class label for one query point or a batch of them.

        xq : a single query point (1-D) or several query points (2-D, one
            per row).

        Returns a single label for 1-D input and an array of labels for 2-D
        input.  Raises AttributeError when called before ``fit``.
        """
        if self.node is None:
            raise AttributeError("Decision_Tree.predict() called before fit()")
        xq = np.asarray(xq)
        if xq.ndim == 2:
            return np.array([self.node.predict(row) for row in xq])
        return self.node.predict(xq)
    

In [776]:
from sklearn.datasets import make_classification
# Synthetic classification data with the generator's default settings.
# NOTE(review): no random_state is passed, so x and y presumably differ on every run.
x,y=make_classification()

In [777]:
# depth=1 lets only the root split (its children sit at depth 1 and stay leaves);
# a node must hold more than 20 points to be split.
m=Decision_Tree(depth=1,min_points=20)

In [778]:
# Grow the tree on the whole data set (nothing is held out).
m.fit(x,y)

In [781]:
# Predicted label for a single training sample.
m.predict(x[3])

1

In [782]:
# True label of the same sample, for comparison with the prediction.
y[3]

1

In [783]:
# Accuracy on the data the tree was fitted on: count the samples whose
# predicted label equals the true label, then report the fraction.
accuracy = 0
for i, sample in enumerate(x):
    accuracy += int(y[i] == m.predict(sample))
print('accuracy :', accuracy / len(y))

accuracy : 0.91
