Implementation of the k-NN classifier algorithm where we find the distance between vectors and use those to influence whether or not a value is
the nearest neighbor.

Added on to this is the use of a heap to maintain information about the "k" elements

In [71]:
import numpy as np
import pandas as pd
import heapq

In [43]:
iris_df = pd.read_csv('datasets/iris.data', header=0, names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'label'])

examples = iris_df.iloc[:,:-1].to_numpy()
example_labels = iris_df.iloc[:, -1:].to_numpy()

example_labels

array([['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-setosa'],
       ['Iris-se

In [120]:
def dist(x, y):
    # print(x, y, np.dstack((x, y)))
    d = 0
    zipped = np.dstack((x, y))[0]
    for pair in zipped:
        d += (pair[0] - pair[1])**2
    return d**0.5

def knn(data, k, x):
    x = np.array(x)
    h = []
    for idx, train_row in data.iterrows():
        heapq.heappush(h, (-dist(x, train_row), idx))
        if len(h) > k:
            heapq.heappop(h)

    neighbors = []
    while h:
        neighbors.append(heapq.heappop(h)[1])
    return neighbors

# knn(examples, 3, [4.2, 1.4, 3.2, 0.6])

In [164]:
def knn_classifier(data, labels, k, x):
    neighbors = knn(data, k, x)
    filtered_labels = labels.loc[labels.index.isin(neighbors)].to_numpy()
    # from https://www.geeksforgeeks.org/python-find-most-frequent-element-in-a-list/
    unique, counts = np.unique(filtered_labels, return_counts=True)
    index = np.argmax(counts)
    return unique[index]

# knn_classifier(examples, example_labels, 3, [2.5, 1.4, 1.2, 0.6])

In [166]:
def train_test_split(df, label_column, train_ratio, test_ratio):
    if train_ratio + test_ratio != 1.0:
        raise 'Invalid train-test ratio'

    np.random.seed(42)
    df = df.reindex(np.random.permutation(df.index))
    data = df.loc[:, df.columns != label_column]
    labels = df.loc[:, df.columns == label_column]
    ratio = int(df.shape[0] * train_ratio)
    train_data = data.iloc[ratio:,:]
    test_data = data.iloc[:ratio,:]
    train_labels = labels.iloc[ratio:,:]
    test_labels = labels.iloc[:ratio,:]

    return train_data, train_labels, test_data, test_labels

# Handles the model initialization and loading using a k-NN classifier
def run_model(df, label_column, train_ratio, test_ratio, k):
    train_data, train_labels, test_data, test_labels = train_test_split(df, label_column, train_ratio, test_ratio)

    misses = 0

    for idx, row in test_data.iterrows():
        expected_label = test_labels.loc[[idx]]['label'].values[0]
        predicted_label = knn_classifier(train_data, train_labels, k, row)
        if expected_label != predicted_label:
            misses += 1

    return misses / test_data.shape[0]

In [170]:
for k in range(1, 20):
    print(f'Using {k}-NN has a miss rate of {run_model(iris_df, "label", 0.6, 0.4, k)}')

Using 1-NN has a miss rate of 0.056179775280898875
Using 2-NN has a miss rate of 0.06741573033707865
Using 3-NN has a miss rate of 0.0449438202247191
Using 4-NN has a miss rate of 0.0449438202247191
Using 5-NN has a miss rate of 0.0449438202247191
Using 6-NN has a miss rate of 0.0449438202247191
Using 7-NN has a miss rate of 0.0449438202247191
Using 8-NN has a miss rate of 0.0449438202247191
Using 9-NN has a miss rate of 0.0449438202247191
Using 10-NN has a miss rate of 0.056179775280898875
Using 11-NN has a miss rate of 0.06741573033707865
Using 12-NN has a miss rate of 0.056179775280898875
Using 13-NN has a miss rate of 0.06741573033707865
Using 14-NN has a miss rate of 0.06741573033707865
Using 15-NN has a miss rate of 0.0898876404494382
Using 16-NN has a miss rate of 0.07865168539325842
Using 17-NN has a miss rate of 0.07865168539325842
Using 18-NN has a miss rate of 0.0898876404494382
Using 19-NN has a miss rate of 0.0898876404494382
