In [2]:
import os
from glob import glob
from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import logging
import ml_collections
import numpy as np
import optax
import tensorflow as tf
from tqdm import tqdm
from flax.metrics import tensorboard
# from flax.core import FrozenDict
from flax.training import train_state

In [3]:
# Config
NUM_CLASSES = 9
NUM_INPUTS = 3
KERNEL_SIZE = 5

tf.config.experimental.set_visible_devices([], 'GPU')

In [4]:
def read_example(serialized: bytes) -> Tuple[jax.Array, jax.Array]:
    """Parses and reads a training example from bytes.

    Args:
        serialized: Serialized example bytes.

    Returns: An (inputs, labels) pair of arrays.
    """
    npz = jnp.load(serialized)
    inputs = npz['inputs']
    labels = npz['labels']
    inputs = inputs.astype(jnp.float16)
    labels = labels.astype(jnp.uint8)

    # Classifications are measured against one-hot encoded vectors.
    one_hot_labels = jax.nn.one_hot(labels[:, :, :, 0], NUM_CLASSES)
    return (inputs, one_hot_labels)

In [5]:
def read_dataset(
    data_path: str, train_test_ratio: float
) -> Tuple[Tuple[jax.Array, jax.Array], Tuple[jax.Array, jax.Array]]:
    files = glob(os.path.join(data_path, "*.npz"))
    files = files[:5]
    # Load data from npz files
    inputs_list = []
    labels_list = []
    for file in files:
        with open(file, "rb") as f:
            inputs, labels = read_example(f)
            inputs_list.append(inputs)
            labels_list.append(labels)

    # Concatenate data
    inputs = jnp.concatenate(inputs_list, axis=0)
    labels = jnp.concatenate(labels_list, axis=0)
    print(f"Inputs: {inputs.shape}, Labels: {labels.shape}")
    print(f"Dataset created with {(inputs.shape[0])} examples")
    
    # Normalize data
    inputs = jax.nn.standardize(inputs)

    train_size = int(inputs.shape[0] * train_test_ratio)
    train_inputs, test_inputs = inputs[:train_size], inputs[train_size:]
    train_labels, test_labels = labels[:train_size], labels[train_size:]

    print(f"Training data: {train_inputs.shape}, Labels: {train_labels.shape}")
    print(f"Testing data: {test_inputs.shape}, Labels: {test_labels.shape}")

    return (train_inputs, train_labels), (test_inputs, test_labels)

In [6]:
# _ = read_dataset("../data/climate_change/", 0.9)

In [7]:
# Define the Fully Convolutional Network.
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(KERNEL_SIZE, KERNEL_SIZE))(x)
        x = nn.relu(x)
        x = nn.ConvTranspose(features=16, kernel_size=(KERNEL_SIZE, KERNEL_SIZE))(x)
        x = nn.relu(x)
        x = nn.Dense(features=NUM_CLASSES)(x)
        # x = nn.softmax(x)
        return x

In [14]:
@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        # one_hot = jax.nn.one_hot(labels[:, :, :, 0], NUM_CLASSES)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=labels))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    return grads, loss, accuracy

In [9]:
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

In [10]:
def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds[0])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds[0]))
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_images = train_ds[0][perm, ...]
        batch_labels = train_ds[1][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy

In [11]:
def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    model = CNN()
    params = model.init(rng, jnp.ones([1, 128, 128, NUM_INPUTS]))["params"]
    tx = optax.adam(config.learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


def train_and_evaluate(
    config: ml_collections.ConfigDict, workdir: str
) -> train_state.TrainState:
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.

    Returns:
      The train state (which includes the `.params`).
    """
    train_ds, test_ds = read_dataset("../data/climate_change/", config.train_test_split)
    rng = jax.random.key(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in tqdm(range(config.num_epochs)):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(
            state, train_ds, config.batch_size, input_rng
        )
        _, test_loss, test_accuracy = apply_model(state, test_ds[0], test_ds[1])

        logging.info(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
            "est_accuracy: %.2f"
            % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
        )
        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

    summary_writer.flush()
    return state

In [12]:
config = ml_collections.ConfigDict()

config.learning_rate = 0.002
config.batch_size = 16
config.num_epochs = 10
config.train_test_split = 0.9

In [15]:
train_and_evaluate(config, "models/flax")

Inputs: (533, 128, 128, 3), Labels: (533, 128, 128, 9)
Dataset created with 533 examples
Training data: (479, 128, 128, 3), Labels: (479, 128, 128, 9)
Testing data: (54, 128, 128, 3), Labels: (54, 128, 128, 9)


100%|██████████| 10/10 [03:12<00:00, 19.28s/it]


TrainState(step=Array(290, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of CNN()>, params={'ConvTranspose_0': {'bias': Array([ 0.03885959,  0.01519761, -0.01803895,  0.07234871, -0.01056602,
       -0.01196541,  0.06809472,  0.06171704,  0.00197484, -0.02126722,
        0.09919451, -0.01911338, -0.01825645, -0.01731723, -0.02100618,
       -0.01282899], dtype=float32), 'kernel': Array([[[[ 2.68274080e-02, -2.15668529e-02, -1.42716821e-02, ...,
          -1.19334254e-02,  1.77734196e-02,  2.12352909e-02],
         [ 3.24405432e-02, -1.71673987e-02,  2.33256468e-03, ...,
           2.30056942e-02, -2.63182130e-02,  2.86908029e-03],
         [-1.61250995e-03,  6.28879964e-02,  1.54315308e-02, ...,
           3.73944081e-02, -2.70166285e-02,  9.80455521e-03],
         ...,
         [ 2.71996967e-02, -2.46805269e-02, -6.11730516e-02, ...,
           1.29407539e-03, -1.47007992e-02,  3.79358605e-02],
         [ 2.72708712e-03, -3.99893001e-02, -5.85625358e-02, ...,
     