# K-Nearest Neighbour

For a datapoint you want to make a prediction for, you can use the neighbouring points to decide what class it should fall into.

![image.png](attachment:image.png)

K is the number of neighbours you consider when predicting the class of an unknown point. Among the K closest training points (the points inside the circle drawn around the query point), you find the class that occurs most often, which makes it the most likely target class. However, keeping track of an arbitrary number of closest distances is not trivial.

### Example 


With K = 3, the algorithm has to keep track of the three closest training points seen so far, i.e. the 1st, 2nd and 3rd nearest. You can keep these candidates in a sorted structure ordered by distance, so that finding where a new distance belongs takes O(log K) time. The distances and their classes can be stored as {dist1: class1, dist2: class2, ...} or [(dist1, class1), (dist2, class2), ...]. Once the K nearest neighbours have been collected, the votes are counted as {class1: num_class1, class2: num_class2, ...}, and you pick the class with the most votes.
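The bookkeeping described above can be sketched with the standard library alone. This is a minimal illustration, not the notebook's implementation. It swaps the sorted list for a max-heap from `heapq` (distances are negated so the farthest kept neighbour sits at the top). The function name `k_nearest_votes` and the sample data are hypothetical.

```python
import heapq
from collections import Counter

def k_nearest_votes(distances_and_classes, k):
    """Return the k closest (dist, class) pairs, sorted, plus the class vote counts."""
    heap = []  # max-heap via negated distance; never holds more than k entries
    for dist, label in distances_and_classes:
        if len(heap) < k:
            heapq.heappush(heap, (-dist, label))
        elif dist < -heap[0][0]:
            # closer than the current farthest kept neighbour: replace it
            heapq.heapreplace(heap, (-dist, label))
    nearest = sorted((-neg_dist, label) for neg_dist, label in heap)
    votes = Counter(label for _, label in nearest)
    return nearest, votes

# hypothetical (distance, class) pairs for one query point
points = [(0.9, "a"), (0.2, "b"), (0.5, "a"), (0.1, "b"), (1.4, "a")]
nearest, votes = k_nearest_votes(points, k=3)
prediction = votes.most_common(1)[0][0]
```

Each push or replace costs O(log K), so scanning N training points costs O(N log K) per query in this sketch.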

### Breaking Ties
- Take the argmax of the votes (this returns the first class among those tied)
- Pick one of the tied classes at random
- Weight each vote by the distance to the neighbour (more difficult)
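The three strategies above could look roughly like the sketch below. The function name `break_tie`, the strategy labels, and the inverse-distance weighting are assumptions made for illustration. Other weighting schemes are equally valid.

```python
import random
from collections import defaultdict

def break_tie(neighbours, strategy="first", rng=None):
    """neighbours: list of (dist, class) pairs. Returns the predicted class."""
    counts = defaultdict(int)
    for _, label in neighbours:
        counts[label] += 1
    top = max(counts.values())
    # classes sharing the highest vote count, in order of first appearance
    tied = [label for label in counts if counts[label] == top]
    if len(tied) == 1 or strategy == "first":
        return tied[0]  # argmax-style: the first maximal class wins
    if strategy == "random":
        return (rng or random).choice(tied)
    if strategy == "distance":
        # assumed weighting: each neighbour votes with weight 1 / distance
        weights = defaultdict(float)
        for dist, label in neighbours:
            if label in tied:
                weights[label] += 1.0 / (dist + 1e-12)
        return max(tied, key=lambda label: weights[label])
    raise ValueError(f"unknown strategy: {strategy}")

# hypothetical 2-2 tie between classes "a" and "b"
tie = [(0.4, "a"), (0.1, "b"), (0.5, "a"), (0.2, "b")]
first_choice = break_tie(tie, "first")
weighted_choice = break_tie(tie, "distance")
random_choice = break_tie(tie, "random", rng=random.Random(0))
```

In this example the distance-weighted vote favours "b", whose two neighbours are much closer to the query than those of "a".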

KNN is a lazy classifier: training does nothing but store X and Y. predict(X) does all the work by searching the stored X and Y for the nearest neighbours of each input.


## Implementation


In [None]:
import numpy as np
from sortedcontainers import SortedList

# importing from the util.py file to get the preprocessed data from the mnist folder
from util import get_mnist_data
from datetime import datetime

class KNN(object):
    def __init__(self, k):
        self.k = k

    # KNN is lazy, so training only stores X and Y
    def fit(self, X, Y):
        self.X = X
        self.y = Y

    # predict takes only X and infers the corresponding Y values
    def predict(self, X):
        y = np.zeros(len(X))
        for i, x in enumerate(X):
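            # sorted list of (distance, class) pairs for the nearest neighbours found so far
            sl = SortedList()
            for j, xt in enumerate(self.X):  # j is the index, xt is the training point
                diff = x - xt
                d = diff.dot(diff)  # squared Euclidean distance between x and xt


`SortedList` keeps its entries in sorted order, so the farthest neighbour found so far is always the last element. It has no built-in size limit: the `load` argument accepted by older sortedcontainers releases is only an internal tuning parameter, and newer releases do not accept it in the constructor. The code therefore has to keep the list at K entries itself. The inner for loop goes through all the training points for each input test point to find the nearest neighbours, where `j` is the index and `xt` is the training point.
`d` is the squared Euclidean distance between `x` and `xt`. Squaring preserves the ordering of distances, so the square root is not needed.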
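One way the loop could be carried through to a prediction is sketched below. This is a hedged illustration and not necessarily how the original notebook continues. The class name `KNNSketch` and the toy data are hypothetical, and ties fall to the tied class whose neighbour is closest.

```python
import numpy as np
from sortedcontainers import SortedList

class KNNSketch:
    def __init__(self, k):
        self.k = k

    def fit(self, X, y):
        # lazy learner: just store the training data
        self.X = np.asarray(X, dtype=float)
        self.y = np.asarray(y)

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        y_pred = np.zeros(len(X), dtype=self.y.dtype)
        for i, x in enumerate(X):
            sl = SortedList()  # (squared distance, class), closest first
            for j, xt in enumerate(self.X):
                diff = x - xt
                d = diff.dot(diff)
                if len(sl) < self.k:
                    sl.add((d, self.y[j]))
                elif d < sl[-1][0]:
                    # closer than the current farthest neighbour: swap it out
                    sl.pop()
                    sl.add((d, self.y[j]))
            # count votes; iteration is closest-first, so on a tie
            # max() returns the tied class with the nearest neighbour
            votes = {}
            for _, label in sl:
                votes[label] = votes.get(label, 0) + 1
            y_pred[i] = max(votes, key=votes.get)
        return y_pred

# hypothetical toy data: two well-separated clusters
X_train = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
y_train = [0, 0, 0, 1, 1, 1]
model = KNNSketch(k=3)
model.fit(X_train, y_train)
predictions = model.predict([[0.2, 0.1], [5.5, 5.2]])
```

The list never grows beyond K entries because a new point is added only when there is room or when it displaces the current farthest neighbour.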