# Using a Cosine Similarity Model to Classify MNIST Handwritten Digits

Below I write out the algorithm and measure its accuracy on images drawn at random from the full MNIST dataset. Later I hope to show that this approach needs far less data than other classification methods to work well (I think).
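For reference, the method is a nearest-neighbor vote that uses cosine similarity as the closeness measure. For two flattened 784-pixel images $x$ and $y$:

$$\cos(x, y) = \frac{x \cdot y}{\lVert x \rVert \, \lVert y \rVert} = \frac{\sum_{i=1}^{784} x_i y_i}{\sqrt{\sum_{i=1}^{784} x_i^2}\,\sqrt{\sum_{i=1}^{784} y_i^2}}$$

An image $x$ is then labeled with the most common label among its $n$ most similar images, not counting $x$ itself: $\hat{y}(x) = \operatorname{mode}\{\, t_j : j \in N_n(x) \,\}$, where $t_j$ is the label of image $j$ and $N_n(x)$ (notation introduced only for this summary) is the set of those $n$ indices. Version 1 below is the special case $n = 1$.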

In [2]:
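import numpy as np
import heapq
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity
from sklearn import datasets
# fetch_mldata depended on mldata.org, which is offline, and it was removed in scikit-learn 0.22.
# On newer versions, datasets.fetch_openml('mnist_784', version=1, as_frame=False) loads the same
# 70,000 images, but its rows are not sorted by digit and its labels are strings, not floats.
mnist = datasets.fetch_mldata('MNIST original')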

data, target = mnist.data, mnist.target

In [2]:
data.shape, target.shape

((70000, 784), (70000,))

In [3]:
indices = [0, 5923, 12665, 18623, 24754, 30596, 36017, 41935, 48200, 54051, 10000, 18000, 22000, 25000, 35000, 38000, 43000, 49000, 55000]
for i in indices:
    print(target[i])

0.0
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0


In [9]:
# Version 1: classify one image by comparing it against every image, one pair at a time
def classification(indx):
    """indx: index of the datapoint in the dataset to be classified
    returns: (predicted label, whether the prediction matches the true label)"""
    
    dist = []
    
    # one cosine_similarity call per image pair: simple, but very slow
    for i in range(0, len(target)):
        cpmr = np.array([data[indx], data[i]])
        cosim = cosine_similarity(cpmr)[0][1]
        dist.append(cosim)
    
    dist = np.array(dist)
    top = heapq.nlargest(2, range(len(dist)), dist.take)
    
    # top[0] is normally the image itself (similarity 1.0), so use the next most similar image
    decision = target[top[1]]
    
    return decision, target[indx] == decision

In [16]:
%%time
p = classification(18623)

CPU times: user 6.6 s, sys: 3.91 ms, total: 6.61 s
Wall time: 6.63 s


Great! The basics work. Now we just need to figure out how to measure accuracy across the whole dataset...

In [20]:
%%time
# estimate accuracy on 100 randomly chosen images; with enough samples every digit gets picked
correct = 0
count = 0

for i in range(0, 100):
    indx = np.random.choice(target.shape[0])
    pred = classification(indx)
    
    if pred[1]:
        correct += 1
    count += 1
        
    if i % 10 == 0:
        print(i)

0
10
20
30
40
50
60
70
80
90
CPU times: user 10min 52s, sys: 176 ms, total: 10min 53s
Wall time: 10min 53s


In [21]:
print("correct:", correct)
print("count:", count)
print("accuracy:", correct/count)

correct: 97
count: 100
accuracy: 0.97


## Optimize
Now I will experiment to see how far I can optimize this. The version above, **Version 1, takes about 6.6 seconds per classification and almost 11 minutes to classify 100 digits. V1 is very slow**, because it calls cosine_similarity 70,000 times, once per image pair.

In [23]:
%%time
cosine_similarity(np.array([data[0]]), data)

CPU times: user 486 ms, sys: 274 ms, total: 760 ms
Wall time: 634 ms


array([[1.        , 0.87006533, 0.63943673, ..., 0.46985674, 0.49680307,
        0.53896219]])

Note that comparing one image against the entire dataset in a single call takes about 0.6 seconds, an order of magnitude less than the 6.6 seconds Version 1 needs.
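Why is one call so much faster? cosine_similarity normalizes the rows and then gets all 70,000 dot products from one matrix product, instead of paying Python and function-call overhead 70,000 times. A minimal NumPy sketch of that idea for a single query vector (the function name `cosine_similarity_to_all` is hypothetical, and scikit-learn's own implementation differs in detail):

```python
import numpy as np

def cosine_similarity_to_all(query, data):
    """Cosine similarity between one vector `query` and every row of `data`.

    Hypothetical helper, shown only to illustrate the vectorized idea.
    """
    data = np.asarray(data, dtype=np.float64)
    query = np.asarray(query, dtype=np.float64)
    # L2 norm of every row and of the query; all-zero vectors get norm 1 to avoid 0/0
    row_norms = np.linalg.norm(data, axis=1)
    row_norms[row_norms == 0] = 1.0
    q_norm = np.linalg.norm(query)
    if q_norm == 0:
        q_norm = 1.0
    # a single matrix-vector product computes all the dot products at once
    return (data @ query) / (row_norms * q_norm)

toy = np.array([[3, 5, 2], [4, 4, 1], [200, 10, 10]])
# first entry is (approximately) 1.0: the row compared with itself
print(cosine_similarity_to_all(toy[0], toy))
```

The values match the sklearn output for the same toy vectors used later in this notebook (about 0.960 and 0.542 for the second and third rows).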

In [98]:
# Version 2: compare against the whole dataset in one call, then vote
def classif(indx, n):
    """indx: index of the datapoint in the dataset to be classified
    n: number of most similar images that vote on the label
    returns: the predicted label (majority vote of the n most similar images)"""

    # one call compares the image against all 70,000 images at once
    dist = cosine_similarity(np.array([data[indx]]), data)
    
    top = heapq.nlargest((n+1), range(len(dist[0])), dist.take)
    # drop the first index, which is the image itself
    top = [target[i] for i in top[1:(n+1)]]
    
    mc = Counter(top)
    # majority vote on the label to return
    return mc.most_common(1)[0][0]

In [107]:
%%time
classif(30596, 7) == target[30596]

CPU times: user 665 ms, sys: 248 ms, total: 913 ms
Wall time: 699 ms


True

Next, find the accuracy:

In [100]:
def get_acc(n, steps):
    correct = 0
    count = 0
    # sampled indices, kept for inspecting which digits were covered (not returned)
    indxls = []
    
    for i in range(0, steps):
        indx = np.random.choice(target.shape[0])
        indxls.append(indx)
        
        pred = classif(indx, n)
        
        if pred == target[indx]:
            correct += 1
        count += 1
        
        # print the running accuracy every 100 steps
        if i % 100 == 0:
            print(str(i) + " correct: " + str(correct/count))
    
    return correct, count, (correct/count)

In [108]:
%%time
get_acc(5, 100)

0 correct: 0.0
CPU times: user 1min 42s, sys: 30.3 s, total: 2min 12s
Wall time: 1min 8s


(97, 100, 0.97)

**Version 2 (above) took about 1 minute for 100 digits, an order of magnitude faster than Version 1 (almost 11 minutes)**

In [107]:
%%time
get_acc(6, 1000)

0 correct: 1.0
100 correct: 0.9702970297029703
200 correct: 0.9701492537313433
300 correct: 0.9767441860465116
400 correct: 0.9750623441396509
500 correct: 0.9780439121756487
600 correct: 0.9800332778702163
700 correct: 0.9800285306704708
800 correct: 0.9812734082397003
900 correct: 0.9800221975582686
CPU times: user 16min 34s, sys: 4min 54s, total: 21min 28s
Wall time: 11min 18s


(979, 1000, 0.979)

In [118]:
%%time
get_acc(7, 10000)

0 correct: 1.0
100 correct: 0.9900990099009901
200 correct: 0.9900497512437811
300 correct: 0.9833887043189369
400 correct: 0.9850374064837906
500 correct: 0.9860279441117764
600 correct: 0.9883527454242929
700 correct: 0.985734664764622
800 correct: 0.9862671660424469
900 correct: 0.9866814650388457
1000 correct: 0.988011988011988
1100 correct: 0.9872842870118075
1200 correct: 0.9866777685262281
1300 correct: 0.9846272098385856
1400 correct: 0.9857244825124911
1500 correct: 0.9846768820786143
1600 correct: 0.985633978763273
1700 correct: 0.9841269841269841
1800 correct: 0.9816768461965575
1900 correct: 0.9810625986322988
2000 correct: 0.9810094952523738
2100 correct: 0.9804854831032842
2200 correct: 0.9809177646524307
2300 correct: 0.9817470664928292
2400 correct: 0.9816743023740109
2500 correct: 0.9804078368652539
2600 correct: 0.98000768935025
2700 correct: 0.9796371714179933
2800 correct: 0.9796501249553731
2900 correct: 0.9800068941744227
3000 correct: 0.9803398867044318
...

(9799, 10000, 0.9799)

Version 1 needed almost 11 minutes for just 100 digits, while Version 2 classifies 1,000 digits in about 11 minutes, so **Version 2 is roughly ten times faster than Version 1**.


**The next step is to see how little data this can work on.** I also want to:

- speed it up even more via vectorization and by pre-picking the random index values,
- try weighted voting, and compute similarities between whole matrices of images instead of one vector at a time,
- return the index of any image it gets wrong.

Below I start by trying to further optimize the cosine similarity step by passing in whole matrices and pre-picking the random indices.
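One possible form of the weighted voting idea, sketched under the assumption that each of the n neighbors votes with its cosine similarity instead of a flat 1. The helper `weighted_vote` is hypothetical and not part of this notebook:

```python
from collections import defaultdict

def weighted_vote(neighbor_sims, neighbor_labels):
    """Return the label whose neighbors have the largest summed similarity.

    neighbor_sims: cosine similarities of the n nearest neighbors
    neighbor_labels: their labels, in the same order
    Hypothetical helper (assumption: weight = raw cosine similarity).
    """
    scores = defaultdict(float)
    for sim, label in zip(neighbor_sims, neighbor_labels):
        scores[label] += sim  # closer neighbors contribute more
    return max(scores, key=scores.get)

# a plain majority vote would pick 4.0 (3 votes to 2),
# but the two 9.0 neighbors are much closer, so they win here
sims = [0.99, 0.98, 0.60, 0.58, 0.57]
labels = [9.0, 9.0, 4.0, 4.0, 4.0]
print(weighted_vote(sims, labels))  # 9.0
```

Whether this beats the plain vote on MNIST is untested here; it mainly matters when a few very close neighbors disagree with a larger group of distant ones.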

## Work towards vectorizing and optimizing even more

In [3]:
data.shape, target.shape

((70000, 784), (70000,))

In [7]:
#choose n random indices
random_indx = [np.random.choice(len(target)) for i in range(0, 1000)]
random_indx[:10]

[66201, 6093, 39272, 22214, 4028, 55381, 66627, 4710, 5307, 8157]

Now I need a matrix P holding the n randomly chosen images from data. The same trick works for either matrix passed into cosine_similarity, which lets me choose how much data is being tested on.
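A small sketch of the pre-picking step using NumPy fancy indexing. It assumes `np.random.default_rng` (NumPy 1.17+) and draws without replacement, so, unlike calling `np.random.choice(len(target))` once per sample, no image is picked twice. `sample_rows` is a hypothetical helper:

```python
import numpy as np

def sample_rows(data, k, seed=None):
    """Pick k distinct row indices at random and return (indices, rows).

    Hypothetical helper; replace=False guarantees no repeated images.
    """
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(data), size=k, replace=False)
    # fancy indexing copies the selected rows into a (k, n_features) array in one step
    return idx, data[idx]

toy = np.arange(20).reshape(10, 2)
idx, rows = sample_rows(toy, 4, seed=0)
print(idx.shape, rows.shape)  # (4,) (4, 2)
```

With the real dataset this would be `random_indx, P = sample_rows(data, 1000)`.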

In [11]:
P = [data[i] for i in random_indx]
P = np.array(P)
P.shape

(1000, 784)

In [14]:
# test how the output datastructure looks for cosine similarity between two matrices
X = np.array([[3, 5, 2],
             [4, 4, 1]])
Y = np.array([[3, 5, 2],
             [200, 10, 10]])
cosine_similarity(X, Y)

array([[1.        , 0.54208823],
       [0.96013024, 0.73798737]])

In [30]:
Z = [[3, 5, 3], [200, 10, 10]]
mc = Counter(Z)
classification = mc.most_common(1)
classification

TypeError: unhashable type: 'list'

In [25]:
top = [(heapq.nlargest(2, range(len(i)), i.take)) for i in Y]

top

[[1, 0], [0, 1]]

In [45]:
def most_common(lst):
    return max(set(lst), key=lst.count)

In [47]:
x = np.random.randint(1, 10, (5000))
x = list(x)
most_common(x)

8

In [77]:
# Draft of a batched classifier; it has bugs that the final version below fixes
def superclassif(comparisons, n):
    correct = 0
    count = 0
    # classify all comparisons
    random_indx = [np.random.choice(len(target)) for i in range(0, comparisons)]
    X = [data[i] for i in random_indx]
    X = np.array(X)
    
    # comparisons x size data structure
    cosim = cosine_similarity(X, data)
    
    # get top n indices for each
    top = [(heapq.nlargest((n), range(len(i)), i.take)) for i in cosim]
    # bug: top[:(n)] keeps only the first n rows, and each image itself is not excluded
    top = [[target[j] for j in i] for i in top[:(n)]]
    
    pred = [most_common(i) for i in top]
    
    correct_classification = [target[i] for i in random_indx]
    
    # bug: this unpacks the two lists themselves; zip(pred, correct_classification) is needed
    for i, j in pred, correct_classification:
        if(i==j):
            correct += 1
            count += 1
        else:
            count += 1
    # bug: nothing is returned

In [55]:
random_indx = [np.random.choice(len(target)) for i in range(0, 5)]
random_indx

[20097, 23328, 27054, 16216, 21490]

In [75]:
target[21490]

3.0

In [56]:
X = [data[i] for i in random_indx]
X = np.array(X)
X

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)

In [58]:
cosim = cosine_similarity(X, data)
cosim.shape

(5, 70000)

In [92]:
top = [(heapq.nlargest((6), range(len(i)), i.take)) for i in cosim]
top

[[20097, 63829, 21066, 21379, 63263, 21168],
 [23328, 21463, 23675, 19946, 23251, 22894],
 [27054, 64563, 24990, 25201, 64527, 26471],
 [16216, 62191, 19868, 12828, 16222, 18017],
 [21490, 22622, 23171, 19168, 22810, 20315]]

In [93]:
top = [[target[j] for j in i[1:6]] for i in top]
top

[[3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0],
 [4.0, 4.0, 4.0, 4.0, 4.0],
 [2.0, 3.0, 2.0, 2.0, 2.0],
 [3.0, 3.0, 3.0, 3.0, 3.0]]

In [94]:
pred = [most_common(i) for i in top]
pred

[3.0, 3.0, 4.0, 2.0, 3.0]

In [95]:
correct_classification = [target[i] for i in random_indx]
correct_classification

[3.0, 3.0, 4.0, 2.0, 3.0]

In [91]:
count = 0
correct = 0
for i, j in zip(pred, correct_classification):
    if(i==j):
        correct += 1
        count += 1
    else:
        count += 1
        
print(correct/count)

1.0


## Final
This is the final optimized version (Version 3), built from the work above, though there are always more ways to optimize.

I will now run it on 1,000 comparisons and see how it performs. The previous optimized version (Version 2) took about 11 minutes to get through 1,000.

In [102]:
def most_common(lst):
    return max(set(lst), key=lst.count)

def superclassif(comparisons, n):
    """comparisons: number of randomly chosen images to classify
    n: number of most similar images that vote on each prediction
    returns: number of correct predictions, total predictions, and accuracy in percent"""
    # pick the images to classify (drawn with replacement, so repeats are possible)
    random_indx = [np.random.choice(len(target)) for i in range(0, comparisons)]
    X = np.array([data[i] for i in random_indx])
    
    # (comparisons x 70,000) matrix: one row of similarities per image being classified
    cosim = cosine_similarity(X, data)
    
    # top n+1 indices per row; drop the first, which is the image itself
    top = [heapq.nlargest(n + 1, range(len(row)), row.take) for row in cosim]
    top = [[target[j] for j in row[1:(n + 1)]] for row in top]
    
    # majority vote among the n neighbors
    pred = [most_common(labels) for labels in top]
    
    correct_classification = [target[i] for i in random_indx]
    
    correct = sum(1 for p, c in zip(pred, correct_classification) if p == c)
    count = len(pred)
    acc = (correct / count) * 100
    
    return correct, count, acc

In [105]:
%%time
superclassif(1, 6)

CPU times: user 636 ms, sys: 271 ms, total: 906 ms
Wall time: 673 ms


(1, 1, 100.0)

In [110]:
%%time
superclassif(100, 5)

CPU times: user 4.8 s, sys: 297 ms, total: 5.1 s
Wall time: 4.19 s


(96, 100, 96.0)

**Version 3 (above) classifies 100 samples in about four seconds, far faster than Version 2 (about a minute) and Version 1 (almost 11 minutes). Note, however, that Version 3 classifies a single image in about the same time as Version 2 (0.67 s vs. 0.70 s). This is because Version 3 is built to push many images through a single cosine_similarity call: it is no faster when only one comparison is made, but much faster when there are many. If only single comparisons were needed, much could still be done to make Version 3 faster for that case.**
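One caveat with pushing many images through at once: `cosine_similarity(X, data)` builds the full comparisons × 70,000 matrix, which for 10,000 images is 700 million values (roughly 5.6 GB as float64). A hedged sketch of processing the queries in chunks to bound memory, written in plain NumPy rather than scikit-learn. `predict_in_chunks` and `chunk_size` are hypothetical, and like `superclassif` it assumes every query is also a row of `data`, so each query's top match (itself) is skipped:

```python
import numpy as np

def predict_in_chunks(queries, data, labels, n, chunk_size=1000):
    """Majority-vote nearest neighbors on cosine similarity, chunk_size queries at a time.

    Hypothetical helper: memory use is about chunk_size x len(data) floats per chunk.
    Assumes each query also appears in `data`, so the most similar column is dropped.
    """
    labels = np.asarray(labels)
    data = np.asarray(data, dtype=np.float64)
    norms = np.linalg.norm(data, axis=1, keepdims=True)
    data_unit = data / np.where(norms == 0, 1.0, norms)  # normalize rows once
    preds = []
    for start in range(0, len(queries), chunk_size):
        q = np.asarray(queries[start:start + chunk_size], dtype=np.float64)
        q_norms = np.linalg.norm(q, axis=1, keepdims=True)
        q_unit = q / np.where(q_norms == 0, 1.0, q_norms)
        sims = q_unit @ data_unit.T  # (chunk, len(data)) cosine similarities
        # most similar first; column 0 is the query itself, so keep columns 1..n
        order = np.argsort(-sims, axis=1)[:, 1:n + 1]
        for row in labels[order]:
            values, counts = np.unique(row, return_counts=True)
            preds.append(values[np.argmax(counts)])
    return np.array(preds)

pts = np.array([[1, 0], [0.9, 0.1], [0.95, 0.05], [0, 1], [0.1, 0.9], [0.05, 0.95]])
print(predict_in_chunks(pts, pts, [0, 0, 0, 1, 1, 1], n=2, chunk_size=4))  # [0 0 0 1 1 1]
```

Voting ties here go to the smallest label (`np.unique` sorts, `argmax` takes the first maximum), which may differ from the `most_common` helper above.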

In [111]:
%%time
superclassif(1000, 6)

CPU times: user 40.4 s, sys: 458 ms, total: 40.9 s
Wall time: 37.7 s


(978, 1000, 97.8)

What took Version 2 11 minutes to complete, and what would have taken Version 1 over 100 minutes to complete, was done in 37 seconds on Version 3!

Below I vary the n parameter (the number of neighbors that vote) to see whether adjusting it nudges accuracy higher.

In [112]:
%%time
superclassif(1000, 3)

CPU times: user 39.2 s, sys: 422 ms, total: 39.6 s
Wall time: 36.1 s


(987, 1000, 98.7)

In [113]:
%%time
superclassif(1000, 4)

CPU times: user 39.5 s, sys: 475 ms, total: 40 s
Wall time: 36.7 s


(974, 1000, 97.39999999999999)

In [114]:
%%time
superclassif(1000, 5)

CPU times: user 39.9 s, sys: 393 ms, total: 40.3 s
Wall time: 36.8 s


(977, 1000, 97.7)

In [115]:
%%time
superclassif(1000, 13)

CPU times: user 40.6 s, sys: 398 ms, total: 41 s
Wall time: 37.5 s


(978, 1000, 97.8)

In [116]:
%%time
superclassif(1000, 21)

CPU times: user 40.1 s, sys: 475 ms, total: 40.6 s
Wall time: 37.3 s


(964, 1000, 96.39999999999999)

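Looking at the results above, smaller values of n seem to do better than larger ones (98.7% at n=3 versus 96.4% at n=21), though the drop for larger n is not very steep. Each call draws its own random 1,000 images, so differences of around a percentage point may partly be sampling noise.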
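Since every `superclassif` call draws a fresh random sample, different n values end up scored on different images. A sketch of a fairer comparison, assuming you keep the neighbor labels for the largest n of interest (query itself already removed, ordered from most to least similar): slice that one ranking for each n, so the similarity matrix is computed once and every n is scored on the same images. `accuracy_for_each_n` is hypothetical:

```python
import numpy as np
from collections import Counter

def accuracy_for_each_n(neighbor_labels, true_labels, n_values):
    """Majority-vote accuracy for several n, reusing one neighbor ranking.

    neighbor_labels: (num_queries, max_n) labels of each query's neighbors,
                     most similar first, with the query itself already removed
    Hypothetical helper; ties go to the label seen first (the closer neighbor).
    """
    neighbor_labels = np.asarray(neighbor_labels)
    true_labels = np.asarray(true_labels)
    results = {}
    for n in n_values:
        preds = [Counter(row[:n]).most_common(1)[0][0] for row in neighbor_labels]
        results[n] = float(np.mean(np.array(preds) == true_labels))
    return results

ranked = [[3, 3, 5, 5, 5], [1, 2, 2, 2, 2], [7, 7, 7, 1, 1]]
print(accuracy_for_each_n(ranked, [3, 2, 7], [1, 3, 5]))
# {1: 0.6666666666666666, 3: 1.0, 5: 0.6666666666666666}
```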

Now I will test on a larger sample of 10,000 images with n=3.

In [117]:
%%time
superclassif(10000, 3)

CPU times: user 6min 30s, sys: 1.61 s, total: 6min 32s
Wall time: 6min 4s


(9806, 10000, 98.06)

Once again, the optimized version wins, taking 6 minutes to classify 10,000 images instead of about 1 hour 44 minutes :)  Also note the slightly higher 98.06% accuracy with n=3!

## Next Steps
Next I will (a) see how this performs with less data and (b) optimize it even more.
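As one possible direction for (b), a hedged sketch that replaces the per-row `heapq.nlargest` calls in `superclassif` with array operations: `np.argpartition` finds the n + 1 most similar columns of every row without a full sort, and the votes for all rows are counted at once. It assumes integer labels (for MNIST here, something like `target.astype(int)`) and a similarity matrix whose rows include each query itself; `vectorized_knn_predict` is hypothetical:

```python
import numpy as np

def vectorized_knn_predict(sims, labels, n, num_classes=10):
    """Majority vote over the n most similar columns of each row of `sims`.

    sims: (num_queries, num_data) cosine similarity matrix whose rows include
          the query itself, so n + 1 columns are taken and the best one dropped
    labels: integer labels 0..num_classes-1 for the columns of `sims`
    Hypothetical sketch, not the notebook's implementation.
    """
    labels = np.asarray(labels, dtype=np.int64)
    k = n + 1
    # unordered indices of the k largest similarities in each row
    top_k = np.argpartition(-sims, k - 1, axis=1)[:, :k]
    # order those k columns by similarity so the query itself comes first
    top_sims = np.take_along_axis(sims, top_k, axis=1)
    top_k = np.take_along_axis(top_k, np.argsort(-top_sims, axis=1), axis=1)
    neighbor_labels = labels[top_k[:, 1:]]  # drop the query itself
    # (rows, n, classes) one-hot comparison summed over neighbors = votes per class
    votes = (neighbor_labels[:, :, None] == np.arange(num_classes)).sum(axis=1)
    return votes.argmax(axis=1)

sims = np.array([[1.0, 0.9, 0.8, 0.1, 0.2, 0.3],
                 [0.2, 0.1, 0.3, 1.0, 0.95, 0.7]])
print(vectorized_knn_predict(sims, [0, 0, 0, 1, 1, 1], n=2))  # [0 1]
```

Voting ties go to the smallest label via `argmax`, which may differ from `most_common`. The per-row Python loops disappear, but the full similarity matrix is still needed, so this pairs naturally with processing queries in chunks.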