In [None]:
import os

from functools import partial

from absl import app
from dotenv import load_dotenv
import haiku as hk
import jax
import jax.numpy as jnp
import neptune
import numpy as np
import optax
from tqdm import tqdm
from typing import Dict, NamedTuple, Tuple, Callable
from learntrix.training.losses import cross_entropy_loss

from learntrix.dataloaders.computer_vision.mnist import load_mnist_dataset

from learntrix.types import Batch, TrainingState, Metrics

# Force JAX onto the CPU backend for this notebook (the device cell below also
# explicitly requests "cpu" devices).
# NOTE(review): JAX generally needs platform config set before the first backend
# use, so keep this in the setup cell ahead of any JAX computation.
jax.config.update("jax_platform_name", "cpu")


In [None]:
# Enumerate the CPU devices JAX exposes; `devices` is the device list that the
# replication (device_put_replicated) and pmap calls below map over.
devices = jax.devices("cpu")
num_devices = len(devices)
print("Detected the following devices: " + str(tuple(devices)))

In [None]:
def load_env_variable(path, name):
    """Load the dotenv file at `path`, then return the environment variable `name`.

    Returns whatever `os.getenv(name)` yields, i.e. None when `name` is unset.
    """
    load_dotenv(path)
    return os.getenv(name)

def run_neptune(path, project):
    """Start a Neptune run for `project` and return its handle.

    The env file at `path` is loaded first, and the API token is then taken from
    the `NEPTUNE_API_TOKEN` environment variable (see `load_env_variable`).

    path: path of the .env file holding the Neptune API token
    project: Neptune project identifier, e.g. "yanisadel/learn-jax"
    """
    token = load_env_variable(path=path, name='NEPTUNE_API_TOKEN')
    return neptune.init_run(project=project, api_token=token)

# Open the Neptune run used for all logging in this notebook.
run = run_neptune(path='./.env', project="yanisadel/learn-jax")

# Hyperparameters recorded on the run under "parameters".
params = dict(learning_rate=0.001, optimizer="Adam")
run["parameters"] = params

In [None]:
# MNIST iterators (consumed with `next` in run_train): shuffled batches of 64 for
# training; the test split unshuffled in batches of 10,000, of which run_train
# only ever draws the first.
data_train = load_mnist_dataset("train", shuffle=True, batch_size=64)
data_test = load_mnist_dataset("test", shuffle=False, batch_size=10_000)

In [None]:
def forward_fn(x: jax.Array) -> jax.Array:
    """MLP mapping an image batch to 10 un-activated outputs per example.

    The input is cast to float32 and divided by 255 (assumes raw 0-255 pixel
    intensities; confirm against load_mnist_dataset), flattened, and passed through
    Linear(300) -> ReLU -> Linear(100) -> ReLU -> Linear(10).
    """
    inputs = x.astype(jnp.float32) / 255.
    # Modules are created in the same order as a literal list would create them,
    # so Haiku's auto-generated parameter names stay the same.
    layers = [hk.Flatten()]
    for width in (300, 100):
        layers.extend([hk.Linear(width), jax.nn.relu])
    layers.append(hk.Linear(10))
    return hk.Sequential(layers)(inputs)

In [None]:
class Trainer:
    """Bundle a Haiku model, a loss function and an Optax optimizer into train/eval steps.

    `update` averages gradients with `jax.lax.pmean` over the axis named 'batch',
    so it must run inside `jax.pmap(..., axis_name='batch')`. Called outside such a
    mapped context, that axis name is unbound and JAX will raise.

    Both `update` and `evaluate` are jitted with `self` as a static argument, so
    the instance must stay hashable (default identity hashing suffices).
    """

    def __init__(self, forward_fn, loss_fn, optimizer, num_classes):
        # Transformed model whose `apply` takes no RNG key (hk.without_apply_rng).
        transformed = hk.without_apply_rng(hk.transform(forward_fn))
        self._forward_fn = transformed
        self._optimizer = optimizer
        self._num_classes = num_classes
        # The loss is expected to return (loss, metrics): it is used with
        # has_aux=True in `update` and unpacked as a pair in `evaluate`.
        self._loss_fn = partial(
            loss_fn,
            forward_fn=transformed.apply,
            num_classes=num_classes,
        )

    def init(self, random_key, x):
        """Build the initial TrainingState from a PRNG key and an example input `x`."""
        model_params = self._forward_fn.init(random_key, x)
        return TrainingState(model_params, self._optimizer.init(model_params))

    @partial(jax.jit, static_argnums=0)
    def update(self, state: TrainingState, batch: Batch) -> Tuple[TrainingState, Metrics]:
        """Take one optimizer step on `batch` and return the new state plus the loss metrics."""
        loss_and_grad = jax.value_and_grad(self._loss_fn, has_aux=True)
        (_loss, step_metrics), raw_grads = loss_and_grad(state.params, batch)
        # Average gradients across all devices mapped over the 'batch' axis.
        mean_grads = jax.lax.pmean(raw_grads, axis_name='batch')
        updates, new_opt_state = self._optimizer.update(mean_grads, state.opt_state)
        new_params = optax.apply_updates(state.params, updates)
        return TrainingState(new_params, new_opt_state), step_metrics

    @partial(jax.jit, static_argnums=0)
    def evaluate(self, state: TrainingState, batch: Batch):
        """Return the loss function's metrics for `batch` without updating parameters."""
        _loss, eval_metrics = self._loss_fn(state.params, batch)
        return eval_metrics
    

In [None]:
# Build the trainer: the MLP above, cross-entropy over 10 classes, and Adam.
# The learning rate is read from the `params` dict logged to Neptune, so the
# recorded hyperparameter cannot drift from the optimizer actually used.
trainer = Trainer(
    forward_fn=forward_fn,
    loss_fn=cross_entropy_loss,
    optimizer=optax.adam(learning_rate=params["learning_rate"]),
    num_classes=10,
)

In [None]:
# Initialise parameters and optimizer state with a fixed PRNG seed and a dummy
# all-ones batch shaped (32, 28, 28, 1). Assumes real batches share the per-example
# shape (28, 28, 1); verify against load_mnist_dataset. Then copy the state onto
# every device so it can be fed to the pmapped step functions.
training_state = jax.device_put_replicated(
    trainer.init(jax.random.PRNGKey(0), x=jnp.ones(shape=(32, 28, 28, 1))),
    devices,
)

In [None]:
def run_train(update_fn, evaluate_fn, state, devices, dataset_train, dataset_test, num_steps, validation_step, run_neptune):
    """Run a pmap-based training loop with periodic evaluation and Neptune logging.

    Args:
        update_fn: per-device step ``(state, batch) -> (state, metrics)``. It is
            wrapped in ``jax.pmap(..., axis_name='batch')``, so it may use collectives
            over the 'batch' axis (e.g. a gradient ``pmean``).
        evaluate_fn: per-device step ``(state, batch) -> metrics``, pmapped the same way.
        state: training state already replicated across ``devices``
            (e.g. via ``jax.device_put_replicated``).
        devices: devices to map over.
        dataset_train: iterator yielding one training batch per ``next`` call.
        dataset_test: iterator; only its first batch is drawn, and it is reused for
            every evaluation.
        num_steps: number of training steps, numbered 1..num_steps.
        validation_step: evaluate every ``validation_step`` steps (must be non-zero,
            otherwise ``step % validation_step`` raises).
        run_neptune: run handle indexed by key and logged via ``.log(value, step=...)``.
            This parameter shadows the module-level ``run_neptune`` helper.

    Returns:
        ``(state, all_metrics)``: the final replicated state, and a dict of lists of
        per-step train loss/accuracy plus per-evaluation validation loss/accuracy.

    Each device receives the *same* full batch (``device_put_replicated``), so every
    device computes identical gradients. This is replication, not data parallelism.
    """
    all_metrics: Metrics = {
        "train_loss": [],
        "train_acc": [],
        "train_step": [],
        "val_loss": [],
        "val_acc": [],
        "val_step": [],
    }

    # Build the mapped step functions once instead of re-wrapping them every step.
    # Bug fix: evaluation now uses the `evaluate_fn` argument rather than a global.
    p_update = jax.pmap(update_fn, devices=devices, axis_name='batch')
    p_evaluate = jax.pmap(evaluate_fn, devices=devices, axis_name='batch')

    # One fixed test batch, replicated to every device, is reused for all evaluations.
    batch_test = jax.device_put_replicated(next(dataset_test), devices)

    for step in tqdm(range(1, num_steps + 1)):
        batch = jax.device_put_replicated(next(dataset_train), devices)
        state, metrics = p_update(state, batch)

        # Metrics carry a leading device axis; average across devices on the host.
        train_loss = jax.device_get(metrics["loss"]).mean()
        train_accuracy = jax.device_get(metrics["accuracy"]).mean()

        run_neptune["train/loss"].log(train_loss, step=step)
        run_neptune["train/accuracy"].log(train_accuracy, step=step)

        all_metrics["train_loss"].append(train_loss)
        all_metrics["train_acc"].append(train_accuracy)
        all_metrics["train_step"].append(step)

        if step % validation_step == 0:
            val_metrics = p_evaluate(state, batch_test)
            test_loss = jax.device_get(val_metrics["loss"]).mean()
            test_accuracy = jax.device_get(val_metrics["accuracy"]).mean()

            run_neptune["test/loss"].log(test_loss, step=step)
            run_neptune["test/accuracy"].log(test_accuracy, step=step)

            all_metrics["val_loss"].append(test_loss)
            all_metrics["val_acc"].append(test_accuracy)
            all_metrics["val_step"].append(step)

    return state, all_metrics

In [None]:
# Train for 100 steps, evaluating on the fixed test batch every 10 steps.
state, metrics = run_train(
    update_fn=trainer.update,
    evaluate_fn=trainer.evaluate,
    state=training_state,
    devices=devices,
    dataset_train=data_train,
    dataset_test=data_test,
    num_steps=100,
    validation_step=10,
    run_neptune=run,
)

In [None]:
# Stop the Neptune run opened at the top of the notebook (presumably this also
# flushes any pending logged values; see Neptune's Run.stop docs).
run.stop()