Let's first get the imports out of the way.

In [0]:
import array
import gzip
import itertools
import numpy
import numpy.random as npr
import os
import struct
import time
from os import path
import urllib.request

import jax.numpy as np
from jax.api import jit, grad
from jax.config import config
from jax.scipy.special import logsumexp
from jax import random

The following cell contains boilerplate code to download and load MNIST data.

In [0]:
_DATA = "/tmp/"

def _download(url, filename):
  """Download ``url`` to ``_DATA/filename`` unless that file already exists.

  The payload is first written to a ``.part`` sibling file and then renamed
  into place, so an interrupted download never leaves a truncated file that a
  later call would mistake for a complete one and skip.
  """
  # exist_ok avoids the race between checking for the directory and creating it.
  os.makedirs(_DATA, exist_ok=True)
  out_file = path.join(_DATA, filename)
  if path.isfile(out_file):
    return
  tmp_file = out_file + ".part"
  urllib.request.urlretrieve(url, tmp_file)
  os.replace(tmp_file, out_file)
  print("downloaded {} to {}".format(url, _DATA))


def _partial_flatten(x):
  """Collapse every axis after the first, giving shape ``(x.shape[0], -1)``."""
  leading = x.shape[0]
  return numpy.reshape(x, (leading, -1))


def _one_hot(x, k, dtype=numpy.float32):
  """Return a ``(len(x), k)`` one-hot matrix of ``dtype`` for integer labels ``x``."""
  column = x[:, numpy.newaxis]          # labels as a column vector
  matches = column == numpy.arange(k)   # broadcast against every class id
  return numpy.array(matches, dtype)


def mnist_raw():
  """Download (if needed) and parse the raw MNIST dataset.

  Returns:
    ``(train_images, train_labels, test_images, test_labels)`` as writable
    ``uint8`` arrays; images have shape ``(num_data, rows, cols)`` and labels
    shape ``(num_data,)``, with sizes taken from each file's IDX header.

  Raises:
    ValueError: if a file's IDX magic number is wrong or its payload size does
      not match its header (e.g. a corrupted or truncated download).
  """
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
  # IDX magic numbers: 0x00000801 (uint8, 1-D labels), 0x00000803 (uint8, 3-D images).
  label_magic, image_magic = 2049, 2051

  def read_payload(fh, filename, expected_size):
    # bytearray keeps the buffer writable, so the returned array is writable
    # just like the array.array-based copy it replaces, without a per-byte loop.
    data = bytearray(fh.read())
    if len(data) != expected_size:
      raise ValueError("{}: expected {} payload bytes, found {}".format(
          filename, expected_size, len(data)))
    return numpy.frombuffer(data, dtype=numpy.uint8)

  def parse_labels(filename):
    with gzip.open(filename, "rb") as fh:
      magic, num_data = struct.unpack(">II", fh.read(8))
      if magic != label_magic:
        raise ValueError("{}: bad label-file magic number {}".format(filename, magic))
      return read_payload(fh, filename, num_data)

  def parse_images(filename):
    with gzip.open(filename, "rb") as fh:
      magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
      if magic != image_magic:
        raise ValueError("{}: bad image-file magic number {}".format(filename, magic))
      pixels = read_payload(fh, filename, num_data * rows * cols)
      return pixels.reshape(num_data, rows, cols)

  for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                   "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
    _download(base_url + filename, filename)

  train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
  train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
  test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
  test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

  return train_images, train_labels, test_images, test_labels


def mnist(create_outliers=False, num_outliers=30000):
  """Download, parse and process MNIST data to unit scale and one-hot labels.

  Args:
    create_outliers: if True, the first ``num_outliers`` training images are
      shuffled among themselves (seeded RandomState(0)) while their labels
      stay in place, so those training examples become mislabeled.
    num_outliers: how many leading training examples to corrupt; clamped to
      the training-set size.

  Returns:
    ``(train_images, train_labels, test_images, test_labels)``: images
    flattened to rows and divided by 255, labels one-hot over 10 classes.
  """
  train_images, train_labels, test_images, test_labels = mnist_raw()

  train_images = _partial_flatten(train_images) / numpy.float32(255.)
  test_images = _partial_flatten(test_images) / numpy.float32(255.)
  train_labels = _one_hot(train_labels, 10)
  test_labels = _one_hot(test_labels, 10)

  if create_outliers:
    num_outliers = min(num_outliers, train_images.shape[0])
    perm = numpy.random.RandomState(0).permutation(num_outliers)
    # Fancy indexing on the right-hand side copies, so the in-place write is safe.
    train_images[:num_outliers] = train_images[:num_outliers][perm]

  return train_images, train_labels, test_images, test_labels

def shape_as_image(images, labels, dummy_dim=False):
  """Reshape flat images to ``(-1, 28, 28, 1)`` and pass ``labels`` through.

  With ``dummy_dim=True`` an extra singleton axis is inserted after the batch
  axis, giving ``(-1, 1, 28, 28, 1)``.
  """
  image_shape = (28, 28, 1)
  if dummy_dim:
    image_shape = (1,) + image_shape
  return np.reshape(images, (-1,) + image_shape), labels

# Load the clean dataset (no label-noise outliers) for Problem 1.
train_images, train_labels, test_images, test_labels = mnist(create_outliers=False)
num_train = len(train_images)  # number of training examples (rows)

# **Problem 1**

This function computes the output of a fully-connected neural network (i.e., multilayer perceptron) by iterating over all of its layers and:

1. taking the `activations` of the previous layer (or the input itself for the first hidden layer) to compute the `outputs` of a linear classifier. Recall from the lectures that `outputs` is what we wrote as $z = w \cdot x + b$, where $x$ is the input to the linear classifier.
2. applying a non-linear activation. Here we will use $\tanh$.

Complete the following cell to compute `outputs` and `activations`. 

In [0]:
def predict(params, inputs):
  """Forward pass of a tanh MLP, returning row-wise log-softmax scores.

  Args:
    params: list of ``(w, b)`` pairs; every pair but the last is a hidden
      layer followed by ``tanh``, the last produces the logits.
    inputs: 2-D batch of examples, one per row.

  Returns:
    The output logits minus their per-row ``logsumexp`` (log-probabilities).
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = inputs
  for weight, bias in hidden_layers:
    # Affine map z = x.w + b, then the tanh non-linearity.
    hidden = np.tanh(np.dot(hidden, weight) + bias)
  logits = np.dot(hidden, out_w) + out_b
  # Normalising in log space avoids exponentiating large logits directly.
  return logits - logsumexp(logits, axis=1, keepdims=True)

The following cell computes the loss of our model. Here we are using cross-entropy combined with a softmax, but the implementation uses the `LogSumExp` trick for numerical stability. This is why our previous function `predict` returns the logits minus their `logsumexp` (i.e., the log-softmax). We discussed this in class, but you can read more about it [here](https://blog.feedly.com/tricks-of-the-trade-logsumexp/).

Complete the return line. Recall that the loss is defined as :
$$ l(X, Y) = -\frac{1}{n} \sum_{i\in 1..n}  \sum_{j\in 1.. K}y_j^{(i)} \log(f_j(x^{(i)})) = -\frac{1}{n} \sum_{i\in 1..n}  \sum_{j\in 1.. K}y_j^{(i)} \log\left(\frac{e^{z_j^{(i)}}}{\sum_{k\in 1..K}e^{z_k^{(i)}}}\right) $$
where $X$ is a matrix containing a batch of $n$ training inputs, and $Y$ is a matrix containing a batch of one-hot encoded labels over $K$ classes. Here $z_j^{(i)}$ is the $j$-th logit (i.e., input to the softmax) of the model on example $i$ of our batch of training examples $X$.

In [0]:
def loss(params, batch):
  """Mean cross-entropy of the model on ``batch = (inputs, targets)``.

  ``targets`` are one-hot rows and ``predict`` is expected to return
  log-probabilities, so each example's loss is minus its target-weighted sum.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = np.sum(targets * log_probs, axis=1)
  return -np.mean(per_example)

The following cell defines the accuracy of our model and how to initialize its parameters. 

In [0]:
def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)

def init_random_params(layer_sizes, rng=None):
  """Create ``(w, b)`` pairs for a fully-connected network.

  Args:
    layer_sizes: layer widths, input first (e.g. ``[784, 1024, 128, 10]``).
    rng: optional ``numpy.random.RandomState``. When omitted, a fresh
      ``RandomState(0)`` is made on every call so repeated calls agree.

  Returns:
    One ``(w, b)`` tuple per consecutive pair ``(m, n)`` of sizes, with ``w``
    of shape ``(m, n)`` and ``b`` of shape ``(n,)``, drawn from a standard
    normal and scaled by 0.1.
  """
  if rng is None:
    # A RandomState in the signature is created once and shared, so a second
    # call (e.g. re-running a training cell) would silently get new weights.
    rng = npr.RandomState(0)
  scale = 0.1
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

The following line defines our architecture as the number of neurons in each fully-connected layer (the input layer has 784 units because MNIST images are 28*28=784 pixels, and the output layer has 10 neurons because MNIST has 10 classes).

In [0]:
layer_sizes = [28 * 28, 1024, 128, 10]  # input pixels, two hidden layers, 10 classes

The following cell creates a Python generator for our dataset. It outputs one batch of $n$ training examples at a time. 

In [0]:
batch_size = 128
# With 60000 training images: 468 full batches plus one partial batch of 96.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  A RandomState seeded with 0 reshuffles the training set on every pass;
  each pass yields ``num_batches`` slices of up to ``batch_size`` examples.
  """
  rng = npr.RandomState(0)
  while True:
    order = rng.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()

We are now ready to define our optimizer. Here we use mini-batch stochastic gradient descent. Complete `<w UPDATE RULE>` and `<b UPDATE RULE>` using the update rule we saw in class. Recall that `dw` is the partial derivative of the `loss` with respect to `w` and `learning_rate` is the learning rate of gradient descent. 

In [0]:
learning_rate = 0.1

@jit
def update(params, batch):
  """Return ``params`` after one plain SGD step on ``batch``.

  Each ``(w, b)`` pair moves against its gradient of ``loss``, scaled by the
  module-level ``learning_rate``. Under ``jit`` that global is read while
  tracing, so reassigning it later is not picked up by already-traced calls.
  """
  step = grad(loss)(params, batch)
  new_params = []
  for (w, b), (dw, db) in zip(params, step):
    new_params.append((w - learning_rate * dw, b - learning_rate * db))
  return new_params

This is now the proper training loop for our fully-connected neural network. 

In [9]:
num_epochs = 10
params = init_random_params(layer_sizes)
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = update(params, next(batches))  # one SGD step on the next mini-batch
  # JAX dispatches device work asynchronously; wait for the final update so
  # epoch_time measures the computation rather than just its dispatch.
  for w, b in params:
    w.block_until_ready()
    b.block_until_ready()
  epoch_time = time.time() - start_time

  # Full-dataset evaluation (excluded from epoch_time).
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.73 sec
Training set accuracy 0.9401500225067139
Test set accuracy 0.9377000331878662
Epoch 1 in 0.49 sec
Training set accuracy 0.9592833518981934
Test set accuracy 0.95250004529953
Epoch 2 in 0.50 sec
Training set accuracy 0.9681666493415833
Test set accuracy 0.9607000350952148
Epoch 3 in 0.48 sec
Training set accuracy 0.9759166836738586
Test set accuracy 0.9663000702857971
Epoch 4 in 0.49 sec
Training set accuracy 0.9795500040054321
Test set accuracy 0.9676000475883484
Epoch 5 in 0.49 sec
Training set accuracy 0.982616662979126
Test set accuracy 0.970300018787384
Epoch 6 in 0.50 sec
Training set accuracy 0.9865833520889282
Test set accuracy 0.9716000556945801
Epoch 7 in 0.50 sec
Training set accuracy 0.9892333149909973
Test set accuracy 0.9736000299453735
Epoch 8 in 0.50 sec
Training set accuracy 0.9911666512489319
Test set accuracy 0.9741000533103943
Epoch 9 in 0.51 sec
Training set accuracy 0.992983341217041
Test set accuracy 0.9746000170707703


### **Slow Convergence**

In [10]:
def predict(params, inputs):
  """Forward pass of a tanh MLP, returning row-wise log-softmax scores.

  Args:
    params: list of ``(w, b)`` pairs; every pair but the last is a hidden
      layer followed by ``tanh``, the last produces the logits.
    inputs: 2-D batch of examples, one per row.

  Returns:
    The output logits minus their per-row ``logsumexp`` (log-probabilities).
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = inputs
  for weight, bias in hidden_layers:
    # Affine map z = x.w + b, then the tanh non-linearity.
    hidden = np.tanh(np.dot(hidden, weight) + bias)
  logits = np.dot(hidden, out_w) + out_b
  # Normalising in log space avoids exponentiating large logits directly.
  return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
  """Mean cross-entropy of the model on ``batch = (inputs, targets)``.

  ``targets`` are one-hot rows and ``predict`` is expected to return
  log-probabilities, so each example's loss is minus its target-weighted sum.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = np.sum(targets * log_probs, axis=1)
  return -np.mean(per_example)
def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)
def init_random_params(layer_sizes, rng=None):
  """Create ``(w, b)`` pairs for a fully-connected network.

  Args:
    layer_sizes: layer widths, input first (e.g. ``[784, 1024, 128, 10]``).
    rng: optional ``numpy.random.RandomState``. When omitted, a fresh
      ``RandomState(0)`` is made on every call so repeated calls agree.

  Returns:
    One ``(w, b)`` tuple per consecutive pair ``(m, n)`` of sizes, with ``w``
    of shape ``(m, n)`` and ``b`` of shape ``(n,)``, drawn from a standard
    normal and scaled by 0.1.
  """
  if rng is None:
    # A RandomState in the signature is created once and shared, so a second
    # call (e.g. re-running a training cell) would silently get new weights.
    rng = npr.RandomState(0)
  scale = 0.1
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
layer_sizes = [28 * 28, 1024, 128, 10]  # input pixels, two tanh hidden layers, 10 classes
batch_size = 128
# Full mini-batches per epoch, plus one extra if examples are left over.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  A RandomState seeded with 0 reshuffles the training set on every pass;
  each pass yields ``num_batches`` slices of up to ``batch_size`` examples.
  """
  rng = npr.RandomState(0)
  while True:
    order = rng.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()
learning_rate = 0.001

@jit
def update(params, batch):
  """Return ``params`` after one plain SGD step on ``batch``.

  Each ``(w, b)`` pair moves against its gradient of ``loss``, scaled by the
  module-level ``learning_rate``. Under ``jit`` that global is read while
  tracing, so reassigning it later is not picked up by already-traced calls.
  """
  step = grad(loss)(params, batch)
  new_params = []
  for (w, b), (dw, db) in zip(params, step):
    new_params.append((w - learning_rate * dw, b - learning_rate * db))
  return new_params
num_epochs = 10
params = init_random_params(layer_sizes)
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = update(params, next(batches))  # one SGD step on the next mini-batch
  # JAX dispatches device work asynchronously; wait for the final update so
  # epoch_time measures the computation rather than just its dispatch.
  for w, b in params:
    w.block_until_ready()
    b.block_until_ready()
  epoch_time = time.time() - start_time

  # Full-dataset evaluation (excluded from epoch_time).
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 1.26 sec
Training set accuracy 0.5794000029563904
Test set accuracy 0.5819000005722046
Epoch 1 in 0.47 sec
Training set accuracy 0.7056000232696533
Test set accuracy 0.7110000252723694
Epoch 2 in 0.50 sec
Training set accuracy 0.760366678237915
Test set accuracy 0.7656000256538391
Epoch 3 in 0.49 sec
Training set accuracy 0.7911666631698608
Test set accuracy 0.7988000512123108
Epoch 4 in 0.52 sec
Training set accuracy 0.8119333386421204
Test set accuracy 0.8196000456809998
Epoch 5 in 0.49 sec
Training set accuracy 0.8274999856948853
Test set accuracy 0.8343000411987305
Epoch 6 in 0.50 sec
Training set accuracy 0.8385833501815796
Test set accuracy 0.8464000225067139
Epoch 7 in 0.50 sec
Training set accuracy 0.8474000096321106
Test set accuracy 0.8584000468254089
Epoch 8 in 0.50 sec
Training set accuracy 0.8540499806404114
Test set accuracy 0.8649000525474548
Epoch 9 in 0.52 sec
Training set accuracy 0.8603500127792358
Test set accuracy 0.8695000410079956


### **Oscillations but still converges**

In [3]:
def predict(params, inputs):
  """Forward pass of a tanh MLP, returning row-wise log-softmax scores.

  Args:
    params: list of ``(w, b)`` pairs; every pair but the last is a hidden
      layer followed by ``tanh``, the last produces the logits.
    inputs: 2-D batch of examples, one per row.

  Returns:
    The output logits minus their per-row ``logsumexp`` (log-probabilities).
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = inputs
  for weight, bias in hidden_layers:
    # Affine map z = x.w + b, then the tanh non-linearity.
    hidden = np.tanh(np.dot(hidden, weight) + bias)
  logits = np.dot(hidden, out_w) + out_b
  # Normalising in log space avoids exponentiating large logits directly.
  return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
  """Mean cross-entropy of the model on ``batch = (inputs, targets)``.

  ``targets`` are one-hot rows and ``predict`` is expected to return
  log-probabilities, so each example's loss is minus its target-weighted sum.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = np.sum(targets * log_probs, axis=1)
  return -np.mean(per_example)
def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)
def init_random_params(layer_sizes, rng=None):
  """Create ``(w, b)`` pairs for a fully-connected network.

  Args:
    layer_sizes: layer widths, input first (e.g. ``[784, 1024, 128, 10]``).
    rng: optional ``numpy.random.RandomState``. When omitted, a fresh
      ``RandomState(0)`` is made on every call so repeated calls agree.

  Returns:
    One ``(w, b)`` tuple per consecutive pair ``(m, n)`` of sizes, with ``w``
    of shape ``(m, n)`` and ``b`` of shape ``(n,)``, drawn from a standard
    normal and scaled by 0.1.
  """
  if rng is None:
    # A RandomState in the signature is created once and shared, so a second
    # call (e.g. re-running a training cell) would silently get new weights.
    rng = npr.RandomState(0)
  scale = 0.1
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
layer_sizes = [28 * 28, 1024, 128, 10]  # input pixels, two tanh hidden layers, 10 classes
batch_size = 128
# Full mini-batches per epoch, plus one extra if examples are left over.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  A RandomState seeded with 0 reshuffles the training set on every pass;
  each pass yields ``num_batches`` slices of up to ``batch_size`` examples.
  """
  rng = npr.RandomState(0)
  while True:
    order = rng.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()
learning_rate = 1.6

@jit
def update(params, batch):
  """Return ``params`` after one plain SGD step on ``batch``.

  Each ``(w, b)`` pair moves against its gradient of ``loss``, scaled by the
  module-level ``learning_rate``. Under ``jit`` that global is read while
  tracing, so reassigning it later is not picked up by already-traced calls.
  """
  step = grad(loss)(params, batch)
  new_params = []
  for (w, b), (dw, db) in zip(params, step):
    new_params.append((w - learning_rate * dw, b - learning_rate * db))
  return new_params
num_epochs = 10
params = init_random_params(layer_sizes)
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = update(params, next(batches))  # one SGD step on the next mini-batch
  # JAX dispatches device work asynchronously; wait for the final update so
  # epoch_time measures the computation rather than just its dispatch.
  for w, b in params:
    w.block_until_ready()
    b.block_until_ready()
  epoch_time = time.time() - start_time

  # Full-dataset evaluation (excluded from epoch_time).
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.73 sec
Training set accuracy 0.8037166595458984
Test set accuracy 0.8108000159263611
Epoch 1 in 0.51 sec
Training set accuracy 0.8496833443641663
Test set accuracy 0.8445000648498535
Epoch 2 in 0.51 sec
Training set accuracy 0.8338333368301392
Test set accuracy 0.8315000534057617
Epoch 3 in 0.52 sec
Training set accuracy 0.9210333228111267
Test set accuracy 0.9182000160217285
Epoch 4 in 0.50 sec
Training set accuracy 0.9023500084877014
Test set accuracy 0.8945000171661377
Epoch 5 in 0.51 sec
Training set accuracy 0.8478666543960571
Test set accuracy 0.844700038433075
Epoch 6 in 0.47 sec
Training set accuracy 0.9323333501815796
Test set accuracy 0.9280000329017639
Epoch 7 in 0.52 sec
Training set accuracy 0.9317333698272705
Test set accuracy 0.9284000396728516
Epoch 8 in 0.52 sec
Training set accuracy 0.935533344745636
Test set accuracy 0.9326000213623047
Epoch 9 in 0.48 sec
Training set accuracy 0.956083357334137
Test set accuracy 0.9481000304222107


### **Instability and diverges**

In [3]:
def predict(params, inputs):
  """Forward pass of a tanh MLP, returning row-wise log-softmax scores.

  Args:
    params: list of ``(w, b)`` pairs; every pair but the last is a hidden
      layer followed by ``tanh``, the last produces the logits.
    inputs: 2-D batch of examples, one per row.

  Returns:
    The output logits minus their per-row ``logsumexp`` (log-probabilities).
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = inputs
  for weight, bias in hidden_layers:
    # Affine map z = x.w + b, then the tanh non-linearity.
    hidden = np.tanh(np.dot(hidden, weight) + bias)
  logits = np.dot(hidden, out_w) + out_b
  # Normalising in log space avoids exponentiating large logits directly.
  return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
  """Mean cross-entropy of the model on ``batch = (inputs, targets)``.

  ``targets`` are one-hot rows and ``predict`` is expected to return
  log-probabilities, so each example's loss is minus its target-weighted sum.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = np.sum(targets * log_probs, axis=1)
  return -np.mean(per_example)
def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)
def init_random_params(layer_sizes, rng=None):
  """Create ``(w, b)`` pairs for a fully-connected network.

  Args:
    layer_sizes: layer widths, input first (e.g. ``[784, 1024, 128, 10]``).
    rng: optional ``numpy.random.RandomState``. When omitted, a fresh
      ``RandomState(0)`` is made on every call so repeated calls agree.

  Returns:
    One ``(w, b)`` tuple per consecutive pair ``(m, n)`` of sizes, with ``w``
    of shape ``(m, n)`` and ``b`` of shape ``(n,)``, drawn from a standard
    normal and scaled by 0.1.
  """
  if rng is None:
    # A RandomState in the signature is created once and shared, so a second
    # call (e.g. re-running a training cell) would silently get new weights.
    rng = npr.RandomState(0)
  scale = 0.1
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
layer_sizes = [28 * 28, 1024, 128, 10]  # input pixels, two tanh hidden layers, 10 classes
batch_size = 128
# Full mini-batches per epoch, plus one extra if examples are left over.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  A RandomState seeded with 0 reshuffles the training set on every pass;
  each pass yields ``num_batches`` slices of up to ``batch_size`` examples.
  """
  rng = npr.RandomState(0)
  while True:
    order = rng.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()
learning_rate = 2

@jit
def update(params, batch):
  """Return ``params`` after one plain SGD step on ``batch``.

  Each ``(w, b)`` pair moves against its gradient of ``loss``, scaled by the
  module-level ``learning_rate``. Under ``jit`` that global is read while
  tracing, so reassigning it later is not picked up by already-traced calls.
  """
  step = grad(loss)(params, batch)
  new_params = []
  for (w, b), (dw, db) in zip(params, step):
    new_params.append((w - learning_rate * dw, b - learning_rate * db))
  return new_params
num_epochs = 10
params = init_random_params(layer_sizes)
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = update(params, next(batches))  # one SGD step on the next mini-batch
  # JAX dispatches device work asynchronously; wait for the final update so
  # epoch_time measures the computation rather than just its dispatch.
  for w, b in params:
    w.block_until_ready()
    b.block_until_ready()
  epoch_time = time.time() - start_time

  # Full-dataset evaluation (excluded from epoch_time).
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.68 sec
Training set accuracy 0.09035000205039978
Test set accuracy 0.08920000493526459
Epoch 1 in 0.54 sec
Training set accuracy 0.10441666841506958
Test set accuracy 0.10280000418424606
Epoch 2 in 0.52 sec
Training set accuracy 0.09863333404064178
Test set accuracy 0.0958000048995018
Epoch 3 in 0.54 sec
Training set accuracy 0.09930000454187393
Test set accuracy 0.10320000350475311
Epoch 4 in 0.56 sec
Training set accuracy 0.10441666841506958
Test set accuracy 0.10280000418424606
Epoch 5 in 0.57 sec
Training set accuracy 0.09930000454187393
Test set accuracy 0.10320000350475311
Epoch 6 in 0.55 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09800000488758087
Epoch 7 in 0.56 sec
Training set accuracy 0.09863333404064178
Test set accuracy 0.0958000048995018
Epoch 8 in 0.58 sec
Training set accuracy 0.09736666828393936
Test set accuracy 0.0982000082731247
Epoch 9 in 0.55 sec
Training set accuracy 0.10218333452939987
Test set accuracy 0.10100000351667404


### **Underfit**

In [3]:
def predict(params, inputs):
  """Forward pass of a tanh MLP, returning row-wise log-softmax scores.

  Args:
    params: list of ``(w, b)`` pairs; every pair but the last is a hidden
      layer followed by ``tanh``, the last produces the logits.
    inputs: 2-D batch of examples, one per row.

  Returns:
    The output logits minus their per-row ``logsumexp`` (log-probabilities).
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = inputs
  for weight, bias in hidden_layers:
    # Affine map z = x.w + b, then the tanh non-linearity.
    hidden = np.tanh(np.dot(hidden, weight) + bias)
  logits = np.dot(hidden, out_w) + out_b
  # Normalising in log space avoids exponentiating large logits directly.
  return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
  """Mean cross-entropy of the model on ``batch = (inputs, targets)``.

  ``targets`` are one-hot rows and ``predict`` is expected to return
  log-probabilities, so each example's loss is minus its target-weighted sum.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = np.sum(targets * log_probs, axis=1)
  return -np.mean(per_example)
def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)
def init_random_params(layer_sizes, rng=None):
  """Create ``(w, b)`` pairs for a fully-connected network.

  Args:
    layer_sizes: layer widths, input first (e.g. ``[784, 4, 10]``).
    rng: optional ``numpy.random.RandomState``. When omitted, a fresh
      ``RandomState(0)`` is made on every call so repeated calls agree.

  Returns:
    One ``(w, b)`` tuple per consecutive pair ``(m, n)`` of sizes, with ``w``
    of shape ``(m, n)`` and ``b`` of shape ``(n,)``, drawn from a standard
    normal and scaled by 0.1.
  """
  if rng is None:
    # A RandomState in the signature is created once and shared, so a second
    # call (e.g. re-running a training cell) would silently get new weights.
    rng = npr.RandomState(0)
  scale = 0.1
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
layer_sizes = [28 * 28, 4, 10]  # a single 4-unit hidden layer
batch_size = 128
# Full mini-batches per epoch, plus one extra if examples are left over.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  A RandomState seeded with 0 reshuffles the training set on every pass;
  each pass yields ``num_batches`` slices of up to ``batch_size`` examples.
  """
  rng = npr.RandomState(0)
  while True:
    order = rng.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()
learning_rate = 0.1

@jit
def update(params, batch):
  """Return ``params`` after one plain SGD step on ``batch``.

  Each ``(w, b)`` pair moves against its gradient of ``loss``, scaled by the
  module-level ``learning_rate``. Under ``jit`` that global is read while
  tracing, so reassigning it later is not picked up by already-traced calls.
  """
  step = grad(loss)(params, batch)
  new_params = []
  for (w, b), (dw, db) in zip(params, step):
    new_params.append((w - learning_rate * dw, b - learning_rate * db))
  return new_params

num_epochs = 10
params = init_random_params(layer_sizes)
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = update(params, next(batches))  # one SGD step on the next mini-batch
  # JAX dispatches device work asynchronously; wait for the final update so
  # epoch_time measures the computation rather than just its dispatch.
  for w, b in params:
    w.block_until_ready()
    b.block_until_ready()
  epoch_time = time.time() - start_time

  # Full-dataset evaluation (excluded from epoch_time).
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.47 sec
Training set accuracy 0.6809833645820618
Test set accuracy 0.6829000115394592
Epoch 1 in 0.48 sec
Training set accuracy 0.7415333390235901
Test set accuracy 0.7386000156402588
Epoch 2 in 0.48 sec
Training set accuracy 0.7940000295639038
Test set accuracy 0.7928000092506409
Epoch 3 in 0.45 sec
Training set accuracy 0.8138499855995178
Test set accuracy 0.8131000399589539
Epoch 4 in 0.47 sec
Training set accuracy 0.8187167048454285
Test set accuracy 0.8192000389099121
Epoch 5 in 0.47 sec
Training set accuracy 0.8266833424568176
Test set accuracy 0.8264000415802002
Epoch 6 in 0.48 sec
Training set accuracy 0.8317500352859497
Test set accuracy 0.8302000164985657
Epoch 7 in 0.46 sec
Training set accuracy 0.8362333178520203
Test set accuracy 0.8311000466346741
Epoch 8 in 0.49 sec
Training set accuracy 0.840499997138977
Test set accuracy 0.8371000289916992
Epoch 9 in 0.50 sec
Training set accuracy 0.8401833176612854
Test set accuracy 0.834100067615509


### **Overfit**

In [4]:
# Reload MNIST with create_outliers=True so part of the training set carries
# mislabeled examples (see `mnist`); the test set is unaffected.
train_images, train_labels, test_images, test_labels = mnist(create_outliers=True)
num_train = len(train_images)

def predict(params, inputs):
  """Forward pass of a tanh MLP, returning row-wise log-softmax scores.

  Args:
    params: list of ``(w, b)`` pairs; every pair but the last is a hidden
      layer followed by ``tanh``, the last produces the logits.
    inputs: 2-D batch of examples, one per row.

  Returns:
    The output logits minus their per-row ``logsumexp`` (log-probabilities).
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]
  hidden = inputs
  for weight, bias in hidden_layers:
    # Affine map z = x.w + b, then the tanh non-linearity.
    hidden = np.tanh(np.dot(hidden, weight) + bias)
  logits = np.dot(hidden, out_w) + out_b
  # Normalising in log space avoids exponentiating large logits directly.
  return logits - logsumexp(logits, axis=1, keepdims=True)
def loss(params, batch):
  """Mean cross-entropy of the model on ``batch = (inputs, targets)``.

  ``targets`` are one-hot rows and ``predict`` is expected to return
  log-probabilities, so each example's loss is minus its target-weighted sum.
  """
  inputs, targets = batch
  log_probs = predict(params, inputs)
  per_example = np.sum(targets * log_probs, axis=1)
  return -np.mean(per_example)
def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)
def init_random_params(layer_sizes, rng=None):
  """Create ``(w, b)`` pairs for a fully-connected network.

  Args:
    layer_sizes: layer widths, input first (e.g. ``[784, 2048, ..., 10]``).
    rng: optional ``numpy.random.RandomState``. When omitted, a fresh
      ``RandomState(0)`` is made on every call so repeated calls agree.

  Returns:
    One ``(w, b)`` tuple per consecutive pair ``(m, n)`` of sizes, with ``w``
    of shape ``(m, n)`` and ``b`` of shape ``(n,)``, drawn from a standard
    normal and scaled by 0.1.
  """
  if rng is None:
    # A RandomState in the signature is created once and shared, so a second
    # call (e.g. re-running a training cell) would silently get new weights.
    rng = npr.RandomState(0)
  scale = 0.1
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
layer_sizes = [28 * 28, 2048, 1024, 512, 1024, 512, 128, 10]  # six hidden layers
batch_size = 128
# Full mini-batches per epoch, plus one extra if examples are left over.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)

def data_stream():
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  A RandomState seeded with 0 reshuffles the training set on every pass;
  each pass yields ``num_batches`` slices of up to ``batch_size`` examples.
  """
  rng = npr.RandomState(0)
  while True:
    order = rng.permutation(num_train)
    for start in range(0, num_batches * batch_size, batch_size):
      idx = order[start:start + batch_size]
      yield train_images[idx], train_labels[idx]
batches = data_stream()
learning_rate = 0.01

@jit
def update(params, batch):
  """Return ``params`` after one plain SGD step on ``batch``.

  Each ``(w, b)`` pair moves against its gradient of ``loss``, scaled by the
  module-level ``learning_rate``. Under ``jit`` that global is read while
  tracing, so reassigning it later is not picked up by already-traced calls.
  """
  step = grad(loss)(params, batch)
  new_params = []
  for (w, b), (dw, db) in zip(params, step):
    new_params.append((w - learning_rate * dw, b - learning_rate * db))
  return new_params

num_epochs = 30
params = init_random_params(layer_sizes)
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    params = update(params, next(batches))  # one SGD step on the next mini-batch
  # JAX dispatches device work asynchronously; wait for the final update so
  # epoch_time measures the computation rather than just its dispatch.
  for w, b in params:
    w.block_until_ready()
    b.block_until_ready()
  epoch_time = time.time() - start_time

  # Full-dataset evaluation (excluded from epoch_time).
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 3.86 sec
Training set accuracy 0.45848333835601807
Test set accuracy 0.7910000085830688
Epoch 1 in 0.81 sec
Training set accuracy 0.49258333444595337
Test set accuracy 0.8338000178337097
Epoch 2 in 0.82 sec
Training set accuracy 0.5123833417892456
Test set accuracy 0.8500000238418579
Epoch 3 in 0.82 sec
Training set accuracy 0.5232999920845032
Test set accuracy 0.8551000356674194
Epoch 4 in 0.82 sec
Training set accuracy 0.5355499982833862
Test set accuracy 0.8591000437736511
Epoch 5 in 0.83 sec
Training set accuracy 0.54708331823349
Test set accuracy 0.8568000197410583
Epoch 6 in 0.82 sec
Training set accuracy 0.5604333281517029
Test set accuracy 0.851900041103363
Epoch 7 in 0.83 sec
Training set accuracy 0.5738000273704529
Test set accuracy 0.8452000617980957
Epoch 8 in 0.83 sec
Training set accuracy 0.5889166593551636
Test set accuracy 0.8370000123977661
Epoch 9 in 0.83 sec
Training set accuracy 0.6087499856948853
Test set accuracy 0.8294000625610352
Epoch 10 in 0.83 sec


# **Problem 2**

Before we get started, we need to import two small libraries that contain boilerplate code for common neural network layer types and for optimizers like mini-batch SGD.

In [0]:
from jax.experimental import optimizers
from jax.experimental import stax

Here is a fully-connected neural network architecture, like the one in Problem 1, but this time defined with `stax`.

In [0]:
init_random_params, predict = stax.serial(
    stax.Flatten,                 # (N, 28, 28, 1) images -> flat vectors
    stax.Dense(1024), stax.Relu,  # hidden layer 1
    stax.Dense(128), stax.Relu,   # hidden layer 2
    stax.Dense(10),               # unnormalised class logits
)

We redefine the cross-entropy loss for this model. As done in Problem 1, complete the return line below (it's identical). 

In [0]:
def loss(params, batch):
  """Mean cross-entropy between one-hot ``targets`` and the model output.

  ``predict`` is expected to return unnormalised logits (the stax model ends
  in a Dense layer), so they are log-normalised with ``stax.logsoftmax``
  before being scored against the targets.
  """
  inputs, targets = batch
  log_probs = stax.logsoftmax(predict(params, inputs))
  return -np.mean(np.sum(targets * log_probs, axis=1))

Next, we define the mini-batch SGD optimizer, this time with the optimizers library in JAX. 

In [0]:
learning_rate = 0.15
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)

@jit
def update(_, i, opt_state, batch):
  """Apply one optimizer step for iteration ``i`` and return the new state.

  The first positional argument is accepted but unused (the training loop
  passes its PRNG key there).
  """
  grads = grad(loss)(get_params(opt_state), batch)
  return opt_update(i, grads, opt_state)

The next cell contains our training loop, very similar to Problem 1. 

In [13]:
num_epochs = 12

key = random.PRNGKey(123)
_, init_params = init_random_params(key, (-1, 28, 28, 1))
opt_state = opt_init(init_params)
itercount = itertools.count()
# The test set never changes, so reshape it once instead of twice per epoch.
test_batch = shape_as_image(test_images, test_labels)

for epoch in range(1, num_epochs + 1):
  for _ in range(num_batches):
    opt_state = update(key, next(itercount), opt_state, shape_as_image(*next(batches)))

  params = get_params(opt_state)
  test_acc = accuracy(params, test_batch)
  test_loss = loss(params, test_batch)
  print('Epoch {} Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(epoch, test_loss, 100 * test_acc))

Epoch 1 Test set loss, accuracy (%): (0.18, 94.58)
Epoch 2 Test set loss, accuracy (%): (0.12, 96.46)
Epoch 3 Test set loss, accuracy (%): (0.10, 96.95)
Epoch 4 Test set loss, accuracy (%): (0.08, 97.57)
Epoch 5 Test set loss, accuracy (%): (0.08, 97.64)
Epoch 6 Test set loss, accuracy (%): (0.06, 97.84)
Epoch 7 Test set loss, accuracy (%): (0.07, 97.88)
Epoch 8 Test set loss, accuracy (%): (0.06, 98.02)
Epoch 9 Test set loss, accuracy (%): (0.08, 97.68)
Epoch 10 Test set loss, accuracy (%): (0.06, 98.09)
Epoch 11 Test set loss, accuracy (%): (0.06, 98.07)
Epoch 12 Test set loss, accuracy (%): (0.07, 97.93)


### **Convnet**

In [0]:
init_random_params, predict = stax.serial(
    # Feature extractor: two conv -> ReLU -> max-pool stages.
    stax.Conv(out_chan=12, filter_shape=(5, 5), strides=(1, 1)), stax.Relu,
    stax.MaxPool(window_shape=(2, 2)),
    stax.Conv(out_chan=16, filter_shape=(5, 5), strides=(1, 1)), stax.Relu,
    stax.MaxPool(window_shape=(2, 2)),
    # NOTE(review): no strides are passed to MaxPool, and stax pooling layers
    # appear to default to stride 1, so these pools may not halve the spatial
    # size. Pass strides=(2, 2) if LeNet-style downsampling was intended --
    # verify against the installed stax version.
    # Classifier head: flatten, two ReLU dense layers, 10 output logits.
    stax.Flatten,
    stax.Dense(120), stax.Relu,
    stax.Dense(84), stax.Relu,
    stax.Dense(10),
)

In [0]:
def loss(params, batch):
  """Mean cross-entropy between one-hot ``targets`` and the model output.

  ``predict`` is expected to return unnormalised logits (the stax model ends
  in a Dense layer), so they are log-normalised with ``stax.logsoftmax``
  before being scored against the targets.
  """
  inputs, targets = batch
  log_probs = stax.logsoftmax(predict(params, inputs))
  return -np.mean(np.sum(targets * log_probs, axis=1))

def accuracy(params, batch):
  """Fraction of examples in ``batch = (inputs, targets)`` classified correctly.

  Both the true and the predicted class are the row-wise argmax (of the
  one-hot targets and of ``predict``'s output respectively).
  """
  inputs, targets = batch
  predicted = np.argmax(predict(params, inputs), axis=1)
  actual = np.argmax(targets, axis=1)
  return np.mean(actual == predicted)

In [0]:
learning_rate = 0.06
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)

@jit
def update(_, i, opt_state, batch):
  """Apply one optimizer step for iteration ``i`` and return the new state.

  The first positional argument is accepted but unused (the training loop
  passes its PRNG key there).
  """
  grads = grad(loss)(get_params(opt_state), batch)
  return opt_update(i, grads, opt_state)

In [8]:
num_epochs = 20
num_runs = 5
batch_size = 128
test_acc_list = []
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
def data_stream(seed=0):
  """Yield shuffled ``(images, labels)`` mini-batches forever.

  Args:
    seed: seed of the RandomState that reshuffles the training set on every
      pass; the default 0 reproduces the original fixed ordering.
  """
  rng = npr.RandomState(seed)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]

# The test set never changes, so reshape it once instead of twice per epoch.
test_batch = shape_as_image(test_images, test_labels)

for runs in range(1, num_runs + 1):
  print('Run {}'.format(runs))
  # Every run needs its own data order and initialisation: reusing
  # RandomState(0) and PRNGKey(123) made the runs (near-)identical, so the
  # reported std did not measure seed variance. The offsets keep run 1 on the
  # original seeds.
  batches = data_stream(seed=runs - 1)

  key = random.PRNGKey(123 + runs - 1)
  _, init_params = init_random_params(key, (-1, 28, 28, 1))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  for epoch in range(1, num_epochs + 1):
    for _ in range(num_batches):
      opt_state = update(key, next(itercount), opt_state, shape_as_image(*next(batches)))

    params = get_params(opt_state)
    test_acc = accuracy(params, test_batch)
    test_loss = loss(params, test_batch)
    print('Epoch {} Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(epoch, test_loss, 100 * test_acc))
  test_acc_list.append(test_acc)  # final-epoch accuracy of this run
mean_accuracy, std_accuracy = np.mean(np.array(test_acc_list)), np.std(np.array(test_acc_list))
print('Test Set Mean Accuracy over {} runs: {}, Standard Deviation: {}'.format(num_runs, mean_accuracy, std_accuracy))

Run 1
Epoch 1 Test set loss, accuracy (%): (0.07, 97.79)
Epoch 2 Test set loss, accuracy (%): (0.04, 98.48)
Epoch 3 Test set loss, accuracy (%): (0.04, 98.77)
Epoch 4 Test set loss, accuracy (%): (0.03, 98.95)
Epoch 5 Test set loss, accuracy (%): (0.03, 98.90)
Epoch 6 Test set loss, accuracy (%): (0.03, 99.07)
Epoch 7 Test set loss, accuracy (%): (0.03, 99.05)
Epoch 8 Test set loss, accuracy (%): (0.04, 98.80)
Epoch 9 Test set loss, accuracy (%): (0.05, 98.39)
Epoch 10 Test set loss, accuracy (%): (0.03, 99.12)
Epoch 11 Test set loss, accuracy (%): (0.04, 98.75)
Epoch 12 Test set loss, accuracy (%): (0.04, 98.85)
Epoch 13 Test set loss, accuracy (%): (0.03, 99.09)
Epoch 14 Test set loss, accuracy (%): (0.03, 99.14)
Epoch 15 Test set loss, accuracy (%): (0.03, 99.24)
Epoch 16 Test set loss, accuracy (%): (0.03, 99.08)
Epoch 17 Test set loss, accuracy (%): (0.03, 99.19)
Epoch 18 Test set loss, accuracy (%): (0.03, 99.19)
Epoch 19 Test set loss, accuracy (%): (0.03, 99.20)
Epoch 20 Test s