In [None]:
from absl import app
from collections.abc import Iterator
from typing import NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

from functools import partial

from typing import Dict
from typing_extensions import TypeAlias
from typing import Callable, List, Optional, Tuple

import neptune

from tqdm import tqdm

from dotenv import load_dotenv

import os

# Select CPU as JAX's default platform for everything that follows in this notebook.
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Enumerate the CPU devices JAX exposes; `devices` is what the pmapped
# functions run on and `num_devices` is its length.
devices = jax.devices(backend="cpu")
num_devices = len(devices)
print(f"Detected the following devices: {tuple(devices)}")

In [None]:
# Load the variables declared in the local .env file into the process
# environment, then pick up the Neptune token (None when it is not set).
load_dotenv(".env")
api_token = os.environ.get("NEPTUNE_API_TOKEN")

In [None]:
# Open the Neptune run that the training loop logs its metrics to.
run = neptune.init_run(
    project="yanisadel/learn-jax",
    api_token=api_token,
)

# Hyperparameters stored alongside the run. The learning rate recorded here has
# to mirror what the optimizer really uses: main() builds optax.adam with 1e-4,
# so logging 0.001 would misreport the experiment.
params = {"learning_rate": 1e-4, "optimizer": "Adam"}
run["parameters"] = params

In [None]:
class Batch(NamedTuple):
  """One batch of examples; `load_dataset` builds it from each dataset element."""
  image: np.ndarray  # images, shape [B, H, W, 1]
  label: np.ndarray  # class labels (one-hot encoded in the loss), shape [B]

class TrainingState(NamedTuple):
  """State carried from one training step to the next."""
  params: hk.Params  # network parameters
  opt_state: optax.OptState  # optimizer state created from `params`

# Maps a metric name (e.g. "loss", "accuracy", "predictions") to the array holding its value.
Metrics: TypeAlias = Dict[str, jnp.ndarray]

In [None]:
def load_dataset(
    split: str,
    *,
    shuffle: bool,
    batch_size: int,
) -> Iterator[Batch]:
  """Loads one MNIST split as an endless iterator of NumPy batches.

  Args:
    split: Name of the TFDS split to read, e.g. "train" or "test".
    shuffle: Whether to shuffle the examples (fixed seed 0) using a buffer as
      large as the whole split.
    batch_size: Number of examples in each yielded batch.

  Returns:
    An iterator that never terminates (the dataset is repeated) and yields
    `Batch` tuples whose fields are NumPy arrays.
  """
  ds, ds_info = tfds.load("mnist:3.*.*", split=split, with_info=True)
  # tf.data transformations return a new dataset rather than mutating in place,
  # so the cached dataset must be re-bound; a bare `ds.cache()` is a no-op.
  ds = ds.cache()
  if shuffle:
    ds = ds.shuffle(ds_info.splits[split].num_examples, seed=0)
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  # Each element is a dict keyed by feature name; wrap it in the Batch tuple.
  ds = ds.map(lambda x: Batch(**x))
  return iter(tfds.as_numpy(ds))

In [None]:
def main() -> None:
    """Trains an MLP classifier on MNIST and logs metrics to the Neptune run.

    Builds the Haiku network, the data iterators and the Adam optimizer, then
    runs NUM_STEPS pmapped update steps. Every `validation_step` steps the model
    is evaluated on one fixed batch from the "test" split. Train and test
    loss/accuracy are logged to the module-level `run` and also collected in a
    local dict. The Neptune run is stopped at the end.

    The nested functions reference `network`, `optimizer` and `NUM_CLASSES`,
    which are assigned further down in this function; this works because the
    names are only looked up when the functions are first called.
    """
    def forward_fn(x: jax.Array) -> jax.Array:
        """Maps a batch of images to 10 logits per example via a 300-100-10 MLP."""
        # Cast to float32 and rescale by 1/255 before flattening.
        x = x.astype(jnp.float32) / 255.
        mlp = hk.Sequential([
            hk.Flatten(),
            hk.Linear(300), jax.nn.relu,
            hk.Linear(100), jax.nn.relu,
            hk.Linear(10),
        ])
        return mlp(x)
    
    def loss_fn(params: hk.Params, batch: NamedTuple) -> Tuple[jax.Array, Metrics]:
        """Computes the softmax cross-entropy loss plus auxiliary metrics.

        Returns:
            A `(loss, metrics)` pair; `metrics` holds the class probabilities,
            the loss and the accuracy averaged over the 'batch' pmap axis.
            Because of the `pmean`, this must be called inside a pmap that
            names its axis 'batch'.
        """
        logits = network.apply(params, batch.image)
        labels = jax.nn.one_hot(batch.label, NUM_CLASSES)
        # Cross-entropy summed over all examples, then divided by the batch size.
        loss_value = -jnp.sum(labels * jax.nn.log_softmax(logits)) / batch.image.shape[0]

        predictions_proba = jax.nn.softmax(logits, axis=-1)
        predictions = jnp.argmax(predictions_proba, axis=-1)
        accuracy = jnp.mean(predictions == batch.label)
        pmean_accuracy = jax.lax.pmean(accuracy, axis_name='batch') # TODO: verify this when num_devices > 1

        # Note: only the accuracy is averaged across devices here; the loss is
        # reported per device and averaged on the host by the caller.
        metrics: Metrics = {
            "predictions": predictions_proba,
            "loss": loss_value,
            "accuracy": pmean_accuracy
        }

        return loss_value, metrics

    @partial(jax.pmap, devices=devices, axis_name='batch')
    def update(state: NamedTuple, batch: NamedTuple) -> Tuple[NamedTuple, Metrics]:
        """Runs one optimizer step on every device and returns the new state and metrics."""
        (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params, batch)
        # Average the gradients over the 'batch' axis so all replicas apply the same update.
        grads = jax.lax.pmean(grads, axis_name='batch')
        updates, opt_state = optimizer.update(grads, state.opt_state)
        params = optax.apply_updates(state.params, updates)

        return TrainingState(params, opt_state), metrics
    
    @partial(jax.pmap, devices=devices, axis_name='batch')
    def evaluate(state: NamedTuple, batch: NamedTuple) -> Metrics:
        """Computes the metrics of `loss_fn` for `batch` without updating the parameters."""
        _, metrics = loss_fn(state.params, batch)
        # Accuracy (already pmean-ed over 'batch') is computed inside loss_fn,
        # so nothing has to be recomputed from metrics["predictions"] here.

        return metrics

    NUM_CLASSES = 10
    BATCH_SIZE_TRAIN = 64
    # presumably the whole MNIST test split in a single batch — TODO confirm
    BATCH_SIZE_TEST = 10000

    print("Initializing network..")
    network = hk.without_apply_rng(hk.transform(forward_fn))

    print("Initializing data..")
    data_train = load_dataset("train", shuffle=True, batch_size=BATCH_SIZE_TRAIN)
    data_test = load_dataset("test", shuffle=False, batch_size=BATCH_SIZE_TEST)

    print("Initializing parameters..")
    # The first training batch is consumed here only to initialise the parameters;
    # it is not used for a training step afterwards.
    params = network.init(jax.random.PRNGKey(seed=0), next(data_train).image)
    # NOTE(review): keep this value in sync with the learning_rate recorded in
    # the Neptune run parameters.
    optimizer = optax.adam(learning_rate=1e-4)
    opt_state = optimizer.init(params)
    state = TrainingState(params, opt_state)

    # Replicate parameters and optimizer state across devices
    state = jax.device_put_replicated(state, devices)

    # A single test batch is fetched once and reused for every evaluation.
    batch_test = next(data_test)
    batch_test = jax.device_put_replicated(batch_test, devices)

    # Host-side history of the logged values, kept next to the Neptune logs.
    all_metrics = {
        "train_loss": [],
        "train_acc": [],
        "train_step": [],
        "val_loss": [],
        "val_acc": [],
        "val_step": [],
    }
    
    print("Training..")
    NUM_STEPS = 100
    validation_step = 10

    for step in tqdm(range(1, NUM_STEPS+1)):
        batch = next(data_train)
        # Disabled alternative: reshape so the batch would be split across devices
        # instead of every device receiving the full batch.
        # batch = jax.tree.map(lambda x: x.reshape((num_devices, -1) + x.shape[1:]), batch)
        batch = jax.device_put_replicated(batch, devices)
        state, metrics = update(state, batch)
        
        # Metrics carry a leading per-device axis; average it away on the host.
        train_loss = jax.device_get(metrics["loss"]).mean()
        train_accuracy = jax.device_get(metrics["accuracy"]).mean()

        run["train/loss"].log(train_loss, step=step)
        run["train/accuracy"].log(train_accuracy, step=step)

        all_metrics["train_loss"].append(train_loss)
        all_metrics["train_acc"].append(train_accuracy)
        all_metrics["train_step"].append(step)

        if step % validation_step == 0:
            # Evaluate on test batch
            
            # Disabled alternative: same per-device reshape as for the training batch.
            # batch_test = jax.tree.map(lambda x: x.reshape((num_devices, -1) + x.shape[1:]), batch_test)
            metrics = evaluate(state, batch_test)
            test_loss = jax.device_get(metrics["loss"]).mean()
            test_accuracy = jax.device_get(metrics["accuracy"]).mean()

            run["test/loss"].log(test_loss, step=step)
            run["test/accuracy"].log(test_accuracy, step=step)

            # Stored under "val_*" keys although the values come from the test batch.
            all_metrics["val_loss"].append(test_loss)
            all_metrics["val_acc"].append(test_accuracy)
            all_metrics["val_step"].append(step)

    # Close the Neptune run once training has finished.
    run.stop()

# Run the whole training loop; main() also stops the Neptune run when it finishes.
main()