In [None]:
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

from tensorflow import keras

## A multilayer perceptron (MLP) in JAX

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
  """Draw Gaussian-initialized weights and biases for one dense layer.

  Args:
    m: input width of the layer (number of columns of the weight matrix).
    n: output width of the layer (number of rows of the weight matrix).
    key: JAX PRNG key; split internally into one key for weights, one for biases.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    Tuple (w, b) with w of shape (n, m) and b of shape (n,).
  """
  weight_key, bias_key = random.split(key)
  weights = random.normal(weight_key, (n, m))
  biases = random.normal(bias_key, (n,))
  return scale * weights, scale * biases

In [None]:
def init_network_params(layer_sizes, key):
  """Initialize (w, b) parameters for every layer of a fully connected network.

  Args:
    layer_sizes: sequence of widths [n_in, hidden..., n_out]; one layer is
      created for each consecutive pair of widths.
    key: JAX PRNG key from which all per-layer keys are derived.

  Returns:
    List of len(layer_sizes) - 1 tuples (w, b), as produced by
    `random_layer_params`, ordered from input layer to output layer.
  """
  # One key per entry of layer_sizes, i.e. one more than the number of layers;
  # the spare key is never used. Changing the split count may change the
  # derived keys, and therefore the initial weights.
  layer_keys = random.split(key, len(layer_sizes))
  in_out_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
  layers = []
  for (n_in, n_out), layer_key in zip(in_out_pairs, layer_keys):
    layers.append(random_layer_params(n_in, n_out, layer_key))
  return layers

In [None]:
# Model and training configuration.
# Input width 784 = 28 * 28: one flattened 28x28 image per example.
# NOTE(review): the data-preparation cell loads keras CIFAR-10, whose images are
# 32x32x3 (3072 values when flattened) -- confirm the intended dataset/input size.
layer_sizes = [28 * 28, 512, 512, 10]
# step_size / num_epochs / batch_size are not used in this section; presumably
# they drive a training loop later in the notebook.
step_size = 0.01
num_epochs = 10
batch_size = 128
# Number of classes equals the width of the output layer; derive it so the two
# values cannot drift apart.
n_targets = layer_sizes[-1]
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [None]:
# Summarize the network: how many (w, b) pairs there are and their shapes.
print(f'num. of layers: {len(params)}')
for i, layer in enumerate(params):
  print(f'layer {i}: w={layer[0].shape}, b={layer[1].shape}')

In [None]:
def relu(x):
  """Rectified linear unit: elementwise max(x, 0).

  Args:
    x: scalar or array input.

  Returns:
    Array of the same shape as `x` with negative entries replaced by 0.
  """
  return jnp.maximum(x, 0)

In [None]:
from jax.scipy.special import logsumexp

In [None]:
def predict(params, image):
  """Run the MLP forward on a single, unbatched input vector.

  Args:
    params: list of (w, b) pairs, one per layer, as built by
      `init_network_params`; each w has shape (n_out, n_in).
    image: 1-D input vector whose length matches the first layer's n_in.

  Returns:
    Log-probabilities over the output units (log-softmax of the final logits).
  """
  hidden_layers, output_layer = params[:-1], params[-1]
  hidden = image
  for weights, bias in hidden_layers:
    hidden = relu(jnp.dot(weights, hidden) + bias)
  w_out, b_out = output_layer
  logits = jnp.dot(w_out, hidden) + b_out
  # Subtracting logsumexp turns the raw logits into normalized log-probabilities
  # in a numerically stable way (log-softmax).
  return logits - logsumexp(logits)

In [None]:
# Smoke test: push one random input vector through the untrained network.
# The input width is read from layer_sizes[0] instead of repeating 28 * 28,
# so this cell stays in sync with the configuration cell.
random_flattened_image = random.normal(random.PRNGKey(1), (layer_sizes[0],))
preds = predict(params, random_flattened_image)

In [None]:
# Show the output shape first, then the log-probabilities themselves.
print(f'preds: {preds.shape}', f'preds values: {preds}', sep='\n')

In [None]:
# Vectorize `predict` over a batch: in_axes=(None, 0) shares `params` across the
# batch and maps over axis 0 of the image argument, so an input of shape
# (batch, n_in) yields an output of shape (batch, n_out).
batched_predict = vmap(predict, in_axes=(None, 0))

In [None]:
# Same smoke test on a batch of 10 random inputs; the input width follows
# layer_sizes[0] instead of a hard-coded 28*28.
random_flattened_images = random.normal(random.PRNGKey(1), (10, layer_sizes[0]))
batched_preds = batched_predict(params, random_flattened_images)
# Expect one row of log-probabilities per input and one column per output unit.
assert batched_preds.shape == (random_flattened_images.shape[0], layer_sizes[-1])
print(batched_preds.shape)

In [None]:
def one_hot(x, k, dtype=jnp.float32):
  """One-hot encode integer class labels along a new trailing axis.

  Args:
    x: integer label array of any shape (e.g. shape (batch,)).
    k: number of classes; the encoding has one slot per class 0..k-1.
    dtype: dtype of the returned array.

  Returns:
    Array of shape x.shape + (k,) holding 1 where the label equals the class
    index and 0 elsewhere. Labels outside [0, k) produce an all-zero row.
  """
  # x[..., None] appends a length-1 axis (e.g. (10,) -> (10, 1)) that broadcasts
  # against arange(k). Unlike x[:, None], it also accepts scalar labels and
  # labels of any rank; for 1-D input the result is unchanged.
  return jnp.array(x[..., None] == jnp.arange(k), dtype=dtype)

In [None]:
def accuracy(params, images, targets):
  """Classification accuracy of the MLP on a batch.

  Args:
    params: network parameters, as accepted by `predict`.
    images: batch of flattened inputs, mapped over axis 0 by `batched_predict`.
    targets: expected to be (batch, n_classes), e.g. from `one_hot`; the true
      class of each example is the argmax along axis 1.

  Returns:
    Scalar fraction of examples whose predicted class equals the target class.
  """
  true_labels = jnp.argmax(targets, axis=1)
  scores = batched_predict(params, images)
  predicted_labels = jnp.argmax(scores, axis=1)
  # Both label vectors are 1-D (length = batch), so the mean of the boolean
  # matches is (#correct predictions) / batch.
  hits = predicted_labels == true_labels
  return hits.mean()

## Data preparation

In [None]:
# Load the Keras CIFAR-10 train/test splits as NumPy arrays.
# NOTE(review): CIFAR-10 images are 32x32 RGB (3072 values when flattened) and
# its labels have shape (N, 1). By contrast, layer_sizes[0] is 784 (28 * 28), and
# accuracy() expects 2-D (N, n_classes) targets built from 1-D labels. That setup
# matches MNIST (keras.datasets.mnist) rather than CIFAR-10. Confirm which
# dataset is intended, or flatten the labels and update layer_sizes before
# training.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

In [None]:
# Report the array shapes of both splits, train first, then test.
print(
    f'x_train_shape: {x_train.shape} - y_train_shape: {y_train.shape}',
    f'x_test_shape: {x_test.shape} - y_test_shape: {y_test.shape}',
    sep='\n',
)