In [1]:
import numpy as np

class Fish:
    """A fish that, at each time step, deposits ``data`` with probability 1/2.

    Randomness comes from NumPy's global RNG (``np.random``), so outcomes
    depend on that generator's state (e.g. set via ``np.random.seed``).
    """

    def __init__(self, data:float=5.):
        # Amount added to the history slot whenever the fish acts.
        self.data = data

    def act(self, t:int, history:np.ndarray):
        """Possibly add ``self.data`` to ``history[t]``, in place.

        One uniform sample is drawn from the global NumPy RNG. The add only
        happens when that sample is strictly greater than 0.5. Returns None.
        """
        draw = np.random.rand()
        if draw <= 0.5:
            return
        history[t] += self.data

class Simulation:
    """Steps a single fish through a fixed number of time steps."""

    def __init__(self, length=10000):
        # Number of time steps, which is also the size of the returned history.
        self.length = length

    def run(self, fish):
        """Return a zero-initialised float array of shape ``(length,)`` after
        calling ``fish.act(t, history)`` once for each step ``t``, in order.

        ``fish.act`` is expected to write into ``history`` in place.
        """
        record = np.zeros(shape=(self.length,))
        for step in range(self.length):
            fish.act(step, record)
        return record


# Baseline: the plain NumPy / OOP version, simulating one fish at a time.
# Seed NumPy's global RNG, which both `np.random.rand()` below and `Fish.act`
# draw from, so that re-running this cell reproduces the same draws. The seed
# value is the same as the one used for the JAX PRNG key further down.
np.random.seed(42)

simulation = Simulation()

for fish_i in range(1000):
    fish = Fish(data=np.random.rand())
    _ = simulation.run(fish)



In [2]:
import jax
import jax.numpy as jnp
from jax import random, vmap, lax

# Define the Fish behavior
@jax.jit
def fish_act(static_dict, t, rng, history):
    """Functional JAX counterpart of ``Fish.act``.

    Draws one uniform sample from ``rng``. If it is greater than 0.5,
    ``static_dict["data"]`` is added to ``history[t]``; otherwise the slot
    receives +0.0, so the value is unchanged.

    Args:
        static_dict: dict whose "data" entry is the amount to deposit. Other
            keys, such as "length", are ignored here.
        t: index of the history slot to update.
        rng: PRNG key for this step.
        history: 1-D array of deposits.

    Returns:
        A new array equal to ``history`` with slot ``t`` possibly incremented.
    """
    rand_value = random.uniform(rng)
    # Mask the scalar increment rather than branching with lax.cond. Under
    # vmap a cond with a batched predicate is lowered to `select`, which
    # evaluates both branches over the whole `history` array on every step.
    # This form touches only slot t.
    increment = jnp.where(rand_value > 0.5, static_dict["data"], 0.0)
    return history.at[t].add(increment)


# one simulation for one single fish
def one_fish_one_simulation(static_dict, sim_key):
    """Run a single simulation of ``static_dict["length"]`` steps for one fish.

    Args:
        static_dict: dict with "data" (forwarded to ``fish_act``) and "length"
            (number of steps). "length" must be a concrete Python int because
            it sets array shapes (``random.split`` count, ``jnp.zeros`` size).
        sim_key: PRNG key for this simulation. It is split into one key per step.

    Returns:
        1-D array of shape ``(length,)`` holding the deposit made at each step.
    """
    length = static_dict["length"]

    step_keys = random.split(sim_key, length)
    history = jnp.zeros(length)

    def simulate_step(history, xs):
        # Scan over (t, key_t) as `xs` instead of carrying the full key array.
        # Step t still uses step_keys[t], so results are unchanged.
        t, step_key = xs
        return fish_act(static_dict, t, step_key, history), None

    history, _ = lax.scan(simulate_step, history, (jnp.arange(length), step_keys))
    return history


# all simulation for one single fish
def one_fish_all_simulation(data, length, sim_keys):
    """Run every simulation for a single fish, vectorised over ``sim_keys``.

    Args:
        data: the fish's per-step deposit. It is shared across simulations.
        length: number of steps per simulation, shared across simulations.
        sim_keys: PRNG keys with one entry per simulation along axis 0.

    Returns:
        Histories stacked along a new leading axis, one row per simulation.
    """
    static_dict = {"data": data, "length": length}
    # Broadcast the settings dict; map only over the simulation keys.
    batched_run = vmap(one_fish_one_simulation, in_axes=(None, 0))
    return batched_run(static_dict, sim_keys)


# Vectorized simulation for multiple fish
def run_simulations(length, fish_keys, sim_keys):
    """Simulate every fish, each with several independent runs.

    Each fish's ``data`` is drawn as ``random.uniform`` of its own key in
    ``fish_keys``. The fish are then simulated in parallel: axis 0 of
    ``sim_keys`` is paired with the fish, and ``length`` is shared.

    Returns:
        Histories with a leading fish axis, followed by the per-fish
        simulation axis and the time axis.
    """
    per_fish_run = vmap(one_fish_all_simulation, in_axes=(0, None, 0))
    fish_data = vmap(random.uniform)(fish_keys)
    return per_fish_run(fish_data, length, sim_keys)


# Parameters
num_fish = 1000  # number of independent fish (outer vmap axis in run_simulations)
simulation_per_fish = 10  # simulations per fish (inner vmap axis)
length = 10000  # time steps per simulation; keep as a Python int, since it sets array shapes

# Random keys for reproducibility.
# The split order below determines every random draw. Reordering or changing
# these calls changes all results, even with the same seed.
main_key = random.PRNGKey(42)
# One branch for the per-fish "data" values, one for the simulation steps.
fish_key, simulation_key = random.split(main_key, 2)
# One key per fish; run_simulations draws each fish's "data" from its key.
fish_keys = random.split(fish_key, (num_fish,))
# One key per (fish, simulation) pair; the leading dims are (num_fish, simulation_per_fish).
sim_keys = random.split(simulation_key, (num_fish, simulation_per_fish))

# Run simulations.
# The result has shape (num_fish, simulation_per_fish, length), i.e. (1000, 10, 10000)
# here. That is roughly 400 MB, assuming the default float32 dtype (x64 disabled).
histories = run_simulations(length, fish_keys, sim_keys)

In [3]:
# Sanity check: one trajectory of `length` steps for every (fish, simulation) pair.
assert histories.shape == (num_fish, simulation_per_fish, length), histories.shape
print(histories.shape)

(1000, 10, 10000)
