In [None]:
import os

from collections.abc import Iterator
from functools import partial

from absl import app
from dotenv import load_dotenv
import haiku as hk
import jax
import jax.numpy as jnp
import neptune
import numpy as np
import optax
from tqdm import tqdm
from typing import Dict, NamedTuple, Tuple, Callable

from learntrix.datalaoders.computer_vision.mnist import load_mnist_dataset

from learntrix.types import Batch, TrainingState, Metrics

# Make the CPU the default JAX backend for this notebook (the device list
# below is also explicitly requested for the "cpu" platform).
# NOTE(review): presumably this has to run before any JAX computation to take
# effect — keep it in the first code cell.
jax.config.update("jax_platform_name", "cpu")


In [None]:
# Collect every CPU device JAX exposes; the pmapped steps later run once per
# entry of this list.
devices = jax.devices("cpu")
num_devices = len(devices)
print("Detected the following devices: {}".format(tuple(devices)))

In [None]:

def load_env_variable(path, name):
    """Populate the process environment from a dotenv file, then read one key.

    Args:
        path: path of the ``.env`` file handed to ``load_dotenv``.
        name: name of the environment variable to look up.

    Returns:
        The variable's value as a string, or None when it is not set.
    """
    load_dotenv(path)
    return os.environ.get(name)

def run_neptune(path, project):
    """Start a Neptune run authenticated with the token stored in a dotenv file.

    Args:
        path: path of the ``.env`` file that defines ``NEPTUNE_API_TOKEN``.
        project: name of the Neptune project, e.g. ``"workspace/project"``.

    Returns:
        The run object created by ``neptune.init_run``.
    """
    # NOTE(review): if the variable is missing, api_token is None and it is up
    # to neptune.init_run to decide how to handle that.
    api_token = load_env_variable(path=path, name='NEPTUNE_API_TOKEN')

    run = neptune.init_run(
        project=project,
        api_token=api_token,
    )

    return run

# Open the Neptune run (API token read from ./.env) and attach the
# hyperparameters of this experiment to it.
run = run_neptune(project="yanisadel/learn-jax", path='./.env')

params = dict(learning_rate=0.001, optimizer="Adam")
run["parameters"] = params

In [None]:
BATCH_SIZE_TRAIN = 64
# NOTE(review): presumably large enough to take the whole MNIST test split in
# a single batch (it is read once with next() later) — confirm with the loader.
BATCH_SIZE_TEST = 10000

# Training batches are shuffled; the test split keeps a fixed order.
data_train = load_mnist_dataset("train", batch_size=BATCH_SIZE_TRAIN, shuffle=True)
data_test = load_mnist_dataset("test", batch_size=BATCH_SIZE_TEST, shuffle=False)

In [None]:
NUM_CLASSES = 10

def forward_fn(x: jax.Array) -> jax.Array:
    """MLP classifier (300 -> 100 -> NUM_CLASSES), meant to be wrapped by ``hk.transform``.

    Args:
        x: batch of images. ``hk.Flatten`` keeps the leading (batch) axis and
            flattens the rest. Values are divided by 255, so presumably raw
            pixel intensities in [0, 255] — TODO confirm against the loader.

    Returns:
        Unnormalised logits with a last axis of size NUM_CLASSES.
    """
    x = x.astype(jnp.float32) / 255.
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        # Tie the output width to NUM_CLASSES so the logits always line up
        # with the one-hot labels built from NUM_CLASSES in the loss.
        hk.Linear(NUM_CLASSES),
    ])
    return mlp(x)

def loss_fn(
    params: hk.Params,
    batch: Batch,
    forward_fn: Callable[[hk.Params, jax.Array], jax.Array],
) -> Tuple[jax.Array, Metrics]:
    """Softmax cross-entropy loss plus classification metrics for one batch.

    Must be called inside a collective context that defines an axis named
    ``'batch'`` (e.g. ``jax.pmap(..., axis_name='batch')``), because the
    accuracy is averaged over that axis with ``jax.lax.pmean``.

    Args:
        params: model parameters passed to ``forward_fn``.
        batch: object exposing ``image`` and integer class ``label`` arrays.
        forward_fn: apply function called as ``forward_fn(params, batch.image)``
            and returning logits.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the cross-entropy averaged over the
        examples seen by this call, and ``metrics`` holds the softmax
        probabilities (``"predictions"``), the same loss (``"loss"``) and the
        accuracy averaged over the ``'batch'`` axis (``"accuracy"``).
    """
    logits = forward_fn(params, batch.image)
    labels = jax.nn.one_hot(batch.label, NUM_CLASSES)
    # Sum of per-example cross-entropies divided by the local batch size.
    loss_value = -jnp.sum(labels * jax.nn.log_softmax(logits)) / batch.image.shape[0]

    predictions_proba = jax.nn.softmax(logits, axis=-1)
    predictions = jnp.argmax(predictions_proba, axis=-1)
    accuracy = jnp.mean(predictions == batch.label)
    # Average accuracy across the 'batch' axis. Only the accuracy is reduced
    # here; the loss stays local to each call. TODO: still to be verified when
    # num_devices > 1.
    pmean_accuracy = jax.lax.pmean(accuracy, axis_name='batch')

    metrics: Metrics = {
        "predictions": predictions_proba,
        "loss": loss_value,
        "accuracy": pmean_accuracy
    }

    return loss_value, metrics

In [None]:
class Trainer:
    """Bundles a Haiku model, a loss function and an Optax optimizer into steps.

    ``update`` and ``evaluate`` rely on collectives over an axis named
    ``'batch'`` (``jax.lax.pmean`` here and inside the loss), so they are meant
    to be wrapped with ``jax.pmap(..., axis_name='batch')``.
    """

    def __init__(self, forward_fn, loss_fn, optimizer, num_classes):
        """
        Args:
            forward_fn: model function, transformed here with ``hk.transform``
                and stripped of its apply-time RNG argument.
            loss_fn: callable ``(params, batch, forward_fn=...) -> (loss, metrics)``;
                its ``forward_fn`` keyword is bound to the transformed model's
                ``apply``.
            optimizer: Optax gradient transformation (``init`` / ``update``).
            num_classes: number of output classes. NOTE(review): stored but not
                read by any method of this class.
        """
        self._forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
        self._loss_fn = partial(loss_fn, forward_fn=self._forward_fn.apply)
        self._optimizer = optimizer
        self._num_classes = num_classes

    def init(self, random_key, x):
        """Create initial model parameters from example input ``x`` plus the optimizer state."""
        params = self._forward_fn.init(random_key, x)
        opt_state = self._optimizer.init(params)

        return TrainingState(params, opt_state)

    # ``self`` is a static jit argument. Trainer defines no __hash__/__eq__, so
    # it is keyed by identity: mutating its attributes after the first call
    # presumably reuses the stale compiled step.
    @partial(jax.jit, static_argnums=0)
    def update(self, state: TrainingState, batch: Batch) -> Tuple[TrainingState, Metrics]:
        """Run one optimisation step; return the new state and the loss metrics."""
        (_, metrics), grads = jax.value_and_grad(self._loss_fn, has_aux=True)(state.params, batch)
        # Average gradients over the 'batch' axis so every replica applies
        # the same update and the parameters stay in sync.
        grads = jax.lax.pmean(grads, axis_name='batch')
        updates, opt_state = self._optimizer.update(grads, state.opt_state)
        params = optax.apply_updates(state.params, updates)

        return TrainingState(params, opt_state), metrics

    @partial(jax.jit, static_argnums=0)
    def evaluate(self, state: TrainingState, batch: Batch) -> Metrics:
        """Return the loss-function metrics for ``batch`` without touching the state."""
        _, metrics = self._loss_fn(state.params, batch)

        return metrics
    

In [None]:
# Read the learning rate from the hyperparameters logged to Neptune, and take
# the class count from NUM_CLASSES. This keeps the tracked run and the actual
# optimizer/model configuration from drifting apart.
trainer = Trainer(
    forward_fn=forward_fn,
    loss_fn=loss_fn,
    optimizer=optax.adam(learning_rate=params["learning_rate"]),
    num_classes=NUM_CLASSES,
    )

In [None]:
# Build parameters + optimizer state from a dummy image batch (only the
# per-example shape matters), then place one copy of the state on each device
# so it can be fed to the pmapped step.
training_state = jax.device_put_replicated(
    trainer.init(random_key=jax.random.PRNGKey(0), x=jnp.ones((32, 28, 28, 1))),
    devices,
)

In [None]:
# Pull the evaluation batch once and put a copy of it on every device.
batch_test = jax.device_put_replicated(next(data_test), devices)

# Running history of logged values, one list per series (train_* every step,
# val_* every validation step).
all_metrics = {
    series: []
    for series in (
        "train_loss",
        "train_acc",
        "train_step",
        "val_loss",
        "val_acc",
        "val_step",
    )
}


In [None]:
NUM_STEPS = 100
validation_step = 10  # evaluate on the test batch every `validation_step` steps

update = trainer.update

# Build the pmapped step functions once, outside the loop. Re-creating a
# `jax.pmap` wrapper on every iteration throws away the per-wrapper dispatch
# cache. `axis_name='batch'` must match the axis used by `jax.lax.pmean` in
# the loss and in the trainer's update step.
p_update = jax.pmap(update, devices=devices, axis_name='batch')
p_evaluate = jax.pmap(trainer.evaluate, devices=devices, axis_name='batch')


def _split_across_devices(batch):
    """Lay out a host batch for `p_update`, one slice per device.

    When every leaf's leading axis is divisible by `num_devices`, each leaf is
    reshaped from (B, ...) to (num_devices, B // num_devices, ...), so devices
    work on disjoint slices (data parallelism). Otherwise the whole batch is
    replicated on every device. Both layouts give mathematically the same
    averaged gradients and metrics, because the loss is normalised by the
    per-device batch size; splitting just avoids duplicated work.
    """
    leaves = jax.tree_util.tree_leaves(batch)
    divisible = bool(leaves) and all(
        np.ndim(leaf) > 0 and np.shape(leaf)[0] % num_devices == 0 for leaf in leaves
    )
    if divisible:
        return jax.tree_util.tree_map(
            lambda leaf: leaf.reshape((num_devices, -1) + tuple(np.shape(leaf)[1:])),
            batch,
        )
    return jax.device_put_replicated(batch, devices)


for step in tqdm(range(1, NUM_STEPS + 1)):
    batch = _split_across_devices(next(data_train))

    training_state, metrics = p_update(training_state, batch)

    # Metrics carry a leading device axis: average on the host and convert to
    # plain Python floats for Neptune and `all_metrics`.
    train_loss = float(jax.device_get(metrics["loss"]).mean())
    train_accuracy = float(jax.device_get(metrics["accuracy"]).mean())

    run["train/loss"].log(train_loss, step=step)
    run["train/accuracy"].log(train_accuracy, step=step)

    all_metrics["train_loss"].append(train_loss)
    all_metrics["train_acc"].append(train_accuracy)
    all_metrics["train_step"].append(step)

    if step % validation_step == 0:
        # `batch_test` is already laid out per device in an earlier cell.
        # Use a separate name so the training metrics are not overwritten.
        test_metrics = p_evaluate(training_state, batch_test)
        test_loss = float(jax.device_get(test_metrics["loss"]).mean())
        test_accuracy = float(jax.device_get(test_metrics["accuracy"]).mean())

        run["test/loss"].log(test_loss, step=step)
        run["test/accuracy"].log(test_accuracy, step=step)

        all_metrics["val_loss"].append(test_loss)
        all_metrics["val_acc"].append(test_accuracy)
        all_metrics["val_step"].append(step)


In [None]:
# Stop the Neptune run opened earlier in the notebook; this is the final cell.
run.stop()