# How to set up periodic actions in the training loop

This notebook demonstrates the `CronTab` feature of bobbin, which runs actions periodically inside a training loop.

## Preamble: Install prerequisites, import modules.

In [1]:
!pip -q install --upgrade pip
!pip -q install --upgrade "jax[cpu]"
!pip -q uninstall -y bobbin
!pip -q install --upgrade git+https://github.com/yotarok/bobbin.git

In [2]:
%%capture
import logging
import sys
import tempfile
import time

import bobbin
import chex
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

Array = chex.Array

logging.basicConfig(stream=sys.stdout)
logging.root.setLevel(logging.INFO)



In [3]:
logger = logging.getLogger()
logger.addHandler(logging.FileHandler("/dev/stdout"))

## Define tasks and models

Here, we demonstrate how to construct a loop that involves a full training setup, so some training and evaluation code is needed first.
Only minimal explanations are given for the training and evaluation code below. Please refer to the following documents for details on training and evaluation tasks in bobbin:

- Training: [How to write a training loop](https://bobbin.readthedocs.io/en/latest/train_task.html)
- Evaluation: [How to define an evaluation task](https://bobbin.readthedocs.io/en/latest/eval_task.html)

First, let's build the input pipelines for the training and evaluation datasets.
The functions can be written as follows:

In [4]:
def get_train_dataset(batch_size):
    ds = tfds.load("mnist", split="train", as_supervised=True)
    ds = ds.repeat().shuffle(1024).batch(batch_size).prefetch(1)
    return ds


def get_eval_dataset(batch_size):
    ds = tfds.load("mnist", split="test[:1000]", as_supervised=True)
    ds = ds.batch(batch_size).prefetch(1)
    return ds

Then, we define the classifier model and the loss function (in a subclass of `TrainTask`) as follows
(see also [How to write a training loop](https://bobbin.readthedocs.io/en/latest/train_task.html)):

In [5]:
class MnistClassifier(nn.Module):
    @nn.compact
    def __call__(self, x: Array, *, training=True) -> Array:
        batch_size, *unused_image_dims = x.shape
        x = x.reshape((batch_size, -1))  # flatten the input image.
        hidden = nn.sigmoid(nn.Dense(features=512)(x))
        return nn.Dense(features=10)(hidden)


class MnistTrainingTask(bobbin.TrainTask):
    def __init__(self):
        super().__init__(
            MnistClassifier(),
            example_args=(
                np.zeros((1, 28, 28, 1), np.float32),  # the trailing comma makes this a 1-tuple
            ),
        )

    def compute_loss(self, params, batch, *, extra_vars, prng_key, step):
        images, labels = batch
        logits = self.model.apply({"params": params}, images)
        per_sample_loss = optax.softmax_cross_entropy(
            logits=logits, labels=jax.nn.one_hot(labels, 10)
        )
        return jnp.mean(per_sample_loss), ({}, None)


task = MnistTrainingTask()
train_step_fn = task.make_training_step_fn().jit()
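For reference, `optax.softmax_cross_entropy` with one-hot labels from `jax.nn.one_hot` computes, for each example, the negative log-softmax probability of the true class, and `compute_loss` above then averages it over the batch. The plain-Python sketch below only illustrates that formula; the helper names `softmax_cross_entropy` and `mean_loss` are hypothetical and are not part of optax or bobbin.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross entropy between softmax(logits) and a one-hot target at index `label`."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # With a one-hot target only the true-class term survives: -log_softmax[label].
    return log_sum_exp - logits[label]


def mean_loss(batch_logits, labels):
    """Average per-example loss, mirroring `jnp.mean(per_sample_loss)`."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


print(mean_loss([[0.0, 0.0], [2.0, 0.0]], [0, 0]))
```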

The evaluation metrics and the evaluation procedure can be defined as follows
(see also [How to define an evaluation task](https://bobbin.readthedocs.io/en/latest/eval_task.html)):

In [12]:
class EvalResults(bobbin.EvalResults):
    correct_count: int
    predict_count: int

    @property
    def accuracy(self) -> float:
        return self.correct_count / self.predict_count

    def is_better_than(self, other: "EvalResults") -> bool:
        return self.accuracy > other.accuracy

    def reduce(self, other: "EvalResults") -> "EvalResults":
        return jax.tree_util.tree_map(lambda x, y: x + y, self, other)

    def to_log_message(self) -> str:
        return f"formatted in `EvalResults.to_log_message`. acc={self.accuracy:.2f}"


class EvalTask(bobbin.EvalTask):
    def __init__(self):
        self.model = MnistClassifier()

    def create_eval_results(self, dataset_name):
        return EvalResults(correct_count=0, predict_count=0)

    def evaluate(self, batch, model_vars) -> EvalResults:
        inputs, labels = batch
        logits = self.model.apply(model_vars, inputs)
        predicts = logits.argmax(axis=-1)
        return EvalResults(
            correct_count=(predicts == labels).astype(np.int32).sum(),
            predict_count=labels.shape[0],
        )


eval_batch_gens = {
    "test": get_eval_dataset(32).as_numpy_iterator,
}
evaler = EvalTask()

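To illustrate how these pieces could fit together, the sketch below assumes (without confirming bobbin's internals) that evaluation starts from the value returned by `create_eval_results`, folds in one result per batch with `reduce`, and reads `accuracy` at the end. `CountResults` and `evaluate_batch` are hypothetical plain-Python stand-ins for the classes above, so the example runs without JAX or bobbin.

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class CountResults:
    """Hypothetical stand-in for the `EvalResults` class above (not bobbin's class)."""

    correct_count: int
    predict_count: int

    @property
    def accuracy(self) -> float:
        return self.correct_count / self.predict_count

    def is_better_than(self, other: "CountResults") -> bool:
        return self.accuracy > other.accuracy

    def reduce(self, other: "CountResults") -> "CountResults":
        # Field-wise sum, like the `tree_map(lambda x, y: x + y, ...)` above.
        return CountResults(
            self.correct_count + other.correct_count,
            self.predict_count + other.predict_count,
        )


def evaluate_batch(predicts, labels) -> CountResults:
    """Counts correct predictions in one batch."""
    correct = sum(int(p == y) for p, y in zip(predicts, labels))
    return CountResults(correct_count=correct, predict_count=len(labels))


# Two toy "batches" of (predictions, labels).
batches = [([1, 2, 3], [1, 2, 0]), ([4, 5], [4, 5])]
initial = CountResults(correct_count=0, predict_count=0)  # like create_eval_results
total = functools.reduce(lambda acc, b: acc.reduce(evaluate_batch(*b)), batches, initial)
print(total, f"acc={total.accuracy:.2f}")
```

Summing raw counts and dividing only at the end gives the exact accuracy over the whole dataset, even when the last batch is smaller than the others.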


## Set up the crontab

Given the models and tasks above, we are now ready to write a training loop.
As a first example, we make the main loop greet the user every 0.1 seconds by using the `CronTab.schedule` method.

In [19]:
def say_hello(train_state, *, message: str, **kwargs):
    print(
        f"{message} Training is currently at {train_state.step}-th step. {time.time()}"
    )


crontab = bobbin.CronTab()
crontab.schedule(say_hello, time_interval=0.1)

prng_key = jax.random.PRNGKey(0)
train_state = task.initialize_train_state(jax.random.PRNGKey(0), optax.sgd(0.01))
for batch in get_train_dataset(64).take(500).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
    crontab.run(train_state, message="Hello!!", is_train_state_replicated=False)

Hello!! Training is currently at 1-th step. 1677497122.7034385
Hello!! Training is currently at 27-th step. 1677497122.8050828
Hello!! Training is currently at 76-th step. 1677497122.9070957
Hello!! Training is currently at 136-th step. 1677497123.0086877
Hello!! Training is currently at 197-th step. 1677497123.1097565
Hello!! Training is currently at 257-th step. 1677497123.211207
Hello!! Training is currently at 319-th step. 1677497123.312581
Hello!! Training is currently at 380-th step. 1677497123.4127362
Hello!! Training is currently at 416-th step. 1677497123.5137897
Hello



The first argument of `CronTab.schedule` is called an "action", which can be any callable invoked as `f(train_state, **kwargs)`.
An action registered with `CronTab.schedule` is called when you call `CronTab.run` at the end of each training step, but only if its pre-specified condition is met.
In this case, the condition is satisfied when more than 0.1 seconds have elapsed since the action was last executed.
(In other words, the action is executed only once per `CronTab.run` call, even if a step took longer than 0.2 seconds.)
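The time-based condition described above can be sketched as follows. This is a hypothetical class, not bobbin's implementation; it assumes the timer starts when the trigger is created and takes an injectable `clock` so the behavior can be checked without sleeping.

```python
import time
from typing import Callable


class TimeIntervalTrigger:
    """Hypothetical sketch of a time-based condition (not bobbin's implementation)."""

    def __init__(self, interval: float, clock: Callable[[], float] = time.monotonic):
        self.interval = interval
        self.clock = clock
        # Assumption: the timer starts when the trigger is created.
        self.last_fired = clock()

    def should_fire(self) -> bool:
        now = self.clock()
        if now - self.last_fired >= self.interval:
            # Reset the timer: a long step yields one firing, not a catch-up burst.
            self.last_fired = now
            return True
        return False


# Drive the trigger with a fake clock instead of real time.
fake_now = [0.0]
trigger = TimeIntervalTrigger(0.1, clock=lambda: fake_now[0])
fired = []
for t in [0.05, 0.12, 0.15, 0.5, 0.55, 0.65]:
    fake_now[0] = t
    fired.append(trigger.should_fire())
print(fired)
```

Between t=0.15 and t=0.5 several intervals pass, yet the trigger fires only once at t=0.5, which matches the "only once per call" behavior described above.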

You can pass additional context information to actions by adding keyword arguments to the `CronTab.run` call; in the example above, `message="Hello!!"` reaches `say_hello` this way.
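A minimal sketch of this dispatch pattern is shown below. `MiniCronTab` is a hypothetical, simplified class, not bobbin's `CronTab`: it takes an explicit `condition` callable instead of keyword arguments like `time_interval`, and it uses a plain integer step as the "train state". It shows actions being called as `f(train_state, **kwargs)` with the extra keyword arguments of `run` forwarded unchanged.

```python
from typing import Any, Callable, List, Tuple

# An "action" is any callable invoked as f(train_state, **kwargs).
Action = Callable[..., None]
Condition = Callable[[Any], bool]


class MiniCronTab:
    """Hypothetical, simplified crontab for illustration; not bobbin's CronTab."""

    def __init__(self) -> None:
        self._entries: List[Tuple[Action, Condition]] = []

    def schedule(self, action: Action, condition: Condition) -> None:
        self._entries.append((action, condition))

    def run(self, train_state: Any, **kwargs: Any) -> None:
        for action, condition in self._entries:
            if condition(train_state):
                # Extra keyword arguments given to run() are forwarded unchanged.
                action(train_state, **kwargs)


log = []


def say_hello(train_state, *, message, **kwargs):
    # `**kwargs` absorbs any keyword arguments this action does not use.
    log.append(f"{message} step={train_state}")


crontab = MiniCronTab()
crontab.schedule(say_hello, condition=lambda step: step % 2 == 0)
for step in range(1, 5):
    crontab.run(step, message="Hello!!", unused_flag=False)
print(log)
```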

`CronTab` is designed as a hub that loosely connects the functionality provided by other bobbin submodules.

For example, `TrainTask` provides an action that writes training logs to the logger, and `EvalTask` provides an action that runs the evaluation process over the datasets:

In [20]:
crontab = bobbin.CronTab()
crontab.schedule(
    task.make_log_writer(loglevel=logging.WARNING), at_step=123, step_interval=100
)
crontab.schedule(
    evaler.make_cron_action(eval_batch_gens, tensorboard_root_path=None),
    step_interval=123,
)
prng_key = jax.random.PRNGKey(0)
train_state = task.initialize_train_state(jax.random.PRNGKey(0), optax.sgd(0.01))
for batch in get_train_dataset(64).take(500).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
    crontab.run(train_state, step_info=step_info, is_train_state_replicated=False)

@step=100, loss=0.953003
@step=123, loss=0.935717
Start evaluation process over test
Evaluation results for dataset=test @step=123
formatted in `EvalResults.to_log_message`. acc=0.84
@step=200, loss=0.647676
Start evaluation process over test
Evaluation results for dataset=test @step=246
formatted in `EvalResults.to_log_message`. acc=0.89
@step=300, loss=0.472512
Start evaluation process over test
Evaluation results for dataset=test @step=369
formatted in `EvalResults.to_log_message`. acc=0.90
@step=400, loss=0.521293
Start evaluation process over test
Evaluation results for da



In this example, `TrainTask.make_log_writer` writes only a very simple log message; this can be customized by overriding the `TrainTask.write_trainer_log` method.
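In the log above, the log writer fires at steps 100, 123, 200, 300, and 400, and the evaluator at steps 123, 246, and 369. This suggests that `at_step` and `step_interval` are combined with a logical OR. The hypothetical `StepTrigger` below (not bobbin's implementation) reproduces that schedule in plain Python, assuming the step counter is 1 after the first training step, as the first example's output indicates.

```python
from typing import Optional


class StepTrigger:
    """Hypothetical sketch of a step-based condition (not bobbin's implementation)."""

    def __init__(self, at_step: Optional[int] = None, step_interval: Optional[int] = None):
        self.at_step = at_step
        self.step_interval = step_interval

    def should_fire(self, step: int) -> bool:
        # Either condition is sufficient (logical OR).
        if self.at_step is not None and step == self.at_step:
            return True
        return self.step_interval is not None and step % self.step_interval == 0


log_trigger = StepTrigger(at_step=123, step_interval=100)
eval_trigger = StepTrigger(step_interval=123)
log_steps = [s for s in range(1, 501) if log_trigger.should_fire(s)]
eval_steps = [s for s in range(1, 501) if eval_trigger.should_fire(s)]
print(log_steps)
print(eval_steps)
```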

`CronTab` can also be used to tie the training loop to checkpoint writers. In the example below, we use two directories for storing checkpoints: one for regular checkpoints used to resume training, and the other for keeping the best-performing checkpoints for later use.

In [18]:
checkpoint_temp_dir = tempfile.TemporaryDirectory()
best_checkpoint_temp_dir = tempfile.TemporaryDirectory()

crontab = bobbin.CronTab()
crontab.schedule(
    task.make_checkpoint_saver(checkpoint_temp_dir.name), step_interval=1000
)
crontab.schedule(
    evaler.make_cron_action(
        eval_batch_gens, tensorboard_root_path=None
    ).keep_best_checkpoint("test", best_checkpoint_temp_dir.name),
    step_interval=1000,
)

prng_key = jax.random.PRNGKey(0)
train_state = task.initialize_train_state(jax.random.PRNGKey(0), optax.sgd(0.1))
for batch in get_train_dataset(64).take(5000).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
    crontab.run(train_state, step_info=step_info, is_train_state_replicated=False)

print("Latest checkpoints:")
!ls {checkpoint_temp_dir.name}
print("Best checkpoints:")
!ls {best_checkpoint_temp_dir.name}
print("Results of the best checkpoint")
!cat {best_checkpoint_temp_dir.name}/results.json

Saving checkpoint at step: 1000
Saved checkpoint at /tmp/tmppqxpbnaf/checkpoint_1000
Start evaluation process over test
Evaluation results for dataset=test @step=1000
formatted in `EvalResults.to_log_message`. acc=0.90
Saving checkpoint at step: 1000
Saved checkpoint at /tmp/tmpf5gihi9h/checkpoint_1000
Saving checkpoint at step: 2000
Saved checkpoint at /tmp/tmppqxpbnaf/checkpoint_2000
Removing checkpoint at /tmp/tmppqxpbnaf/checkpoint_1000
Start evaluation process over test
Evaluation results for dataset=test @step=2000
formatted in `EvalResults.to_log_message`. acc=0.90
Savin
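The bookkeeping behind keeping only the best checkpoint can be sketched roughly as follows. `BestResultKeeper` is a hypothetical class, not bobbin's `keep_best_checkpoint`. `save_fn` stands in for a real checkpoint writer, and the `results.json` layout here is an assumption. A new result replaces the stored one only if it is strictly better, which mirrors `EvalResults.is_better_than` above.

```python
import json
import os
import tempfile
from typing import Optional


class BestResultKeeper:
    """Hypothetical sketch of 'keep the best checkpoint' bookkeeping (not bobbin's)."""

    def __init__(self, directory: str, save_fn):
        self.directory = directory
        self.save_fn = save_fn  # stand-in for a real checkpoint writer
        self.best: Optional[dict] = None

    def update(self, step: int, accuracy: float) -> bool:
        # Require a strict improvement, like `accuracy > other.accuracy`.
        if self.best is not None and accuracy <= self.best["accuracy"]:
            return False  # keep the existing best checkpoint
        self.best = {"step": step, "accuracy": accuracy}
        self.save_fn(self.directory, step)
        # Record the results alongside the checkpoint (assumed file layout).
        with open(os.path.join(self.directory, "results.json"), "w") as f:
            json.dump(self.best, f)
        return True


saved = []
with tempfile.TemporaryDirectory() as d:
    keeper = BestResultKeeper(d, save_fn=lambda directory, step: saved.append(step))
    for step, acc in [(1000, 0.90), (2000, 0.89), (3000, 0.93)]:
        keeper.update(step, acc)
    with open(os.path.join(d, "results.json")) as f:
        best_on_disk = json.load(f)
print(saved, best_on_disk)
```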