Example. As an example we apply least squares classification to the MNIST data set described in §4.3. The (training) data set contains 60,000 images of size 28 by 28. The number of examples per digit varies between 5421 (for digit five) and 6742 (for digit one). The pixel intensities are scaled to lie between 0 and 1. There is also a separate test set containing 10,000 images.

In [1]:
import struct
import gzip
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from IPython.display import HTML
%matplotlib widget

In [2]:
# Empty containers for the two data splits; the cells below fill each one
# with an 'image' array and a 'label' array.
train, test = {}, {}

In [3]:
def get_images(filename, directory='mnist'):
    """Load a gzip-compressed MNIST image file in IDX3 format.

    Parameters
    ----------
    filename : str
        Name of the ``.gz`` archive inside ``directory``.
    directory : str or path-like, optional
        Folder holding the archive; defaults to ``'mnist'``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(size, rows, cols)`` as given by the header.

    Raises
    ------
    ValueError
        If the magic number is not 2051 (IDX3, unsigned bytes) or the pixel
        payload does not hold exactly ``size * rows * cols`` bytes.
    """
    with gzip.GzipFile(Path(directory, filename), 'rb') as f:
        # 16-byte header: four big-endian uint32 values.
        magic, size, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError(
                f"{filename}: not an MNIST image file (magic {magic}, expected 2051)")
        images = np.frombuffer(f.read(), dtype=np.uint8)
    expected = size * rows * cols
    # Fail with a clear message on truncated/corrupt files instead of a
    # generic reshape error.
    if images.size != expected:
        raise ValueError(
            f"{filename}: expected {expected} pixel bytes, found {images.size}")
    return images.reshape(size, rows, cols)

# Load the image arrays for the training split, then the test split.
train.update(image=get_images('train-images-idx3-ubyte.gz'))
test.update(image=get_images('t10k-images-idx3-ubyte.gz'))

In [4]:
def get_labels(filename, directory='mnist'):
    """Load a gzip-compressed MNIST label file in IDX1 format.

    Parameters
    ----------
    filename : str
        Name of the ``.gz`` archive inside ``directory``.
    directory : str or path-like, optional
        Folder holding the archive; defaults to ``'mnist'``.

    Returns
    -------
    numpy.ndarray
        1-D ``uint8`` array with one label per example.

    Raises
    ------
    ValueError
        If the magic number is not 2049 (IDX1, unsigned bytes) or the number
        of labels read differs from the count stored in the header.
    """
    with gzip.GzipFile(Path(directory, filename), 'rb') as f:
        # 8-byte header: big-endian uint32 magic number and item count.
        magic, num = struct.unpack(">II", f.read(8))
        if magic != 2049:
            raise ValueError(
                f"{filename}: not an MNIST label file (magic {magic}, expected 2049)")
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != num:
        raise ValueError(
            f"{filename}: header announces {num} labels, found {labels.size}")
    return labels

In [5]:
# Load the digit labels for the training split, then the test split.
train.update(label=get_labels('train-labels-idx1-ubyte.gz'))
test.update(label=get_labels('t10k-labels-idx1-ubyte.gz'))

For each digit, we can define a Boolean classifier that distinguishes the digit from the other nine digits. Here we consider the classifier that distinguishes the digit zero from all other digits. In a first experiment, we use the n = 28 × 28 = 784 pixel intensities as features in the least squares classifier (12.1).

In [6]:
# Flatten each image into one row of pixel intensities and scale the uint8
# values (0-255) to the range 0-1. Use the actual image count instead of a
# hard-coded 60000 so the cell works for any subset of the data.
x = train['image'].reshape(len(train['image']), -1) / 255
# One-vs-rest targets over ALL training digits: digit 0 -> +1, digits 1-9 -> -1.
y = (train['label'] > 0).astype(int) * -2 + 1

In [7]:
from sklearn import linear_model as slm

# Ordinary least squares (with intercept) on the raw pixel features.
# fit() returns the estimator itself, so construction and fitting chain.
lm = slm.LinearRegression().fit(x, y)
# Real-valued regression output; its sign gives the predicted class.
yhat = lm.predict(x)

In [30]:
# Left: the fitted coefficient vector lm.coef_ laid out on the 28x28 pixel
# grid (colors clipped to [-.3, .3]); right: one training observation.
fig, (ax, ax2) = plt.subplots(ncols=2, figsize=(10,5))
fig.suptitle("Linear Regression: y = mx+b\n predicted class = coefficient*observation + intercept")
ax.set_title("classifier/coeff(m)")
im = ax.imshow(lm.coef_.reshape(28,28), cmap="RdBu", vmin=-.3, vmax=.3)
cb = fig.colorbar(im, ax=ax, fraction=.045)
cb.set_ticks([-.3, -.15, 0, .15, .3])
# The end ticks are inequalities because imshow clips values to vmin/vmax.
# Raw strings keep "\l"/"\g" from being parsed as (invalid) escape sequences.
cb.set_ticklabels([r"$\leq -.3$", "-.15", "0", ".15", r"$\geq .3$"])

ax2.set_title("observation (x) ")
im2 = ax2.imshow(x[1000].reshape(28,28), cmap='gray')
cb2 = fig.colorbar(im2, ax=ax2, fraction=.045)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

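The output of the linear regression, yhat, is a real number rather than a binary class assignment; the predicted class is its sign. As the histogram below shows, there is some misclassification: some zeros receive negative scores (classified as "not zero"), and some non-zero digits receive positive scores (classified as zero).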

In [37]:
def draw_histogram(alpha, ax):
    """Overlay normalized histograms of the shifted scores, one per true class.

    Reads the module-level ``y`` (true labels, +1/-1) and ``yhat``
    (regression output). ``alpha`` is the offset added to every score before
    binning; the ``alpha=.75`` passed to ``hist`` is bar transparency.
    Returns the patches for ``y == +1`` and for ``y == -1`` so the caller can
    remove them later.
    """
    shifted = yhat + alpha
    hist_style = dict(bins=50, density=True, alpha=.75)
    _, _, positives = ax.hist(shifted[y == 1], label="y = +1", color="tab:blue", **hist_style)
    _, _, negatives = ax.hist(shifted[y == -1], label="y = -1", color="tab:orange", **hist_style)
    return positives, negatives

# Score histograms for the two true classes over shaded decision regions:
# positive shifted scores are predicted +1, negative ones -1.
fig, ax1 = plt.subplots(figsize=(12, 4))

alpha = 0  # no threshold shift yet
patches1, patches2 = draw_histogram(alpha, ax=ax1)
_ = ax1.set_title("Confusion Histogram")
_ = ax1.axvspan(0, 3, label=r"$\hat{y}$ = +1",
                facecolor='lavender', edgecolor='lightgray', zorder=-6)
_ = ax1.axvspan(-3, 0, label=r"$\hat{y}$ = -1",
                facecolor='cornsilk', edgecolor='lightgray', zorder=-6)
_ = ax1.axvline(x=0, color='k')  # decision threshold
_ = ax1.set_xlim(-3, 3)
_ = ax1.legend(ncol=2, loc=1, facecolor='white', framealpha=.95)
_ = ax1.set_xlabel("coeff * obs")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

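The vertical line at 0 marks the decision threshold for $\alpha=0$, where $\alpha$ is an offset added to the regression output before taking its sign. Shifting $\alpha$ trades off the true positive rate against the false positive rate; we can visualize this trade-off with what is called a receiver operating characteristic (ROC) curve.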

In [38]:
from functools import lru_cache
def make_roc_factory(y, yhat):
    """Return a memoized function mapping a threshold offset to one ROC point.

    Parameters
    ----------
    y : array-like of +1/-1
        True class labels.
    yhat : array-like of float
        Real-valued classifier scores aligned with ``y``.

    Returns
    -------
    compute_roc : callable
        ``compute_roc(alpha) -> (fpr, tpr)``. Each example is predicted +1
        when ``yhat + alpha >= 0`` and -1 otherwise. A rate is NaN when its
        class is absent from ``y``. Results are cached per ``alpha`` (LRU,
        default size), so ``y``/``yhat`` must not be mutated afterwards.
    """
    y = np.asarray(y)
    yhat = np.asarray(yhat)
    # Class masks and counts do not depend on alpha: compute them once.
    is_pos = y == 1
    is_neg = y == -1
    n_pos = is_pos.sum()
    n_neg = is_neg.sum()

    @lru_cache
    def compute_roc(alpha):
        # Scores exactly on the threshold go to +1; np.sign would map them to
        # 0 and silently drop them from every confusion-matrix cell.
        pred_pos = (yhat + alpha) >= 0
        # Guard the denominators instead of producing a 0/0 warning.
        tpr = (pred_pos & is_pos).sum() / n_pos if n_pos else np.nan
        fpr = (pred_pos & is_neg).sum() / n_neg if n_neg else np.nan
        return fpr, tpr

    return compute_roc

# Bind the ROC helper to the training labels and the regression scores.
compute_roc = make_roc_factory(y=y, yhat=yhat)

In [39]:
# Evaluate the ROC point at 1000 evenly spaced offsets in [-1, 1] and
# transpose the (fpr, tpr) pairs into separate x and y sequences.
fpx, fpy = zip(*map(compute_roc, np.linspace(-1, 1, 1000)))

In [40]:
# Static ROC curve over the alpha sweep, with the alpha = 0 operating point
# marked and labelled. (This figure has no histogram panel, so nothing is
# drawn onto the previous figure's ax1.)
fig, ax2 = plt.subplots(figsize=(12, 4))

alpha = 0

_ = ax2.plot(fpx, fpy, color='darkseagreen')

xo, yo = compute_roc(alpha)

pcoll = ax2.scatter(xo, yo, s=30, c='seagreen', zorder=5)
_ = ax2.set_xlabel("False Positive")
_ = ax2.set_ylabel("True Positive")
# label_offsets() is only defined in a later cell; for alpha = 0 it returns
# (.015, .015), so use that directly to keep this cell runnable in order.
xof, yof = .015, .015
label = ax2.text(xo+xof, yo+yof, f'alpha: {alpha:.2f}', color='mediumseagreen')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

More helpful still is to visualize how changing alpha moves both the histograms and the operating point on the ROC curve.

In [41]:
def label_offsets(alpha, offset=.015):
    """Return the (dx, dy) offset for the ROC-point text label at ``alpha``.

    The base ``offset`` is scaled per range of ``alpha`` so the label is
    moved left/down as alpha grows. Thresholds are checked from the highest
    down and the first match wins. With the original sequential ``if``s, every
    alpha >= .3 ended up as (-1, -4), which made the two higher branches dead.

    Returns
    -------
    tuple
        ``(xscale * offset, yscale * offset)``.
    """
    if alpha > .9:
        xscale, yscale = -10, 1
    elif alpha >= .6:
        xscale, yscale = -4, -5
    elif alpha >= .3:
        xscale, yscale = -1, -4
    else:
        xscale, yscale = 1, 1
    return xscale*offset, yscale*offset

In [42]:
%%capture

def draw_histogram(alpha, ax):
    """Overlay normalized histograms of the shifted scores, one per true class.

    Reads the module-level ``y`` (true labels, +1/-1) and ``yhat``
    (regression output). ``alpha`` is the offset added to every score before
    binning; the ``alpha=.75`` passed to ``hist`` is bar transparency.
    Returns the patches for ``y == +1`` and for ``y == -1`` so the caller can
    remove them later.
    """
    shifted = yhat + alpha
    hist_style = dict(bins=50, density=True, alpha=.75)
    _, _, positives = ax.hist(shifted[y == 1], label="y = +1", color="tab:blue", **hist_style)
    _, _, negatives = ax.hist(shifted[y == -1], label="y = -1", color="tab:orange", **hist_style)
    return positives, negatives

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 4),
                               gridspec_kw={'width_ratios': [1.5, 1]})

alpha = 0
patches1, patches2 = draw_histogram(alpha, ax=ax1)

# Left panel: shaded decision regions on either side of the threshold at 0.
_ = ax1.axvspan(0, 3, facecolor= 'lavender', edgecolor='lightgray', label=r"$\hat{y}$ = +1", zorder=-6)
_ = ax1.axvspan(-3, 0, facecolor='cornsilk', edgecolor='lightgray', label=r"$\hat{y}$ = -1", zorder=-6)
_ = ax1.axvline(x=0, color='k')
_ = ax1.set_xlim(-3,3)
_ = ax1.legend(ncol=2, loc=1, facecolor='white', framealpha=.95)

# Right panel: full ROC curve plus a marker and text label for the current alpha.
_ = ax2.plot(fpx, fpy, color='darkseagreen')

xo, yo = compute_roc(alpha)

pcoll = ax2.scatter(xo, yo, s=30, c='seagreen', zorder=5)
_ = ax2.set_xlabel("False Positive")
_ = ax2.set_ylabel("True Positive")
xof, yof = label_offsets(alpha)
label = ax2.text(xo+xof, yo+yof, f'alpha: {alpha:.2f}', color='mediumseagreen')

# Histogram patches currently drawn on ax1; animate() removes and replaces
# them every frame. Build an explicit list: depending on the Matplotlib
# version, hist() may return tuple-based BarContainers, and tuple + tuple is
# an immutable tuple that cannot be cleared or extended.
cached_state = [*patches1, *patches2]

def animate(alpha):
    """Redraw the animation frame for threshold offset ``alpha``.

    Removes the histogram patches listed in the module-level ``cached_state``
    and replaces them with freshly drawn ones on ``ax1``. Moves the ROC marker
    ``pcoll`` to ``compute_roc(alpha)`` and updates the text ``label``.
    Returns a flat list of the modified Artists, which is the iterable
    FuncAnimation expects from its update function when blitting.
    """
    global cached_state
    for patch in cached_state:
        patch.remove()
    patches1, patches2 = draw_histogram(alpha, ax=ax1)
    # Rebind to a fresh list so this works whether the previous state was a
    # list or a tuple of patches.
    cached_state = [*patches1, *patches2]

    xp, yp = compute_roc(alpha)
    pcoll.set_offsets([[xp, yp]])

    label.set_text(f'alpha: {alpha:.2f}')
    dx, dy = label_offsets(alpha)
    label.set_position((xp + dx, yp + dy))

    # Individual patches, not the BarContainers that hold them.
    return [*cached_state, pcoll, label]

In [43]:
# Step alpha from -1 towards 1 (exclusive) in increments of 0.1, showing one
# frame every 500 ms, and render the result as an embeddable JS player.
ani = animation.FuncAnimation(
    fig,
    animate,
    frames=np.arange(-1, 1, .1),
    interval=500,
)
HTML(ani.to_jshtml())

In [None]:
# Write the animation to disk as a GIF using Matplotlib's 'imagemagick' writer.
ani.save(filename='alpha.gif', writer='imagemagick', dpi=80)

In [None]:
from ipywidgets import interact

In [None]:
def draw_histogram(alpha, ax):
    """Overlay normalized histograms of the shifted scores, one per true class.

    Reads the module-level ``y`` (true labels, +1/-1) and ``yhat``
    (regression output). ``alpha`` is the offset added to every score before
    binning; the ``alpha=.75`` passed to ``hist`` is bar transparency.
    Returns the patches for ``y == +1`` and for ``y == -1`` so the caller can
    remove them later.
    """
    shifted = yhat + alpha
    hist_style = dict(bins=50, density=True, alpha=.75)
    _, _, positives = ax.hist(shifted[y == 1], label="y = +1", color="tab:blue", **hist_style)
    _, _, negatives = ax.hist(shifted[y == -1], label="y = -1", color="tab:orange", **hist_style)
    return positives, negatives


fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 4),
                               gridspec_kw={'width_ratios': [1.5, 1]})

patches1, patches2 = draw_histogram(0, ax=ax1)

# Left panel: shaded decision regions on either side of the threshold at 0.
_ = ax1.axvspan(0, 3, facecolor= 'lavender', edgecolor='lightgray', label=r"$\hat{y}$ = +1", zorder=-6)
_ = ax1.axvspan(-3, 0, facecolor='cornsilk', edgecolor='lightgray', label=r"$\hat{y}$ = -1", zorder=-6)
_ = ax1.axvline(x=0, color='k')
_ = ax1.set_xlim(-3,3)
_ = ax1.legend(ncol=2, loc=1, facecolor='white', framealpha=.95)


# Right panel: full ROC curve plus a marker and text label for the current alpha.
_ = ax2.plot(fpx, fpy, color='darkseagreen')

alpha = 0
xo, yo = compute_roc(alpha)

pcoll = ax2.scatter(xo, yo, s=30, c='seagreen', zorder=5)
_ = ax2.set_xlabel("False Positive")
_ = ax2.set_ylabel("True Positive")
xof, yof = label_offsets(alpha)
label = ax2.text(xo+xof, yo+yof, f'alpha: {alpha:.2f}', color='mediumseagreen')

# Histogram patches currently drawn on ax1; the slider callback removes and
# replaces them. Build an explicit list: depending on the Matplotlib version,
# hist() may return tuple-based BarContainers, and tuple + tuple is an
# immutable tuple that cannot be cleared or extended.
cached_state = [*patches1, *patches2]

@interact(alpha=(-1, 1, .05))
def animate(alpha):
    """Slider callback: redraw histograms, ROC marker and label for ``alpha``.

    ``interact`` builds a float slider from the ``(min, max, step)`` tuple
    and calls this function whenever the slider moves. It deliberately
    returns None, because interact displays any non-None return value below
    the widget.
    """
    global cached_state
    for patch in cached_state:
        patch.remove()
    patches1, patches2 = draw_histogram(alpha, ax=ax1)
    # Rebind to a fresh list so this works whether the previous state was a
    # list or a tuple of patches.
    cached_state = [*patches1, *patches2]

    xp, yp = compute_roc(alpha)
    pcoll.set_offsets([[xp, yp]])

    label.set_text(f'alpha: {alpha:.2f}')
    dx, dy = label_offsets(alpha)
    label.set_position((xp + dx, yp + dy))