# In [None]:
import jax
import jax.numpy as jnp
from jax import random, vmap, lax


def simulate_step(carry, t):
    """Advance the simulation by one timestamp (currently a no-op placeholder).

    This function is the body passed to ``lax.scan``, so it returns a
    ``(carry, output)`` pair. Any real implementation must keep the carry's
    structure, shapes and dtypes unchanged from step to step.

    Args:
      carry: dictionary holding all the data needed for the simulation.
      t: the current timestamp (a scalar integer array when driven by
        ``lax.scan`` over ``jnp.arange``).

    Returns:
      A ``(carry, None)`` tuple: the carry after timestamp ``t`` (returned
      untouched for now) and no per-step output.
    """
    # TODO: implement the per-timestamp update; `t` is not used yet.
    step_output = None
    next_carry = carry
    return next_carry, step_output

# One simulation run for a single fish.
def one_fish_one_simulation(
    carry: dict,
    length: int,
    key=None,
    *,
    step_fn=None,
):
    """Run ``length`` timestamps of the simulation for one fish.

    Loops over timestamps ``0 .. length - 1`` with ``lax.scan``, calling
    ``step_fn(carry, t)`` (``simulate_step`` by default) at each one.

    Args:
      carry: dict holding the simulation state for this run.
      length: number of timestamps; must be a concrete Python int because it
        sizes ``jnp.arange``.
      key: optional PRNG key for this run. When given, it is stored in the
        carry under ``"key"`` (overwriting any existing entry) so the step
        function can draw randomness from it.
      step_fn: optional ``lax.scan`` body ``(carry, t) -> (carry, output)``;
        defaults to ``simulate_step``.

    Returns:
      The carry after the last timestamp. Per-step outputs of ``step_fn``
      are discarded.

    Raises:
      TypeError: if ``key`` is given but ``carry`` is not a dict.
    """
    if step_fn is None:
        step_fn = simulate_step
    if key is not None:
        if not isinstance(carry, dict):
            raise TypeError(
                f"carry must be a dict when key is given, got {type(carry).__name__}"
            )
        # Copy rather than mutate: the caller's dict may be shared across runs.
        carry = {**carry, "key": key}
    carry, _ = lax.scan(step_fn, carry, jnp.arange(length))
    return carry


# All simulation runs for a single fish.
def one_fish_all_simulation(carry, length, sim_keys, *, step_fn=None):
    """Run one simulation per key in ``sim_keys``, all from the same start state.

    Args:
      carry: dict with the initial state shared by every run of this fish.
      length: number of timestamps per run (concrete Python int).
      sim_keys: PRNG keys for this fish, one per run, stacked along axis 0.
      step_fn: optional step function forwarded to ``one_fish_one_simulation``.

    Returns:
      The final carry of every run, each leaf stacked along a new leading
      axis of size ``len(sim_keys)``.
    """

    # Close over carry/length so only the key is mapped; length stays a
    # concrete int instead of becoming a traced value.
    def run_one(run_key):
        return one_fish_one_simulation(carry, length, run_key, step_fn=step_fn)

    return vmap(run_one)(sim_keys)


# Vectorized simulation for multiple fish.
def run_simulations(length, fish_keys, sim_keys, *, step_fn=None):
    """Simulate every fish, several runs each, fully vectorized with ``vmap``.

    Args:
      length: number of timestamps per run (concrete Python int).
      fish_keys: PRNG keys, one per fish, stacked along axis 0. Each one
        seeds that fish's initial ``"data"`` value via ``random.uniform``.
      sim_keys: PRNG keys with leading axes ``(num_fish, runs_per_fish)``.
      step_fn: optional step function; defaults to ``simulate_step``.

    Returns:
      Dict of final carries with keys ``"data"`` and ``"key"`` (``lax.scan``
      keeps the carry structure fixed). Every leaf has leading axes
      ``(num_fish, runs_per_fish)``. Only the final state is returned, not a
      per-timestamp history.
    """
    initial = {"data": vmap(random.uniform)(fish_keys)}

    def run_fish(fish_carry, fish_sim_keys):
        return one_fish_all_simulation(
            fish_carry, length, fish_sim_keys, step_fn=step_fn
        )

    return vmap(run_fish)(initial, sim_keys)


# Simulation sizes
num_fish = 1000
simulation_per_fish = 10
length = 10_000

# Derive every PRNG key from a single seed so runs are reproducible.
main_key = random.PRNGKey(42)
fish_key, simulation_key = random.split(main_key)  # num defaults to 2
fish_keys = random.split(fish_key, num_fish)  # one key per fish
sim_keys = random.split(simulation_key, (num_fish, simulation_per_fish))  # one key per (fish, run)

# Launch every simulation for every fish.
histories = run_simulations(length, fish_keys, sim_keys)