In [2]:
from zapbench.ts_forecasting.configs import linear
import zapbench.models.util as model_util
import jax
import jax.numpy as jnp

# Build the linear (Nlinear) forecasting model from its default config and
# initialise its parameters with a fixed seed so the cell is reproducible.
config = linear.get_config()
model = model_util.model_from_config(config)

# Dummy-input dimensions used for the forward pass and the summary table.
# NOTE(review): these should agree with the trailing (timesteps, features)
# dims of `config.series_shape` — TODO confirm and derive them from the
# config instead of hard-coding.
BATCH_SIZE = 8
NUM_TIMESTEPS_INPUT = 4
NUM_FEATURES = 71721

init_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(42), num=2)
# `dropout_rng` is not consumed in this cell (every call uses train=False);
# it is kept so later cells that may reference it keep working.
variables = model.init(init_rng, jnp.ones(config.series_shape), train=False)
params = variables['params']
batch_stats = variables.get('batch_stats', None)

input_batch = jnp.ones((BATCH_SIZE, NUM_TIMESTEPS_INPUT, NUM_FEATURES))
output = model.apply(variables, input_batch, train=False)


def _format_key_path(path) -> str:
    """Join a pytree key path into a dotted string such as 'Dense_0.kernel'.

    Dict entries expose ``.key``, sequence entries ``.idx`` and attribute
    entries ``.name``; anything else falls back to ``str``. This keeps the
    listing working if the parameter tree contains lists/tuples, where a bare
    ``k.key`` would raise AttributeError.
    """
    parts = []
    for entry in path:
        for attr in ('key', 'idx', 'name'):
            if hasattr(entry, attr):
                parts.append(str(getattr(entry, attr)))
                break
        else:
            parts.append(str(entry))
    return '.'.join(parts)


print("=== Model Structure ===")
print(model)

print("\n=== Parameter Shapes ===")
print(jax.tree_util.tree_map(lambda x: x.shape, params))

print("\n=== Detailed Parameters ===")
for path, param in jax.tree_util.tree_flatten_with_path(params)[0]:
    print(f"{_format_key_path(path):20} {param.shape} ({param.size:,} params)")

print("\n=== Model Tabulate ===")
# tabulate() runs the module itself, so pass train=False to match the
# init/apply calls above; widen the console so long shapes are not elided.
print(model.tabulate(
    jax.random.PRNGKey(0),
    jnp.ones((1, NUM_TIMESTEPS_INPUT, NUM_FEATURES)),
    train=False,
    console_kwargs={'width': 120},
))

print(f"\nOutput shape: {output.shape}")


=== Model Structure ===
Nlinear(
    # attributes
    config = NlinearConfig(num_outputs=32, constant_init=True, normalization=False)
)

=== Parameter Shapes ===
{'Dense_0': {'bias': (32,), 'kernel': (4, 32)}}

=== Detailed Parameters ===
Dense_0.bias         (32,) (32 params)
Dense_0.kernel       (4, 32) (128 params)

=== Model Tabulate ===

                                Nlinear Summary                                 
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module  ┃ inputs            ┃ outputs          ┃ params            ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│         │ Nlinear │ float32[1,4,7172… │ float32[1,32,71… │                   │
├─────────┼─────────┼───────────────────┼──────────────────┼───────────────────┤
│ Dense_0 │ Dense   │ [2

In [4]:
# Show every named input spec (e.g. timeseries / covariates) of each training
# split, with a blank line after each entry for readability.
for spec_group in config.train_specs:
    for spec_name, spec in spec_group.items():
        print(spec_name, spec, end="\n\n")

timeseries {'driver': 'zarr3', 'kvstore': {'bucket': 'zapbench-release', 'driver': 'gcs', 'path': 'volumes/20240930/traces/'}, 'transform': {'input_exclusive_max': [454, 71721], 'input_inclusive_min': [0, 0], 'input_labels': ['t', 'f'], 'output': [{'input_dimension': 0, 'offset': 1}, {'input_dimension': 1}]}}

covariates {'driver': 'zarr', 'kvstore': {'bucket': 'zapbench-release', 'driver': 'gcs', 'path': 'volumes/20240930/stimuli_features/'}, 'metadata': {'shape': [7879, 26]}, 'transform': {'input_exclusive_max': [454, [26]], 'input_inclusive_min': [0, 0], 'input_labels': ['t', 'f'], 'output': [{'input_dimension': 0, 'offset': 1}, {'input_dimension': 1}]}}

timeseries {'driver': 'zarr3', 'kvstore': {'bucket': 'zapbench-release', 'driver': 'gcs', 'path': 'volumes/20240930/traces/'}, 'transform': {'input_exclusive_max': [1240, 71721], 'input_inclusive_min': [0, 0], 'input_labels': ['t', 'f'], 'output': [{'input_dimension': 0, 'offset': 650}, {'input_dimension': 1}]}}

covariates {'drive