In [1]:
import os

# Number of virtual CPU devices XLA should expose. `jax.pmap` maps over
# devices, so this lets the parallel-simulation cells run on a CPU-only machine.
N_HOST_DEVICES = 8


def with_forced_host_device_count(xla_flags: str, n_devices: int) -> str:
    """Return `xla_flags` with the host-platform device count set to `n_devices`.

    Any existing ``--xla_force_host_platform_device_count=...`` entry is
    replaced; all other flags are preserved in their original order, so
    flags the user already exported are not silently discarded.
    """
    prefix = "--xla_force_host_platform_device_count="
    kept = [flag for flag in xla_flags.split() if not flag.startswith(prefix)]
    return " ".join([*kept, f"{prefix}{n_devices}"])


# Simulate a multi-device machine on the CPU. This must run before `jax` is
# first imported in the kernel: the flag is read when the JAX backend is
# initialized, so setting it later has no effect (restart the kernel then).
os.environ["XLA_FLAGS"] = with_forced_host_device_count(
    os.environ.get("XLA_FLAGS", ""), N_HOST_DEVICES
)

In [2]:
import jax
import jax.numpy as jnp
from jaxfish.data_classes import freeze, Fish, Simulation, Terrain
from jaxfish.defaults import BRAIN_EIGHT_EYE_NO_HIDDEN
from jaxfish.simulation import run_simulation


# One random seed per device. `jax.pmap` maps the leading axis of `seeds`
# across devices, so its length must not exceed the number of local devices.
# Deriving it from the live device count (instead of hard-coding 8) keeps this
# cell consistent with whatever XLA_FLAGS setting actually took effect.
seeds = jnp.arange(jax.local_device_count())

# The configs are frozen before being handed to `pmap` as static arguments,
# which JAX requires to be hashable (presumably what `freeze` provides).
terrain_config = freeze(Terrain())

fish_config = freeze(Fish())

simulation_config = Simulation(simulation_ind=0)
simulation_config.psp_waveform_length = 30  # reduce memory use
simulation_config.max_simulation_length = 1000  # reduce memory use
simulation_config = freeze(simulation_config)

brain_config = freeze(BRAIN_EIGHT_EYE_NO_HIDDEN)

In [6]:
# Run one simulation per seed, one per device. Only the seed (argument 0) is
# mapped; the four configs are broadcast unchanged to every device and marked
# static, so JAX requires them to be hashable. Every output gets the
# device/seed axis first (out_axes=0).
run_simulation_parallel = jax.pmap(
    run_simulation,
    in_axes=(0, None, None, None, None),
    out_axes=0,
    static_broadcasted_argnums=(1, 2, 3, 4),
)

simulation_results = run_simulation_parallel(
    seeds, terrain_config, simulation_config, fish_config, brain_config
)
# Alias under the original (misspelled) name so cells that still reference it
# keep working.
simulation_resulsts = simulation_results

In [9]:
# Keep only the run for seed 0: the pmapped simulation puts the seed (device)
# axis first on every output, so indexing with 0 selects that seed's run.
simulation_result_0 = [per_seed_output[0] for per_seed_output in simulation_resulsts]

# Name the outputs we inspect; unused slots are discarded into `_`.
(
    _,
    _,
    _,
    terrain_map,
    food_positions_history,
    fish_position_history,
    health_history,
    firing_history,
    _,
    psp_history,
) = simulation_result_0

# Peek at the first 10 entries of each history (presumably the first 10 time
# steps); the format mirrors the f-string `{expr=}` debug specifier.
for history_name, history in (
    ("fish_position_history", fish_position_history),
    ("health_history", health_history),
    ("firing_history", firing_history),
):
    print(f"\n{history_name}[0:10]={history[0:10]!r}")


fish_position_history[0:10]=Array([[5, 2],
       [5, 2],
       [5, 3],
       [5, 2],
       [5, 3],
       [5, 2],
       [5, 3],
       [5, 2],
       [5, 3],
       [5, 2]], dtype=int32)

health_history[0:10]=Array([10.       ,  9.99     ,  9.978999 ,  9.9679985,  9.956998 ,
        9.945997 ,  9.934997 ,  9.923996 ,  9.912995 ,  9.901995 ],      dtype=float32)

firing_history[0:10]=Array([[0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
