# Interactive time-series forecasting

This tutorial shows how to run a JAX-based time-series forecasting model interactively, which can be useful, for example, during development. The results in the manuscript were generated by running models [through command-line execution](https://github.com/google-research/zapbench/blob/main/zapbench/ts_forecasting/README.md).

In [None]:
# Install ZapBench straight from its GitHub repository, plus penzai (imported
# below as `pz`).
# NOTE(review): penzai is installed unpinned; the `pz.ts` helpers used in the
# next cell may differ between releases — confirm against the installed
# version.
!pip install git+https://github.com/google-research/zapbench.git#egg=zapbench
!pip install penzai

In [None]:
import functools
import importlib
import itertools as it

from connectomics.jax import training
import flax.jax_utils as flax_utils
import jax
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from penzai import pz
from zapbench.ts_forecasting import heads
from zapbench.ts_forecasting import input_pipeline
from zapbench.ts_forecasting import train


# Notebook display setup through penzai's helpers.
# NOTE(review): the per-line descriptions below are inferred from the call
# names only — confirm against the penzai documentation for the installed
# version.
# Enable penzai's interactive context for this session.
pz.enable_interactive_context()
# Register the `pz.ts` renderer as the default for displayed values.
pz.ts.register_as_default()
# Register the autovisualize cell magic.
pz.ts.register_autovisualize_magic()
# Use an `ArrayAutovisualizer` instance as the interactive autovisualizer.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

Set up config, seeding, and load datasets:

In [None]:
# Choose the model config to load; valid names are the modules under
# zapbench/ts_forecasting/configs/.
model_name = 'timemix'
config_arg = 'timesteps_input=4'

config_module = importlib.import_module(
    f'zapbench.ts_forecasting.configs.{model_name}')
config = config_module.get_config(arg=config_arg)

# Overrides applied on top of the loaded config for this session.
config.per_device_batch_size = 4
config.prefetch = True

# Derive a dedicated key for the data pipeline from the root key, then turn
# it into a plain Python int drawn from [0, int32 max).
rng = training.get_rng(config.seed)
rng, seed_key = jax.random.split(rng)
int32_max = np.iinfo(np.int32).max
data_seed = int(
    jax.random.randint(seed_key, shape=(), minval=0, maxval=int32_max)
)

# `create_datasets` returns four values; only the first (train loader) and
# third (validation loader) are used here.
train_loader, _, val_loader, _ = input_pipeline.create_datasets(
    config, data_seed
)
covariates_static = input_pipeline.get_static_covariates(config)

Initialize the model, state, and helpers:

In [None]:
# Model input shapes: the series shape first, followed by one entry per
# covariate shape listed in the config.
input_shapes = (config.series_shape, *config.covariates_shapes)
print(f'{input_shapes=}')

# Use a fresh key for model initialisation and keep `rng` for later use.
rng, model_rng = jax.random.split(rng)
model, optimizer, schedule, train_state = train.create_train_state(
    config, model_rng, input_shapes=input_shapes
)
head = heads.create_head(config)

# Bind every argument except the train state and the batch, so the pmapped
# step is called with just `train_state=` and `batch=`.
train_step_fn = functools.partial(
    train.train_step,
    model=model,
    head=head,
    optimizer=optimizer,
    schedule=schedule,
    covariates=config.covariates,
)
p_train_step = jax.pmap(train_step_fn, axis_name='batch')

# Replicate the state for use with the pmapped step.
train_state = flax_utils.replicate(train_state)

Train the model for 3,000 steps and monitor the loss:

In [None]:
num_train_steps = 3_000
# `config` is not modified inside the loop, so decide once whether each batch
# needs the static covariates attached.
use_static_covariates = 'covariates_static' in config.covariates

# Maps train step index -> training loss recorded at that step.
train_losses = {}
train_iter = iter(train_loader)

for step in tqdm(range(num_train_steps)):
  batch = next(train_iter)
  if use_static_covariates:
    batch['covariates_static'] = covariates_static
  # Reshape the batch for the local devices before the pmapped step.
  batch = training.reshape_batch_local_devices(batch)

  train_state, replicated_metrics = p_train_step(
      train_state=train_state, batch=batch
  )

  # Undo the replication of the metrics before computing their values.
  step_metrics = flax_utils.unreplicate(replicated_metrics).compute()
  train_losses[step] = step_metrics['train_loss']

In [None]:
# One point per recorded step: x is the step index, y the loss stored for it.
steps = list(train_losses)
losses = list(train_losses.values())

plt.plot(steps, losses, label='train loss')
plt.xlabel('train step')
# The trailing semicolon keeps the notebook from echoing the return value.
plt.ylabel('loss');

Finally, we make predictions on the validation set and compare them against the ground truth.

In [None]:
# Take a single batch from the validation set.
batch = next(iter(val_loader))
# Mirror the training loop: attach the static covariates when the config
# lists them, otherwise the `batch[k]` lookup below raises KeyError.
# NOTE(review): assumes validation batches, like training batches, do not
# already carry 'covariates_static' and that `covariates_static` matches the
# validation batch size — confirm.
if 'covariates_static' in config.covariates:
  batch['covariates_static'] = covariates_static

# Positional model inputs: the input series first, then the covariates in the
# order given by the config. Parameters are un-replicated because the model is
# applied directly here rather than through the pmapped step.
out = model.apply(
    {'params': flax_utils.unreplicate(train_state.params)},
    *[batch[k] for k in it.chain(('timeseries_input',), config.covariates)],
    train=False,
    capture_intermediates=False,
)

In [None]:
# Show the model output computed above; as the cell's last expression, `out`
# is rendered by the notebook.
print('predictions')
out

In [None]:
# Show the 'timeseries_output' entry of the same validation batch, for
# comparison with the predictions.
print('targets')
batch['timeseries_output']