In [1]:
import io
import os
from glob import glob
from typing import BinaryIO, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import orbax.checkpoint as ocp
import tensorflow as tf
from flax.metrics import tensorboard
# from flax.core import FrozenDict
from flax.training import orbax_utils, train_state
from modal import App, Image, Volume, gpu
from tqdm import tqdm



In [2]:
# Model / dataset constants shared by the cells below.
NUM_CLASSES = 9
NUM_INPUTS = 3
KERNEL_SIZE = 5

# Pass an empty device list for the "GPU" device type so TensorFlow sees no GPUs.
tf.config.experimental.set_visible_devices([], "GPU")

# Modal application and the named volume later mounted by the training function.
app = App("flax-climate-forecast")
volume = Volume.from_name("climate-forecast")

# Python packages installed into the container image.
_PIP_PACKAGES = (
    "flax",
    "numpy",
    "tensorflow[and-cuda]",
    "tensorboard",
    "tqdm",
    "ml-collections",
    "tensorrt",
)

# Shell commands executed while building the image.
_IMAGE_SETUP_COMMANDS = [
    # Disabled alternative: install the system-wide CUDA toolkit.
    # "apt-get update",
    # "apt-get install -y wget",
    # "wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb",
    # "dpkg -i cuda-keyring_1.1-1_all.deb",
    # "apt-get update",
    # "apt-get -y install cuda-toolkit-12-3",
    "pip install -U 'jax[cuda12_pip]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'",
    # Diagnostics printed into the build log.
    "python -m site",
    "pip list | grep nvidia",
    # NOTE(review): these exports run as build commands; whether they carry over
    # to the runtime environment is not visible here -- confirm (an image-level
    # environment setting may be what is intended).
    "export PATH=/usr/local/cuda-12/bin:$PATH",
    "export LD_LIBRARY_PATH=/usr/local/cuda-12/lib64:/usr/local/lib/python3.11/site-packages/tensorrt_libs/:$LD_LIBRARY_PATH:",
]

img = (
    Image.debian_slim()
    .pip_install(*_PIP_PACKAGES)
    .run_commands(_IMAGE_SETUP_COMMANDS)
)

In [3]:
def read_example(
    serialized: Union[bytes, str, os.PathLike, BinaryIO],
) -> Tuple[np.ndarray, np.ndarray]:
    """Loads one ``.npz`` training example and returns its arrays.

    Args:
        serialized: The example, given as raw ``.npz`` bytes, as an open binary
            file object, or as a filesystem path.

    Returns:
        An ``(inputs, labels)`` pair of NumPy arrays read from the archive's
        ``inputs`` and ``labels`` entries, with whatever dtype they were
        stored with.

    Raises:
        KeyError: If the archive has no ``inputs`` or ``labels`` entry.
    """
    # np.load interprets a bytes object as a *path*, not as file content, so an
    # in-memory payload has to be wrapped in a file-like object first.
    if isinstance(serialized, (bytes, bytearray, memoryview)):
        serialized = io.BytesIO(serialized)

    # Entries are decoded on access; reading them inside the context manager
    # lets the archive release its handles as soon as the arrays exist.
    with np.load(serialized) as npz:
        inputs = npz["inputs"]
        labels = npz["labels"]
    return inputs, labels

In [4]:
def read_dataset(
    data_path: str, train_test_ratio: float
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """Loads every ``*.npz`` file in a directory and splits it into train/test.

    All examples are concatenated along axis 0 and then split positionally:
    the first ``int(n * train_test_ratio)`` examples form the training set and
    the remainder the test set. No shuffling is performed.

    Args:
        data_path: Directory containing the ``.npz`` example files.
        train_test_ratio: Fraction of the examples assigned to the training set.

    Returns:
        ``((train_inputs, train_labels), (test_inputs, test_labels))`` as NumPy
        arrays.

    Raises:
        FileNotFoundError: If ``data_path`` contains no ``.npz`` files.
    """
    # glob yields matches in arbitrary (filesystem) order; sorting makes the
    # positional train/test split below reproducible across runs and machines.
    files = sorted(glob(os.path.join(data_path, "*.npz")))
    if not files:
        raise FileNotFoundError(f"No .npz files found in {data_path!r}")

    # Load data from npz files
    inputs_list = []
    labels_list = []
    for file in files:
        with open(file, "rb") as f:
            inputs, labels = read_example(f)
        inputs_list.append(inputs)
        labels_list.append(labels)

    # Concatenate data
    inputs = np.concatenate(inputs_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    print(f"Inputs: {inputs.shape}, Labels: {labels.shape}")
    print(f"Dataset created with {(inputs.shape[0])} examples")

    train_size = int(inputs.shape[0] * train_test_ratio)
    train_inputs, test_inputs = inputs[:train_size], inputs[train_size:]
    train_labels, test_labels = labels[:train_size], labels[train_size:]

    print(f"Training data: {train_inputs.shape}, Labels: {train_labels.shape}")
    print(f"Testing data: {test_inputs.shape}, Labels: {test_labels.shape}")

    return (train_inputs, train_labels), (test_inputs, test_labels)

In [5]:
# _ = read_dataset("../data/climate_change/", 0.9)

In [6]:
# Define the convolutional classifier (conv/pool feature extractor followed by dense layers).
class CNN(nn.Module):
    """Convolutional classifier that emits ``NUM_CLASSES`` logits per example."""

    @nn.compact
    def __call__(self, x):
        """Maps a batch of inputs to unnormalized class scores.

        Args:
            x: Input batch; the leading axis is kept as the batch axis when
                the feature maps are flattened.

        Returns:
            Array of shape ``(x.shape[0], NUM_CLASSES)`` holding raw logits
            (no softmax is applied).
        """
        kernel = (KERNEL_SIZE, KERNEL_SIZE)
        pool = dict(window_shape=(2, 2), strides=(2, 2))

        # Stage 1: convolution -> ReLU -> 2x2 average pooling.
        features = nn.relu(nn.Conv(features=32, kernel_size=kernel)(x))
        features = nn.avg_pool(features, **pool)

        # Stage 2: transposed convolution -> ReLU -> 2x2 average pooling.
        features = nn.relu(
            nn.ConvTranspose(features=16, kernel_size=kernel)(features)
        )
        features = nn.avg_pool(features, **pool)

        # Classification head: flatten everything but the batch axis.
        flat = features.reshape((features.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=256)(flat))
        return nn.Dense(features=NUM_CLASSES)(hidden)

In [7]:
@jax.jit
def apply_model(state, images, labels):
    """Evaluates one batch and returns its gradients, loss and accuracy.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        images: Batch of model inputs.
        labels: Integer class labels; one-hot encoded to ``NUM_CLASSES`` for
            the loss and compared against the arg-max prediction for accuracy.

    Returns:
        ``(grads, loss, accuracy)`` where ``grads`` is the gradient of the mean
        softmax cross-entropy with respect to ``state.params``.
    """

    def batch_loss(params):
        predictions = state.apply_fn({"params": params}, images)
        targets = jax.nn.one_hot(labels, NUM_CLASSES)
        per_example = optax.softmax_cross_entropy(logits=predictions, labels=targets)
        # The logits ride along as auxiliary output so accuracy can reuse them.
        return jnp.mean(per_example), predictions

    value_and_grad = jax.value_and_grad(batch_loss, has_aux=True)
    (loss, logits), grads = value_and_grad(state.params)

    predicted_classes = jnp.argmax(logits, -1)
    accuracy = jnp.mean(predicted_classes == labels)
    return grads, loss, accuracy

In [8]:
@jax.jit
def update_model(state, grads):
    """Returns the train state obtained by applying ``grads`` to ``state``."""
    updated_state = state.apply_gradients(grads=grads)
    return updated_state

In [9]:
def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch.

    Examples are visited in a random order drawn from ``rng``; a trailing
    incomplete batch is dropped. Each input batch is cast to float32 and passed
    through ``jax.nn.standardize``; labels are cast to uint8.

    Args:
        state: Train state to update.
        train_ds: ``(inputs, labels)`` pair of arrays indexable along axis 0.
        batch_size: Number of examples per optimization step.
        rng: JAX PRNG key used to shuffle the examples.

    Returns:
        ``(state, train_loss, train_accuracy)`` where the metrics are the means
        of the per-batch loss and accuracy.

    Raises:
        ValueError: If the dataset holds fewer than ``batch_size`` examples, so
            no full batch can be formed.
    """
    train_ds_size = len(train_ds[0])
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        # Without this guard the loop body never runs and the epoch silently
        # reports NaN metrics while leaving the model untouched.
        raise ValueError(
            f"batch_size ({batch_size}) exceeds the number of training "
            f"examples ({train_ds_size}); no full batch can be formed"
        )

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    # Move the indices to the host once: the dataset lives in NumPy arrays, so
    # slicing a device array per step only adds transfers.
    perms = np.asarray(perms).reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = jnp.array(train_ds[0][perm, ...], dtype=jnp.float32)
        batch_images = jax.nn.standardize(batch_images)

        batch_labels = jnp.array(train_ds[1][perm, ...], dtype=jnp.uint8)

        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy

In [10]:
def create_train_state(rng, config):
    """Builds the initial `TrainState` for a freshly initialized `CNN`.

    Args:
        rng: JAX PRNG key used to initialize the model parameters.
        config: Configuration object; only ``config.learning_rate`` is read.

    Returns:
        A `TrainState` wrapping the model's apply function, its initial
        parameters and an Adam optimizer.
    """
    network = CNN()
    # Parameters are initialized by tracing a dummy batch of this fixed shape.
    dummy_batch = jnp.ones([1, 128, 128, NUM_INPUTS])
    variables = network.init(rng, dummy_batch)
    optimizer = optax.adam(config.learning_rate)
    return train_state.TrainState.create(
        apply_fn=network.apply,
        params=variables["params"],
        tx=optimizer,
    )


@app.function(
    image=img,
    timeout=60 * 60 * 24,  # 86,400 -- presumably seconds (24 h); confirm against Modal docs
    volumes={"/vol": volume},
    gpu=gpu.A10G(count=1),
    _allow_background_volume_commits=True,
)
def train_and_evaluate(
    config: ml_collections.ConfigDict, data_dir: str, work_dir: str, ckpt_dir: str
) -> train_state.TrainState:
    """Execute model training and evaluation loop.

    Registered as a Modal function that runs in `img` with `volume` mounted at
    `/vol`. For each epoch it trains on the training split, evaluates on the
    held-out split, prints and logs the metrics, and saves a checkpoint.

    Note that any existing content of `ckpt_dir` is deleted before training
    starts.

    Args:
      config: Hyperparameter configuration for training and evaluation. Reads
        `train_test_split`, `num_epochs` and `batch_size` here, and is passed
        on to `create_train_state`.
      data_dir: Directory the dataset is read from via `read_dataset`.
      work_dir: Directory where the tensorboard summaries are written to.
      ckpt_dir: Directory handed to the checkpoint manager; wiped on entry.

    Returns:
      The train state (which includes the `.params`).
    """
    # Function-local imports (`os` shadows the module-level import).
    import os
    import shutil

    # logging.get_absl_handler().python_handler.stream = sys.stdout
    os.makedirs(work_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    # Start from a clean checkpoint directory; a missing directory is not an error.
    shutil.rmtree(ckpt_dir, ignore_errors=True)

    ckpt_options = ocp.CheckpointManagerOptions(
        max_to_keep=3,
    )
    # NOTE(review): `erase_and_create_empty` comes from orbax's `test_utils`;
    # presumably it recreates `ckpt_dir` empty and returns its path -- confirm.
    ckpt_manager = ocp.CheckpointManager(
        ocp.test_utils.erase_and_create_empty(ckpt_dir),
        options=ckpt_options,
    )

    print(f"JAX process: {jax.process_index()} / {jax.process_count()}")
    print(f"JAX local devices: {jax.local_devices()}")
    train_ds, test_ds = read_dataset(data_dir, config.train_test_split)
    # Fixed seed: parameter init and per-epoch shuffling both derive from this key.
    rng = jax.random.key(0)

    summary_writer = tensorboard.SummaryWriter(work_dir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    # The whole test split is converted once, up front, and reused every epoch.
    test_images = jnp.array(test_ds[0], dtype=jnp.float32)
    test_images = jax.nn.standardize(test_images)

    test_labels = jnp.array(test_ds[1], dtype=jnp.uint8)

    for epoch in tqdm(range(config.num_epochs)):
        # Fresh key per epoch so each epoch gets its own shuffle.
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(
            state, train_ds, config.batch_size, input_rng
        )

        # Evaluation reuses `apply_model` on the full test split; the gradients
        # it also returns are discarded.
        _, test_loss, test_accuracy = apply_model(state, test_images, test_labels)

        print(
            f"epoch:{epoch}, train_loss: {train_loss}, train_accuracy: {train_accuracy * 100}, test_loss: {test_loss}, test_accuracy: {test_accuracy * 100}"
        )
        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

        # Checkpoint the full train state under the epoch number, then wait for
        # the save to finish before starting the next epoch.
        ckpt = {"model": state}
        # save_args = orbax_utils.save_args_from_target(ckpt)
        ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
        ckpt_manager.wait_until_finished()

    summary_writer.flush()
    return state

In [14]:
# Hyperparameters read by the training functions above.
config = ml_collections.ConfigDict(
    {
        "learning_rate": 0.0002,
        "batch_size": 32,
        "num_epochs": 100,
        "train_test_split": 0.9,
    }
)

In [12]:
# ckpt_dir = os.path.abspath("../models/flax/checkpoints")
# train_and_evaluate(
#     config,
#     "../data/climate_change",
#     "../models/flax/logs",
#     ckpt_dir,
# )

In [15]:
# Launch training on Modal; every path lives on the volume mounted at /vol.
with app.run(detach=True):
    train_and_evaluate.remote(
        config,
        data_dir="/vol/data/",
        work_dir="/vol/flax/logs",
        ckpt_dir="/vol/flax/checkpoints",
    )

Output()

Output()

Output()