In [None]:
import os
from collections.abc import Sequence

import flax.jax_utils as flax_utils
import flax.linen as nn
import grain.python as grain
import jax
import numpy as np
from absl import logging
from connectomics.jax import checkpoint, training
from etils import epath
from orbax import checkpoint as ocp

import zapbench.models.util as model_util
from zapbench.ts_forecasting import heads, input_pipeline, train
from zapbench.ts_forecasting.configs import infer, linear


def _get_checkpoint_step(
    checkpoint_manager: ocp.CheckpointManager,
    selection_strategy: str,
) -> int | None:
  """Resolves a checkpoint selection strategy to a concrete step.

  For 'latest' the manager's latest step is returned directly. For the two
  "best" strategies, the corresponding bookkeeping item is restored from the
  latest checkpoint and its recorded `best_step` is returned.

  Args:
    checkpoint_manager: Checkpoint manager to query and restore from.
    selection_strategy: One of 'early_stopping', 'best_val_loss', or 'latest'.

  Returns:
    The selected checkpoint step.

  Raises:
    ValueError: If `selection_strategy` is not one of the supported values.
  """
  if selection_strategy == 'latest':
    return checkpoint_manager.latest_step()

  # Name of the checkpointed item that stores `best_step` for this strategy.
  if selection_strategy == 'early_stopping':
    item_name = 'early_stop'
  elif selection_strategy == 'best_val_loss':
    item_name = 'track_best_val_loss_step'
  else:
    raise ValueError(f'Unknown checkpoint selection: {selection_strategy}')

  # Only the bookkeeping item is requested; it is read from the latest step.
  restored = checkpoint.restore_checkpoint(
      checkpoint_manager,
      state={item_name: None},
      step=checkpoint_manager.latest_step(),
  )
  return restored[item_name]['best_step']


def infer_single_step(
    model: nn.Module,
    head: heads.Head,
    train_state: train.TrainState,
    data_source: grain.RandomAccessDataSource,
    idx: int,
    infer_key: jax.Array,  # pylint: disable=unused-argument
    covariates: Sequence[str] = (),
    covariates_static: jax.Array | None = None,
    with_carry: bool = False,
) -> tuple[jax.Array, jax.Array]:
  """Runs one prediction step for a single index of the data source.

  Args:
    model: Model forwarded to `train.pred_step`.
    head: Head used to turn the model output into a distribution.
    train_state: Train state forwarded to `train.pred_step`.
    data_source: Source indexed with `idx` to obtain the batch.
    idx: Index of the element to run inference on.
    infer_key: Unused by this function.
    covariates: Covariate names forwarded to `train.pred_step`. If it contains
      'covariates_static', `covariates_static` is written into the batch.
    covariates_static: Value stored under the batch's 'covariates_static' key
      when requested via `covariates`.
    with_carry: Forwarded as `return_carry`. When true, the step output is
      indexed as a (carry, model output) pair and the carry is discarded.

  Returns:
    prediction: mode of the distribution produced by `head`
    target: the batch's 'timeseries_output' entry
  """
  batch = data_source[idx]
  if 'covariates_static' in covariates:
    # Note: this writes into the element returned by the data source.
    batch['covariates_static'] = covariates_static

  # Each call starts without a carry, so indices are processed independently.
  out = train.pred_step(
      model,
      train_state,
      batch,
      covariates,
      initial_carry=None,
      return_carry=with_carry,
  )

  model_output = out[1] if with_carry else out
  prediction = head.get_distribution(model_output).mode()
  target = batch['timeseries_output']

  return prediction, target

In [None]:
# Experiment working directory: `config.json` is read from it and the
# checkpoint manager is opened on it. The default is the path this notebook was
# authored with; set the ZAPBENCH_EXP_WORKDIR environment variable to point the
# notebook at another experiment without editing this cell.
exp_workdir = os.environ.get(
    'ZAPBENCH_EXP_WORKDIR', '/Users/s/vault/zapbench/train_subject_14'
)
exp_config = model_util.load_config(os.path.join(exp_workdir, 'config.json'))

model = model_util.model_from_config(exp_config)

covariates_static = input_pipeline.get_static_covariates(exp_config)

checkpoint_manager = checkpoint.get_checkpoint_manager(
    exp_workdir,
    item_names=(
        'early_stop',
        'train_state',
        'track_best_val_loss_step',
    ),
)

# Step recorded as `best_step` by the 'best_val_loss' strategy.
step = _get_checkpoint_step(checkpoint_manager, 'best_val_loss')

# Restore only the train state for that step.
checkpointed_state = checkpoint.restore_checkpoint(
    checkpoint_manager,
    state=dict(train_state=None),
    step=step,
)

# The restored entry is unpacked as keyword arguments to rebuild a TrainState,
# which is then replicated; the inference cell calls `train.merge_batch_stats`
# and `flax_utils.unreplicate` on this replicated value.
train_state = flax_utils.replicate(
    train.TrainState(**checkpointed_state['train_state'])
)

In [None]:
# Inference setup.
# Start from the `linear` config and overlay the `infer` config on top of it.
# NOTE(review): the model above was built from the experiment's own
# `exp_config`, while the head, covariates and data source below come from
# `linear.get_config()` — confirm the two agree for this experiment.
infer_config = infer.get_config()
config = linear.get_config()
config.update(infer_config)

# Prediction head and inference data source, both derived from `config`.
head = heads.create_head(config)
infer_source = input_pipeline.create_inference_source_with_transforms(config)

# Split the seeded RNG, then fold the restored checkpoint step into the
# inference half to obtain the key handed to `infer_single_step`.
rng, infer_rng = jax.random.split(training.get_rng(config.seed))
infer_key = jax.random.fold_in(infer_rng, step)

In [None]:
# Run inference for every configured index set and print its metrics.
# `prediction` and `target` stay bound to the last processed index after the
# loop; the plotting cells further down read them.
for infer_idx_set in config.infer_idx_sets:
  name, idx_list = (infer_idx_set[k] for k in ('name', 'idx_list'))
  metrics_key = f'infer_{name}'
  # Metrics are only accumulated when the head defines them for this set.
  collect_metrics = metrics_key in head.metrics
  infer_metrics = None
  train_state = train.merge_batch_stats(train_state)

  # Loop-invariant inputs: the state and covariate names do not change while
  # iterating over indices, so prepare them once per index set instead of once
  # per index.
  train_state_unreplicated = flax_utils.unreplicate(train_state)
  infer_covariates = tuple(config.covariates)

  for idx in idx_list:
    prediction, target = infer_single_step(
        model,
        head,
        train_state_unreplicated,
        infer_source,
        idx,
        infer_key=infer_key,
        covariates=infer_covariates,
        covariates_static=covariates_static,
        with_carry=config.infer_with_carry,
    )

    if collect_metrics:
      metrics_update = head.metrics[
          metrics_key
      ].single_from_model_output(predictions=prediction, targets=target)
      infer_metrics = (
          metrics_update
          if infer_metrics is None
          else infer_metrics.merge(metrics_update)
      )

  if infer_metrics is not None:
    # Convert the computed metric values to NumPy arrays before printing.
    infer_metrics_cpu = jax.tree.map(np.array, infer_metrics.compute())
    print(infer_metrics_cpu)


In [None]:
import matplotlib.pyplot as plt
# Not referenced by name anywhere below; presumably imported for its side
# effect of making the 'nature' / 'no-latex' styles available — confirm before
# removing it as "unused".
import scienceplots
# Select the 'nature' and 'no-latex' styles for the figures created below.
plt.style.use(['nature', 'no-latex'])

In [None]:
# Every 500th index from 0 up to (but excluding) 10000, i.e. 20 traces. The
# same selection is reused by a later plotting cell.
# NOTE(review): assumes the last axis of `prediction` / `target` has more than
# 9500 entries — confirm for the dataset in use.
i_neuron_list = range(0, 10000, 500)

# One small figure per selected index, drawn from the first batch element of
# the most recent inference result.
for i_neuron in i_neuron_list:
  fig, ax = plt.subplots(figsize=(4, 1), dpi=200)
  # Prediction is drawn first and the target second, on the same axes.
  for trace in (prediction, target):
    ax.plot(trace[0, :, i_neuron])
  ax.set_ylim(0, 0.3)

In [None]:
def smooth_target(target, window_size=10):
  """Applies a uniform moving average along axis 1 of `target`.

  Each 1-D slice along axis 1 is convolved with a length-`window_size` kernel
  of equal weights using 'valid' mode, so only fully overlapping windows are
  kept and axis 1 shrinks by `window_size - 1`.

  Args:
    target: Array-like with at least two dimensions; axis 1 is smoothed.
    window_size: Number of consecutive samples averaged per output value.

  Returns:
    Array with the same shape as `target` except that axis 1 has length
    `target.shape[1] - window_size + 1`.

  Raises:
    ValueError: If `window_size` is smaller than 1 or larger than the length
      of axis 1.
  """
  if window_size < 1:
    raise ValueError(f'window_size must be >= 1, got {window_size}')

  target = np.asanyarray(target)
  # np.convolve swaps its operands when the kernel is the longer one, which
  # would silently yield a wrong-length, under-scaled result instead of a
  # moving average, so reject that case explicitly.
  if target.ndim > 1 and window_size > target.shape[1]:
    raise ValueError(
        f'window_size ({window_size}) exceeds the length of axis 1 '
        f'({target.shape[1]})'
    )

  kernel = np.ones(window_size) / window_size

  def _moving_average(series):
    return np.convolve(series, kernel, mode='valid')

  return np.apply_along_axis(_moving_average, axis=1, arr=target)

# Moving average of `target` with a 5-sample window (overriding the default of
# 10). `target` is whatever the inference loop assigned last.
target_smoothed = smooth_target(target, window_size=5)

In [None]:
# The smoothed target may be shorter than `target` along axis 1: a 'valid'
# moving average keeps only fully overlapping windows, and its k-th value
# averages samples k .. k + window - 1, i.e. it is centred (window - 1) / 2
# samples after index k. Shift its x-coordinates by half the length difference
# so it lines up with the prediction instead of appearing shifted to the left.
# The offset is zero when both arrays have the same length.
smoothing_offset = (target.shape[1] - target_smoothed.shape[1]) / 2
x_smoothed = np.arange(target_smoothed.shape[1]) + smoothing_offset

# One small figure per selected index: prediction vs. smoothed target for the
# first batch element.
for i_neuron in i_neuron_list:
  fig, ax = plt.subplots(figsize=(4, 1), dpi=200)
  ax.plot(prediction[0, :, i_neuron])
  ax.plot(x_smoothed, target_smoothed[0, :, i_neuron])
  ax.set_ylim(0, 0.3)

In [None]:
import matplotlib.pyplot as plt

# Absolute prediction error for the first batch element.
abs_diff = np.abs(prediction - target)[0, :, :]

# Heatmap of the transposed error matrix (axis 1 of the data runs along x),
# with the colour scale clipped to [0, 0.2] and the axes hidden.
fig, ax = plt.subplots(figsize=(5, 5), dpi=200)
im = ax.imshow(
    abs_diff.T, aspect='auto', cmap='magma', origin='lower', vmin=0, vmax=0.2
)
fig.colorbar(im)
ax.axis('off')
plt.show()