In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from scipy.spatial.distance import cdist
from collections import Counter
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

In [2]:
class KNN:
    """k-nearest-neighbours classifier (Euclidean distance, majority vote).

    Parameters
    ----------
    k : int, default 3
        Number of nearest training samples that vote on each prediction.

    Raises
    ------
    ValueError
        If ``k`` is not a positive integer.
    """

    def __init__(self, k = 3):
        if not isinstance(k, (int, np.integer)) or k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self._k = int(k)
        self._X = None  # training samples, populated by fit()
        self._y = None  # training labels, populated by fit()

    def fit(self, X, y):
        """Store the training set; no model is built ahead of time.

        Parameters
        ----------
        X : array-like, 2-D
            One training sample per row.
        y : array-like, 1-D
            Label of each row of ``X``.

        Returns
        -------
        KNN
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, ``y`` is not 1-D, their lengths differ, or
            ``k`` exceeds the number of training samples.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D (samples x features), got ndim={X.ndim}")
        if y.ndim != 1:
            raise ValueError(f"y must be 1-D, got ndim={y.ndim}")
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X and y disagree on the number of samples: {X.shape[0]} != {y.shape[0]}"
            )
        # Without this check a too-large k would silently vote with fewer
        # neighbours than requested.
        if self._k > X.shape[0]:
            raise ValueError(
                f"k={self._k} exceeds the number of training samples ({X.shape[0]})"
            )
        self._X = X #data
        self._y = y #target
        return self

    def predict(self, X):
        """Predict a label for every row of ``X``.

        Parameters
        ----------
        X : array-like, 2-D
            One query sample per row, with the same number of columns as the
            training data.

        Returns
        -------
        tf.Tensor
            float32 tensor holding one predicted label per query row.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self._X is None or self._y is None:
            raise RuntimeError("KNN.predict() called before fit()")

        # distances[i, j] is the distance from training sample i to query j.
        distances = cdist(self._X, X, 'euclidean')
        # A stable sort ranks equidistant training samples by their original
        # order, which makes the neighbour selection deterministic.
        nearest = np.argsort(distances, axis = 0, kind = 'stable')[:self._k]
        # Row q holds the labels of query q's neighbours, nearest first.
        neighbor_labels = self._y[nearest.T]

        # Counter.most_common orders equal counts by first appearance, so a
        # tied vote goes to the tied label whose neighbour is closest.
        predictions = [
            Counter(row.tolist()).most_common(1)[0][0] for row in neighbor_labels
        ]
        # The vote is done in plain Python so it does not depend on eager
        # execution; only the final conversion touches TensorFlow.
        return tf.cast(tf.convert_to_tensor(np.asarray(predictions)), 'float32')

In [3]:
# Load scikit-learn's bundled digits dataset; later cells read
# `digits.images` (the samples) and `digits.target` (the labels).
digits = datasets.load_digits()

In [7]:
# Flatten each image into one feature row so the data is (n_samples, n_features).
n_samples = digits.images.shape[0]
data = digits.images.reshape(n_samples, -1)

# 50/50 hold-out split; shuffle=False keeps the dataset's original order.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.5,
    shuffle=False,
)

# Number of neighbours that vote on each prediction.
k = 3
# NOTE(review): the name says "iris" but the model is trained on the digits
# data; the name is kept as-is because other cells may refer to it.
iris_knn = KNN(k=k).fit(X_train, y_train)  # fit() returns the classifier itself
predictions = iris_knn.predict(X_test)

# confusion_matrix is called with true labels first, predictions second.
print(confusion_matrix(y_true=y_test, y_pred=predictions))

[[87  0  0  0  1  0  0  0  0  0]
 [ 0 88  1  0  0  0  0  0  2  0]
 [ 1  0 81  4  0  0  0  0  0  0]
 [ 0  0  0 83  0  1  0  2  3  2]
 [ 0  0  0  0 86  0  0  0  0  6]
 [ 0  0  0  0  0 88  1  0  0  2]
 [ 0  0  0  0  0  0 91  0  0  0]
 [ 0  0  0  0  0  0  0 88  1  0]
 [ 0  2  0  2  0  0  0  0 84  0]
 [ 0  0  0  3  0  2  0  0  1 86]]
