In [None]:
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

from tensorflow import keras

## MLP with JAX

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
  """Sample scaled Gaussian weights and biases for a single layer.

  Args:
    m: Size of the second axis of the weight matrix (the layer's input width
      as used by `init_network_params`).
    n: Size of the first axis of the weight matrix and length of the bias
      vector (the layer's output width).
    key: JAX PRNG key; it is split into two subkeys, one per sampled array.
    scale: Factor applied to the standard-normal samples.

  Returns:
    A `(weights, biases)` tuple with shapes `(n, m)` and `(n,)`.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m))
  biases = scale * random.normal(bias_key, (n,))
  return weights, biases

In [None]:
def init_network_params(layer_sizes, key):
  """Build the initial parameters for every layer of the network.

  Args:
    layer_sizes: Sequence of layer widths; each adjacent pair
      `(layer_sizes[i], layer_sizes[i + 1])` defines one layer.
    key: JAX PRNG key from which the per-layer subkeys are derived.

  Returns:
    A list with one `random_layer_params` result per adjacent pair of sizes,
    ordered from the first layer to the last.
  """
  # One subkey is generated per entry of layer_sizes; zip() stops at the
  # shorter sequences, so the final subkey is never consumed.
  subkeys = random.split(key, len(layer_sizes))
  input_widths = layer_sizes[:-1]
  output_widths = layer_sizes[1:]

  network = []
  for width_in, width_out, subkey in zip(input_widths, output_widths, subkeys):
    network.append(random_layer_params(width_in, width_out, subkey))
  return network

In [None]:
# Model and training configuration.
n_targets = 10  # number of output classes; also the width of the final layer
# The CIFAR-10 images loaded later in this notebook are 32x32 RGB, i.e.
# 32 * 32 * 3 = 3072 values once flattened. The previous input width of 784
# (28 * 28) is the MNIST image size and does not match this dataset.
# NOTE(review): assumes the images are flattened before being fed to the
# network -- confirm against the training cells.
input_size = 32 * 32 * 3
layer_sizes = [input_size, 512, 512, n_targets]
step_size = 0.01
num_epochs = 10
batch_size = 128
# Fixed PRNG seed so the initial parameters are the same on every run.
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [None]:
# `params` holds one (weights, biases) pair per pair of adjacent entries in
# `layer_sizes`, so this reports the number of weight layers, which is one
# less than len(layer_sizes).
print(f'num. of layers: {len(params)}')

## Data preparation

In [None]:
# Load CIFAR-10 through Keras as (inputs, labels) pairs for the train and
# test splits.
# NOTE(review): the call may download the dataset on first use -- confirm
# before relying on this cell in an offline environment.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

In [None]:
# Sanity check: show the array shapes of both splits.
# Per the Keras docs these should be (50000, 32, 32, 3) / (50000, 1) for the
# training split and (10000, 32, 32, 3) / (10000, 1) for the test split
# -- TODO confirm on first run.
print(f'x_train_shape: {x_train.shape} - y_train_shape: {y_train.shape}')
print(f'x_test_shape: {x_test.shape} - y_test_shape: {y_test.shape}')