In [1]:
import numpy as np

In [2]:
def calculate_distance(a,b):
    """Euclidean distance between ``a`` and ``b``.

    Parameters
    ----------
    a, b : array-like
        Point coordinates. Inputs are converted to float arrays, so plain
        lists and unsigned-integer arrays are handled without wrap-around.
        Shapes are broadcast against each other.

    Returns
    -------
    numpy.float64 for two 1-D points; for stacked points (2-D input) an array
    holding one distance per row, since the reduction runs over the last axis.
    """
    diff = np.asarray(a,dtype=float) - np.asarray(b,dtype=float)
    # np.sum reduces inside NumPy; the builtin sum() would iterate the array
    # element by element in Python.
    return np.sqrt(np.sum(diff**2,axis=-1))
    
def calculate_axis_distance(a,b,axis):
    """Return the absolute difference between ``a`` and ``b`` at index ``axis``."""
    gap = a[axis] - b[axis]
    return np.abs(gap)

def sort_by_column(data,sort_coordinate):
    """Return the rows of 2-D ``data`` reordered so that column
    ``sort_coordinate`` is ascending. The caller's array is not modified."""
    row_order = np.argsort(data[:,sort_coordinate])
    return data[row_order]

In [3]:
class KDNode:
    """A single tree node: a stored value, the axis it is associated with,
    a label, and links to a left and a right child."""

    def __init__(self,val,axis,label):
        # Payload supplied by the caller.
        self.val, self.axis, self.label = val, axis, label
        # Both child links start out empty; the tree builder attaches them.
        self.left = self.right = None

In [4]:
class KDTree:
    """k-d tree over labelled points with k-nearest-neighbour lookup.

    Every row given to ``fit`` holds the point coordinates followed by one
    trailing label column.
    """

    def __init__(self):
        self.head = None  # root KDNode, or None while the tree is empty

    def fit(self,data):
        """Build the tree from ``data``, replacing any previously built tree.

        Parameters
        ----------
        data : array-like, 2-D
            One row per point: coordinates first, label in the last column.
            Empty input yields an empty tree (``head`` is None).

        Raises
        ------
        ValueError
            If non-empty ``data`` is not 2-D or has no coordinate column.
        """
        data = np.asarray(data)
        if len(data) > 0 and (data.ndim != 2 or data.shape[1] < 2):
            raise ValueError(
                "data must be 2-D with at least one coordinate column "
                "followed by a label column"
            )
        self.head = self.__build_tree(data)

    def __build_tree(self,data,axis=0):
        """Recursively build the subtree for ``data``, splitting on ``axis``.

        Returns the subtree's root KDNode, or None when ``data`` has no rows.
        """
        if len(data) == 0:
            return None

        next_axis = (axis+1)%(data.shape[1]-1) #last column is reserved for label

        data = sort_by_column(data,axis)
        # Split at the middle row itself. Rows equal to the split value may end
        # up on either side; the search stays correct because it only skips a
        # subtree when the splitting plane is at least as far away as the
        # current k-th best distance. Splitting at the first occurrence of the
        # median value would put every duplicate on the right and degrade the
        # tree into a chain (one level per row) on data with repeated
        # coordinates.
        med_id = len(data)//2
        left_data = data[:med_id]
        right_data = data[med_id+1:]

        # The pivot row becomes this node: coordinates, then the label cast to int.
        node_val,node_label = data[med_id][:-1], int(data[med_id][-1])
        node = KDNode(val=node_val,axis=axis,label=node_label)

        node.left = self.__build_tree(left_data,next_axis)
        node.right = self.__build_tree(right_data,next_axis)

        return node

    def find_k_closest(self,val,k):
        """Find the ``k`` stored points nearest to ``val``.

        Returns
        -------
        (k_min_nodes, k_min_distances)
            Two length-``k`` arrays, index-aligned but not sorted by distance.
            When the tree holds fewer than ``k`` points the unused slots keep
            their initial values: ``None`` for the node, ``inf`` for the
            distance.
        """
        k_min_distances = np.array([np.inf]*k)
        k_min_nodes = np.array([None]*k)
        self.__find_k_closest_go_down(self.head,val,k,k_min_distances,k_min_nodes)
        return k_min_nodes,k_min_distances

    def __find_k_closest_go_down(self,node,val,k,k_min_distances,k_min_nodes):
        """Visit ``node`` and its relevant subtrees, updating the two result
        arrays in place."""
        if node is None:
            return

        # Replace the current worst candidate when this node is strictly closer.
        curr_dist = calculate_distance(node.val,val)
        max_id = np.argmax(k_min_distances)
        if curr_dist < k_min_distances[max_id]:
            k_min_distances[max_id] = curr_dist
            k_min_nodes[max_id] = node

        # Descend first into the side of the splitting plane that contains val.
        axis = node.axis
        if val[axis] < node.val[axis]:
            near,far = node.left,node.right
        else:
            near,far = node.right,node.left
        self.__find_k_closest_go_down(near,val,k,k_min_distances,k_min_nodes)

        # The other side can only hold a better candidate if the plane itself
        # is closer than the current worst of the k best distances.
        if calculate_axis_distance(val,node.val,axis) < np.max(k_min_distances):
            self.__find_k_closest_go_down(far,val,k,k_min_distances,k_min_nodes)

In [5]:
class KNN:
    """k-nearest-neighbours classifier that answers queries through a KDTree."""

    def __init__(self,k=5):
        self.tree = KDTree()
        self.k = k  # number of neighbours that vote on each prediction

    def fit(self,X,y):
        """Store the training set in the tree.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training points.
        y : array-like, shape (n_samples,)
            Class labels; they are appended to ``X`` as a last column before
            the combined array is handed to the tree.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        # hstack already allocates a fresh array, so no extra copy is needed.
        data = np.hstack([X,y.reshape(len(y),1)])
        self.tree.fit(data)

    def __predict_point(self,point):
        """Return the majority label among the nearest neighbours of ``point``.

        Ties go to the smallest label. Raises ValueError when the tree returns
        no neighbour at all (nothing has been fitted).
        """
        k_closest,_ = self.tree.find_k_closest(point,self.k)
        # The tree leaves None in unused slots when it holds fewer than k
        # points; only real neighbours get a vote.
        labels = np.array([node.label for node in k_closest if node is not None])
        if labels.size == 0:
            raise ValueError("cannot predict: no training points available, call fit first")
        # np.unique returns the labels sorted, so argmax keeps the
        # smallest-label tie-break, and negative labels are accepted.
        values,counts = np.unique(labels,return_counts=True)
        return values[np.argmax(counts)]

    def predict(self,X):
        """Predict a label for every row of ``X``; returns an array of labels."""
        return np.array([self.__predict_point(point) for point in X])

<h3>Verifying that our KNN produces the same predictions as scikit-learn's KNeighborsClassifier</h3>

In [6]:
from sklearn.neighbors import KNeighborsClassifier

SEED = 42  # fixed so the comparison is reproducible after a kernel restart

k = 5
cols = 4
rows = 10_000
test_size = 100
num_classes = 10

# Uniform random points in the unit hypercube with uniformly random integer labels.
rng = np.random.default_rng(SEED)
X = rng.random((rows,cols))
y = rng.integers(0,num_classes,size = (rows,))

# The last `test_size` rows are held out for prediction.
X_train,X_test = X[:-test_size],X[-test_size:]
y_train,y_test = y[:-test_size],y[-test_size:]

In [7]:
# Train the hand-written classifier on the training split, then label the held-out points.
my_knn = KNN(k=k)
my_knn.fit(X=X_train,y=y_train)
my_pred = my_knn.predict(X=X_test)

In [8]:
# Reference model: scikit-learn's classifier with the same k, trained on the same split.
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(X_train,y_train)
pred = knn.predict(X_test)

<h3>Results are equal :)</h3>

In [9]:
# Fail loudly on a re-run if any prediction diverges from scikit-learn's,
# rather than relying on a visual scan of the boolean array shown below.
assert np.array_equal(my_pred,pred), "KNN predictions differ from KNeighborsClassifier"
my_pred == pred

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])