In [2]:
from zapbench import constants
from zapbench import data_utils
from zapbench.ts_forecasting import data_source


# Choose which condition, context length, and split to load. Later cells rely
# on `num_timesteps_context`, `config`, and `source` defined here.
condition_name = 'turning'  # can be any name in constants.CONDITION_NAMES
num_timesteps_context = 4  # 4 for short context, 256 for long context
split = 'train'  # change to 'val' for validation set, e.g., for early stopping

# Adjust the trace spec to the chosen condition/split; each sample consists of
# `num_timesteps_context` input steps plus the prediction-window output steps.
config = data_source.TensorStoreTimeSeriesConfig(
    input_spec=data_utils.adjust_spec_for_condition_and_split(
        condition=constants.CONDITION_NAMES.index(condition_name),
        split=split,
        spec=data_utils.get_spec('240930_traces'),
        num_timesteps_context=num_timesteps_context),
    timesteps_input=num_timesteps_context,
    timesteps_output=constants.PREDICTION_WINDOW_LENGTH,
)
source = data_source.TensorStoreTimeSeries(config)

print(f'{len(source)=}')

In [None]:
# Total span of one sample: prediction steps plus context steps. Shown next to
# the size of the volume's first axis (presumably time — confirm).
offset = config.timesteps_output + config.timesteps_input
(source.volume.shape[0], offset)

In [None]:
# Inspect the source's `n_indexer` attribute (internal detail of zapbench's
# data_source; presumably relates sample indices to timesteps — confirm).
source.n_indexer

When indexing into the data source, we get `series_input`, i.e., past activity of length `num_timesteps_context`, and `series_output`, the 32 timesteps of subsequent activity (the prediction horizon used in ZAPBench).

By enabling `prefetch` on `data_source.TensorStoreTimeSeries`, we can load the entire dataset into memory upfront. This makes indexing significantly faster once the source has been initialized.

In [None]:
import random

# No prefetching (the default): data is not loaded into memory upfront.
source = data_source.TensorStoreTimeSeries(
    config,
    prefetch=False,
)

In [None]:
%%timeit
_ = source[random.randint(0, len(source)-1)]

In [None]:
# Prefetch: load everything into memory now so later indexing is faster.
source = data_source.TensorStoreTimeSeries(
    config,
    prefetch=True,
)

In [None]:
%%timeit
_ = source[random.randint(0, len(source)-1)]

We can also create a data source that combines data from all training conditions (prefetching takes about a minute):

In [None]:
sources = []

# Iterate over all training conditions (excludes 'taxis'), and create
# data sources. The per-condition config uses its own name so that any
# notebook-level `config` variable is not silently overwritten.
for condition_id in constants.CONDITIONS_TRAIN:
  condition_config = data_source.TensorStoreTimeSeriesConfig(
      input_spec=data_utils.adjust_spec_for_condition_and_split(
          condition=condition_id,
          split='train',
          spec=data_utils.get_spec('240930_traces'),
          num_timesteps_context=num_timesteps_context),
      timesteps_input=num_timesteps_context,
      timesteps_output=constants.PREDICTION_WINDOW_LENGTH,
  )
  sources.append(
      data_source.TensorStoreTimeSeries(condition_config, prefetch=True))

# Concatenate into a single source.
source = data_source.ConcatenatedTensorStoreTimeSeries(*sources)

f'{len(source)=}'

Next, we set up an index sampler and construct a data loader with `grain`:

In [None]:
import grain.python as grain


batch_size = 8
num_epochs = 1
shuffle = True

# One shard holding every record (no splitting across processes here).
index_sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.ShardOptions(
        shard_index=0,
        shard_count=1,
        drop_remainder=True,
    ),
    num_epochs=num_epochs,
    shuffle=shuffle,
    seed=101,  # fixed seed so the shuffled order is reproducible
)

# Group records into batches, dropping an incomplete final batch.
# worker_count=0: no separate worker processes are launched.
data_loader = grain.DataLoader(
    data_source=source,
    sampler=index_sampler,
    operations=[grain.Batch(batch_size=batch_size, drop_remainder=True)],
    worker_count=0,
)

We can iterate over the data loader, which yields elements with a batch dimension, in random order, for `num_epochs` epochs:

In [None]:
from tqdm import tqdm


for element in tqdm(data_loader):
  # Placeholder: train the model on `element` here.
  pass

# Display the last batch yielded by the loader.
element

`grain` has many useful features -- for example, we can easily add operations to the data loader to adjust shapes, or add augmentations. More details are in [grain's DataLoader guide](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html).

## Evaluation

Say we have trained a new baseline, how do we evaluate it?

As an example, we use the mean baseline from the manuscript: it can easily be re-implemented in NumPy and does not require any training.

In [None]:
import numpy as np


def f_mean(past_activity: np.ndarray) -> np.ndarray:
  """Mean baseline

  Args:
    past_activity: Past activity as time x neurons matrix.

  Returns:
    Predicted activity calculated by taking the per-neuron mean across time and
    repeating it for all 32 timesteps in the prediction horizon.
  """
  return past_activity.mean(axis=0).reshape((1, -1)).repeat(
      constants.PREDICTION_WINDOW_LENGTH, axis=0)

For inference, we create a data source containing the full trace matrix, and index it as described in [the manuscript](https://openreview.net/pdf?id=oCHsDpyawq) (section 3.2) to compute metrics.

In [None]:
# Source over the full trace matrix (no condition/split restriction), so that
# windows from every condition's test range can be indexed; fully prefetched.
infer_source = data_source.TensorStoreTimeSeries(
    data_source.TensorStoreTimeSeriesConfig(
        timesteps_input=num_timesteps_context,
        timesteps_output=constants.PREDICTION_WINDOW_LENGTH,
        input_spec=data_utils.get_spec('240930_traces'),
    ),
    prefetch=True,
)

In [None]:
from collections import defaultdict

from connectomics.jax import metrics


# Per-condition lists of MAE arrays, one entry per evaluated test window.
MAEs = defaultdict(list)

# Iterate over all conditions, and make predictions for all contiguous snippets
# of length 32 in the respective test set. tqdm wraps the list of names rather
# than the `enumerate` object (which has no length), so the bar shows a total.
for condition_id, condition_name in enumerate(tqdm(constants.CONDITION_NAMES)):
  # Held-out conditions use 'test_holdout'; all others use 'test'. A distinct
  # name avoids overwriting any notebook-level `split` variable.
  eval_split = ('test' if condition_id not in constants.CONDITIONS_HOLDOUT
                else 'test_holdout')
  test_min, test_max = data_utils.adjust_condition_bounds_for_split(
      eval_split,
      *data_utils.get_condition_bounds(condition_id),
      num_timesteps_context=num_timesteps_context)

  num_windows = data_utils.get_num_windows(
      test_min, test_max, num_timesteps_context)
  for window in range(num_windows):
    element = infer_source[test_min + window]

    predictions = f_mean(element['series_input'])
    mae = metrics.mae(predictions=predictions, targets=element['series_output'])

    # Store a NumPy copy of this window's MAE.
    MAEs[condition_name].append(np.array(mae))

Let's plot our results:

In [None]:
import matplotlib.pyplot as plt


# Derive the horizon from the benchmark constant instead of hard-coding 32.
horizon = constants.PREDICTION_WINDOW_LENGTH
steps_ahead = np.arange(1, horizon + 1)

# Label the context length actually used (4 = short, 256 = long).
context_label = {4: 'short', 256: 'long'}.get(
    num_timesteps_context, f'{num_timesteps_context}-step')

fig, ax = plt.subplots()
for condition_name in constants.CONDITION_NAMES:
  mae = np.stack(MAEs[condition_name]).mean(axis=0)  # Average over windows
  ax.plot(steps_ahead, mae, label=condition_name)

ax.set(
    title=f'mean baseline, {context_label} context',
    xlabel='steps predicted ahead',
    ylabel='MAE',
    xlim=(1, horizon),
    # NOTE(review): y-limits appear tuned to the short-context results; adjust
    # if other settings fall outside this range.
    ylim=(0.015, 0.06),
)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

Finally, we briefly check that these results match the ones in the manuscript:

In [None]:
from connectomics.common import ts_utils
import pandas as pd


# Load dataframe with results reported in the manuscript, read from the
# public zapbench-release GCS bucket (plain string: no interpolation needed).
df = pd.DataFrame(
    ts_utils.load_json('gs://zapbench-release/dataframes/20250131/combined.json'))
df.head()

In [None]:
# Compare our per-step MAEs with the manuscript's numbers for the same method,
# condition, and context length. The context comes from
# `num_timesteps_context` rather than being hard-coded, so the check stays
# valid when the long context is selected.
for condition_name in constants.CONDITION_NAMES:
  mae = np.stack(MAEs[condition_name]).mean(axis=0)  # Average over windows
  mae_reported = df.query(
      'method == "mean" and context == @num_timesteps_context'
      ' and condition == @condition_name'
  ).sort_values('steps_ahead')['MAE'].to_numpy()
  np.testing.assert_array_almost_equal(mae, mae_reported, decimal=8)