### MNIST using scikit-learn

In [1]:
import os
import re
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
from sklearn.datasets import fetch_openml
assert sys.version_info >= (3, 5)
%matplotlib inline

# Scikit-Learn ≥0.20 is required
assert sklearn.__version__ >= "0.20"



In [2]:
# Seed NumPy's global RNG so any later draws from np.random are reproducible.
np.random.seed(42)

# Figure styling: slightly larger axis-label and tick-label fonts.
mpl.rcParams.update({
    "axes.labelsize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
})

# Saved figures go to <PROJECT_ROOT_DIR>/images/<CHAPTER_ID>/, created if missing.
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

In [3]:
def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Write the current matplotlib figure to ``IMAGES_PATH/<fig_id>.<fig_extension>``.

    Args:
        fig_id: base file name (without extension); also echoed to stdout.
        tight_layout: when True, call ``plt.tight_layout()`` before saving.
        fig_extension: file extension, also passed to ``savefig`` as ``format``.
        resolution: dots per inch passed to ``savefig``.
    """
    target_file = os.path.join(IMAGES_PATH, "".join((fig_id, ".", fig_extension)))
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()  # adjust subplot padding before writing the file
    plt.savefig(target_file, format=fig_extension, dpi=resolution)

In [4]:
def plot_digit(data):
    """Draw one 28x28 digit on the current axes with axis decorations hidden.

    Args:
        data: the 784 pixel values of a single image, as a NumPy array or any
            array-like ``np.asarray`` accepts (e.g. a pandas Series row).
    """
    # A pandas Series has no .reshape; newer scikit-learn versions of
    # fetch_openml can hand back pandas objects, so normalise first.
    image = np.asarray(data).reshape(28, 28)
    plt.imshow(image, cmap=mpl.cm.binary,
               interpolation="nearest")
    plt.axis("off")

In [5]:
def plot_digits(instances, images_per_row=10, **options):
    """Draw several 28x28 digits tiled into a single image on the current axes.

    Digits are laid out left to right, at most ``images_per_row`` per row; a
    short last row is padded on the right with zero-valued (blank) pixels.

    Args:
        instances: flattened 784-pixel images, one per row (2-D array, list of
            arrays, or a pandas DataFrame).
        images_per_row: maximum number of digits per row; must be >= 1.
        **options: extra keyword arguments forwarded to ``plt.imshow``.

    Raises:
        ValueError: if ``instances`` is empty or ``images_per_row`` < 1.
    """
    size = 28
    # Iterating a DataFrame yields column labels, not rows, so normalise any
    # array-like to a NumPy array before looping over it.
    instances = np.asarray(instances)
    if len(instances) == 0:
        raise ValueError("plot_digits needs at least one instance")
    if images_per_row < 1:
        raise ValueError("images_per_row must be at least 1")
    images_per_row = min(len(instances), images_per_row)
    images = [instance.reshape(size, size) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row + 1  # ceiling division
    n_empty = n_rows * images_per_row - len(instances)
    # Blank block as wide as the missing images. It sits at index
    # len(instances), so it only falls inside the last row's slice when that
    # row is short (n_empty > 0); otherwise no slice reaches it.
    images.append(np.zeros((size, size * n_empty)))
    row_images = []
    for row in range(n_rows):
        rimages = images[row * images_per_row: (row + 1) * images_per_row]
        row_images.append(np.concatenate(rimages, axis=1))
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap=mpl.cm.binary, **options)
    plt.axis("off")

In [6]:
def display_random_numbers(X):
    """Show the first digit of ``X`` on its own, then up to the first 100 as a grid.

    Despite the name, nothing is sampled at random: both figures use the first
    rows of ``X``. Each figure is saved with ``save_fig`` ("some_digit_plot",
    then "more_digits_plot") before being shown.

    Args:
        X: flattened 28x28 images, one per row (2-D NumPy array or pandas
            DataFrame).
    """
    # Convert only the rows we need. On a DataFrame, X[0] would look up a
    # column labelled 0 instead of the first row, and iterating it inside
    # plot_digits would yield column labels.
    first_rows = np.asarray(X[:100])

    some_digit_image = first_rows[0].reshape(28, 28)
    plt.imshow(some_digit_image, cmap=mpl.cm.binary)
    plt.axis("off")
    save_fig("some_digit_plot")
    plt.show()

    plt.figure(figsize=(9, 9))
    plot_digits(first_rows, images_per_row=10)
    save_fig("more_digits_plot")
    plt.show()

In [8]:
def load_minst():
    """Fetch MNIST (OpenML ``mnist_784``, version 1) and return ``(X, y)`` as NumPy arrays.

    Prints the keys of the fetched bunch and the shapes of ``X`` and ``y``.

    Returns:
        X: NumPy array with one flattened image per row.
        y: NumPy array of ``np.uint8`` labels, one per row of ``X``.
    """
    mnist = fetch_openml('mnist_784', version=1)
    print(mnist.keys())
    # Newer scikit-learn versions return pandas objects here by default, where
    # X[0] would be a column lookup rather than the first image. np.asarray is
    # used instead of as_frame=False so this also works on scikit-learn
    # 0.20/0.21, which predate that parameter.
    X = np.asarray(mnist["data"])
    # Cast labels to small unsigned ints before converting (same cast as before,
    # applied to whatever type fetch_openml returned).
    y = np.asarray(mnist["target"].astype(np.uint8))
    print(X.shape, y.shape)
    return X, y

In [None]:
# Conventional MNIST split: the first 60,000 images are the training set and
# the rest the test set (assumes the OpenML copy keeps that ordering).
N_TRAIN = 60000

X, y = load_minst()
# Keep the test set aside before looking at the data at all.
X_train, X_test, y_train, y_test = X[:N_TRAIN], X[N_TRAIN:], y[:N_TRAIN], y[N_TRAIN:]
assert len(X_train) == len(y_train) == N_TRAIN, "unexpected training-set size"
# Show the first training digits (not a random sample) to get a feel for the task.
display_random_numbers(X_train)
