# Timo's simplistic attempts at an MNIST classifier

The code here roughly reproduces what's done in 3Blue1Brown's introductory videos on Neural Networks. 
In case you have never seen those videos: you absolutely must watch them! They are awesome! They can be found at: https://www.3blue1brown.com/topics/neural-networks

Classifying digits is the "Hello World" of neural networks -- a moderately difficult task that can be solved reasonably well even with a simple neural network. It's also possible to train on standard computers without needing to wait forever.

In [1]:
import time
import random
import os
import struct
import numpy as np
import torch
import torch.nn as nn
import ipywidgets as widgets
import PIL.Image

In [2]:
import matplotlib.pyplot as plt
from collections import defaultdict
# we will need the same plotting code repeatedly, so we define it once here and re-use it later.
def plot(plot_data):
    """Draw every series in ``plot_data`` as a labelled line on one figure.

    Args:
        plot_data: Mapping from a series name to the sequence of values
            recorded for it. The name becomes the legend label of the line.
    """
    # Passing a name selects the figure with that label (creating it if it
    # does not exist yet), so all curves end up in the same figure.
    plt.figure("evolution over training iterations")
    # Iterate name/values pairs directly; this also avoids shadowing the
    # builtin ``type`` with the loop variable.
    for label, series in plot_data.items():
        plt.plot(series, label=label)
    plt.legend(loc='upper right')

In [4]:
# this code reads a database of digits (the database is already included in the repository)

# minimally adapted from https://gist.github.com/akesling/5358964
def read_mnist(dataset, path="data/MNIST/raw/"):
    """Yield ``(label, image)`` pairs read from raw MNIST IDX files.

    Args:
        dataset: Either ``"training"`` or ``"testing"``; selects which pair
            of label/image files is read.
        path: Directory that contains the uncompressed IDX files.

    Yields:
        Tuples ``(label, image)``. ``label`` is a NumPy ``int8`` scalar and
        ``image`` is a flat float array of ``rows * cols`` pixel values
        scaled by ``1 / 255``.

    Raises:
        ValueError: If ``dataset`` is not a known name, if a file does not
            start with the expected IDX magic number, or if the image file
            does not hold exactly one image per label.

    Note:
        This is a generator function: nothing is validated or read until
        the first item is requested.
    """
    if dataset == "training":
        fname_img = os.path.join(path, "train-images-idx3-ubyte")
        fname_lbl = os.path.join(path, "train-labels-idx1-ubyte")
    elif dataset == "testing":
        fname_img = os.path.join(path, "t10k-images-idx3-ubyte")
        fname_lbl = os.path.join(path, "t10k-labels-idx1-ubyte")
    else:
        raise ValueError("dataset must be 'testing' or 'training'")

    # Magic numbers of the IDX format: 0x0801 for an unsigned-byte vector
    # (labels), 0x0803 for an unsigned-byte 3-D array (images).
    label_magic = 2049
    image_magic = 2051

    with open(fname_lbl, "rb") as flbl:
        magic, _ = struct.unpack(">II", flbl.read(8))
        if magic != label_magic:
            raise ValueError(
                "%s is not an MNIST label file (magic number %d)" % (fname_lbl, magic))
        # int8 is kept so that yielded labels have the same dtype as before;
        # the digit labels 0-9 fit into it.
        labels = np.fromfile(flbl, dtype=np.int8)

    with open(fname_img, "rb") as fimg:
        magic, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != image_magic:
            raise ValueError(
                "%s is not an MNIST image file (magic number %d)" % (fname_img, magic))
        pixels = np.fromfile(fimg, dtype=np.uint8)

    # Fail with a readable message instead of an opaque reshape error.
    if pixels.size != len(labels) * rows * cols:
        raise ValueError(
            "image file holds %d pixel values, expected %d labels x %d x %d"
            % (pixels.size, len(labels), rows, cols))

    # One flattened image per row, with pixel values scaled to floats.
    images = np.multiply(pixels.reshape(len(labels), rows * cols), 1.0 / 255.0)

    # Hand out one (label, image) pair at a time.
    for pair in zip(labels, images):
        yield pair
        
# read both splits completely into lists of (label, image) tuples so that
# individual examples can be accessed by index
training = list(read_mnist("training"))
testing = list(read_mnist("testing"))

In [5]:
import io
# show the first ten training examples, each as its target digit followed by the image
for i in range(10):
    label, pixels = training[i]
    # the stored image is a flat vector of scaled values; scale it back up
    # and restore the 28x28 grid before converting to 8-bit greyscale
    picture = PIL.Image.fromarray((256*pixels).reshape(28, 28)).convert('L')
    # encode as PNG in memory, because the image widget takes raw bytes
    buffer = io.BytesIO()
    picture.save(buffer, format="PNG")
    png_bytes = buffer.getvalue()
    display(label, widgets.Image(value=png_bytes, width=28, height=28))

5

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


0

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


4

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


1

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


9

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


2

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


1

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


3

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


1

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


4

Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


In [6]:
# number of units in the first hidden layer of the classifier
HIDDEN_DIM1 = 64
# number of units in the second hidden layer of the classifier
HIDDEN_DIM2 = 64
# NOTE(review): not referenced by the model definitions visible here;
# presumably a dropout probability used elsewhere -- confirm
DROPOUT_RATE = 0.5

In [7]:
class MNISTClassify(nn.Module):
    """Fully connected digit classifier with explicitly managed weights.

    The network maps a flattened 28x28 image to 10 output activations
    through two hidden layers. Every layer computes
    ``sigmoid(input @ W + b)`` with its own weight matrix and bias row.

    Args:
        hidden_dim1: Width of the first hidden layer. Defaults to the
            module-level ``HIDDEN_DIM1``.
        hidden_dim2: Width of the second hidden layer. Defaults to the
            module-level ``HIDDEN_DIM2``.
    """

    def __init__(self, hidden_dim1=None, hidden_dim2=None):
        super().__init__()
        # Fall back to the notebook-wide settings when no size is given.
        if hidden_dim1 is None:
            hidden_dim1 = HIDDEN_DIM1
        if hidden_dim2 is None:
            hidden_dim2 = HIDDEN_DIM2
        input_size = 28 * 28
        # Biases are stored as 1 x N rows so they broadcast over the batch.
        self.W1 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(input_size, hidden_dim1)))
        self.b1 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, hidden_dim1)))
        self.W2 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(hidden_dim1, hidden_dim2)))
        self.b2 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, hidden_dim2)))
        self.W3 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(hidden_dim2, 10)))
        self.b3 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 10)))

    def forward(self, x):
        """Return the 10 output activations for a batch of flattened images.

        Args:
            x: Tensor whose last dimension has ``28 * 28`` entries.

        Returns:
            Tensor with 10 values per input row, each in ``(0, 1)``.
        """
        # first hidden layer:
        h1 = torch.sigmoid(x @ self.W1 + self.b1)
        # second hidden layer:
        h2 = torch.sigmoid(h1 @ self.W2 + self.b2)
        # output layer: sigmoid is element-wise and accepts no ``dim``
        # argument; passing one raised a TypeError on every call.
        # NOTE(review): if training uses a log-likelihood loss, the output
        # should instead be log_softmax over the class dimension -- confirm.
        activation = torch.sigmoid(h2 @ self.W3 + self.b3)
        return activation

In [None]:
# DO NOT EXECUTE THIS CELL!

class MNISTClassify(nn.Module):
    """Digit classifier built from the built-in densely connected layers.

    Two sigmoid-activated hidden layers are followed by a linear output
    layer whose 10 scores are turned into log-probabilities.

    Args:
        hidden_dim1: Width of the first hidden layer. Defaults to the
            module-level ``HIDDEN_DIM1``.
        hidden_dim2: Width of the second hidden layer. Defaults to the
            module-level ``HIDDEN_DIM2``.
    """

    def __init__(self, hidden_dim1=None, hidden_dim2=None):
        super().__init__()
        # Fall back to the notebook-wide settings when no size is given.
        if hidden_dim1 is None:
            hidden_dim1 = HIDDEN_DIM1
        if hidden_dim2 is None:
            hidden_dim2 = HIDDEN_DIM2
        input_size = 28 * 28
        self.l1 = nn.Linear(input_size, hidden_dim1)
        self.l2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.l3 = nn.Linear(hidden_dim2, 10)

    def forward(self, x, isTrain=False):
        """Return log-probabilities over the 10 digit classes.

        Args:
            x: Tensor whose last dimension has ``28 * 28`` entries.
            isTrain: Accepted for interface compatibility; it has no
                effect on the computation.

        Returns:
            Tensor with 10 log-probabilities along its last dimension.
        """
        h1 = torch.sigmoid(self.l1(x))
        h2 = torch.sigmoid(self.l2(h1))
        # Normalise over the class dimension explicitly; omitting ``dim``
        # relies on a deprecated implicit choice. The last dimension is
        # the class dimension for both a single image and a batch.
        return nn.functional.log_softmax(self.l3(h2), dim=-1)
