In [1]:
import os

import torch
import numpy as np
import torchvision.transforms as tt

from utils import load_data, supervised_samples, set_random_seed, CustomDataSet
import config

In [2]:
# Seed the random number generators before any data is loaded or sampled.
# `set_random_seed` is a project helper from `utils`; which generators it covers
# is not visible in this notebook, so NumPy's global generator is also seeded
# explicitly with the same value.
set_random_seed(config.RANDOM_SEED)
np.random.seed(config.RANDOM_SEED)

Setting seeds ...... 



In [4]:
# Load the full training set, the test set and the list of class names.
full_ds, test_ds, classes = load_data()
# Move the test set to the configured device. The return value is discarded,
# so `.to` is presumably in-place on the dataset -- TODO confirm in `utils`.
test_ds.to(config.DEVICE)

In [5]:
classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [6]:
class Model:
	"""k-nearest-neighbours classifier over an in-memory labelled dataset.

	The "model" is just the training set plus the neighbourhood size ``k``;
	prediction is a majority vote among the ``k`` training samples closest
	(by Euclidean distance) to the query point.

	``train_ds`` must support ``train_ds[:]`` returning ``(X, y)`` tensors and
	``train_ds.to(device)``. ``X`` is reduced over dims 1-3, so it is assumed
	to be 4-D, e.g. (N, C, H, W) -- TODO confirm against ``CustomDataSet``.
	"""

	# The annotation is quoted so that defining the class does not require
	# `CustomDataSet` to be resolvable at definition time.
	def __init__(self, train_ds: "CustomDataSet", k=5):
		"""Store the training set and the number of neighbours to vote with."""
		self.train_ds = train_ds
		self.k = k

	def to(self, device):
		"""Move the underlying training set to ``device``."""
		self.train_ds.to(device)

	@staticmethod
	def euclidean(p1, p2):
		"""Return the Euclidean distance between ``p1`` and ``p2``.

		The squared difference is summed over dims 1, 2 and 3, so the result
		has one distance per entry along dim 0 of the broadcast difference.
		"""
		return torch.sqrt(torch.sum((p1-p2)**2, dim=[1,2,3]))

	def evaluate(self, test_point: torch.Tensor):
		"""Classify a single sample by majority vote of its nearest neighbours.

		Args:
			test_point: one sample, broadcastable against the training inputs.

		Returns:
			``(label, confidence)`` where ``label`` is a 0-d tensor with the
			same dtype as the training labels and ``confidence`` is the
			fraction of the consulted neighbours that voted for it (a float).

		If the training set holds fewer than ``k`` samples, every sample votes
		and the confidence is relative to that smaller neighbourhood. When two
		labels tie, the smaller label wins (``torch.unique`` returns sorted
		values and ``argmax`` picks the first maximum). Which of several
		equidistant samples lands on the neighbourhood boundary is unspecified.
		"""
		X_train, y_train = self.train_ds[:]

		distances = self.euclidean(X_train, test_point)

		# Never ask for more neighbours than there are training samples.
		n_neighbours = min(self.k, distances.shape[0])

		# Only the k smallest distances are needed, so avoid a full sort.
		_, nearest = torch.topk(distances, n_neighbours, largest=False)

		# Index the labels directly instead of packing them next to the float
		# distances; this keeps the labels in their original dtype.
		labels, counts = torch.unique(y_train[nearest], return_counts=True)

		majority_vote = labels[counts.argmax()]

		return majority_vote, (counts.max() / n_neighbours).item()

	def calculate_accuracy(self, test_ds):
		"""Return the fraction of ``test_ds`` samples that are classified correctly.

		``test_ds`` must be iterable as ``(sample, label)`` pairs and support
		``len()``. Raises ``ZeroDivisionError`` if it is empty.
		"""
		correct = 0

		for test_point, label in test_ds:
			pred_label, _ = self.evaluate(test_point)
			correct += int((pred_label == label).item())

		return correct / len(test_ds)
	

In [7]:
import pickle

In [8]:
# Sweep the k-NN classifier over neighbourhood sizes `k` and labelled
# training-set sizes `x`, recording test accuracy and pickling each model.

# Training-set sizes to try; 'full' means the whole of `full_ds`.
# The sizes do not depend on `k`, so they are set once, outside the loop.
x_values = [50, 100, 500, 'full']
if config.USED_DATA == "EMNIST":
	x_values = [100, 200, 1000, 'full']

# Create the output directory up front so an expensive accuracy run is not
# followed by a FileNotFoundError when the model is saved.
output_dir = f'{config.USED_DATA}/KNN'
os.makedirs(output_dir, exist_ok=True)

# Accuracies for every k, in `x_values` order. Kept outside the loop so that
# results already computed survive an interrupted run.
accuracy_by_k = {}

for k in [1, 3, 5]:

	# Still exposed under its original name; holds the results for the current k.
	accuracy_values = []
	accuracy_by_k[k] = accuracy_values

	for x in x_values:
		if x == 'full':
			train_ds = full_ds
		else:
			# NOTE(review): the third argument (10) is presumably the number
			# of classes -- confirm against `supervised_samples`.
			train_ds = supervised_samples(full_ds, x, 10)

		train_ds.to(config.DEVICE)

		model = Model(train_ds, k)

		accuracy = model.calculate_accuracy(test_ds)
		accuracy_values.append(accuracy)

		# Move the data to the CPU before opening the file, so a failed move
		# cannot leave a truncated pickle behind.
		model.to('cpu')

		# NOTE(review): the path does not include `k`, so each k overwrites the
		# files written for the previous k. Left unchanged because other code
		# may load these exact paths -- decide whether `k` belongs in the name.
		with open(f'{output_dir}/_{x}.pkl', 'wb') as f:
			pickle.dump(model, f)

	print(f'{k=}, acc: {accuracy_values}')


KeyboardInterrupt: 

In [None]:
# Show the accuracies gathered for the most recent `k` of the sweep above;
# the list is re-created at the start of every `k` iteration.
accuracy_values

[]