-
Notifications
You must be signed in to change notification settings - Fork 0
/
KNN.py
74 lines (59 loc) · 2.75 KB
/
KNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import random
from data_prep import *
class KNN:
    """k-nearest-neighbours image classifier wrapping ``cv2.ml.KNearest``.

    Intended call order: ``createModel()`` -> ``trainModel()`` ->
    ``evaluateModel(...)``.
    """

    # Largest k tried when ``evaluateModel`` sweeps k (``k == 0``).
    _MAX_K = 15
    # Shape of the preview grid that ``evaluateModel`` fills for a fixed k.
    _GRID_ROWS = 2
    _GRID_COLS = 3

    def __init__(self, trainDir, validationDir):
        """Remember where the training and validation data live.

        Args:
            trainDir: Location handed to ``load_data_KNN`` when training.
            validationDir: Location handed to ``load_test_KNN``; stored as
                ``predictionsDir``.
        """
        self.trainDir = trainDir
        self.predictionsDir = validationDir

    def createModel(self):
        """Create an untrained OpenCV k-nearest-neighbours model."""
        self.knn = cv2.ml.KNearest_create()

    def trainModel(self):
        """Load the training and prediction sets and train the model.

        Populates ``trainData``, ``trainLabels``, ``labelsDictionary``,
        ``predictionData`` and ``predictionLabels`` from the project loaders
        (``load_data_KNN`` / ``load_test_KNN``, defined outside this class).
        """
        (self.trainData,
         self.trainLabels,
         self.labelsDictionary) = load_data_KNN(self.trainDir)
        self.predictionData, self.predictionLabels = load_test_KNN(
            self.predictionsDir)
        self.knn.train(self.trainData, cv2.ml.ROW_SAMPLE, self.trainLabels)

    def evaluateModel(self, k, accPlot, predictGraph, console):
        """Evaluate the trained model on the prediction set.

        Args:
            k: Number of neighbours. ``0`` means "sweep k from 1 to 15 and
                plot accuracy against k"; any other value classifies with
                that k, previews random samples and reports the accuracy.
            accPlot: Matplotlib-style axes used for the accuracy curve
                (only touched when ``k == 0``).
            predictGraph: 2x3 grid of axes indexable as ``[row, col]``
                (only touched when ``k != 0``).
            console: Object with ``append(str)`` that receives the accuracy
                message (only touched when ``k != 0``).
        """
        if k == 0:
            self._plotAccuracySweep(accPlot)
        else:
            self._showPredictions(k, predictGraph, console)

    def _predict(self, k):
        """Return the predicted class value of every prediction sample."""
        _, result, _, _ = self.knn.findNearest(self.predictionData, k)
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # dtype that alias used to stand for.
        return result.flatten().astype(int)

    def _accuracyPercent(self, predictions):
        """Return the share of correct predictions as a whole-number percent."""
        total = len(predictions)
        if total == 0:
            # Nothing was classified; avoid dividing by zero.
            return 0.0
        hits = sum(
            1
            for index, predicted in enumerate(predictions)
            if predicted == self.labelsDictionary[self.predictionLabels[index]]
        )
        # Scale before rounding so results such as 56.99999999999999 cannot
        # appear; float() keeps the historical "57.0" formatting.
        return float(round(100.0 * hits / total))

    def _plotAccuracySweep(self, accPlot):
        """Plot accuracy for every k in 1.._MAX_K onto ``accPlot``."""
        k_values = np.arange(1, self._MAX_K + 1)
        accuracies = [
            self._accuracyPercent(self._predict(neighbours))
            for neighbours in range(1, self._MAX_K + 1)
        ]
        accPlot.plot(k_values, accuracies, label='Accuracy in %')
        accPlot.set_title("Accuracy of KNN algorithm")
        accPlot.set_xlabel("k Value")
        accPlot.set_ylabel("Accuracy %")
        accPlot.set_xlim([1, self._MAX_K])
        accPlot.set_xticks(k_values)
        accPlot.legend()

    def _showPredictions(self, k, predictGraph, console):
        """Preview random classified samples and append the accuracy message."""
        predictions = self._predict(k)
        # The predicted class value is used as a position in the dictionary's
        # key order (assumes values are 0..n-1 in insertion order; confirm
        # against load_data_KNN).
        class_names = list(self.labelsDictionary.keys())

        cells = [
            (row, col)
            for row in range(self._GRID_ROWS)
            for col in range(self._GRID_COLS)
        ]
        sample_total = len(self.predictionData)
        # random.sample raises ValueError when asked for more items than
        # exist, so never request more previews than there are samples.
        chosen = random.sample(
            range(0, sample_total), k=min(len(cells), sample_total))

        for (row, col), index in zip(cells, chosen):
            flat = np.array(self.predictionData[index])
            # Split the flat vector into three rows, one per colour channel.
            channels = flat.reshape(len(flat) // 3, -1).T
            red = channels[0].reshape(getTargetSize())
            green = channels[1].reshape(getTargetSize())
            blue = channels[2].reshape(getTargetSize())
            rgb = np.dstack((red, green, blue)).astype(np.uint8)

            axes = predictGraph[row, col]
            axes.set_title(class_names[predictions[index]])
            axes.imshow(Image.fromarray(rgb, 'RGB'))
            axes.axis('off')

        accuracy = self._accuracyPercent(predictions)
        console.append('Accuracy: ' + str(accuracy) + '%')