In [None]:
#!/usr/bin/env python3
"""Script to understand exactly how num_train_records is calculated."""

import os
import sys

import jax
import tensorstore as ts

from zapbench.ts_forecasting import data_source
from zapbench.ts_forecasting import input_pipeline
from zapbench.ts_forecasting.configs import linear

In [None]:
# Build the linear-model config from an argument string and read back its epoch count.
# NOTE(review): the copy of `_ARGS` pasted in a later cell gives `normalization` a
# boolean default (False); confirm that `4` is the intended value here.
config = linear.get_config("normalization=4")

config.num_epochs

In [None]:
# Copyright 2025 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear models."""

from collections import abc
import dataclasses

from connectomics.jax import config_util
import immutabledict
import ml_collections as mlc
from zapbench import constants
from zapbench import hparam_utils as hyper
from zapbench.models import nlinear
from zapbench.ts_forecasting.configs import common


# Keyword values handed to `config_util.parse_arg` below. This cell is a pasted
# copy of a config module (see the license header and imports above) so that the
# argument parsing can be stepped through by hand — presumably these act as the
# defaults that the argument string overrides; verify against `parse_arg`.
_ARGS = immutabledict.immutabledict({
    'normalization': False,
    'seed': -1,
    'timesteps_input': 4,
    'num_epochs': 1,
    'val_ckpt_every_steps': 250,
    'log_loss_every_steps': 100,
})
# Rebinds `config` (previously the result of `linear.get_config`) to an empty ConfigDict.
config = mlc.ConfigDict()
# The argument string names timesteps_input and num_epochs; every `_ARGS` entry
# is forwarded as a keyword argument.
config.arg = config_util.parse_arg("timesteps_input=4,num_epochs=10", **_ARGS)

In [None]:
# Merge the config returned by `common.get_config` (called with the parsed arguments) into `config`.
config.update(common.get_config(**config.arg))

In [None]:
# Show the epoch count held by the hand-built config.
config.num_epochs

In [None]:
from zapbench.ts_forecasting.configs import linear
from zapbench.ts_forecasting import input_pipeline
from zapbench.ts_forecasting import data_source
from connectomics.jax import grain_util
import grain.python as grain
import jax

# Assemble the training input pipeline for the linear config and draw one batch.
config = linear.get_config("timesteps_input=4")
drop_remainder = True
shard_options = grain.ShardByJaxProcess(drop_remainder=drop_remainder)
all_ops = input_pipeline.get_all_ops()

# Operation order: ops parsed from `pre_process_str`, then batching, then ops
# parsed from `batch_process_str`.
transformations = [*grain_util.parse(config.pre_process_str, all_ops)]
process_batch_size = jax.local_device_count() * config.per_device_batch_size
batch_op = grain.Batch(batch_size=process_batch_size, drop_remainder=drop_remainder)
transformations.append(batch_op)
transformations.extend(grain_util.parse(config.batch_process_str, all_ops))


def _timeseries_source(name, input_spec):
    """Build one TensorStoreTimeSeries for `input_spec`, prefixed with `name`."""
    if hasattr(input_spec, 'to_dict'):
        input_spec = input_spec.to_dict()
    series_config = data_source.TensorStoreTimeSeriesConfig(
        input_spec=input_spec,
        timesteps_input=config.timesteps_input,
        timesteps_output=config.timesteps_output,
    )
    return data_source.TensorStoreTimeSeries(
        config=series_config,
        prefetch=config.prefetch,
        prefix=name,
        sequential=config.sequential_data_source,
    )


def _merged_source(series):
    """Merge the per-name sources of one `train_specs` entry."""
    return data_source.MergedTensorStoreTimeSeries(
        *(_timeseries_source(name, input_spec) for name, input_spec in series.items())
    )


# One merged source per entry of `config.train_specs`, concatenated in order.
train_source = data_source.ConcatenatedTensorStoreTimeSeries(
    *(_merged_source(series) for series in config.train_specs)
)

train_sampler = grain.IndexSampler(
    num_records=len(train_source),
    shuffle=True,
    seed=42,
    num_epochs=config.num_epochs,
    shard_options=shard_options,
)
train_loader = grain.DataLoader(
    data_source=train_source,
    sampler=train_sampler,
    operations=transformations,
    worker_count=config.grain_num_workers,
)

# Pull a single batch so its contents can be inspected in later cells.
train_iter = iter(train_loader)
batch = next(train_iter)

In [None]:
# Display the concatenated training source built in the pipeline cell.
train_source

In [None]:
from zapbench.data_utils import get_condition_bounds, adjust_condition_bounds_for_split, get_spec, adjust_spec_for_condition_and_split

In [None]:
# Look up the time bounds of condition 0, then show the bounds adjusted for the
# 'train' split next to the unadjusted ones. The trailing 4 is presumably the
# input context length — confirm against `adjust_condition_bounds_for_split`.
inclusive_min, exclusive_max = get_condition_bounds(0)
adjust_condition_bounds_for_split('train', inclusive_min, exclusive_max, 4), inclusive_min, exclusive_max

In [None]:
# Inspect the tensorstore dimension expression that selects the dimension labelled 't'.
ts.d['t']

In [None]:
# Load `spec` here: its only other assignment is in a later cell, which would
# leave this cell failing with NameError on a fresh-kernel, top-to-bottom run.
spec = get_spec('240930_traces')

# Restrict the 't' dimension to the condition bounds computed in the earlier cell.
spec[ts.d['t'][slice(inclusive_min, exclusive_max)]]

In [None]:
import tensorstore as ts

# Open the 'train' split of condition 1 for the '240930_traces' spec. The final
# argument (32) is the fourth positional parameter, which the version of this
# function defined later in the notebook names `num_timesteps_context`.
spec = get_spec('240930_traces')
train_spec = adjust_spec_for_condition_and_split(spec, 1, 'train', 32)
ds = ts.open(train_spec).result()

In [None]:
# Display the TensorStore opened in the previous cell.
ds

In [None]:
# Show the adjusted spec itself (without opening it) for the same arguments as above.
adjust_spec_for_condition_and_split(spec, 1, 'train', 32)

In [None]:
# Experiment: index 't' with an explicit, non-contiguous list of positions (note
# the jump from 4 to 7), re-origin the result at 0 via `translate_to[0]`, and open it.
ts.open(spec[ts.d['t'][[0, 1, 2, 3, 4, 7, 8]]].translate_to[0]).result()

In [None]:
# Summarize the pipeline pieces: the operation list, the per-process batch size,
# and the configured grain worker count.
transformations, process_batch_size, config.grain_num_workers

In [None]:
# Compare `_len` of the first inner source in the last concatenated group with
# that of the first inner source in group index 1 (presumably record counts).
# NOTE(review): `srcs` and `_len` look like internal attributes of the
# data_source classes; this cell may break if their implementation changes.
train_source.srcs[-1].srcs[0]._len, train_source.srcs[1].srcs[0]._len

In [None]:
# List the entries present in the sampled batch.
batch.keys()

In [None]:
# Shapes of the input and output series in the batch, plus its 'timestep' entry.
batch['timeseries_input'].shape, batch['timeseries_output'].shape, batch['timestep']

Explore linear model

In [None]:
from zapbench.ts_forecasting.configs import linear
import zapbench.models.util as model_util
import jax
import jax.numpy as jnp

# Instantiate the linear model from its default config and initialise its variables.
# Note: this rebinds `config`, replacing the one used by the pipeline cells above.
config = linear.get_config()
model = model_util.model_from_config(config)
# Fixed seed 42. `dropout_rng` is split off but not used anywhere in this cell.
init_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(42), num=2)
# Initialise with an all-ones array of shape `config.series_shape`, with train=False.
variables = model.init(init_rng, jnp.ones(config.series_shape), train=False)
params = variables['params']
# None when `variables` has no 'batch_stats' entry.
batch_stats = variables.get('batch_stats', None)

# Forward pass on an all-ones batch.
# NOTE(review): (8, 4, 71721) is hard-coded — presumably (batch, timesteps_input,
# number of series); confirm it is compatible with `config.series_shape`.
input_batch = jnp.ones((8, 4, 71721))
output = model.apply(variables, input_batch, train=False)

In [None]:
# Display the model object's repr.
model

In [None]:
# Re-run the forward pass, then print a structural summary of the model and its parameters.
input_batch = jnp.ones((8, 4, 71721))
output = model.apply(variables, input_batch, train=False)

print("=== Model Structure ===")
print(model)

print("\n=== Parameter Shapes ===")
shape_tree = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
print(shape_tree)

print("\n=== Detailed Parameters ===")
# Each flattened entry pairs a key path with the parameter array stored at that path.
flat_params = jax.tree_util.tree_flatten_with_path(params)[0]
for key_path, param in flat_params:
    path_str = '.'.join([str(entry.key) for entry in key_path])
    print(f"{path_str:20} {param.shape} ({param.size:,} params)")

print("\n=== Model Tabulate ===")
tabulate_rng = jax.random.PRNGKey(0)
tabulate_input = jnp.ones((1, 4, 71721))
print(model.tabulate(tabulate_rng, tabulate_input))

print(f"\nOutput shape: {output.shape}")


Review code changes

In [None]:
import tensorstore as ts
from typing import Optional
from connectomics.jax import grain_util
FlatFeatures = grain_util.FlatFeatures

from zapbench import constants

In [None]:
# Show the per-condition intervals stored in the default dataset's config.
constants.get_dataset_config(constants.DEFAULT_DATASET)['condition_intervals']

In [None]:

## 2. DATA_UTILS.PY MODIFICATIONS

### New Function: get_condition_intervals()
def get_condition_intervals(condition: int, dataset_name: str = constants.DEFAULT_DATASET) -> tuple[tuple[int, int], ...]:
    """Return the intervals of a condition, each shrunk by the condition padding.

    Args:
        condition: Index into the dataset config's 'condition_intervals'.
        dataset_name: Name of the dataset whose config is consulted.

    Returns:
        A tuple of `(start, end)` pairs in which every configured interval has
        `constants.CONDITION_PADDING` added to its start and subtracted from
        its end. Intervals left with `start >= end` are omitted.
    """
    raw_intervals = constants.get_dataset_config(dataset_name)['condition_intervals'][condition]
    shrunk = (
        (start + constants.CONDITION_PADDING, end - constants.CONDITION_PADDING)
        for start, end in raw_intervals
    )
    # Keep only intervals that are still non-empty after padding.
    return tuple(bounds for bounds in shrunk if bounds[0] < bounds[1])

# Example: padded intervals for condition 2 of the default dataset.
get_condition_intervals(2)

In [None]:
from zapbench import constants

def safe_calculate_window_size(num_timesteps_context: int) -> int:
    """Return the full window length (context plus prediction) for a context length.

    Args:
        num_timesteps_context: Number of context timesteps. Must be positive
            and no larger than `constants.MAX_CONTEXT_LENGTH`.

    Returns:
        `num_timesteps_context + constants.PREDICTION_WINDOW_LENGTH`.

    Raises:
        ValueError: If `num_timesteps_context` is not positive or exceeds
            `constants.MAX_CONTEXT_LENGTH`.
    """
    if num_timesteps_context <= 0:
        raise ValueError(f"num_timesteps_context must be > 0, got {num_timesteps_context}")
    elif num_timesteps_context > constants.MAX_CONTEXT_LENGTH:
        raise ValueError(f"num_timesteps_context {num_timesteps_context} exceeds MAX_CONTEXT_LENGTH {constants.MAX_CONTEXT_LENGTH}")

    return num_timesteps_context + constants.PREDICTION_WINDOW_LENGTH

# Example: window size for a 4-step context.
safe_calculate_window_size(4)

In [None]:
def build_valid_timesteps(intervals: tuple[tuple[int, int], ...], window_size: int) -> list[int]:
    """Collect every timestep at which a complete window fits inside one interval.

    Args:
        intervals: Half-open `(start, end)` timestep ranges.
        window_size: Number of consecutive timesteps one window occupies.
            Must be positive.

    Returns:
        Sorted list of start timesteps `t` such that `start <= t` and
        `t + window_size <= end` for at least one interval.

    Raises:
        ValueError: If `window_size` is not positive, if `intervals` is empty,
            or if no interval is long enough to hold a single window.
    """
    # A non-positive window would make the range below run up to or past
    # `end`, yielding start positions outside the half-open interval.
    if window_size <= 0:
        raise ValueError(f"window_size must be > 0, got {window_size}")

    intervals = tuple(intervals)
    if not intervals:
        raise ValueError("No intervals provided; cannot build valid timesteps.")

    valid_timesteps = []
    for start, end in intervals:
        if end - start >= window_size:
            # Last admissible start is `end - window_size` (inclusive).
            valid_timesteps.extend(range(start, end - window_size + 1))
        else:
            print(f"Warning: Interval [{start}, {end}) too small for window_size {window_size}")

    if not valid_timesteps:
        min_interval_size = min(end - start for start, end in intervals)
        raise ValueError(f"No intervals large enough for window_size={window_size}. "
                        f"Minimum interval size: {min_interval_size}")

    # Intervals may be given in any order; callers get ascending timesteps.
    return sorted(valid_timesteps)

# Print the configured intervals for condition 0, then sanity-check the helper on
# two toy intervals: with window_size 7 only starts 0-3 and 20-23 leave room for
# a full window.
print(constants.get_dataset_config(constants.DEFAULT_DATASET)['condition_intervals'][0])
build_valid_timesteps(((0, 10),(20, 30)), 7)

In [None]:
def adjust_spec_for_condition_and_split(
    spec: ts.Spec,
    condition: int,
    split: Optional[str],
    num_timesteps_context: int,
    dataset_name: str = constants.DEFAULT_DATASET,
) -> ts.Spec:
    """Restrict a spec's 't' dimension to the timesteps of one condition and split.

    The candidate timesteps are the window start positions returned by
    `build_valid_timesteps` for the condition's padded intervals. They are then
    narrowed to the requested split and used to index the 't' dimension.

    Args:
        spec: Spec to restrict; must have a dimension labelled 't'.
        condition: Condition index passed to `get_condition_intervals`.
        split: One of 'train', 'val', 'test', 'test_holdout'. A falsy value
            (None or '') selects every candidate timestep.
        num_timesteps_context: Context length; determines the window size and
            how far the 'val' and 'test' selections reach back.
        dataset_name: Dataset whose condition intervals are used.

    Returns:
        The spec indexed along 't' and translated so its origin is 0.

    Raises:
        ValueError: If `split` is a non-empty string other than the four
            supported names, or if the selection contains no timesteps.
    """
    intervals = get_condition_intervals(condition, dataset_name)
    window_size = safe_calculate_window_size(num_timesteps_context)
    valid_timesteps = build_valid_timesteps(intervals, window_size)

    # Apply train/val/test split.
    if split:
        total = len(valid_timesteps)
        test_count = int(total * constants.TEST_FRACTION)
        val_count = int(total * constants.VAL_FRACTION)
        train_count = total - test_count - val_count

        # NOTE(review): the offsets below are positions in `valid_timesteps`,
        # not raw timestep values. When the condition has several intervals,
        # the `num_timesteps_context` entries reached back for 'val'/'test'
        # may come from the other side of a gap — confirm this is intended.
        if split == 'train':
            valid_timesteps = valid_timesteps[:train_count]
        elif split == 'val':
            val_start = max(0, train_count - num_timesteps_context)
            valid_timesteps = valid_timesteps[val_start:train_count + val_count]
        elif split == 'test':
            test_start = max(0, train_count + val_count - num_timesteps_context)
            valid_timesteps = valid_timesteps[test_start:]
        elif split == 'test_holdout':
            holdout_start = max(0, total - constants.MAX_CONTEXT_LENGTH - constants.PREDICTION_WINDOW_LENGTH)
            valid_timesteps = valid_timesteps[holdout_start:]
        else:
            # Without this branch a misspelled split would silently return the
            # unsplit data.
            raise ValueError(
                f"Unknown split: {split!r}; expected one of "
                "'train', 'val', 'test', 'test_holdout'"
            )

    # The split can leave nothing (e.g. a train share of zero when few
    # candidates exist); fail with a clear message instead of an IndexError.
    if not valid_timesteps:
        raise ValueError(
            f"No timesteps selected for condition={condition}, split={split!r}, "
            f"num_timesteps_context={num_timesteps_context}"
        )

    first, last = valid_timesteps[0], valid_timesteps[-1]
    # A gap-free run can be expressed as a plain slice; otherwise fall back to
    # indexing with the explicit list of timesteps.
    if len(valid_timesteps) == last - first + 1:
        return spec[ts.d['t'][slice(first, last + 1)]].translate_to[0]
    return spec[ts.d['t'][valid_timesteps]].translate_to[0]


# Try the function defined above on the 'train' split of condition 0 with the
# maximum context length. Uses the `spec` loaded via `get_spec` in an earlier cell.
adjust_spec_for_condition_and_split(
    spec,
    condition=0,
    split='train',
    num_timesteps_context=constants.MAX_CONTEXT_LENGTH,
    dataset_name=constants.DEFAULT_DATASET,
)