In [10]:
# You don't know JAX - Colin Raffel tutorial

import random
import itertools

import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

from __future__ import print_function

In [11]:
#XOR function: input 1 or 2, but not both and not none
#let's make a neural network with one 3 neuron hidden layer, 
#hyperbolic tangent nonlinearity + sigmoid activation function
#cross-entropy loss + stochastic gradient descent training

# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Computes our network's output
# tanh then sigmoid with 2 sets of weights and biases
def net(params, x):
    w1, b1, w2, b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    return sigmoid(np.dot(w2, hidden) + b2)

# Cross-entropy loss
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
# just prints all our inputs and corresponding outputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

#initialise parameters using original numpy
#we don't need to do any fancy transformations so we shouldnt use jax
def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),  # b1
        onp.random.randn(3),  # w2
        onp.random.randn(),  #b2
    ]

In [12]:
# jax.grad takes a function and returns a new function to compute the gradient of the original function
# default = gradient wrt first argument
loss_grad = jax.grad(loss)

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [13]:
#JIT compiler takes a standard function and compiles it to run efficiently on an accelerator
#also avoids overhead of Python interpreter

# Time the original gradient function
%timeit loss_grad(params, x, y)

loss_grad = jax.jit(jax.grad(loss))
# Run once to trigger JIT compilation
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)

10.3 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
102 µs ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [14]:
#now let's run the NN again to make sure it works
#super quick!!

params = initial_params()

for n in itertools.count():
    x = inputs[onp.random.choice(inputs.shape[0])]
    y = onp.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [15]:
#jax.vmap automatically vectorises a function
#now we can compute the output of a function in parallel over some axis of the input
#here we can use it to get a loss function gradient which can take a minibatch of examples
#more arguments: in_axes = axes for parallelisation, same length as the number of arguments in the function
                        # being vectorised
                # out_axes = axes for parallelisation of the output
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))

params = initial_params()

batch_size = 100

for n in itertools.count():
    # Generate a batch of inputs
    # chose batch_size inputs from the input list for our x matrix
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 200
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 300
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 400
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [16]:
#JAX Quickstart tutorial

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [17]:
#how to generate random data
#need an explicity PRNG first
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]


In [18]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
#block_until_ready - removes asynchronous execution
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

174 ms ± 5.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
#also works on regular numpy arrays
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

199 ms ± 4.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
#ensure that NDArray is backed by device memory using device_put()
#still acts like an NDArray, but only copies values back to the CPU when needed
#equivalent of jit(lambda x: x)
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

172 ms ± 941 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [21]:
#run on GPU or TPU
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

115 ms ± 6.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
#using jit() to speed up code

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

3.85 ms ± 95.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
#use jit decorator to compile  multiple operations together
#jit-compiles the first time selu is called, then cached thereafter
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

912 µs ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
#using grad() to calculate derivatives

#calculate gradient for this sigmoid function
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [25]:
#verify correct gradient with finite diff
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))


[0.24998187 0.1964569  0.10502338]


In [26]:
#can mix grad and jit
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.035325594


In [27]:
#jax.vjp() for reverse-mode vector-Jacobian products
#jax.jvp() for forward-mode Jacobian-vector products
#can be composed arbitrarily with one another

#efficient computation of full Hessian matrices:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

In [28]:
#using vmap() to vectorize

#pushes loop into a function's primitive operations for better performance
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

In [29]:
#inefficient: loop over a batch dimension
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
3.65 ms ± 68.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
#efficient: use jnp.dot to handle extra batch dimensions transparently
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(mat,v_batched.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
41.8 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [31]:
#or just use vmap() to add batching support automatically
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
49.9 µs ± 2.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [32]:
#training a simple neural network, with pytorch data loading - jax tutorial

In [3]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [4]:
#hyperparameters

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
param_scale = 0.1
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))



In [5]:
#autobatching predictions

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

#predict a single label, 2 layer process
def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

In [6]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [7]:
# Doesn't work with a batch
# batch of 10 images
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [8]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [9]:
#sorts the classes
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

#check if target matches prediction
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

#check the loss
def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

#call jit to work on the update function, vectorises it
#gradient of the loss with respect to params, inputs, outputs
#updates weights and bias
@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [23]:
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)



[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...[0m


HBox(children=(HTML(value='Dl Completed...'), FloatProgress(value=0.0, max=4.0), HTML(value='')))



[1mDataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [24]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


In [25]:
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 3.18 sec
Training set accuracy 0.9253833293914795
Test set accuracy 0.9271000027656555
Epoch 1 in 2.58 sec
Training set accuracy 0.942799985408783
Test set accuracy 0.9413999915122986
Epoch 2 in 2.57 sec
Training set accuracy 0.9533500075340271
Test set accuracy 0.9516000151634216
Epoch 3 in 2.64 sec
Training set accuracy 0.9599666595458984
Test set accuracy 0.9557999968528748
Epoch 4 in 2.76 sec
Training set accuracy 0.9651333093643188
Test set accuracy 0.9603999853134155
Epoch 5 in 2.71 sec
Training set accuracy 0.9690499901771545
Test set accuracy 0.9631999731063843
Epoch 6 in 2.61 sec
Training set accuracy 0.9726333618164062
Test set accuracy 0.965399980545044
Epoch 7 in 2.69 sec
Training set accuracy 0.9753999710083008
Test set accuracy 0.9667999744415283


In [88]:
#getting started with jax tutorial
#https://roberttlange.github.io/posts/2020/03/blog-post-10/

import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random

# Generate key which is used to generate random numbers
key = random.PRNGKey(1)

In [89]:
def ReLU(x):
    """ Rectified Linear Unit (ReLU) activation function """
    return np.maximum(0, x)

jit_ReLU = jit(ReLU)

In [90]:
%time out = ReLU(x).block_until_ready()
# Call jitted version to compile for evaluation time!
%time jit_ReLU(x).block_until_ready()
%time out = jit_ReLU(x).block_until_ready()

CPU times: user 863 µs, sys: 461 µs, total: 1.32 ms
Wall time: 838 µs
CPU times: user 17.2 ms, sys: 1.52 ms, total: 18.8 ms
Wall time: 17.7 ms
CPU times: user 345 µs, sys: 123 µs, total: 468 µs
Wall time: 363 µs


In [91]:
#check that the gradient is correct
def FiniteDiffGrad(x):
    """ Compute the finite difference derivative approx for the ReLU"""
    return np.array((ReLU(x + 1e-3) - ReLU(x - 1e-3)) / (2 * 1e-3))

# Compare the Jax gradient with a finite difference approximation
print("Jax Grad: ", jit(grad(jit(ReLU)))(2.))
print("FD Gradient:", FiniteDiffGrad(2.))

Jax Grad:  1.0
FD Gradient: 0.99998707


In [92]:
#batch of 32 vectors each with 100 features, process it w 512 hidden units and relu activation
batch_dim = 32
feature_dim = 100
hidden_dim = 512

# Generate a batch of vectors to process, dim (batch, feature)
#need key for jax random generation
X = random.normal(key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases
# params = [(hidden weights, output weights), (hidden bias)]
params = [random.normal(key, (hidden_dim, feature_dim)),
          random.normal(key, (hidden_dim, ))]

def relu_layer(params, x):
    """ Simple ReLu layer for single sample """
    return ReLU(np.dot(params[0], x) + params[1])

def batch_version_relu_layer(params, X):
    """ Error prone batch version """
    return ReLU(np.dot(X, params[0].T) + params[1])

#vmap wraps relu_layer and takes the axis over which to batch
#batch dimension (0)
#output axes stacks individual sample outputs also as 0
def vmap_relu_layer(params, x):
    """ vmap version of the ReLU layer """
    return jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))

out = np.stack([relu_layer(params, X[i, :]) for i in range(X.shape[0])])
out = batch_version_relu_layer(params, X)
out = vmap_relu_layer(params, X)

In [93]:
from jax.scipy.special import logsumexp
from jax.experimental import optimizers
import numpy as onp

import torch
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

import time

import tensorflow_datasets as tfds

batch_size = 100

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = jnp.reshape(train_labels, (len(train_labels), ))
#reshape each section, then merge the tensors into a dataset
train_data = TensorDataset(torch.Tensor(train_images), torch.Tensor(train_labels))
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = jnp.reshape(test_labels, (len(test_labels), ))
test_data = TensorDataset(torch.Tensor(test_images), torch.Tensor(test_labels))
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [94]:
def initialize_mlp(sizes, key):
    """ Initialize the weights of all layers of a linear layer network """
    keys = random.split(key, len(sizes))
    # Initialize a single layer with Gaussian weights -  helper function
    def initialize_layer(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
# Return a list of tuples of layer weights (weight,bias)
params = initialize_mlp(layer_sizes, key)

In [95]:
def forward_pass(params, in_array):
    """ Compute the forward pass for each example individually """
    activations = in_array

    # Loop over the ReLU hidden layers
    for w, b in params[:-1]:
        activations = relu_layer([w, b], activations)

    # Perform final trafo to logits
    final_w, final_b = params[-1]
    logits = np.dot(final_w, activations) + final_b
    #return log of softmax output
    return logits - logsumexp(logits)

# Make a batched version of the `predict` function
batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)

In [96]:
#multi-class cross-entropy loss between one-hot encoded class labels and softmax output
def one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k """
    return np.array(x[:, None] == np.arange(k), dtype)

def loss(params, in_arrays, targets):
    """ Compute the multi-class cross-entropy loss """
    preds = batch_forward(params, in_arrays)
    return -np.sum(preds * targets)

def accuracy(params, data_loader):
    """ Compute the accuracy for a provided dataloader """
    acc_total = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data).reshape(data.size(0), 28*28)
        targets = one_hot(np.array(target), num_classes)
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(batch_forward(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

In [97]:
@jit
def update(params, x, y, opt_state):
    """ Compute the gradient for a batch and update the parameters """
    value, grads = value_and_grad(loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, value

# Defining an optimizer in Jax
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

num_epochs = 10
num_classes = 10

In [98]:
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    """ Implements a learning loop over epochs. """
    # Initialize placeholder for loggin
    log_acc_train, log_acc_test, train_loss = [], [], []

    # Get the initial set of parameters
    params = get_params(opt_state)

    # Get initial accuracy after random init
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    log_acc_train.append(train_acc)
    log_acc_test.append(test_acc)

    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            if net_type == "MLP":
                # Flatten the image into 784 vectors for the MLP
                x = np.array(data).reshape(data.size(0), 28*28)
            elif net_type == "CNN":
                # No flattening of the input required for the CNN
                x = np.array(data)
            y = one_hot(np.array(target), num_classes)
            params, opt_state, loss = update(params, x, y, opt_state)
            train_loss.append(loss)

        epoch_time = time.time() - start_time
        train_acc = accuracy(params, train_loader)
        test_acc = accuracy(params, test_loader)
        log_acc_train.append(train_acc)
        log_acc_test.append(test_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(epoch+1, epoch_time,
                                                                    train_acc, test_acc))

    return train_loss, log_acc_train, log_acc_test


train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state,
                                                          net_type="MLP")

Epoch 1 | T: 3.46 | Train A: 0.975 | Test A: 0.970
Epoch 2 | T: 3.18 | Train A: 0.986 | Test A: 0.975
Epoch 3 | T: 2.98 | Train A: 0.988 | Test A: 0.974
Epoch 4 | T: 3.52 | Train A: 0.991 | Test A: 0.974
Epoch 5 | T: 3.02 | Train A: 0.993 | Test A: 0.975
Epoch 6 | T: 3.01 | Train A: 0.992 | Test A: 0.976
Epoch 7 | T: 3.09 | Train A: 0.996 | Test A: 0.977
Epoch 8 | T: 2.79 | Train A: 0.995 | Test A: 0.978
Epoch 9 | T: 2.92 | Train A: 0.995 | Test A: 0.975
Epoch 10 | T: 3.22 | Train A: 0.996 | Test A: 0.977


In [82]:
#CONVOLUTIONS

from jax.experimental import stax
from jax.experimental.stax import (BatchNorm, Conv, Dense, Flatten,
                                   Relu, LogSoftmax)

#stack of convolutional layers with batchnorm and ReLU activation after each
init_fun, conv_net = stax.serial(Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"), Relu,
                                 Flatten,
                                 Dense(num_classes),
                                 LogSoftmax)

_, params = init_fun(key, (batch_size, 1, 28, 28))

In [85]:
def accuracy(params, data_loader):
    """ Compute the accuracy for the CNN case (no flattening of input)"""
    acc_total = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data)
        targets = one_hot(np.array(target), num_classes)
        target_class = np.argmax(targets, axis=1)
        print(conv_net(params,images))
        predicted_class = np.argmax(conv_net(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

def loss(params, images, targets):
    preds = conv_net(params, images)
    return -np.sum(preds * targets)

In [86]:
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)
num_epochs = 10

train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state,
                                                          net_type="CNN")

TypeError: convolution requires lhs and rhs ndim to be equal, got 2 and 4.

In [None]:
##stax attempt

from jax.experimental import stax
from jax.experimental.stax import (BatchNorm, Conv, Dense, Flatten,
                                   Relu, LogSoftmax)

#stack of convolutional layers with batchnorm and ReLU activation after each
init_fun, conv_net = stax.serial(Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"), Relu,
                                 Flatten,
                                 Dense(num_classes),
                                 LogSoftmax)

_, params = init_fun(key, (batch_size, 1, 28, 28))