# Fundamental Concepts of Jax/Flax

In [None]:
!pip install jax jaxlib flax 

Collecting flax
[?25l  Downloading https://files.pythonhosted.org/packages/63/1f/63c720200f7a679d9fd408eb2641960d6fa99030c874ecba24091e694f91/flax-0.3.3-py3-none-any.whl (179kB)
[K     |████████████████████████████████| 184kB 6.2MB/s 
Installing collected packages: flax
Successfully installed flax-0.3.3


## jax.numpy and numpy
---


In [None]:
import numpy as np
import jax.numpy as jnp

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32) 


In [None]:
jax_array[0,0]=1 # This will throw an error

TypeError: ignored

In [None]:
from jax.ops import index, index_add, index_update

In [None]:
new_jax_array = index_update(jax_array, index[1, :], 1.)
new_jax_array

DeviceArray([[0., 0., 0.],
             [1., 1., 1.],
             [0., 0., 0.]], dtype=float32)

In [None]:
new_jax_array = index_add(new_jax_array, index[1, :], 1)
new_jax_array

DeviceArray([[0., 0., 0.],
             [2., 2., 2.],
             [0., 0., 0.]], dtype=float32)

In [None]:
from jax import random
key = random.PRNGKey(0)

In [None]:
print(random.normal(key, shape=(1,)))
# Unlike numpy, this will give the exact same results;
print(random.normal(key, shape=(1,)))

[-0.20584235]
[-0.20584235]


In [None]:
key, subkey, _ = random.split(key, 3)
print(random.normal(key, shape=(1,)))
print(random.normal(subkey, shape=(1,)))

[1.1188383]
[0.5781487]


In [None]:
jnp.add(jax_array, 1) # works

DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)

In [None]:
from jax import lax
lax.add(jax_array, 1) # nope

TypeError: ignored

In [None]:
lax.add(jax_array, 1.0)

DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)

In [None]:
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # This will not work!


dtype('float32')

In [None]:
# from jax.config import config
# config.update("jax_enable_x64", True)
# Must run the above commands on start-up to use float64.

## JIT simple example
---
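This is a simple example of the `jit` functionality in JAX. Note that the `print` calls inside `f` run only while the function is being traced, i.e. on the first call for a given input shape and dtype. If you run the call a second time, the cached compiled version is reused and nothing is printed.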

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import jit

In [None]:
@jit
def f(x, y):
  """Return ``dot(x + 1, y + 1)`` and print the inputs and the result.

  The ``print`` calls are plain Python side effects: under ``jit`` they show
  whatever objects the function body sees when it is executed by Python.
  """
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  shifted_x, shifted_y = x + 1, y + 1
  result = jnp.dot(shifted_x, shifted_y)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)


In [None]:
f(x, y)

DeviceArray([3.2744508, 6.406835 , 5.2696176], dtype=float32)

### Pure functions and jit

In [None]:
g = 0.
def impure_uses_globals(x):
  """Return ``x`` plus the module-level ``g`` (deliberately impure: it reads a global)."""
  shifted = x + g
  return shifted

print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  
print ("Second call: ", jit(impure_uses_globals)(5.))
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

# Question: Why does this happen?

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [None]:
g = 0.
def impure_saves_global(x):
  """Return ``x`` unchanged, but also store it in the module-level ``g``.

  Deliberately impure: the write to ``g`` is a side effect, so whatever object
  the function body receives as ``x`` is what ends up in the global.
  """
  global g
  g = x  # side effect: leaks the function's argument into global state
  return x

print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  


First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


In [None]:
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

45
0


In [None]:
# This works
class Counter:
  """Minimal stateful counter that keeps its running total in ``self.n``."""

  def __init__(self):
    # Every counter starts at zero.
    self.n = 0

  def count(self) -> int:
    """Advance the counter by one and return the updated total."""
    updated = self.n + 1
    self.n = updated
    return updated

  def reset(self):
    """Put the counter back to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())

1
2
3


In [None]:
# Doesn't work
counter.reset()
fast_count = jit(counter.count)

for _ in range(3):
  print(fast_count())

1
1
1


In [None]:
from typing import Tuple

CounterState = int  # the whole counter state is a single integer

class CounterV2:
  """Stateless counter: the state is passed in and handed back explicitly."""

  def count(self, n: CounterState) -> Tuple[int, CounterState]:
    """Return ``(value, new_state)``; both are ``n + 1``."""
    successor = n + 1
    return successor, successor

  def reset(self) -> CounterState:
    """Return the initial state, zero."""
    initial_state = 0
    return initial_state

# The counter state is threaded through explicitly, so `count` stays pure and
# can be compiled safely.
counter = CounterV2()
state = counter.reset()

# Use the `jit` name imported via `from jax import jit`: the bare module name
# `jax` has not been imported at this point of the notebook, so `jax.jit`
# raises NameError on a fresh kernel.
fast_count = jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value)

1
2
3


### Tracing and static variables/operations in jit


In [None]:
# This will NOT work!
@jit
def f(x, neg):
  # With plain `jit`, `neg` is a traced value; using it as the condition of a
  # Python conditional expression requires a concrete boolean, which is what
  # triggers the error shown below this cell.
  return -x if neg else x

f(1, True)

ConcretizationTypeError: ignored

In [None]:
# This works
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
    """Return ``-x`` when ``neg`` is truthy, otherwise ``x``.

    ``neg`` (argument 1) is declared static, so it is an ordinary Python value
    inside the function and can drive Python control flow.
    """
    print("Need to compile")
    if neg:
        return -x
    return x

In [None]:
f(1, True)
f(2, False)

Need to compile


DeviceArray(2, dtype=int32)

In [None]:
# This doesn't work
@jit
def f(x):
  # The target size is computed with jax.numpy inside the jitted function, so it
  # is a traced value rather than a concrete integer; `reshape` rejects it with
  # the error shown below this cell.
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

ConcretizationTypeError: ignored

In [None]:
# To understand why, it is good practice to print which is traced and which is not
@jit
def f(x):
  """Print what the jitted function sees for ``x``, its shape and the shape product.

  Returns ``None``: the failing ``reshape`` is intentionally left commented out.
  """
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # The reshape below is kept commented out so that this cell runs without the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=0/1)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>


In [None]:
# Just replace it with numpy solves the problem
@jit
def f(x):
  """Flatten ``x`` to 1-D, computing the target length with NumPy from ``x.shape``."""
  n_elements = np.prod(x.shape)
  flat_shape = (n_elements,)
  return x.reshape(flat_shape)

f(x)

DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)

### Control flow

In [None]:
# lax.cond is equivalent to 
# def cond(pred, true_fun, false_fun, operand):
#   if pred:
#     return true_fun(operand)
#   else:
#     return false_fun(operand)

In [None]:
@jit
def f(x, neg):
  """Return ``-x`` when ``neg`` is true and ``x`` otherwise, using ``lax.cond``.

  ``lax.cond(pred, true_fun, false_fun, operand)`` calls ``true_fun`` when
  ``pred`` holds, so the negating branch has to come first for ``neg=True`` to
  negate (matching the earlier ``-x if neg else x`` version).
  """
  return lax.cond(neg,
                  lambda inp: -inp,  # neg is True  -> negate
                  lambda inp: inp,   # neg is False -> pass through
                  x)
print(f(x, True))
print(f(x, False))

[[1. 1. 1.]
 [1. 1. 1.]]
[[-1. -1. -1.]
 [-1. -1. -1.]]


In [None]:
# The definition of scan
# def scan(f, init, xs, length=None):
#   if xs is None:
#     xs = [None] * length
#   carry = init
#   ys = []
#   for x in xs:
#     carry, y = f(carry, x)
#     ys.append(y)
#   return carry, np.stack(ys)

# Example: [1,2,3,4,5,6,7,8,9,10]
# f = lambda x, y:x+y, x+y
# init = 0
# xs = [1,2,3,4,5,6,7,8,9,10]
# carry = 0
# ys = []
# carry, y = 1, 1
# ys = [1]
# carry, y = 3, 3
# ys = [1,3]
# And the code continues

In [None]:
# Exercise: rewrite the following code using scan
def cum(x):
    """Add 10 to ``x`` ten times in a plain Python loop and return the result."""
    remaining = 10
    while remaining > 0:
        x += 10
        remaining -= 1
    return x

In [None]:
def cum_v2(x):
    """``lax.scan`` version of ``cum``: ten steps that each add 10 to the carry.

    Returns the ``(final_carry, stacked_outputs)`` pair produced by ``lax.scan``;
    the per-step output is ``None`` because only the carry matters here.
    """
    def add_ten(carry, _unused):
        return carry + 10, None

    return lax.scan(add_ten, init=x, xs=None, length=10)

In [None]:
cum_v2(0)

(DeviceArray(100, dtype=int32), None)

## PyTree

In [None]:
from jax import tree_map, tree_multimap
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

tree_map(lambda x: x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [None]:
another_list_of_lists = list_of_lists
tree_multimap(lambda x, y: x+y*2, list_of_lists, another_list_of_lists)

[[3, 6, 9], [3, 6], [3, 6, 9, 12]]

## vmap and pmap

In [None]:
def predict(W, b, input_vec):
    """Apply a single affine layer to one (unbatched) input vector.

    Args:
      W: weight matrix; ``input_vec`` is multiplied on its right-hand side.
      b: bias, added to (and broadcast against) the matrix-vector product.
      input_vec: a single input vector.

    Returns:
      ``jnp.dot(W, input_vec) + b`` -- no activation is applied to the output.
    """
    # The previous version also computed tanh(outputs) into a local that was
    # never used; the returned value was always the pre-activation.
    return jnp.dot(W, input_vec) + b  # `input_vec` on the right-hand side!

In [None]:
from jax import vmap
W, b = jnp.ones((10, 5)), jnp.ones(1)
input_batch = jnp.ones((64,5))
predictions = vmap(predict, in_axes=(None, None, 0))(W, b, input_batch)
print(predictions)

# Implement Autograd using Jax

## Autograd basics
---
In general, to make a function work with JAX's automatic differentiation, we only need to define its so-called Jacobian-vector product (JVP). To understand what a JVP is, consider the following first-order approximation of an arbitrary function $f$ around a point $x$, for a small perturbation $v$.

$f(x+v) \approx f(x) + \partial f(x)\, v$.

Here, $\partial f(x)$ is the Jacobian of $f$ at $x$ and $\partial f(x)\, v$ is the JVP. In most cases, it suffices to define the JVP in order to get the complete autograd mechanism to work.



In [None]:
def toy1(w, b):
    """Scalar toy objective: the sum of all entries of ``dot(w, b)``."""
    product = jnp.dot(w, b)
    return jnp.sum(product)

In [None]:
w = jnp.ones((3, 3))
b = jnp.ones(3)

toy1(w, b)

DeviceArray(9., dtype=float32)

In [None]:
from jax import grad, value_and_grad
print(grad(toy1)(w, b))

[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]


In [None]:
print(grad(toy1, (0, ))(w, b))

(DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32),)


In [None]:
print(grad(toy1, (1, ))(w, b))

(DeviceArray([3., 3., 3.], dtype=float32),)


In [None]:
print(grad(toy1, (0, 1))(w, b))

(DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32), DeviceArray([3., 3., 3.], dtype=float32))


In [None]:
# A more convenient way is to define the dictionaries. 
def toy2(params):
    """Same objective as ``toy1``, reading ``'w'`` and ``'b'`` from a parameter dict."""
    weights, bias = params['w'], params['b']
    return jnp.sum(jnp.dot(weights, bias))

params = dict()
params['w'] = w
params['b'] = b

print(toy2(params))

9.0


In [None]:
print(grad(toy2)(params))

{'b': DeviceArray([3., 3., 3.], dtype=float32), 'w': DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)}


## Defining our own autograd functions

In [None]:
from jax import custom_jvp

### This is the original function 
@custom_jvp
def log1pexp(x):
  """Compute ``log(1 + exp(x))`` elementwise."""
  return jnp.log(jnp.exp(x) + 1.)

### Custom JVP rule attached to log1pexp
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  """Return ``(f(x), df/dx * x_dot)`` for ``f = log1pexp``.

  ``primals`` holds the original input ``x`` and ``tangents`` the tangent
  ``x_dot`` associated with it.
  """
  (x,) = primals
  (x_dot,) = tangents
  ans = log1pexp(x)  # the function value
  # d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x))
  derivative = 1 - 1/(1 + jnp.exp(x))
  ans_dot = derivative * x_dot  # derivative w.r.t. x times the tangent of x
  return ans, ans_dot

## Entmax-$\alpha$ implementation

In [None]:
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp

def p_tau(z, tau, alpha=1.5):
    """Unnormalised entmax weights for a candidate threshold ``tau``.

    Evaluates ``clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))``. The
    ``(alpha - 1)`` scaling is applied here, so ``z`` must be the raw scores.
    """
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    """One bisection decision; returns the updated ``(tau_max, tau_min)`` pair.

    ``z_value`` is the sum of ``p_tau`` at the midpoint ``tau``. A sum below 1
    means ``tau`` is too large, so it becomes the new upper bound; otherwise it
    becomes the new lower bound.
    """
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )

def body(kwargs, x):
    """``lax.scan`` step: halve the ``[tau_min, tau_max]`` bracket once.

    ``kwargs`` is the carry dict with keys ``'tau_min'``, ``'tau_max'``, ``'z'``
    and ``'alpha'``; ``x`` is the (unused) per-step scan input. Returns the
    updated carry and ``None`` as the per-step output.
    """
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

def map_row(z_input, alpha, T):
    """Entmax-``alpha`` of one 1-D score vector via ``T`` bisection steps.

    Args:
      z_input: 1-D array of raw scores.
      alpha: entmax exponent (the formulas divide by ``alpha - 1``).
      T: number of bisection iterations (the ``length`` of the scan).

    Returns:
      Non-negative weights with the shape of ``z_input`` that sum to one.
    """
    # `p_tau` multiplies its input by (alpha - 1) itself, so it must be fed the
    # raw scores. Scaling here as well (as the previous version did) applied the
    # factor twice and solved for the wrong distribution.
    z_scaled = (alpha - 1) * z_input

    # Bracket the threshold on the scaled scores: at max - 1 the largest entry
    # alone already contributes 1 (sum >= 1); at max - d**(1 - alpha) every entry
    # contributes at most 1/d (sum <= 1).
    z_peak = jnp.max(z_scaled)
    tau_min, tau_max = z_peak - 1, z_peak - z_input.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z_input, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z_input, tau, alpha)
    # Renormalise to absorb the remaining bisection error.
    return result / result.sum()

def _entmax(input, axis=-1, alpha=1.5, T=20):
    """Apply ``map_row`` along ``axis`` of ``input``; the output has the input's shape.

    ``axis`` is the axis that gets normalised, as in softmax. The previous
    version passed ``axis`` to ``vmap`` as the *mapped* axis, which normalised
    over the other axis and, for the default ``axis=-1``, returned a transposed
    result.
    """
    # Move the normalisation axis last and flatten everything else into one
    # batch axis for vmap.
    moved = jnp.moveaxis(input, axis, -1)
    rows = moved.reshape((-1, moved.shape[-1]))
    probs = vmap(lambda z: map_row(z, alpha, T))(rows)
    # Undo the flattening and put the axis back where it was.
    return jnp.moveaxis(probs.reshape(moved.shape), -1, axis)

def entmax(input, axis=-1, alpha=1.5, T=10):
    """Public entry point; forwards all arguments unchanged to ``_entmax``."""
    return _entmax(input, axis=axis, alpha=alpha, T=T)



In [None]:
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    """Scalar toy objective: sum of ``weight`` times ``entmax(input, 0, 1.5, 20)``."""
    probs = entmax(input, 0, 1.5, 20)
    return (weight * probs).sum()

toy(input, weight)

In [None]:
from functools import partial  # `jax.partial` was only an alias of this and is not available in newer JAX releases


@partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    """Entmax with a hand-written JVP rule.

    ``axis``, ``alpha`` and ``T`` are declared non-differentiable, so the JVP
    rule receives them as leading positional arguments.
    """
    return _entmax(input, axis, alpha, T)

def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    """Compute ``(output, output_tangent)`` of ``entmax`` for one input tangent.

    With ``gppr = Y ** (2 - alpha)`` the tangent is
    ``t * gppr - gppr * sum(t * gppr) / sum(gppr)``, both sums taken along ``axis``.
    """
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    # q: gppr-weighted mean of the tangent along `axis`, kept broadcastable.
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    """JVP rule registered on ``entmax``; delegates to ``_entmax_jvp_impl``."""
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

In [None]:
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    """Scalar toy objective: sum of ``weight`` times ``entmax(input, 0, 1.5, 20)``."""
    probs = entmax(input, 0, 1.5, 20)
    return (weight * probs).sum()

value_and_grad(toy)(input, weight)

In [None]:
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp


from functools import partial  # `jax.partial` was only an alias of this and is not available in newer JAX releases


def p_tau(z, tau, alpha=1.5):
    """Unnormalised entmax weights for a candidate threshold ``tau``.

    Evaluates ``clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))``. The
    ``(alpha - 1)`` scaling is applied here, so ``z`` must be the raw scores.
    """
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    """One bisection decision; returns the updated ``(tau_max, tau_min)`` pair.

    ``z_value`` is the sum of ``p_tau`` at the midpoint ``tau``. A sum below 1
    means ``tau`` is too large, so it becomes the new upper bound; otherwise it
    becomes the new lower bound.
    """
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


def body(kwargs, x):
    """``lax.scan`` step: halve the ``[tau_min, tau_max]`` bracket once.

    ``kwargs`` is the carry dict with keys ``'tau_min'``, ``'tau_max'``, ``'z'``
    and ``'alpha'``; ``x`` is the (unused) per-step scan input. Returns the
    updated carry and ``None`` as the per-step output.
    """
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

@partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
    """Jitted entmax-``alpha`` of one 1-D score vector via ``T`` bisection steps.

    Args:
      z_input: 1-D array of raw scores.
      alpha: entmax exponent (the formulas divide by ``alpha - 1``).
      T: number of bisection iterations; static because it is the scan length.

    Returns:
      Non-negative weights with the shape of ``z_input`` that sum to one.
    """
    # `p_tau` multiplies its input by (alpha - 1) itself, so it must be fed the
    # raw scores. Scaling here as well (as the previous version did) applied the
    # factor twice and solved for the wrong distribution.
    z_scaled = (alpha - 1) * z_input

    # Bracket the threshold on the scaled scores: at max - 1 the largest entry
    # alone already contributes 1 (sum >= 1); at max - d**(1 - alpha) every entry
    # contributes at most 1/d (sum <= 1).
    z_peak = jnp.max(z_scaled)
    tau_min, tau_max = z_peak - 1, z_peak - z_input.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z_input, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z_input, tau, alpha)
    # Renormalise to absorb the remaining bisection error.
    return result / result.sum()

from functools import partial  # `jax.partial` was only an alias of this and is not available in newer JAX releases


@partial(jax.jit, static_argnums=(1,3,))
def _entmax(input, axis=-1, alpha=1.5, T=20):
    """Jitted: apply ``map_row`` along ``axis`` of ``input``; output has the input's shape.

    ``axis`` and ``T`` are static. ``axis`` is the axis that gets normalised, as
    in softmax. The previous version passed it to ``vmap`` as the *mapped* axis,
    which normalised over the other axis and, for the default ``axis=-1``,
    returned a transposed result.
    """
    # Move the normalisation axis last and flatten everything else into one
    # batch axis for vmap.
    moved = jnp.moveaxis(input, axis, -1)
    rows = moved.reshape((-1, moved.shape[-1]))
    probs = vmap(lambda z: map_row(z, alpha, T))(rows)
    # Undo the flattening and put the axis back where it was.
    return jnp.moveaxis(probs.reshape(moved.shape), -1, axis)

from functools import partial  # `jax.partial` was only an alias of this and is not available in newer JAX releases


@partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    """Entmax with a hand-written JVP rule.

    ``axis``, ``alpha`` and ``T`` are declared non-differentiable, so the JVP
    rule receives them as leading positional arguments.
    """
    return _entmax(input, axis, alpha, T)

@partial(jax.jit, static_argnums=(0,2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    """Jitted ``(output, output_tangent)`` of ``entmax``; ``axis`` and ``T`` are static.

    With ``gppr = Y ** (2 - alpha)`` the tangent is
    ``t * gppr - gppr * sum(t * gppr) / sum(gppr)``, both sums taken along ``axis``.
    """
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    # q: gppr-weighted mean of the tangent along `axis`, kept broadcastable.
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    """JVP rule registered on ``entmax``; delegates to ``_entmax_jvp_impl``."""
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    """Scalar toy objective: sum of ``weight`` times ``entmax(input, 0, 1.5, 20)``."""
    probs = entmax(input, 0, 1.5, 20)
    return (weight * probs).sum()

jax.jit(value_and_grad(toy))(input, weight)

# Flax


## Defining our own networks
---
Let us first define our own neural networks and see if we can run a toy model.

In [None]:
import jax
import flax
from flax import linen as nn
from jax import random
import jax.numpy as jnp

In [None]:
model = nn.Dense(features=1) # The easiest way is to directly create a model that is predefined.

In [None]:
x_key, noise_key, init_key = random.split(random.PRNGKey(0), 3)
dummy = jnp.ones((10, )) # Dummy input: only its shape matters, it triggers shape inference
# Initialise from `dummy`, not from whatever `x` an earlier cell left behind:
# the parameter shapes are inferred from this input, and `dummy` is what
# `model.apply` is called with below.
params = model.init(init_key, dummy)

In [None]:
# In reality, flax modules are just wraps around pytrees, as can be seen here.
jax.tree_map(lambda x: x.shape, params) 


In [None]:
# To evaluate the model, use apply
model.apply(params, dummy)

In [None]:
x = random.normal(x_key,(50, 10))
# Draw one independent noise value per sample. Without a shape argument,
# random.normal returns a single scalar, which would shift every target by the
# same constant instead of adding noise.
noise = 0.1 * random.normal(noise_key, (x.shape[0],))
y = jnp.dot(x, jnp.ones(10)) + noise

In [None]:
def make_mse_func(x_batched, y_batched):
    """Build a jitted loss ``params -> mean over the batch of half squared error``.

    The batch is captured by closure; predictions come from ``model.apply``.
    """
    def mse(params):
        def squared_error(x, y):
            # Half squared error for a single (x, y) pair.
            residual = y - model.apply(params, x)
            return jnp.inner(residual, residual)/2.0

        per_example = jax.vmap(squared_error)(x_batched, y_batched)
        return jnp.mean(per_example, axis=0)
    return jax.jit(mse)

loss = make_mse_func(x, y)

In [None]:
from flax import optim
optimizer_def = optim.GradientDescent(learning_rate=0.1) # Choose the method
optimizer = optimizer_def.create(params) # Create the wrapping optimizer with initial parameters
loss_grad_fn = jax.value_and_grad(loss)

In [None]:
for i in range(101):
  # Named `grads` so that the loop does not overwrite the `grad` function
  # imported from jax earlier in the notebook.
  loss_val, grads = loss_grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grads) # Return the updated optimizer with parameters. Question why?
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Next, let us see how we can define a multi-layer MLP

In [None]:
from typing import Any, Callable, Sequence, Optional

class MyMLP(nn.Module):
  """MLP with one ``nn.Dense`` layer per entry of ``features``.

  ReLU is applied after every layer except the last one.
  """
  features: Sequence[int]

  def setup(self):
    # Submodules are created once here and reused in __call__.
    self.layers = [nn.Dense(width) for width in self.features]

  def __call__(self, inputs):
    last_index = len(self.layers) - 1
    hidden = inputs
    for index, layer in enumerate(self.layers):
      hidden = layer(hidden)
      if index < last_index:
        hidden = nn.relu(hidden)
    return hidden


In [None]:
mymlp = MyMLP([5,5,1])
params = mymlp.init(init_key, dummy)
jax.tree_map(lambda x:x.shape, params)

However, due to the static graph feature, we can actually even make the definition shorter!

In [None]:
class MyMLP2(nn.Module):
  """Compact variant of the MLP: the Dense layers are declared inline in ``__call__``.

  ReLU is applied after every layer except the last one.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    last_index = len(self.features) - 1
    hidden = inputs
    for index, width in enumerate(self.features):
      # The layer is both defined and applied on this line, which is easy to overlook.
      hidden = nn.Dense(width, name=f'layers_{index}')(hidden)
      if index < last_index:
        hidden = nn.relu(hidden)
    return hidden


In [None]:
mymlp = MyMLP2([5,5,1])
params = mymlp.init(init_key, dummy)
jax.tree_map(lambda x:x.shape, params)

Let us now see how we can define our own layer using module parameters.

In [None]:
class MyDense(nn.Module):
  """Dense layer written with explicit parameters: ``inputs . kernel + bias``."""
  features: int
  kernel_init: Callable = nn.initializers.xavier_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    in_features = inputs.shape[-1]
    kernel = self.param('kernel', self.kernel_init, (in_features, self.features))
    # Contract the last axis of `inputs` with the first axis of `kernel`;
    # there are no batch dimensions.
    contracting_dims = ((inputs.ndim - 1,), (0,))
    batch_dims = ((), ())
    projected = lax.dot_general(inputs, kernel, (contracting_dims, batch_dims))
    bias = self.param('bias', self.bias_init, (self.features,))
    return projected + bias

In [None]:
mydense = MyDense(1)
params = mydense.init(init_key, dummy)
jax.tree_map(lambda x: x, params)

In [None]:
class BiasAdderWithRunningMean(nn.Module):
  """Subtracts a running mean from ``x`` and adds a bias parameter.

  The running mean is kept as variable ``'mean'`` in the ``'batch_stats'``
  collection; ``'bias'`` is a regular parameter. Both are created as zeros with
  shape ``x.shape[1:]``.
  """
  decay: float = 0.99  # weight kept on the previous running mean in each update

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    # (must be checked before self.variable below declares the variable)
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    mean = ra_mean.value # This will get either the value, or trigger init
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      # Exponential moving average of the mean of `x` over axis 0; skipped on
      # the call that creates the variable.
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias

# MNIST Example

In [None]:
import jax
import jax.numpy as jnp            
from flax import linen as nn        
from flax import optim              

import numpy as np                 
import tensorflow_datasets as tfds  

In [None]:
class CNN(nn.Module):
  """Small convolutional classifier that returns log-probabilities over 10 classes.

  Two conv -> ReLU -> 2x2 average-pool stages (32 then 64 features), followed by
  a 256-unit dense layer with ReLU and a 10-unit dense layer with log-softmax.
  """

  @nn.compact
  def __call__(self, x):
    for n_features in (32, 64):
      x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Flatten everything except the leading (batch) axis.
    x = x.reshape((x.shape[0], -1))
    x = nn.relu(nn.Dense(features=256)(x))
    return nn.log_softmax(nn.Dense(features=10)(x))

def cross_entropy_loss(logits, labels):
  """Mean over the batch of minus the ``logits`` entry selected by each label.

  ``labels`` are integer class ids in ``[0, 10)``; they are one-hot encoded and
  multiplied with ``logits`` (the model above outputs log-probabilities).
  """
  targets = jax.nn.one_hot(labels, num_classes=10)
  per_example = jnp.sum(targets * logits, axis=-1)
  return -jnp.mean(per_example)

In [None]:
def create_optimizer(params, learning_rate, beta):
  """Wrap ``params`` in an ``optim.Momentum`` optimizer with the given settings."""
  momentum_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
  return momentum_def.create(params)

In [None]:
def get_initial_params(key):
  """Initialise the CNN on a ``(1, 28, 28, 1)`` float32 input; return its ``'params'``."""
  example_input = jnp.ones((1, 28, 28, 1), jnp.float32)
  variables = CNN().init(key, example_input)
  return variables['params']

In [None]:
def compute_metrics(logits, labels):
  """Return a dict with the cross-entropy ``'loss'`` and the ``'accuracy'``."""
  predicted = jnp.argmax(logits, -1)
  return {
      'loss': cross_entropy_loss(logits, labels),
      'accuracy': jnp.mean(predicted == labels),
  }

In [None]:
def get_datasets():
  """Download MNIST via tfds; return ``(train_ds, test_ds)`` with images scaled by 1/255."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()

  def load_split(split):
    # batch_size=-1 loads the whole split at once.
    split_ds = tfds.as_numpy(ds_builder.as_dataset(split=split, batch_size=-1))
    split_ds['image'] = jnp.float32(split_ds['image']) / 255.0
    return split_ds

  return load_split('train'), load_split('test')

In [None]:
@jax.jit
def train_step(optimizer, batch):
  """Run one gradient step on ``batch``; return the updated optimizer and batch metrics."""
  images, labels = batch['image'], batch['label']

  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    # logits are returned as auxiliary output so the metrics can reuse them.
    return cross_entropy_loss(logits, labels), logits

  (_, logits), grad = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target)
  updated_optimizer = optimizer.apply_gradient(grad)
  return updated_optimizer, compute_metrics(logits, labels)

In [None]:
@jax.jit
def eval_step(params, batch):
  """Apply the CNN with ``params`` to ``batch`` and return its metrics."""
  images, labels = batch['image'], batch['label']
  logits = CNN().apply({'params': params}, images)
  return compute_metrics(logits, labels)

In [None]:
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for one epoch over shuffled mini-batches.

  Prints the epoch's mean loss and accuracy and returns
  ``(optimizer, training_epoch_metrics)``.
  """
  num_examples = len(train_ds['image'])
  steps_per_epoch = num_examples // batch_size

  # Shuffle the example indices and drop the tail that does not fill a batch.
  shuffled = jax.random.permutation(rng, num_examples)
  usable = steps_per_epoch * batch_size
  batch_indices = shuffled[:usable].reshape((steps_per_epoch, batch_size))

  batch_metrics = []
  for indices in batch_indices:
    batch = {name: values[indices, ...] for name, values in train_ds.items()}
    optimizer, step_metrics = train_step(optimizer, batch)
    batch_metrics.append(step_metrics)

  # Fetch the metrics from the device once, then average each one over the batches.
  host_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      name: np.mean([step[name] for step in host_metrics])
      for name in host_metrics[0]
  }

  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

  return optimizer, training_epoch_metrics

In [None]:
def eval_model(model, test_ds):
  """Evaluate on ``test_ds`` and return ``(loss, accuracy)`` as Python scalars.

  ``model`` is passed straight to ``eval_step`` as its ``params`` argument.
  """
  host_metrics = jax.device_get(eval_step(model, test_ds))  # Evaluate the model on the test set
  eval_summary = jax.tree_map(lambda x: x.item(), host_metrics)
  return eval_summary['loss'], eval_summary['accuracy']

In [None]:
train_ds, test_ds = get_datasets()

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m




In [None]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

In [None]:
params = get_initial_params(init_rng)

In [None]:
learning_rate = 0.1
beta = 0.9
num_epochs = 10
batch_size = 32

optimizer = create_optimizer(params, learning_rate=learning_rate, beta=beta)

In [None]:
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  optimizer, train_metrics = train_epoch(optimizer, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch 
  test_loss, test_accuracy = eval_model(optimizer.target, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

Training - epoch: 1, loss: 0.1314, accuracy: 96.04
Testing - epoch: 1, loss: 0.06, accuracy: 98.10


KeyboardInterrupt: ignored