In [None]:
from collections.abc import Iterator
from typing import NamedTuple

from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

# Select CPU as JAX's default platform via the `jax_platform_name` flag, so
# computations in this notebook that are not explicitly placed elsewhere run
# on the CPU backend.
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Enumerate the CPU devices JAX can see and report them. In the visible code
# they are only printed, not used to shard work.
num_devices = len(devices := jax.devices(backend="cpu"))
print(f"Detected the following devices: {tuple(devices)}")

In [None]:
class Batch(NamedTuple):
  """One minibatch of MNIST examples, built from the TFDS feature dict in `load_dataset`."""
  image: np.ndarray  # [B, H, W, 1] pixel values; forward_fn rescales them by 1/255
  label: np.ndarray  # [B] class indices (one-hot encoded for the loss, compared to argmax for accuracy)

class TrainingState(NamedTuple):
  """Model parameters and optimizer state threaded from one training step to the next."""
  params: hk.Params  # Haiku parameters, initially from `network.init`
  opt_state: optax.OptState  # optimizer state from `optimizer.init` / `optimizer.update`

In [None]:
def load_dataset(
    split: str,
    *,
    shuffle: bool,
    batch_size: int,
) -> Iterator[Batch]:
  """Loads an MNIST split as an endless iterator of numpy `Batch`es.

  Args:
    split: TFDS split name, e.g. "train" or "test".
    shuffle: Whether to shuffle examples. The shuffle buffer spans the whole
      split and uses a fixed seed of 0.
    batch_size: Number of examples per yielded batch.

  Returns:
    An infinite iterator (the dataset is repeated) of `Batch(image, label)`
    tuples holding numpy arrays.
  """
  ds, ds_info = tfds.load("mnist:3.*.*", split=split, with_info=True)
  # tf.data transformations return a *new* dataset rather than mutating in
  # place; without reassigning, the cache would be silently discarded.
  ds = ds.cache()
  if shuffle:
    # A buffer as large as the split gives a full uniform shuffle.
    ds = ds.shuffle(ds_info.splits[split].num_examples, seed=0)
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  # The TFDS feature dict's keys ("image", "label") map onto Batch's fields.
  ds = ds.map(lambda x: Batch(**x))
  return iter(tfds.as_numpy(ds))

In [None]:
def main(
    num_rounds: int = 5,
    steps_per_round: int = 100,
    learning_rate: float = 1e-4,
    seed: int = 0,
) -> None:
  """Trains a small MLP (300 and 100 hidden units) on MNIST and reports test accuracy.

  Training runs for `num_rounds` rounds. Each round takes `steps_per_round`
  Adam steps on minibatches of 64 training images, then prints accuracy on
  one 10,000-image test batch. A round is NOT an epoch: with the defaults it
  covers 100 * 64 = 6,400 images, a fraction of the training split (assumed
  to be the standard 60,000-image MNIST train set).

  Args:
    num_rounds: Number of train-then-evaluate rounds.
    steps_per_round: Optimizer steps between evaluations.
    learning_rate: Adam learning rate.
    seed: Seed for the parameter-initialization PRNG key.
  """
  NUM_CLASSES = 10
  BATCH_SIZE_TRAIN = 64
  # Meant to hold the whole MNIST test split (assumed 10,000 examples) in a
  # single batch, so one `next()` yields the complete evaluation set.
  BATCH_SIZE_TEST = 10000

  def forward_fn(x: jax.Array) -> jax.Array:
    """Maps a batch of images with 0-255 pixel values to unnormalized class logits."""
    x = x.astype(jnp.float32) / 255.
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(NUM_CLASSES),
    ])
    return mlp(x)

  print("Initializing network..")
  # The forward pass uses no randomness, so apply() needs no RNG key.
  network = hk.without_apply_rng(hk.transform(forward_fn))
  optimizer = optax.adam(learning_rate=learning_rate)

  def loss_fn(params, batch):
    """Mean softmax cross-entropy of the network's logits over one batch."""
    logits = network.apply(params, batch.image)
    labels = jax.nn.one_hot(batch.label, NUM_CLASSES)
    per_example = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
    # Averaging (instead of summing) keeps the loss and gradient scale
    # independent of the batch size.
    return jnp.mean(per_example)

  @jax.jit
  def update(state, batch):
    """Takes one optimizer step on `batch`; returns the new TrainingState."""
    grads = jax.grad(loss_fn)(state.params, batch)
    updates, opt_state = optimizer.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    return TrainingState(params, opt_state)

  @jax.jit
  def evaluate(state, batch):
    """Returns the fraction of `batch` whose argmax prediction equals its label."""
    logits = network.apply(state.params, batch.image)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == batch.label)

  print("Initializing data..")
  data_train = load_dataset("train", shuffle=True, batch_size=BATCH_SIZE_TRAIN)
  data_test = load_dataset("test", shuffle=False, batch_size=BATCH_SIZE_TEST)

  print("Initializing parameters..")
  # Parameter shapes are inferred from a sample input. This consumes one
  # training batch, which is harmless because the stream repeats forever.
  params = network.init(jax.random.PRNGKey(seed), next(data_train).image)
  opt_state = optimizer.init(params)
  state = TrainingState(params, opt_state)

  batch_test = next(data_test)

  print("Training..")
  for round_idx in range(num_rounds):
    print(f"--- Round {round_idx + 1}/{num_rounds} ({steps_per_round} steps) ---")
    for _ in range(steps_per_round):
      state = update(state, next(data_train))

    accuracy = evaluate(state, batch_test)
    print("Accuracy : ", accuracy)

main()