In [None]:
import datasets


In [None]:
import time

import numpy.random as npr

from jax import jit, grad
from jax.scipy.special import logsumexp
import jax.numpy as jnp

In [None]:
def init_random_params(scale,layer_sizes,rng=None):
  """Build randomly initialised (weights, biases) pairs for a dense network.

  Args:
    scale: multiplier applied to every standard-normal draw.
    layer_sizes: sequence of layer widths; each consecutive pair (m, n)
      produces one layer with an (m, n) weight matrix and an (n,) bias vector.
    rng: optional generator exposing `randn` (e.g. `numpy.random.RandomState`).
      When omitted, a fresh `RandomState(0)` is created for this call.

  Returns:
    A list of `(weights, biases)` tuples, one per consecutive pair of sizes.
  """
  # A default of `npr.RandomState(0)` in the signature would be evaluated once
  # and shared by every call, so re-running the initialisation cell would keep
  # advancing the same stream and return different parameters each time.
  # Creating the generator here makes the default call reproducible.
  if rng is None:
    rng = npr.RandomState(0)
  return [
      (scale*rng.randn(m,n), scale*rng.randn(n))
      for m,n in zip(layer_sizes[:-1], layer_sizes[1:])
  ]

In [None]:
def predict(params, inputs):
  """Run the network forward and return log-normalised outputs per row.

  Every layer except the last applies `tanh(x @ w + b)`; the last layer is
  affine only. The final scores have their log-sum-exp over axis 1 subtracted,
  so exponentiating a row of the result sums to one.

  Args:
    params: list of `(weights, biases)` pairs, one per layer.
    inputs: array whose last axis matches the first layer's weight rows.

  Returns:
    Array of log-probabilities with one row per input row.
  """
  out_w, out_b = params[-1]
  hidden = inputs
  for weight, bias in params[:-1]:
    hidden = jnp.tanh(jnp.dot(hidden, weight) + bias)

  scores = jnp.dot(hidden, out_w) + out_b
  log_normaliser = logsumexp(scores, axis=1, keepdims=True)
  return scores - log_normaliser

In [None]:
def loss(params,batch):
  """Negated batch mean of the target-weighted log-probabilities.

  Args:
    params: network parameters accepted by `predict`.
    batch: `(inputs, targets)` pair; `targets` is multiplied element-wise with
      the output of `predict` (presumably one-hot rows; confirm with caller).

  Returns:
    Scalar: minus the mean over rows of `sum(log_probs * targets, axis=1)`.
  """
  features, labels = batch
  log_probs = predict(params, features)
  per_example = jnp.sum(log_probs * labels, axis=1)
  return -jnp.mean(per_example)

In [None]:
def accuracy(params,batch):
  """Fraction of rows whose arg-max prediction equals the arg-max target.

  Args:
    params: network parameters accepted by `predict`.
    batch: `(inputs, targets)` pair; the class of each row is taken as the
      arg-max of `targets` along axis 1.

  Returns:
    Scalar mean of the element-wise match between predicted and target class.
  """
  features, labels = batch
  guessed = jnp.argmax(predict(params, features), axis=1)
  actual = jnp.argmax(labels, axis=1)
  return jnp.mean(guessed == actual)


In [None]:
# Configuration and data loading. Every name bound here is module-level state
# that later cells (data_stream, update, the training loop) read directly.
if __name__ == "__main__":
  # Layer widths, input to output; consecutive pairs define the weight shapes.
  layer_sizes = [784, 1024, 1024, 10]
  # Multiplier passed to init_random_params for the initial random draws.
  param_scale = 0.1
  # Gradient-descent step used by update().
  step_size = 0.001
  # Number of passes of num_batches updates each in the training loop.
  num_epochs = 10
  # Number of examples sliced per mini-batch in data_stream().
  batch_size = 128

  # datasets.mnist() is a helper from the imported `datasets` module; it
  # returns the four arrays unpacked here.
  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  # Ceiling division: one extra batch is counted when num_train is not an
  # exact multiple of batch_size (bool(leftover) adds 0 or 1).
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/


In [None]:
# NOTE(review): list1 is only displayed in a later cell; nothing visible in
# this notebook ever appends to it.
list1=[]
def data_stream():
  """Yield shuffled `(images, labels)` mini-batches forever.

  Reads the module-level `num_train`, `num_batches`, `batch_size`,
  `train_images` and `train_labels`. Each pass draws a new permutation of the
  training indices from a generator seeded with 0 and yields `num_batches`
  consecutive slices of it; the final slice of a pass is shorter when
  `num_train` is not a multiple of `batch_size`.
  """
  shuffler=npr.RandomState(0)
  while True:
    order=shuffler.permutation(num_train)
    for batch_no in range(num_batches):
      lo=batch_no*batch_size
      hi=(batch_no+1)*batch_size
      chosen=order[lo:hi]
      yield train_images[chosen],train_labels[chosen]
# Single shared generator consumed by the training loop via next(batches).
batches=data_stream()

In [None]:
# Display the generator object; evaluating the name does not consume a batch.
batches

<generator object data_stream at 0x7f52d78d46d0>

In [None]:
# Show how many mini-batches make up one epoch.
num_batches

469

In [None]:
# Display list1. NOTE(review): it is created empty and no visible cell fills
# it; this looks like a leftover inspection cell (confirm before removing).
list1

[]

In [None]:
@jit
def update(params,batch):
  """Apply one plain gradient-descent step and return the new parameters.

  Args:
    params: list of `(weights, biases)` pairs.
    batch: `(inputs, targets)` pair forwarded to `loss`.

  Returns:
    A new list of `(weights, biases)` pairs, each moved against its gradient
    of `loss` by the module-level `step_size`.
  """
  gradients = grad(loss)(params, batch)
  stepped = []
  for (weight, bias), (d_weight, d_bias) in zip(params, gradients):
    stepped.append((weight-step_size*d_weight, bias-step_size*d_bias))
  return stepped

In [None]:
  # Training loop. NOTE(review): this cell is indented as if it continued an
  # `if __name__ == "__main__":` body, but no such header is part of this cell.
  params = init_random_params(param_scale, layer_sizes)
  for epoch in range(num_epochs):
    # Only the parameter updates are timed; the accuracy evaluation below is
    # outside the measured interval.
    start_time = time.time()
    for _ in range(num_batches):
      # `batches` is the shared infinite generator, so each epoch continues
      # from where the previous one stopped.
      params = update(params, next(batches))
    epoch_time = time.time() - start_time
    # Accuracy is computed over the full training and test arrays in one call.
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 9.01 sec
Training set accuracy 0.7352833151817322
Test set accuracy 0.739300012588501
Epoch 1 in 8.93 sec
Training set accuracy 0.814716637134552
Test set accuracy 0.8192999958992004
Epoch 2 in 8.95 sec
Training set accuracy 0.8458666801452637
Test set accuracy 0.848800003528595
Epoch 3 in 8.94 sec
Training set accuracy 0.8645166754722595
Test set accuracy 0.8659999966621399
Epoch 4 in 8.91 sec
Training set accuracy 0.8768666386604309
Test set accuracy 0.8772000074386597
Epoch 5 in 8.90 sec
Training set accuracy 0.8859000205993652
Test set accuracy 0.8862000107765198
Epoch 6 in 8.89 sec
Training set accuracy 0.8929833173751831
Test set accuracy 0.8919000029563904
Epoch 7 in 8.92 sec
Training set accuracy 0.8984333276748657
Test set accuracy 0.8959000110626221
Epoch 8 in 8.88 sec
Training set accuracy 0.9036499857902527
Test set accuracy 0.9010999798774719
Epoch 9 in 8.96 sec
Training set accuracy 0.9073333144187927
Test set accuracy 0.9043999910354614
