*Authored 2023 by [TRL](https://github.com/tylew)*

In [None]:
import heapq
import random

import cv2
import numpy as np
import PIL
import torchvision

# Load the MNIST training and test splits from ./data (download=True lets
# torchvision fetch them). Generated in order: train split first, then test.
mnist_train, mnist_test = (
  torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True)
  for split_is_train in (True, False)
)

In [None]:
from collections import Counter

# Helper functions
def find_most_common_label(k_nearest_items):
  """Return the label that occurs most often in a list of (distance, label) pairs.

  Labels with equal counts are ranked in the order they are first seen in
  ``k_nearest_items`` (``Counter.most_common`` ordering), so when the input
  is sorted by distance a tie goes to the label seen closest first.
  Raises IndexError if ``k_nearest_items`` is empty.
  """
  tally = Counter(lbl for _dist, lbl in k_nearest_items)
  winner, _count = tally.most_common(1)[0]
  return winner

def convert_to_vector(input_image):
  """Return ``input_image`` as a NumPy array.

  An ``np.ndarray`` (or subclass) is returned unchanged, without copying.
  Anything else is copied into a new array with ``np.array``. Despite the
  name, the shape is preserved; nothing is flattened.
  """
  already_array = isinstance(input_image, np.ndarray)
  return input_image if already_array else np.array(input_image)

def distance(v1, v2):
  """Return the Euclidean (L2) distance between arrays ``v1`` and ``v2``.

  Delegates to ``cv2.norm`` with ``cv2.NORM_L2``; OpenCV requires both
  arrays to have the same size and element type.
  """
  return cv2.norm(v1, v2, cv2.NORM_L2)

In [None]:
def k_nearest_neighbor(input_image, dataset: list, k=10):
  """Predict a label for ``input_image`` by majority vote of its k nearest neighbours.

  Args:
    input_image: Image to classify; converted with ``convert_to_vector``.
    dataset: Iterable of ``(image, label)`` pairs to search. Despite the
      ``list`` annotation, any iterable of pairs works (e.g. a torchvision
      dataset). Each image is converted with ``convert_to_vector``, so
      passing pre-converted NumPy arrays avoids repeated conversion.
    k: Number of nearest neighbours that vote; must be >= 1.

  Returns:
    The most common label among the ``k`` items closest to ``input_image``
    (L2 distance). On a tie, the label whose closest item is nearer wins.

  Raises:
    ValueError: If ``k`` < 1 or ``dataset`` yields no items.
  """
  if k < 1:
    raise ValueError(f"k must be a positive integer, got {k!r}")

  input_vector = convert_to_vector(input_image)

  # Lazily compute (distance, label) for every reference item. Nothing is
  # materialised, so memory stays O(k) instead of O(len(dataset)).
  item_distances = (
    (distance(convert_to_vector(img), input_vector), label)
    for img, label in dataset
  )

  # heapq.nsmallest(k, it, key=...) is documented as equivalent to
  # sorted(it, key=...)[:k] (same order, same tie handling) but costs
  # O(n log k) rather than a full O(n log n) sort.
  k_nearest_items = heapq.nsmallest(k, item_distances, key=lambda item: item[0])

  if not k_nearest_items:
    raise ValueError("dataset is empty; cannot classify input_image")

  return find_most_common_label(k_nearest_items)

In [None]:
def test_model(k: int = 10, num_samples: int = 1000, query_set=None, reference_set=None):
  """Print and return k-NN accuracy on a random sample of labelled images.

  Each sampled ``(image, label)`` pair from ``query_set`` is classified with
  ``k_nearest_neighbor`` against ``reference_set``. The prediction is then
  compared with the true label.

  Args:
    k: Number of nearest neighbours that vote on each prediction.
    num_samples: Upper bound on how many query images are evaluated.
    query_set: Indexable collection of ``(image, label)`` pairs to classify.
      Defaults to ``mnist_test``.
    reference_set: Iterable of ``(image, label)`` pairs used as the labelled
      search set. Defaults to ``mnist_train``.

  Returns:
    Accuracy as a percentage in [0, 100].

  Raises:
    ValueError: If there is nothing to evaluate (empty ``query_set`` or
      ``num_samples`` < 1).
  """
  # Measure accuracy on held-out test images, using the training split as
  # the labelled reference data that k-NN searches.
  if query_set is None:
    query_set = mnist_test
  if reference_set is None:
    reference_set = mnist_train

  sample_count = min(num_samples, len(query_set))
  if sample_count < 1:
    raise ValueError("test_model needs at least one query image to evaluate")
  sample_indices = random.sample(range(len(query_set)), sample_count)

  # Convert every reference image to an array once, up front, rather than
  # once per query inside k_nearest_neighbor. convert_to_vector returns
  # arrays unchanged, so later calls on these items cost nothing.
  reference_vectors = [
    (convert_to_vector(img), ref_label) for img, ref_label in reference_set
  ]

  accurate_runs = 0
  for idx in sample_indices:
    sample_image, label = query_set[idx]
    if k_nearest_neighbor(sample_image, reference_vectors, k) == label:
      accurate_runs += 1

  accuracy = (accurate_runs / sample_count) * 100
  print(f"{accuracy}% accuracy for k value = {k}")
  return accuracy



In [None]:
# Run the evaluation with the default neighbourhood size (k=10).
test_model(k=10)