# Prototype Selection for Nearest Neighbor

One way to speed up nearest neighbor classification is to replace the training set by a carefully chosen
subset of *“prototypes”*. We then simply run nearest neighbor on the smaller prototype dataset.
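
To make the idea concrete, here is a minimal sketch of nearest neighbor classification against a small prototype set. The helper name `nearest_prototype_predict` and the toy 2-D data are illustrative assumptions and are not part of this notebook, which uses a `BallTree` for the search instead of brute force.

```python
import numpy as np

def nearest_prototype_predict(proto_data, proto_labels, queries):
    """Label each query with the label of its closest prototype (Euclidean)."""
    # Squared distances via ||q||^2 - 2 q.p + ||p||^2, shape (n_queries, n_protos)
    d2 = (
        np.sum(queries ** 2, axis=1)[:, None]
        - 2.0 * queries.dot(proto_data.T)
        + np.sum(proto_data ** 2, axis=1)[None, :]
    )
    # The closest prototype decides the predicted label
    return proto_labels[np.argmin(d2, axis=1)]

# Toy data (hypothetical): two prototypes, three query points
proto_data = np.array([[0.0, 0.0], [10.0, 10.0]])
proto_labels = np.array([0, 1])
queries = np.array([[1.0, 0.5], [9.0, 11.0], [4.0, 4.0]])

predictions = nearest_prototype_predict(proto_data, proto_labels, queries)
print(predictions)
```

The cost of each query grows with the number of prototypes, which is why a smaller, well-chosen prototype set speeds up classification.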

In this notebook you will <font color="blue">*create your own strategy*</font> for selecting a prototype dataset on **MNIST**. We will then see how your prototype compares to a randomly selected dataset of the same size.

# Setup Notebook

As usual, we start by importing the packages and data required for the notebook. Here we will be using the **entire** `MNIST` dataset. The code below defines some helper functions that download `MNIST` (if necessary) and load it into memory.

In [1]:
import gzip
import sys
import os
import copy
import numpy as np
import pickle

if sys.version_info[0] == 2:
    from urllib import urlretrieve
else:
    from urllib.request import urlretrieve


In [2]:
def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
    print("Downloading %s" % filename)
    urlretrieve(source + filename, filename)

def load_mnist_images(filename):
    if not os.path.exists(filename):
        download(filename)
    # Read the inputs in Yann LeCun's binary format.
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1,784)
    return data / np.float32(256)

def load_mnist_labels(filename):
    if not os.path.exists(filename):
        download(filename)
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=8)
        #data2 = np.zeros( (len(data),10), dtype=np.float32 )
        #for i in range(len(data)):
        #    data2[i][ data[i] ] = 1.0
    return data
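
The `offset=16` and `offset=8` arguments above skip the file headers of the IDX format: an image file begins with four big-endian 32-bit integers (magic number 2051, image count, rows, columns), and a label file begins with two (magic number 2049, label count). As a rough sketch, the image header could be read explicitly as below. The function `parse_idx_images` and the synthetic buffer are assumptions made for illustration; the notebook's loaders simply skip the header.

```python
import struct
import numpy as np

def parse_idx_images(raw):
    """Parse an uncompressed IDX image buffer into an (n, rows*cols) array."""
    # Header: four big-endian unsigned 32-bit integers, 16 bytes in total
    magic, n_images, n_rows, n_cols = struct.unpack('>IIII', raw[:16])
    if magic != 2051:
        raise ValueError("not an IDX image file (magic=%d)" % magic)
    # Everything after the header is one unsigned byte per pixel
    pixels = np.frombuffer(raw, np.uint8, offset=16)
    return pixels.reshape(n_images, n_rows * n_cols)

# Synthetic buffer (hypothetical): 2 images of 3x3 pixels with values 0..17
fake = struct.pack('>IIII', 2051, 2, 3, 3) + bytes(bytearray(range(18)))
images = parse_idx_images(fake)
print(images.shape)
```

Reading the dimensions from the header avoids hard-coding 784, at the cost of a few extra lines.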

We now import the required packages and load in `MNIST`. If necessary, `MNIST` is downloaded onto your computer.

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt 
import time
from sklearn.neighbors import BallTree

## Load the training set
train_data = load_mnist_images('train-images-idx3-ubyte.gz')
train_labels = load_mnist_labels('train-labels-idx1-ubyte.gz')

## Load the testing set
test_data = load_mnist_images('t10k-images-idx3-ubyte.gz')
test_labels = load_mnist_labels('t10k-labels-idx1-ubyte.gz')

Downloading train-labels-idx1-ubyte.gz
Downloading t10k-labels-idx1-ubyte.gz


# Creating a Random Prototype

To give you a better idea of how this process works, let's first consider the case where the prototype is randomly selected. Let us select $M<60000$ datapoints at random from the training set and observe the quality of the fit.

The following function, <font color="blue">**rand_prototype**</font>, returns the prototype features and labels for $M$ datapoints.

In [15]:
def rand_prototype(M):
    indices = np.random.choice( len(train_labels) , M, replace=False)
    return train_data[indices,:], train_labels[indices] 

Here is an example of **rand_prototype** in action. The function returns a subset of `train_data` of size $M=1000$.

In [16]:
example_data, example_labels = rand_prototype(1000)
print("Shape of train_data: %s" % (train_data.shape,))
print("Shape of prototype:  %s" % (example_data.shape,))

Shape of train_data: (60000, 784)
Shape of prototype:  (1000, 784)


Now let's check the accuracy of the **rand_prototype** function for different values of $M$. Intuitively, a prototype should become more accurate when it uses more datapoints to perform nearest neighbor classification. Thus we would expect the accuracy of **rand_prototype** to increase as $M$ increases.

The function <font color="blue">**NN_accuracy**</font> computes the test accuracy of nearest neighbor classification with a specific prototype.

In [25]:
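def NN_accuracy(proto_data, proto_labels):
    # Index the prototypes for fast Euclidean nearest-neighbor queries
    ball_tree = BallTree(proto_data, metric='euclidean')
    # Index of the single closest prototype for each test point
    test_neighbors = np.squeeze(ball_tree.query(test_data, k=1, return_distance=False))
    test_fit = proto_labels[test_neighbors]
    # Fraction of test points whose predicted label matches the true label
    return np.mean(test_fit == test_labels)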


Here is an example of using **NN_accuracy** to calculate the accuracy of prototypes with different sizes.

In [26]:
proto_data, proto_labels = rand_prototype(500)
print("Prototype Accuracy  for 500 datapoints:\t\t%s" % NN_accuracy(proto_data, proto_labels))

proto_data, proto_labels = rand_prototype(5000)
print("Prototype Accuracy  for 5000 datapoints:\t%s" % NN_accuracy(proto_data, proto_labels))

Prototype Accuracy  for 500 datapoints:		0.8416
Prototype Accuracy  for 5000 datapoints:	0.9407


The final function, <font color="blue">**check_strategy**</font>, takes a strategy function such as `rand_prototype` and uses `NN_accuracy` to run multiple trials of that strategy. It then returns the mean accuracy over the trials.

In [27]:
def check_strategy(fn_strategy, M, rounds=1 ):
    acc_list = []
    for i in range(rounds):
        proto_data, proto_labels = fn_strategy(M)
        accuracy = NN_accuracy(proto_data, proto_labels)
        acc_list.append(accuracy)
    return np.mean(acc_list)
    

Here is an example showing how the `check_strategy` function works.

In [28]:
acc = check_strategy(rand_prototype, M=1000, rounds=5)
print("Mean Accuracy of 5 trial prototypes each of size 1000: ", acc)

('Mean Accuracy of 5 trial prototypes each of size 1000: ', 0.88414000000000004)
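
A mean over a handful of trials says nothing about how much the accuracy varies from one random prototype to the next. As a hedged sketch, a variant of `check_strategy` could also report the standard deviation. The name `check_strategy_with_spread`, its `fn_accuracy` parameter, and the toy stand-in functions are assumptions introduced here so that the example runs without MNIST.

```python
import numpy as np

def check_strategy_with_spread(fn_strategy, fn_accuracy, M, rounds=5):
    """Return (mean, std) of the accuracy over `rounds` independent prototypes."""
    accs = [fn_accuracy(*fn_strategy(M)) for _ in range(rounds)]
    return np.mean(accs), np.std(accs)

# Toy stand-ins (hypothetical) for a strategy and an accuracy function
rng = np.random.RandomState(0)

def toy_strategy(M):
    # Random 2-D points with random binary labels
    return rng.rand(M, 2), rng.randint(0, 2, size=M)

def toy_accuracy(data, labels):
    # Placeholder score: fraction of labels equal to 1
    return float(np.mean(labels == 1))

mean_acc, std_acc = check_strategy_with_spread(toy_strategy, toy_accuracy, M=50, rounds=4)
print("mean=%.3f std=%.3f" % (mean_acc, std_acc))
```

In the notebook itself, `fn_accuracy` would be `NN_accuracy` and `fn_strategy` would be `rand_prototype` or your own strategy.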


# Build your Own Prototype

It is now time for you to create your own strategy for picking the prototype! 

Write a function, <font color="blue">**my_prototype**</font>, that creates a prototype of size $M$ using your own strategy. Like the **rand_prototype** function, your function should take $M$ as its input and should return your prototype's data and labels.

In [29]:
# Modify this Cell

def my_prototype(M):
    
    # 
    # Write your own function here
    #
    
    return prototype_data, prototype_labels
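
If you are unsure where to start, one simple idea is to draw the same number of points from every class, so that no digit is under-represented by chance. The sketch below is only an illustration of that idea, with no claim that it beats random selection. The function `balanced_prototype`, its explicit `data`/`labels` arguments, and the toy arrays are assumptions, not part of the notebook.

```python
import numpy as np

def balanced_prototype(data, labels, M, rng=None):
    """Pick about M points, with an equal share drawn at random from each class."""
    rng = np.random.RandomState() if rng is None else rng
    classes = np.unique(labels)
    per_class = M // len(classes)  # any remainder is dropped
    chosen = []
    for c in classes:
        idx = np.flatnonzero(labels == c)
        # Never request more points than the class contains
        take = min(per_class, len(idx))
        chosen.append(rng.choice(idx, take, replace=False))
    chosen = np.concatenate(chosen)
    return data[chosen], labels[chosen]

# Toy data (hypothetical): 20 points in 2-D with imbalanced labels
toy_data = np.arange(40, dtype=np.float32).reshape(20, 2)
toy_labels = np.array([0] * 14 + [1] * 6)

proto_data, proto_labels = balanced_prototype(
    toy_data, toy_labels, M=8, rng=np.random.RandomState(0)
)
print(np.bincount(proto_labels))
```

To try it in this notebook, it could be wrapped as `my_prototype = lambda M: balanced_prototype(train_data, train_labels, M)`.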

# Compare Strategies

It's time to put your code to the test! Let's see if it can do better than randomly selecting a prototype.

In [30]:
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual

You can use the following widget to see how your code fares against the random strategy by moving the sliders around.

In [31]:
@interact_manual( M=(100,7500,100), rounds=(1,15))
def comparison(M,rounds):
    print("Comparing your prototype to the random prototype...")
    rand_acc = check_strategy(rand_prototype, M, rounds) 
    my_acc   = check_strategy(  my_prototype, M, rounds) 
    
    print("")
    print("Prototype Size:\t\t %d" % (M))
    print("Number of Trials:\t %d" % (rounds))
    print("Random Prototype Accuracy:\t %f" % (rand_acc))
    print("Your Prototype Accuracy:\t %f" % (my_acc))
    print("")
    if rand_acc>my_acc:
        print("The RANDOM Prototype Wins!")
    else:
        print("YOUR Prototype Wins!")
    