In [1]:
import os

# Expose 8 virtual host (CPU) devices so the `jax.pmap` call below can run 8
# simulations in parallel without a multi-device machine. XLA only reads this
# flag when JAX initialises its backends, so this cell must run before jax is
# imported/used (restart the kernel if jax was already loaded).
# Any XLA_FLAGS already set in the environment are preserved, and re-running
# this cell does not append the flag a second time.
_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count=8"
_xla_flags = os.environ.get("XLA_FLAGS", "").split()
if _DEVICE_COUNT_FLAG not in _xla_flags:
    _xla_flags.append(_DEVICE_COUNT_FLAG)
os.environ["XLA_FLAGS"] = " ".join(_xla_flags)

In [2]:
import jax
import jax.numpy as jnp
from jaxfish.data_classes import freeze, Fish, Simulation, Terrain
from jaxfish.defaults import BRAIN_EIGHT_EYE_NO_HIDDEN
from jaxfish.simulation import run_simulation
from jaxfish.visulization import get_map_rgb
import matplotlib.pyplot as plt
from ipywidgets import interact
%matplotlib inline


# One simulation per device: jax.pmap maps the leading axis of `seeds` across
# devices, so its length must not exceed the number of visible devices.
N_SIMULATIONS = 8
n_devices = jax.local_device_count()
assert N_SIMULATIONS <= n_devices, (
    f"jax.pmap needs {N_SIMULATIONS} devices but only {n_devices} are visible; "
    "XLA_FLAGS (first cell) must be set before jax is imported - restart the kernel "
    "and run the cells in order."
)
seeds = jnp.arange(N_SIMULATIONS)

# The configs are frozen before being handed to the pmap'd function as static
# arguments (JAX requires static arguments to be hashable; presumably that is
# what `freeze` provides).
terrain_config = freeze(Terrain())

fish_config = freeze(Fish())

simulation_config = Simulation(simulation_ind=0)
simulation_config.psp_waveform_length = 30  # reduce memory use
simulation_config.max_simulation_length = 1000  # reduce memory use
simulation_config = freeze(simulation_config)

brain_config = freeze(BRAIN_EIGHT_EYE_NO_HIDDEN)

In [3]:
# Run one simulation per device: argument 0 (the seed) is mapped over its
# leading axis, while the four config objects (positions 1-4) are broadcast
# unchanged to every device as static, compile-time arguments.
_STATIC_CONFIG_ARGNUMS = (1, 2, 3, 4)
run_simulation_parallel = jax.pmap(
    run_simulation,
    in_axes=(0,) + (None,) * len(_STATIC_CONFIG_ARGNUMS),
    out_axes=0,
    static_broadcasted_argnums=_STATIC_CONFIG_ARGNUMS,
)

# Each output of run_simulation gains a leading axis indexed by simulation
# number (name kept as-is because the viewer cells read it).
simulation_resulsts = run_simulation_parallel(
    seeds,
    terrain_config,
    simulation_config,
    fish_config,
    brain_config,
)

In [4]:
def make_simulation_viewer(simulation_results, sim_num):
    """Display a time-slider viewer for one simulation of the parallel batch.

    The selected simulation, the figure and its axes live in this function's
    closure rather than in notebook globals, so running another viewer cell
    (or re-running this one) cannot make an existing slider draw a different
    simulation or draw onto a different figure.

    Args:
        simulation_results: The sequence of per-output arrays returned by
            ``run_simulation_parallel``; each array is indexed by ``sim_num``
            along its leading axis.
        sim_num: Index of the simulation to show.

    Returns:
        The slider callback (``interact`` returns the function it wraps).
    """
    simulation_result = [v[sim_num] for v in simulation_results]
    # NOTE(review): output 6 of run_simulation is used as the per-step health
    # history shown in the title - confirm against run_simulation's outputs.
    health_history = simulation_result[6]
    n_steps = len(health_history)

    fig, ax = plt.subplots()
    # Close immediately so the inline backend does not also render the figure
    # at the end of the cell; `interact` displays the figure the callback returns.
    plt.close(fig)

    def update_plot(t=0):
        """Redraw the map and title for time step ``t`` and return the figure."""
        ax.clear()
        ax.imshow(get_map_rgb(t, simulation_result))
        ax.set_title(
            f"sim_num: {sim_num}, t: {t+1:3d}/{n_steps}, "
            f"health: {health_history[t]:5.2f}"
        )
        ax.set_axis_off()
        return fig

    return interact(update_plot, t=(0, n_steps - 1))


# Trailing `;` suppresses the `<function ...>` repr of the returned callback.
make_simulation_viewer(simulation_resulsts, sim_num=0);

interactive(children=(IntSlider(value=0, description='t', max=999), Output()), _dom_classes=('widget-interact'…

<function __main__.update_plot(t=0)>

In [5]:
# Interactive viewer for simulation 5.
# All state the slider callback uses is bound in a closure instead of the shared
# notebook globals `sim_num`, `simulation_result`, `f`, `ax` and `update_plot`,
# so running this cell cannot redirect another viewer cell's slider that reads
# those globals when it redraws.
# NOTE(review): duplicates the simulation-0 viewer cell above; consider one
# shared viewer helper used by both cells.
def _view_simulation(simulation_results, sim_num):
    """Show a time-slider view of simulation ``sim_num``; returns the callback."""
    simulation_result = [v[sim_num] for v in simulation_results]
    # NOTE(review): output 6 is used as the health history shown in the title.
    health_history = simulation_result[6]
    n_steps = len(health_history)

    fig, ax = plt.subplots()
    # Keep the inline backend from rendering the figure a second time at the
    # end of the cell; `interact` displays the figure the callback returns.
    plt.close(fig)

    def update_plot(t=0):
        ax.clear()
        ax.imshow(get_map_rgb(t, simulation_result))
        ax.set_title(
            f"sim_num: {sim_num}, t: {t+1:3d}/{n_steps}, "
            f"health: {health_history[t]:5.2f}"
        )
        ax.set_axis_off()
        return fig

    return interact(update_plot, t=(0, n_steps - 1))


# Trailing `;` suppresses the `<function ...>` repr of the returned callback.
_view_simulation(simulation_resulsts, sim_num=5);

interactive(children=(IntSlider(value=0, description='t', max=999), Output()), _dom_classes=('widget-interact'…

<function __main__.update_plot(t=0)>