Example. As an example, we apply least squares classification to the MNIST data set described in §4.3. The (training) data set contains 60,000 images of size 28 by 28. The number of examples per digit varies between 5421 (for digit five) and 6742 (for digit one). The pixel intensities are scaled to lie between 0 and 1. There is also a separate test set containing 10,000 images.

In [1]:
import struct
import gzip
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from IPython.display import HTML
%matplotlib widget

In [2]:
# Empty containers for the MNIST arrays of each split; the loading cells
# below fill in the 'image' and 'label' entries.
train = {}
test = {}

In [3]:
def get_images(filename, directory='mnist'):
    """Load a gzip-compressed IDX3 image file (MNIST format) as a uint8 array.

    Parameters
    ----------
    filename : str
        Name of the ``.gz`` file inside ``directory``.
    directory : str or path-like, optional
        Folder containing the file; defaults to ``'mnist'`` (relative to the
        current working directory, as before).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(size, rows, cols)`` and dtype ``uint8``, with the
        three sizes taken from the file's big-endian header.

    Raises
    ------
    ValueError
        If the header's magic number is not 2051 (IDX, unsigned byte,
        3 dimensions), or if the pixel payload does not match the header sizes.
    """
    with gzip.GzipFile(Path(directory, filename), 'rb') as f:
        # Header: four big-endian uint32 values.
        magic, size, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(
                f'{filename}: expected IDX3 image magic number 2051, got {magic}')
        images = np.frombuffer(f.read(), dtype=np.uint8)
    # reshape raises ValueError if the payload length disagrees with the header.
    return images.reshape(size, rows, cols)

# Decode the gzip-compressed IDX image files of both splits into the dicts.
train.update(image=get_images('train-images-idx3-ubyte.gz'))
test.update(image=get_images('t10k-images-idx3-ubyte.gz'))

In [4]:
def get_labels(filename, directory='mnist'):
    """Load a gzip-compressed IDX1 label file (MNIST format) as a uint8 array.

    Parameters
    ----------
    filename : str
        Name of the ``.gz`` file inside ``directory``.
    directory : str or path-like, optional
        Folder containing the file; defaults to ``'mnist'`` (relative to the
        current working directory, as before).

    Returns
    -------
    numpy.ndarray
        1-D array of dtype ``uint8`` with one entry per label.

    Raises
    ------
    ValueError
        If the header's magic number is not 2049 (IDX, unsigned byte,
        1 dimension), or if the number of labels read differs from the count
        stored in the header (e.g. a truncated file).
    """
    with gzip.GzipFile(Path(directory, filename), 'rb') as f:
        # Header: two big-endian uint32 values (magic number, label count).
        magic, num = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(
                f'{filename}: expected IDX1 label magic number 2049, got {magic}')
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != num:
        raise ValueError(
            f'{filename}: header announces {num} labels but {labels.size} were read')
    return labels

In [5]:
# Decode the gzip-compressed IDX label files of both splits into the dicts.
train.update(label=get_labels('train-labels-idx1-ubyte.gz'))
test.update(label=get_labels('t10k-labels-idx1-ubyte.gz'))

For each digit, we can define a Boolean classifier that distinguishes that digit from the other nine digits. Here we consider the classifier that distinguishes the digit zero from all other digits. In a first experiment, we use the n = 28 × 28 = 784 pixel intensities as features in the least squares classifier (12.1).

In [6]:
# Features: flatten each image to one row of pixel intensities scaled to [0, 1].
# The row count is taken from the data rather than hard-coded as 60000.
x = train['image'].reshape(len(train['image']), -1) / 255
# One-vs-rest targets for digit zero: label 0 -> +1, any other digit -> -1.
y = np.where(train['label'] == 0, 1, -1)

In [7]:
from sklearn import linear_model as slm
# Ordinary least squares regression of the +1/-1 targets on the pixel
# features; fit() returns the estimator itself, so it can be chained.
lm = slm.LinearRegression().fit(x, y)
# Real-valued regression outputs on the training set (thresholded later).
yhat = lm.predict(x)

In [8]:
from functools import lru_cache
def make_roc_factory(y, yhat):
    """Build a memoised ROC evaluator for fixed labels and regression outputs.

    Parameters
    ----------
    y : array-like
        True labels; entries equal to +1 count as positives and entries
        equal to -1 as negatives (any other value is counted in neither).
    yhat : array-like
        Regression outputs, aligned element-wise with ``y``.

    Returns
    -------
    callable
        ``compute_roc(alpha)`` returning ``(fpr, tpr)`` for the classifier
        ``np.sign(yhat + alpha)``. Results for up to 128 distinct (hashable)
        alphas are cached with ``functools.lru_cache``. Outputs whose sign is
        exactly 0 are counted as neither positive nor negative predictions.
    """
    @lru_cache(maxsize=128, typed=False)
    def compute_roc(alpha):
        decision = np.sign(yhat + alpha)
        is_pos, is_neg = (y == 1), (y == -1)
        says_pos, says_neg = (decision == 1), (decision == -1)

        tp = (is_pos & says_pos).sum()
        fn = (is_pos & says_neg).sum()
        fp = (is_neg & says_pos).sum()
        tn = (is_neg & says_neg).sum()

        # Each rate is a fraction within one true class.
        return fp / (fp + tn), tp / (tp + fn)

    return compute_roc

# ROC evaluator for the training-set fit: maps an offset alpha to
# (false positive rate, true positive rate).
compute_roc = make_roc_factory(y=y, yhat=yhat)

In [9]:
def label_offsets(alpha, offset=.015):
    """Return the ``(dx, dy)`` shift for the alpha label next to the ROC marker.

    The thresholds are tested from the largest down and the first match wins.
    The original used independent ``if`` statements, so every alpha >= 0.3
    fell through to the last case and the two upper branches were dead.

    Parameters
    ----------
    alpha : float
        Current classifier offset.
    offset : float, optional
        Base step; both returned components are integer multiples of it.

    Returns
    -------
    tuple of float
        ``(xscale * offset, yscale * offset)``, where ``(xscale, yscale)`` is
        ``(-10, 1)`` for alpha > 0.9, ``(-4, -5)`` for 0.6 <= alpha <= 0.9,
        ``(-1, -4)`` for 0.3 <= alpha < 0.6 and ``(1, 1)`` otherwise.
    """
    if alpha > .9:
        xscale, yscale = -10, 1
    elif alpha >= .6:
        xscale, yscale = -4, -5
    elif alpha >= .3:
        xscale, yscale = -1, -4
    else:
        xscale, yscale = 1, 1
    return xscale * offset, yscale * offset

In [10]:
class OLS_ALPHA:
    """Interactive two-panel plot for choosing the classifier offset ``alpha``.

    Left panel: density histograms of ``yhat + alpha`` for the examples with
    ``y == +1`` and ``y == -1``, drawn over shaded regions labelled
    ``yhat = +1`` (right of 0) and ``yhat = -1`` (left of 0).
    Right panel: the (false positive, true positive) curve traced by
    ``compute_roc`` over ``alphas``, with a marker and a text label at the
    current ``alpha``.

    Mouse interaction:
    * press and release inside the left panel: ``alpha`` becomes the release
      x position, clipped to [-1, 1];
    * press on the marker and release on the curve: ``alpha`` becomes the
      grid value at a curve vertex under the pointer.

    Uses the module-level ``yhat``, ``y``, ``compute_roc`` and
    ``label_offsets``. Matplotlib holds bound-method event callbacks only
    through weak references, so keep a reference to the instance for as long
    as the interaction should work.
    """

    # Offset grid used to trace the curve. The curve is plotted from these
    # points in order, so a vertex index on the line is an index into alphas.
    alphas = np.linspace(-1, 1, 1000)
    fpx, fpy = zip(*[compute_roc(alpha) for alpha in alphas])

    def __init__(self, alpha=0):
        # Start from the requested offset (previously hard-coded to 0 while
        # the label text still displayed the argument).
        self.alpha = alpha
        # Which artist the current drag started on ('histogram', 'scatter' or
        # None). Set by on_press; initialised here so a release that arrives
        # without a preceding press cannot raise AttributeError.
        self.trigger = None

        self.fig, (self.ax1, self.ax2) = plt.subplots(
            ncols=2, figsize=(12, 4), gridspec_kw={'width_ratios': [1.5, 1]})
        self.patches1, self.patches2 = self.draw_histogram()
        # Flat, mutable list of the histogram bars currently on ax1. hist()
        # may return a BarContainer (a tuple subclass), and tuple + tuple is a
        # plain tuple, which update_plot could not clear/extend in place.
        self.cached_state = list(self.patches1) + list(self.patches2)

        # Left panel: shaded decision regions, zero line, limits and legend.
        self.ax1.axvspan(0, 3, facecolor='lavender', edgecolor='lightgray',
                         label=r"$\hat{y}$ = +1", zorder=-6)
        self.ax1.axvspan(-3, 0, facecolor='cornsilk', edgecolor='lightgray',
                         label=r"$\hat{y}$ = -1", zorder=-6)
        self.ax1.axvline(x=0, color='k')
        self.ax1.set_xlim(-3, 3)
        self.ax1.set_ylim(0, 2.5)
        self.ax1.legend(ncol=2, loc=1, facecolor='white', framealpha=.95)

        # Right panel: the full (pickable) curve plus a marker at the current alpha.
        self.line, = self.ax2.plot(self.fpx, self.fpy, color='darkseagreen', picker=True)
        self.line.set_pickradius(10)

        self.sx, self.sy = compute_roc(self.alpha)
        self.s = 30
        self.pcoll = self.ax2.scatter(self.sx, self.sy, s=self.s, c='seagreen',
                                      zorder=5, picker=True)
        self.ax2.set_xlabel("False Positive")
        self.ax2.set_ylabel("True Positive")

        xof, yof = label_offsets(self.alpha)
        self.label = self.ax2.text(self.sx + xof, self.sy + yof,
                                   f'alpha: {self.alpha:.2f}', color='mediumseagreen')

        # Invisible rectangle over the left panel; presses/releases on it
        # drive the dragging update effect.
        self.rect = mpatches.Rectangle((-3, 0), height=3, width=6, ec=None,
                                       facecolor='white', alpha=0, picker=True,
                                       zorder=100)
        self.ax1.add_patch(self.rect)

        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = self.fig.canvas.mpl_connect('button_release_event', self.on_release)
        # TODO: connect on_motion to 'motion_notify_event' for smooth live dragging.
        plt.show()

    def draw_histogram(self):
        """Draw density histograms of ``yhat + self.alpha`` per class on ax1.

        Returns the two bar collections produced by ``Axes.hist``: first for
        ``y == 1``, then for ``y == -1``.
        """
        _, _, patches1 = self.ax1.hist(yhat[y == 1] + self.alpha, label="y = +1",
                                       color="tab:blue", bins=50, density=True, alpha=.75)
        _, _, patches2 = self.ax1.hist(yhat[y == -1] + self.alpha, label="y = -1",
                                       color="tab:orange", bins=50, density=True, alpha=.75)
        return patches1, patches2

    def update_plot(self):
        """Redraw both panels for the current ``self.alpha``.

        Replaces the histogram bars, moves the marker and its label to the
        new (fpr, tpr) point, and schedules a canvas redraw. Returns the
        updated artists.
        """
        for patch in self.cached_state:
            patch.remove()
        self.cached_state.clear()
        self.patches1, self.patches2 = self.draw_histogram()
        self.cached_state.extend(self.patches1)
        self.cached_state.extend(self.patches2)

        xp, yp = compute_roc(self.alpha)
        self.pcoll.set_offsets([[xp, yp]])

        self.label.set_text(f'alpha: {self.alpha:.2f}')
        xo, yo = label_offsets(self.alpha)
        self.label.set_position((xp + xo, yp + yo))

        # Explicit redraw request so the update shows even when the figure is
        # not in interactive auto-redraw mode.
        self.fig.canvas.draw_idle()
        return [self.patches1, self.patches2, self.pcoll, self.label]

    def on_press(self, event):
        """Record where a drag starts and which artist it started on.

        Stores the press data coordinates in ``x0``/``y0`` and sets
        ``trigger`` to 'histogram' (left-panel rectangle), 'scatter' (marker)
        or None.
        """
        self.x0, self.y0 = event.xdata, event.ydata
        if self.rect.contains(event)[0]:
            self.trigger = 'histogram'
        elif self.pcoll.contains(event)[0]:
            self.trigger = 'scatter'
        else:
            self.trigger = None
        print(self.trigger, self.x0, self.y0)

    def on_motion(self, event):
        """Placeholder for live updates while dragging; not connected yet."""

    def on_release(self, event):
        """Apply the drag started in on_press if it ended somewhere valid.

        Returns the updated artists from update_plot, or None when the
        release does not change alpha.
        """
        on_line, attrs = self.line.contains(event)
        if (self.trigger == 'histogram' and event.inaxes is self.ax1
                and self.rect.contains(event)[0]):
            # The rectangle (height 3) extends above ax1's ylim (2.5); a
            # release there is "inside" it but outside the axes, where
            # event.xdata is None, hence the inaxes check.
            self.alpha = np.clip(event.xdata, -1, 1)
            print("rectangle:", self.alpha)
        elif self.trigger == 'scatter' and on_line:
            # Several vertices can lie within the pick radius; take the middle
            # one deterministically instead of a random one.
            hits = np.sort(attrs['ind'])
            self.alpha = self.alphas[hits[len(hits) // 2]]
            print("scatter:", self.alpha)
        else:
            return None
        self.trigger = None
        return self.update_plot()



In [11]:
# Keep a reference to the widget: matplotlib stores bound-method event
# callbacks only as weak references, so an unreferenced instance can be
# garbage-collected and its mouse handlers would silently stop working.
ols_alpha_widget = OLS_ALPHA()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<__main__.OLS_ALPHA at 0x7f437d1dc340>

In [30]:
# Scratch check: a value already inside [1, 2.2] comes back unchanged (1.2).
np.clip(1.2, 1, 2.2)

1.2