*Authored 2023 by [TRL](https://github.com/tylew)*

In [3]:
import random
import numpy as np
import cv2
from collections import Counter
import torchvision
from tqdm import tqdm  # Import tqdm for progress bar

# Load both MNIST splits (downloading into ./data on first run); train first, then test.
mnist_train, mnist_test = (
  torchvision.datasets.MNIST(root='./data', train=is_train_split, download=True)
  for is_train_split in (True, False)
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Helper functions
def find_most_common_label(k_nearest_items):
  """Return the majority label among ``(distance, label)`` pairs.

  Only the second element of each pair is used. Ties are resolved in favour
  of the label encountered first in ``k_nearest_items``: ``Counter`` keeps
  insertion order and ``most_common`` orders equal counts by first
  occurrence. So when the items are sorted by distance, the nearest of the
  tied labels wins.

  Raises:
    ValueError: if ``k_nearest_items`` is empty (previously this surfaced as
      an opaque IndexError from ``most_common(1)[0]``).
  """
  labels = [label for _, label in k_nearest_items]
  if not labels:
    raise ValueError("cannot pick a majority label from an empty neighbour list")
  (winner, _count), = Counter(labels).most_common(1)
  return winner

def convert_to_vector(input_image):
  """Return ``input_image`` as a NumPy array.

  An ``np.ndarray`` is passed through untouched (the very same object, no
  copy). Any other input (presumably the PIL images yielded by the MNIST
  dataset; confirm if other sources are used) is converted with
  ``np.array``, which copies the data.
  """
  already_array = isinstance(input_image, np.ndarray)
  return input_image if already_array else np.array(input_image)

def distance(v1, v2):
  """Euclidean (L2) distance between ``v1`` and ``v2``, computed by OpenCV.

  NOTE(review): ``cv2.norm`` expects both arrays to have the same size and
  type; the callers here pass same-shaped image arrays.
  """
  return cv2.norm(v1, v2, cv2.NORM_L2)

In [5]:
def k_nearest_neighbor(input_image, dataset: list, k=10):
  """Classify ``input_image`` by a majority vote of its k nearest neighbours.

  Args:
    input_image: Image to classify; converted with ``convert_to_vector``.
    dataset: Anything iterable that yields ``(image, label)`` pairs. It is
      iterated exactly once; the notebook passes a torchvision MNIST dataset.
    k: Number of nearest neighbours that vote. Must be >= 1.

  Returns:
    The most common label among the ``k`` dataset images closest to
    ``input_image`` under ``distance``.

  Raises:
    ValueError: if ``k < 1`` or ``dataset`` yields no items.
  """
  # Without this guard, k=0 selected no neighbours and crashed inside the
  # vote, while a negative k made ``[:k]`` drop the *farthest* items and
  # silently vote over almost the entire dataset.
  if k < 1:
    raise ValueError(f"k must be a positive integer, got {k!r}")

  input_vector = convert_to_vector(input_image)

  # Iterate the dataset of comparable images, pairing each one's distance
  # to the unlabelled input with its known label.
  item_distance_list = [
    (distance(convert_to_vector(img), input_vector), label)
    for img, label in dataset
  ]
  if not item_distance_list:
    raise ValueError("dataset contains no images to compare against")

  # sorted() is stable, so equally distant items keep their dataset order;
  # slicing then keeps the k closest (or all of them if k > len(dataset)).
  k_nearest_items = sorted(item_distance_list, key=lambda item: item[0])[:k]

  return find_most_common_label(k_nearest_items)

In [6]:
def test_model(k: int = 10, max_samples: int = 1000, seed=None,
               query_set=None, reference_set=None) -> float:
  """Estimate k-NN classification accuracy on a random subset of images.

  Draws up to ``max_samples`` random ``(image, label)`` pairs from
  ``query_set``. Each one is classified with ``k_nearest_neighbor`` against
  ``reference_set``, and the percentage of correct predictions is printed.

  Args:
    k: Number of neighbours that vote on each prediction.
    max_samples: Upper bound on the number of query images evaluated.
    seed: Optional seed for choosing the sample. ``None`` (the default) keeps
      the previous behaviour of drawing from the global ``random`` state.
    query_set: Images to classify. Defaults to ``mnist_train``.
    reference_set: Labelled images searched for neighbours. Defaults to
      ``mnist_test``.

  Returns:
    The accuracy as a percentage in [0, 100] (also printed).

  Raises:
    ValueError: if no samples would be evaluated (``max_samples < 1`` or an
      empty query set). Previously this ended in a ZeroDivisionError.
  """
  # NOTE(review): the defaults classify *training* images against the *test*
  # split, which reverses the usual k-NN evaluation (queries from mnist_test,
  # neighbours from mnist_train). They are kept for backward compatibility.
  # Pass the sets explicitly to swap the roles; that is presumably slower
  # per query, since the training split is the larger one.
  queries = mnist_train if query_set is None else query_set
  references = mnist_test if reference_set is None else reference_set

  n_samples = min(max_samples, len(queries))
  if n_samples < 1:
    raise ValueError("test_model needs at least one sample to evaluate")

  # A private Random instance makes seeded runs repeatable without
  # disturbing the global random state used elsewhere in the notebook.
  rng = random.Random(seed) if seed is not None else random
  sample_indices = rng.sample(range(len(queries)), n_samples)

  correct = 0
  # The context manager closes the progress bar even if a prediction raises.
  with tqdm(total=n_samples, desc="Testing") as progress_bar:
    for idx in sample_indices:
      sample_image, label = queries[idx]
      prediction = k_nearest_neighbor(sample_image, references, k)
      if label == prediction:
        correct += 1
      progress_bar.update(1)

  accuracy = (correct / n_samples) * 100
  print(f"\n{accuracy:.2f}% accuracy for k value = {k}")
  return accuracy



In [9]:
# Quick smoke run: classify 10 random images with a 15-neighbour vote.
test_model(k=15, max_samples=10)

Testing:   0%|          | 0/10 [00:00<?, ?it/s]

Testing: 100%|██████████| 10/10 [00:14<00:00,  1.49s/it]


100.00% accuracy for k value = 15



