<a href="https://colab.research.google.com/github/present42/PyTorchPractice/blob/main/Following_Jax_Tutorial_Training_a_Simple_Neural_Network%2C_with_PyTorch_Data_Loading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Hyperparameters

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
  """Draw random weights and biases for one dense layer.

  Args:
    m: number of inputs to the layer (columns of the weight matrix).
    n: number of outputs of the layer (rows of the weight matrix).
    key: JAX PRNG key; it is split once so that weights and biases are
      drawn from independent streams.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m))
  biases = scale * random.normal(bias_key, (n,))
  return weights, biases

In [3]:
def init_network_params(sizes, key):
  """Initialise every layer of a fully-connected network.

  Args:
    sizes: sequence of layer widths; consecutive pairs define one layer.
    key: JAX PRNG key used to derive one sub-key per layer.

  Returns:
    A list of ``(weights, biases)`` tuples, one per consecutive pair in
    ``sizes``.
  """
  # len(sizes) keys are derived even though only len(sizes) - 1 layers exist;
  # zip() below simply ignores the surplus key.
  layer_keys = random.split(key, len(sizes))
  fan_ins, fan_outs = sizes[:-1], sizes[1:]
  layers = []
  for fan_in, fan_out, layer_key in zip(fan_ins, fan_outs, layer_keys):
    layers.append(random_layer_params(fan_in, fan_out, layer_key))
  return layers

# Layer widths from input to output; consecutive pairs define one dense layer.
layer_sizes = [784, 512, 512, 10]
# Multiplier applied to the gradients in `update`.
step_size = 0.01
# Number of passes over the training loader.
num_epochs = 8
# Mini-batch size handed to the data loader.
batch_size = 128
# Number of classes used when one-hot encoding labels.
n_targets = 10
# Initial network parameters, derived deterministically from PRNG seed 0.
params = init_network_params(layer_sizes, random.key(0))

## Auto-batching predictions

Let us define a prediction function for a *single* image example.

In [4]:
from jax.scipy.special import logsumexp

def relu(x):
  """Rectified linear unit: element-wise maximum of ``x`` and zero."""
  rectified = jnp.maximum(x, 0)
  return rectified

def predict(params, image):
  """Run the network forward on a single (un-batched) example.

  Args:
    params: list of ``(weights, biases)`` tuples, one per layer.
    image: 1-D input vector for one example.

  Returns:
    The final-layer logits minus their log-sum-exp, i.e. log-softmax
    scores over the output units.
  """
  hidden = image
  # Every layer except the last is followed by a ReLU.
  for weights, bias in params[:-1]:
    hidden = relu(jnp.dot(weights, hidden) + bias)

  out_weights, out_bias = params[-1]
  logits = jnp.dot(out_weights, hidden) + out_bias
  # Subtracting logsumexp normalises the logits into log-probabilities.
  log_normaliser = logsumexp(logits)
  return logits - log_normaliser

In [5]:
# Smoke test: push one random 28*28-element vector through `predict`
# and print the shape of the result.
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [6]:
# A batch of 10 random flattened images, one per row.
random_flattened_images = random.normal(random.key(42), (10, 28 * 28))

# Make a batched version of the predict function:
# in_axes=(None, 0) leaves `params` un-mapped and maps over axis 0 of the images.
batched_predict = vmap(predict, in_axes=(None, 0))
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


# Utility and loss functions

In [7]:
def one_hot(x, k, dtype=jnp.float32):
  """One-hot encode a 1-D array of class indices.

  Args:
    x: 1-D array of labels; each entry is compared against ``0 .. k-1``.
    k: number of classes (width of the encoded output).
    dtype: dtype of the returned array.

  Returns:
    Array of shape ``(len(x), k)`` holding 1 where ``x[i] == j`` and 0
    elsewhere.
  """
  class_ids = jnp.arange(k)
  # Broadcasting a column of labels against the row of class ids gives an
  # (N, k) boolean mask.
  matches = x[:, None] == class_ids
  return jnp.array(matches, dtype)

In [8]:
def accuracy(params, images, targets):
  """Fraction of examples whose arg-max prediction matches the target.

  Args:
    params: network parameters passed through to ``batched_predict``.
    images: batch of inputs, one example per row.
    targets: one-hot labels of shape (N, K), where N is the batch size
      and K is the number of classes.

  Returns:
    Mean of the element-wise equality between predicted and target
    class indices.
  """
  scores = batched_predict(params, images)
  predicted_class = jnp.argmax(scores, axis=-1)
  target_class = jnp.argmax(targets, axis=1)
  hits = predicted_class == target_class
  return jnp.mean(hits)

In [9]:
def loss(params, images, targets):
  """Cross-entropy style loss between predictions and one-hot targets.

  The batched predictions are multiplied element-wise by ``targets`` and
  the negated mean of that product is returned. ``jnp.mean`` is called
  without an axis, so the average runs over every entry of the product
  (classes included), not only over examples.
  """
  log_probs = batched_predict(params, images)
  picked = log_probs * targets
  return -jnp.mean(picked)

In [10]:
@jit
def update(params, x, y):
  """Apply one plain gradient-descent step and return the new parameters.

  Args:
    params: list of ``(weights, biases)`` tuples.
    x: batch of inputs forwarded to ``loss``.
    y: batch of targets forwarded to ``loss``.

  Returns:
    A new list of ``(weights, biases)`` tuples, each moved against its
    gradient by ``step_size`` (read from the enclosing module scope).
  """
  # grad differentiates `loss` with respect to its first argument, so
  # `grads` mirrors the structure of `params`.
  grads = grad(loss)(params, x, y)
  stepped = []
  for (w, b), (dw, db) in zip(params, grads):
    stepped.append((w - step_size * dw, b - step_size * db))
  return stepped

## Data Loading with PyTorch

In [11]:
import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST
import jax

In [12]:
def numpy_collate(batch):
  """Collate a batch with PyTorch's default rules, then convert to NumPy.

  ``data.default_collate`` assembles the samples first; ``np.asarray`` is
  then mapped over every leaf of the collated structure.
  """
  collated = data.default_collate(batch)
  return jax.tree.map(np.asarray, collated)

In [13]:
class NumpyLoader(data.DataLoader):
  """A ``DataLoader`` whose batches are collated with ``numpy_collate``.

  The constructor accepts the listed ``DataLoader`` arguments and forwards
  them unchanged; the only difference is that ``collate_fn`` is always set
  to ``numpy_collate``.
  """

  def __init__(self, dataset, batch_size=1,
               shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0,
               pin_memory=False, drop_last=False,
               timeout=0, worker_init_fn=None):
    # Zero-argument super() is bound to NumpyLoader itself.
    # super(self.__class__, self) would resolve against the *runtime* class,
    # so instantiating any subclass would call this __init__ recursively
    # until a RecursionError is raised.
    super().__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

In [14]:
class FlattenAndCast:
  """Callable transform: convert the input to a flat float32 NumPy array."""

  def __call__(self, pic):
    # np.array performs the float32 conversion; ravel then flattens to 1-D.
    as_float = np.array(pic, dtype=jnp.float32)
    return as_float.ravel()

In [15]:
# MNIST stored under /tmp/mnist/; every sample is passed through
# FlattenAndCast, so it arrives as a flat float32 vector.
# No `train=` argument is given — presumably this selects the training split; confirm against torchvision.
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
# Batches come out as NumPy arrays via NumpyLoader. `shuffle` is not passed,
# so NumpyLoader's default of shuffle=False applies.
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 113097642.45it/s]


Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 22230811.86it/s]

Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 31995056.31it/s]


Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3709937.44it/s]

Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw






In [16]:
# Get the full train dataset (for checking accuracy while training).
# The raw `.data` attribute is used directly, so the FlattenAndCast transform
# is not applied here: images are reshaped to one row each, with no explicit
# dtype conversion.
train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
# One-hot encode the labels into n_targets columns.
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)

In [17]:
# Held-out split (train=False); no transform is attached because the raw
# `.data` / `.targets` attributes are read directly below.
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
# Flatten each test image to one row and cast to float32 as a JAX array.
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
# One-hot encode the test labels into n_targets columns.
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)

# Training Loop

In [None]:
import time

# Train for num_epochs passes over the loader, reporting accuracy after each.
for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    # Labels from the loader are one-hot encoded before the update step.
    y = one_hot(y, n_targets)
    # `update` returns new parameters; rebind the global for the next step.
    params = update(params, x, y)
  # Measures only the mini-batch updates above, not the accuracy evaluation below.
  epoch_time = time.time() - start_time

  # Accuracy is evaluated on the complete train and test arrays in one call each.
  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)

  print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
  print(f"Training set accuracy: {train_acc}")
  print(f"Test set accuracy: {test_acc}")

Epoch 0 in 5.09 sec
Training set accuracy: 0.9158166646957397
Test set accuracy: 0.9194999933242798
Epoch 1 in 4.39 sec
Training set accuracy: 0.9371500015258789
Test set accuracy: 0.9383999705314636
Epoch 2 in 3.61 sec
Training set accuracy: 0.9491333365440369
Test set accuracy: 0.9467999935150146
Epoch 3 in 3.62 sec
Training set accuracy: 0.9568833708763123
Test set accuracy: 0.9531999826431274
Epoch 4 in 4.37 sec
Training set accuracy: 0.9630500078201294
Test set accuracy: 0.9575999975204468
