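Modified copy of the AKNN usage example at https://github.com/b-akshay/aknn-classifier/blob/master/examples/aknn_usage.ipynb — compares the adaptive k-NN (AKNN) classifier with fixed-k k-NN on the notMNIST_small dataset.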

In [1]:
import numpy as np
import scipy as sp
import scipy.io  # older SciPy releases do not load the io submodule on a bare `import scipy`

import aknn

In [2]:
# Class index -> letter: ten classes, the letters A through J.
INT2LABEL = np.array(list('ABCDEFGHIJ'))

data = sp.io.loadmat("notMNIST_small.mat")
# Move the last axis to the front so `images` is indexed by sample first
# (the stored layout is presumably (height, width, n_samples) -- TODO confirm).
images = data['images'].transpose(2, 0, 1)
n_pixels = images.shape[1] * images.shape[2]
# Flatten each image to a row vector and map intensities from [0, 255] to [-1, 1]
# (assuming 8-bit intensities). Cast to float *before* the arithmetic: if the stored
# dtype is an integer type such as uint8, `2 * images` would be computed in that
# dtype and silently wrap around.
samples = (2 * images.reshape(-1, n_pixels).astype(np.float64) - 255.0) / 255.0
# loadmat returns MATLAB vectors as 2-D arrays (squeeze_me=False by default);
# flatten so labels line up element-wise with per-sample predictions.
label_idx = data['labels'].astype(int).ravel()
labels = INT2LABEL[label_idx]

In [3]:
# Upper bound on the number of neighbors stored per sample. The k values used
# below (AKNN's adaptive k and the fixed k for k-NN) presumably cannot exceed
# this budget -- TODO confirm against aknn.calc_nbrs_exact.
# Exact neighbor search over all samples; presumably the most expensive cell.
N_NEIGHBORS = 1000
nbrs_list = aknn.calc_nbrs_exact(samples, k=N_NEIGHBORS)

In [6]:
# Compare AKNN with plain k-NN at several fixed k. For each k, report:
#   * k-NN accuracy over all samples, and
#   * AKNN accuracy restricted to samples where AKNN used at most k neighbors,
#     together with the fraction of samples kept by that restriction ("coverage").
KS_TO_COMPARE = [2, 3, 5, 7, 8, 10, 30, 99]

# AKNN predictions plus the number of neighbors AKNN used for each sample.
aknn_pred, aknn_ks = aknn.predict_nn_rule(nbrs_list, labels)
aknn_correct = (aknn_pred == labels)

for k in KS_TO_COMPARE:
    # Fixed-k k-NN predictions.
    knn_pred = aknn.knn_rule(nbrs_list, labels, k)
    knn_correct = (knn_pred == labels)
    # Samples on which AKNN used at most k neighbors (inclusive: <=).
    aknn_cov_ndces = (aknn_ks <= k)
    aknn_cover = aknn_cov_ndces.mean()
    # mean() of an empty selection emits a RuntimeWarning and returns nan;
    # report nan quietly when no sample is covered.
    if aknn_cov_ndces.any():
        aknn_cond_acc = aknn_correct[aknn_cov_ndces].mean()
    else:
        aknn_cond_acc = float('nan')
    print('KNN accuracy (k = %d): \t\t%.4f' % (k, knn_correct.mean()))
    print('AKNN accuracy (k <= %d): \t%.4f \t\t Coverage: %.3f\n' % (k, aknn_cond_acc, aknn_cover))
# Same precision as the per-k accuracies above.
print('Full AKNN accuracy: %.4f' % aknn_correct.mean())

KNN accuracy (k = 2): 		0.8599
AKNN accuracy (k <= 2): 	0.9702 		 Coverage: 0.838

KNN accuracy (k = 3): 		0.8749
AKNN accuracy (k <= 3): 	0.9702 		 Coverage: 0.838

KNN accuracy (k = 5): 		0.8833
AKNN accuracy (k <= 5): 	0.9450 		 Coverage: 0.918

KNN accuracy (k = 7): 		0.8836
AKNN accuracy (k <= 7): 	0.9408 		 Coverage: 0.926

KNN accuracy (k = 8): 		0.8835
AKNN accuracy (k <= 8): 	0.9362 		 Coverage: 0.936

KNN accuracy (k = 10): 		0.8823
AKNN accuracy (k <= 10): 	0.9322 		 Coverage: 0.943

KNN accuracy (k = 30): 		0.8768
AKNN accuracy (k <= 30): 	0.9159 		 Coverage: 0.969

KNN accuracy (k = 99): 		0.8592
AKNN accuracy (k <= 99): 	0.9072 		 Coverage: 0.982

Full AKNN accuracy: 0.8925977355265969


In [9]:
# Best fixed k for plain k-NN over the searched range only (KNN_K_RANGE),
# not over every possible k. Ties go to the smallest k, because max()
# returns the first maximal key in insertion (ascending k) order.
KNN_K_RANGE = range(2, 30)
knn_acc_by_k = {
    k: (aknn.knn_rule(nbrs_list, labels, k) == labels).mean()
    for k in KNN_K_RANGE
}
hi_k = max(knn_acc_by_k, key=knn_acc_by_k.get)
hi = knn_acc_by_k[hi_k]
print('%.4f (k=%d)' % (hi, hi_k))

0.8836 (k=7)
