In [1]:
%pylab inline
import pickle as pk

Populating the interactive namespace from numpy and matplotlib


In [2]:
# labels holds the labels of the neighbors of each test point.
# It consists of a list of 1000 arrays, each corresponding to one test example;
# each array contains 60000 elements, which are the labels of the 60000 training examples sorted by increasing
# distance to the test example.

# with open('labels.pkl','br') as pklfile:
#     labels=pk.load(pklfile)
# len(labels), labels[0].shape

In [3]:
!ls -lrth *.npz

-rw-r--r--  1 yoavfreund  staff   1.1G Apr 26 22:01 IndexByDistance.npz


In [4]:
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Reading the array inside a `with` block releases the file handle as soon as
# the array has been pulled into memory.
with np.load('IndexByDistance.npz') as X:
    # 'a' is the array used below to index train_labels.
    indicesByDistance=X['a']

In [5]:
import mnist

def recast(x):
    """Return a new NumPy array with the values of `x` converted to float32."""
    as_float32 = np.array(x, dtype=np.float32)
    return as_float32
# Images are converted to float32 via recast; labels are kept exactly as the
# mnist package returns them.
train_images = recast(mnist.train_images())
train_labels = mnist.train_labels()
test_images = recast(mnist.test_images())
test_labels = mnist.test_labels()

In [6]:
# labels[t, j] = train_labels[indicesByDistance[t, j]]
# (presumably the label of the j-th nearest training example to test point t --
#  confirm against how IndexByDistance.npz was built).
labels=train_labels[indicesByDistance]

In [86]:
# p is the probability passed to the binomial `selector` drawn further down;
# where selector is 1 the neighbor label is replaced by the random digit below.
p=0.8
# Seed the global NumPy RNG so the injected label noise is reproducible
# across kernel restarts (the later np.random draws use the same stream).
np.random.seed(0)
# One uniformly random candidate label in {0,...,9} per (test point, neighbor) pair.
random=np.random.choice(10,indicesByDistance.shape)

In [87]:
# Stack along a new last axis: combined[..., 0] is the original neighbor label,
# combined[..., 1] is the random candidate label.
combined=np.dstack((labels,random))

In [88]:
combined.shape

(10000, 60000, 2)

In [89]:
# selector is 1 with probability p and 0 otherwise, one draw per (test point, neighbor) pair.
selector=np.random.binomial(1,p,random.shape)
# One-hot mask with the same layout as `combined`:
# channel 0 is 1 where selector==0, channel 1 is 1 where selector==1.
sel=np.dstack([1-selector,selector])
sel.shape

(10000, 60000, 2)

In [90]:
# Keep the original neighbor label where selector==0 and take the random label
# where selector==1. This produces the same values as summing combined*sel over
# the last axis, but does not allocate another (tests x neighbors x 2) temporary
# and does not rely on %pylab replacing the builtin sum (which has no `axis`).
noisy_labels=np.where(selector==1,random,labels)
noisy_labels.shape

(10000, 60000)

In [91]:
noisy_labels[:10,:7]

array([[5, 2, 1, 7, 7, 9, 0],
       [5, 1, 5, 2, 8, 2, 2],
       [4, 0, 7, 2, 9, 7, 3],
       [6, 9, 7, 0, 3, 9, 2],
       [7, 9, 4, 5, 0, 7, 4],
       [1, 7, 4, 2, 0, 1, 1],
       [5, 8, 7, 0, 8, 8, 6],
       [6, 7, 3, 0, 8, 2, 5],
       [8, 5, 5, 3, 3, 5, 8],
       [1, 4, 8, 3, 8, 7, 8]])

In [92]:
# From here on `labels` refers to the noisy neighbor labels; the value assigned
# to `labels` in the earlier cell is overwritten. The find_label_* functions
# below read this global.
labels=noisy_labels

In [93]:
# Scratch check of argmax vs. max along axis=1.
# Note: this rebinds X, which previously held the object returned by np.load.
X=np.array([[1,3,2],[1,7,1]])
argmax(X,axis=1),np.max(X,axis=1)

(array([1, 1]), array([3, 7]))

In [102]:
_range=1000

def find_sig(S,thr=5):
    """Pick a label and a stopping column from per-label significance curves.

    Parameters
    ----------
    S : 2-D array, one row per label and one column per neighborhood size.
        S[i,j] is the significance associated with label i at column j.
        A 1-D array is treated as a single-row matrix.
    thr : significance threshold.

    Returns
    -------
    (predicted_label, stopping_time) as Python ints.
        If any entry exceeds thr, stopping_time is the smallest column at which
        some row exceeds thr and predicted_label is that row (the smallest row
        index wins when several rows cross at the same column).
        Otherwise predicted_label is the row holding the largest value of S and
        stopping_time is the column where that row attains its maximum.

    The matrix dimensions are taken from S itself rather than from the global
    _range or a hard-coded 10, so any number of labels / columns is handled.
    """
    S=np.atleast_2d(np.asarray(S))
    n_cols=S.shape[1]

    above=S>thr
    crossed=above.any(axis=1)
    # argmax over a boolean row yields the index of its first True;
    # rows that never cross get the sentinel n_cols ("no crossing").
    first_cross=np.where(crossed,above.argmax(axis=1),n_cols)
    # argmin returns the first minimum, so ties go to the smallest label.
    predicted_label=int(np.argmin(first_cross))
    stopping_time=int(first_cross[predicted_label])
    if stopping_time<n_cols:
        return predicted_label,stopping_time

    # did not find anything with significance > thr, then take the one with highest significance
    row_max=np.max(S,axis=1)
    predicted_label=int(np.argmax(row_max))
    stopping_time=int(np.argmax(S[predicted_label]))
    return predicted_label,stopping_time
            
    

In [103]:
from collections import Counter
def find_label_knn(test_index,k=5):
    """Tally the labels among the first k neighbors of test example `test_index`.

    Reads the module-level `labels` array. Returns a 2 x m array where m is the
    number of distinct labels seen: row 0 holds the labels ordered by decreasing
    count (equal counts keep first-seen order), row 1 holds the counts.
    """
    nearest=labels[test_index][:k]
    ranked=Counter(nearest).most_common()
    return np.array(ranked).transpose()

In [104]:
def find_label_adaptive(test_index,thr=4):
    """Find label of test example test_index using adaptive k NN.

    For every digit d and every neighborhood size n = 1.._range the statistic
        (count of d among the first n neighbors - n/10) / sqrt(n)
    is computed from the module-level `labels` array; find_sig then chooses the
    stopping column for threshold `thr`.

    Returns a 2 x m int16 array describing the digits whose statistic is
    positive at the stopping column, strongest first: row 0 holds the digits,
    row 1 holds their statistics cast to int16.
    """
    neighbor_labels=labels[test_index]
    n_seen=np.arange(1,_range+1,1)
    significance=np.zeros([10,_range])
    for digit in range(10):
        running_count=np.cumsum(neighbor_labels==digit)[:_range]
        significance[digit,:]=(running_count-(n_seen/10))/np.sqrt(n_seen)

    # only the stopping column is needed here; the predicted label is re-derived
    # from the ranking of the statistics at that column
    _,stopping_time=find_sig(significance,thr=thr)
    at_stop=significance[:,stopping_time]
    order=np.argsort(at_stop)          # ascending by statistic
    ranked=at_stop[order]
    positive=ranked>0
    prediction=np.array([order[positive],ranked[positive]],dtype=np.int16)
    # reverse the columns so the largest statistic comes first
    return prediction[:,::-1]

In [111]:
def compute_errs(method="adaptive",**kwargs):
    """Evaluate a predictor on every test example and print an error table.

    Parameters
    ----------
    method : 'adaptive' uses find_label_adaptive; any other value uses
        find_label_knn.
    **kwargs : forwarded to the chosen predictor (e.g. thr=... or k=...).

    Each predictor returns a 2 x m array whose row 0 is a ranked label set;
    its first entry is the prediction and m is the set size.

    Returns
    -------
    errs : integer array of shape (4, 11); errs[code, m] counts test examples
        with outcome `code` and label-set size m, where
          0 = prediction correct, true label not in set (cannot occur, since the
              prediction is itself a member of the set),
          1 = prediction wrong, true label not in set,
          2 = prediction correct,
          3 = prediction wrong, true label in set.
    The same counts are printed as three rows: total, incorrect (codes 0, 1, 3)
    and not in set (codes 0, 1).
    """
    # wide integer counters so large test sets cannot overflow
    errs=np.zeros([4,11],dtype=np.int64)
    n=test_labels.shape[0]
    for i in range(n):
        if method=='adaptive':
            pred_label=find_label_adaptive(i,**kwargs)
        else:
            pred_label=find_label_knn(i,**kwargs)

        prediction=int(pred_label[0,0])
        err = 0 #prediction error
        labelSet=set(pred_label[0,:])
        if prediction != test_labels[i]:
            err=1 # incorrect prediction
        if test_labels[i] in labelSet:
            err+=2 # true label is in the prediction set
        multi=pred_label.shape[1]
        # -1 is treated as "no prediction": counted as wrong with an empty set
        # (NOTE(review): neither predictor visible here appears to return -1 -- confirm)
        if prediction==-1:
            multi=0
            err=1
        errs[err,multi]+=1
        if i%100==0:
            print('\r',i,end='')
    print()

    cr_sums=np.sum(errs,axis=1)
    col_sums=np.sum(errs,axis=0)
    # cells are 6 wide so that 5-digit counts do not run into the previous column
    print('size of set total',''.join(["%6d"%x for x in range(11)]))
    print('total      %6d'%np.sum(cr_sums),''.join(["%6d"%x for x in col_sums]))
    print('incorrect  %6d'%np.sum(cr_sums[[0,1,3]]),''.join(["%6d"%x for x in col_sums-errs[2,:]]))
    print('not in set %6d'%np.sum(cr_sums[[0,1]]),''.join(["%6d"%x for x in errs[0,:]+errs[1,:]]))
    # callers assign the result, so hand the count table back
    return errs
 

In [112]:
# Plain k-NN sweep, k = 1..10; compute_errs prints one error table per k.
for k in range(1,11):
    print('='*40,'knn with k=',k)
    errs=compute_errs('knn',k=k)

 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     010000    0    0    0    0    0    0    0    0    0
incorrect    7834     0 7834    0    0    0    0    0    0    0    0    0
not in set   7834     0 7834    0    0    0    0    0    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0 1187 8813    0    0    0    0    0    0    0    0
incorrect    7834     0  677 7157    0    0    0    0    0    0    0    0
not in set   6123     0  677 5446    0    0    0    0    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0  162 2951 6887    0    0    0    0    0    0    0
incorrect    7601     0   47 1830 5724    0    0    0    0    0    0    0
not in set   4710     0   47 1374 3289    0    0    0    0    0    0    0
 99003400
size of set total     0    1    2    3    4    5    6    7    8    9   10
total     

In [113]:
# Plain k-NN sweep, k = 10, 20, ..., 100; compute_errs prints one error table per k.
for k in range(10,101,10):
    print('='*40,'knn with k=',k)
    errs=compute_errs('knn',k=k)

 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0    0    1   23  369 1766 3488 3164 1083  103    3
incorrect    5925     0    0    0    5  128  836 1966 2105  801   81    3
not in set    863     0    0    0    2   29  172  323  254   81    2    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0    0    0    0    0   11  179 1157 3203 3954 1496
incorrect    4504     0    0    0    0    0    3   45  421 1328 1901  806
not in set     88     0    0    0    0    0    0    3   15   39   31    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0    0    0    0    0    0    3   94  895 3783 5225
incorrect    3512     0    0    0    0    0    0    0   22  247 1279 1964
not in set     18     0    0    0    0    0    0    0    1    3   14    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10

In [114]:
# Plain k-NN sweep, k = 100, 200, ..., 800; compute_errs prints one error table per k.
for k in range(100,801,100):
    print('='*40,'knn with k=',k)
    errs=compute_errs('knn',k=k)

 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0    0    0    0    0    0    0    0    0   15 9985
incorrect    1029     0    0    0    0    0    0    0    0    0    2 1027
not in set      0     0    0    0    0    0    0    0    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0    0    0    0    0    0    0    0    0    010000
incorrect     818     0    0    0    0    0    0    0    0    0    0  818
not in set      0     0    0    0    0    0    0    0    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0    0    0    0    0    0    0    0    0    010000
incorrect     842     0    0    0    0    0    0    0    0    0    0  842
not in set      0     0    0    0    0    0    0    0    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10

In [115]:
# Adaptive nearest-neighbor sweep over the significance threshold, theta = 1, 2, 3
# (passed to find_label_adaptive as thr); one error table is printed per theta.
for theta in range(1,4):
    print('='*40,'adaptive NN with theta=',theta)
    errs=compute_errs('adaptive',thr=theta)

 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0 1249  943 2829 2997 1675  293   14    0    0    0
incorrect    2119     0  691  206  494  423  248   57    0    0    0    0
not in set   1225     0  691  138  212  108   65   11    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0  445 2794 4303 2093  352   13    0    0    0    0
incorrect    1066     0   36  213  453  295   65    4    0    0    0    0
not in set    247     0   36   63   87   55    5    1    0    0    0    0
 9900
size of set total     0    1    2    3    4    5    6    7    8    9   10
total       10000     0 1440 4080 3226 1070  174   10    0    0    0    0
incorrect    1078     0   51  285  428  251   60    3    0    0    0    0
not in set    255     0   51   85   76   36    6    1    0    0    0    0


In [62]:
# Scratch arithmetic: these six counts add up to 10000.
# NOTE(review): presumably per-set-size counts from an earlier run -- the numbers
# do not appear in the outputs kept in this notebook.
5254  +2252  +1540   +626   +278    +50

10000

In [None]:
# Visualize the per-label significance curves for a single test example.
i=8520 # a 4 that looks like a 9
i=20
#i=23
#i=320

threshold=5
L=labels[i]
true_label=test_labels[i]
imshow(test_images[i,:,:],cmap='gray')
title(str(true_label))

figure(figsize=[10,8])
_range=1000
C=[]
scale=arange(1,_range+1,1)
# find_sig works on the full (label x neighborhood size) matrix, so collect
# every label's curve before calling it once.
P_all=np.zeros([10,_range])
# a separate loop variable keeps the selected test index `i` intact
for label in range(10):
    C.append(np.cumsum(L==label))
    P=(C[-1][:_range]-(scale/10))/sqrt(scale)
    P_all[label,:]=P
    _name=str(label)
    if label==true_label:
        _name+= ' true label '
    plot(P,label=_name)
legend()
grid()
# (predicted_label, stopping_time) for this example at the chosen threshold
sig=find_sig(P_all,thr=threshold)
print(sig)

In [None]:
# Peek at the first 10 entries of P as left by the previous cell.
P[:10]

In [None]:
# Shape of the index array of entries with P > 1, i.e. how many entries of P exceed 1.
np.nonzero(P>1)[0].shape

In [None]:
# instead of giving up when reaching _range, use the label that showed the largest "p".