In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [10]:
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale*random.normal(key, (n, m)), scale*random.normal(key, (n, ))

def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))

In [17]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):

    activations = image
    for w, b in params[:-1]:
        out = jnp.dot(w, activations) + b
        activations = relu(out)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b

    return logits - logsumexp(logits)

batched_predict = vmap(predict, in_axes=(None, 0))

In [21]:
def one_hot(x, k, dtype=jnp.float32):
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(target_class == predicted_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds*targets)

@jit
def update(params, images, targets):
    grads = grad(loss)(params, images, targets)
    return [(w - step_size*dw, b - step_size*db) for (w, b), (dw, db) in zip(params, grads)]

In [25]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(jnp.asarray, data.default_collate(batch))

# No need to use NumpyLoader, we just need to convert torch tensors to numpy arras.
# Only utility is that default of few args changed.
class NumpyLoader(data.DataLoader): 
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [26]:
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = data.DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate, num_workers=0)

In [29]:
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)



In [31]:
import time

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        targets = one_hot(y, n_targets)
        params = update(params, x, targets) 
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)

    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}\n".format(test_acc))

Epoch 0 in 2.81 sec
Training set accuracy 0.9757000207901001
Test set accuracy 0.9686999917030334

Epoch 1 in 2.85 sec
Training set accuracy 0.9779500365257263
Test set accuracy 0.9702000021934509

Epoch 2 in 2.85 sec
Training set accuracy 0.9797999858856201
Test set accuracy 0.97079998254776

Epoch 3 in 2.61 sec
Training set accuracy 0.9817667007446289
Test set accuracy 0.9723999500274658

Epoch 4 in 2.61 sec
Training set accuracy 0.9832833409309387
Test set accuracy 0.973099946975708

Epoch 5 in 2.60 sec
Training set accuracy 0.98458331823349
Test set accuracy 0.9740999937057495

Epoch 6 in 2.61 sec
Training set accuracy 0.9856833219528198
Test set accuracy 0.9741999506950378

Epoch 7 in 2.61 sec
Training set accuracy 0.9867166876792908
Test set accuracy 0.9745999574661255

