In [None]:
import os

from functools import partial

from absl import app
from dotenv import load_dotenv
import haiku as hk
import jax
import jax.numpy as jnp
import neptune
import numpy as np
import optax
from tqdm import tqdm
from typing import Dict, NamedTuple, Tuple, Callable
from learntrix.training.losses import cross_entropy_loss

from learntrix.dataloaders.computer_vision.mnist import (
    load_mnist_dataset
)
from learntrix.training.supervised_trainers.runners import (
    run_train
)
from learntrix.training.supervised_trainers.classification_trainer import ClassificationTrainer
from learntrix.types import Batch, TrainingState, Metrics

# Make CPU the default JAX platform for everything that follows in this notebook.
jax.config.update("jax_platform_name", "cpu")


In [None]:
# Enumerate the CPU devices JAX exposes. `devices` is reused by later cells
# to replicate the training state and is passed to the training loop.
devices = jax.devices(backend="cpu")
print(f"Detected the following devices: {tuple(devices)}")
num_devices = len(devices)

In [None]:
def load_env_variable(path, name):
    """Load the dotenv file at ``path`` and look up the variable ``name``.

    Args:
        path: Location of the env file handed to ``load_dotenv``.
        name: Name of the environment variable to read afterwards.

    Returns:
        Whatever ``os.getenv`` reports for ``name``; ``None`` when the
        variable is not set after the file has been loaded.
    """
    load_dotenv(path)
    return os.getenv(name)

def run_neptune(path, project):
    """Start a Neptune run authenticated with the token stored in an env file.

    Args:
        path: Path of the env file that holds ``NEPTUNE_API_TOKEN``.
        project: Name of the Neptune project the run is created in.

    Returns:
        The object returned by ``neptune.init_run``.
    """
    return neptune.init_run(
        project=project,
        # The token is read from the env file on every call, never hard-coded.
        api_token=load_env_variable(path=path, name='NEPTUNE_API_TOKEN'),
    )

# Hyperparameters recorded on the run so experiments can be compared later.
params = dict(learning_rate=0.001, optimizer="Adam")

# Open the Neptune run (token comes from the local .env file) and attach them.
run = run_neptune(path='./.env', project="yanisadel/learn-jax")
run["parameters"] = params

In [None]:
# Training split: shuffled, batches of 64.
data_train = load_mnist_dataset("train", shuffle=True, batch_size=64)

# Test split: fixed order, batches of 10000
# (presumably the whole split in a single batch — confirm against the loader).
data_test = load_mnist_dataset("test", shuffle=False, batch_size=10000)

In [None]:
def forward_fn(x: jax.Array) -> jax.Array:
    """Forward pass of a small MLP classifier.

    The input is cast to float32 and divided by 255, flattened, then passed
    through dense layers of width 300 and 100 (each followed by ReLU) and a
    final 10-unit linear layer with no activation.

    Args:
        x: Input batch (presumably uint8 MNIST images with a leading batch
            axis — confirm against the data loader).

    Returns:
        The output of the last linear layer, i.e. 10 values per example.
    """
    scaled = x.astype(jnp.float32) / 255.
    # Modules are instantiated in the same order as they are applied.
    layers = [
        hk.Flatten(),
        hk.Linear(300),
        jax.nn.relu,
        hk.Linear(100),
        jax.nn.relu,
        hk.Linear(10),
    ]
    network = hk.Sequential(layers)
    return network(scaled)

In [None]:
# Build the classification trainer. The Adam step size is read from `params`,
# the same dict that is logged to Neptune, so the hyperparameter recorded for
# the run stays in sync with the one actually used for optimisation.
trainer = ClassificationTrainer(
    forward_fn=forward_fn,
    loss_fn=cross_entropy_loss,
    optimizer=optax.adam(learning_rate=params["learning_rate"]),
    num_classes=10,
)

In [None]:
# Initialise the trainer state from a fixed seed and an all-ones dummy input
# of shape (32, 28, 28, 1), then place a copy of that state on every device.
init_rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((32, 28, 28, 1))
training_state = jax.device_put_replicated(
    trainer.init(init_rng, x=dummy_input),
    devices,
)

In [None]:
# Launch training for 100 steps with the replicated state.
# `validation_step=10` presumably triggers an evaluation on the test set every
# 10 steps, and the Neptune run is presumably used for logging — confirm both
# in `run_train`.
state, metrics = run_train(
    update_fn=trainer.update,
    evaluate_fn=trainer.evaluate,
    state=training_state,
    devices=devices,
    dataset_train=data_train,
    dataset_test=data_test,
    num_steps=100,
    validation_step=10,
    run_neptune=run,
)

In [None]:
# Stop the Neptune run that was opened earlier in the notebook.
run.stop()