This example is available as a Jupyter notebook [here](https://github.com/SimiPixel/x_xy_v2/blob/main/docs/notebooks/batched_simulation.ipynb).

## Batched Dynamical Simulation

A `System` object is a registered JAX PyTree, i.e. a nested container whose leaves are arrays.

Because of this, we can stack multiple systems (or states) along a leading batch axis and apply vectorized operations to them.
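To make the stacking idea concrete, here is a minimal sketch in plain JAX that does not use `x_xy` at all. `ToySystem` and `stack_pytrees` are hypothetical names chosen for illustration, and it is only an assumption that `System.batch` behaves along these lines.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    # Hypothetical stand-in for a PyTree whose leaves are arrays.
    gravity: jnp.ndarray
    dt: jnp.ndarray


def stack_pytrees(*trees):
    # Stack corresponding leaves of all trees along a new leading axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


sys_a = ToySystem(gravity=jnp.array([0.0, 0.0, 9.81]), dt=jnp.array(0.01))
sys_b = sys_a._replace(gravity=jnp.zeros(3))  # same system, gravity disabled

toy_batched = stack_pytrees(sys_a, sys_b)
print(toy_batched.gravity.shape)  # (2, 3)
print(toy_batched.dt.shape)  # (2,)
```

Every leaf gains a leading axis of size two, which is the axis `jax.vmap` maps over.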

### Batched System

For example, we can simulate two different systems starting from the same initial state.

In [8]:
import x_xy

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

import mediapy as media
import vispy

# other backends are fine too, just not `jupyter_rfb`. You can install `pyqt6` using `pip install pyqt6`
vispy.use("pyqt6")

xml_str = """
<x_xy model="double_pendulum">
    <options gravity="0 0 9.81" dt="0.01"/>
    <worldbody>
        <body name="upper" euler="0 90 0" joint="ry" damping="2">
            <geom type="box" mass="10" pos="0.5 0 0" dim="1 0.25 0.2"/>
            <body name="lower" pos="1 0 0" joint="ry" damping="2">
                <geom type="box" mass="10" pos="0.5 0 0" dim="1 0.25 0.2"/>
            </body>
        </body>
    </worldbody>
</x_xy>
"""

sys = x_xy.load_sys_from_str(xml_str)
state = x_xy.State.create(sys)

In [9]:
# second system with gravity disabled
sys_nograv = sys.replace(gravity=sys.gravity * 0.0)
sys_batched = sys.batch(sys_nograv)

next_state_batched = jax.vmap(x_xy.step, in_axes=(0, None))(sys_batched, state)

In [11]:
# note how the state of the system without gravity has not changed at all
next_state_batched.q

Array([[-1.7982468e-10,  2.3305433e-10],
       [ 0.0000000e+00,  0.0000000e+00]], dtype=float32)

### Batched State

Conversely, we can simulate a single system starting from two different initial states.

In [5]:
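# second initial state with unit joint velocities
second_state = x_xy.State.create(sys, qd=jnp.ones((2,)))
state_batched = state.batch(second_state)
# map over the state only; the system is shared across the batch
next_state_batched = jax.vmap(x_xy.step, in_axes=(None, 0))(sys, state_batched)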

In [7]:
next_state_batched.q

Array([[-1.7982468e-10,  2.3305433e-10],
       [ 1.0048340e-02,  9.8215193e-03]], dtype=float32)
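
The two `in_axes` patterns used above can be tried in isolation with a toy step function. This is a minimal sketch in plain JAX; `toy_step` and its parameter dictionary are hypothetical and unrelated to the internals of `x_xy.step`.

```python
import jax
import jax.numpy as jnp


def toy_step(params, x):
    # Hypothetical one-step update: explicit Euler with constant acceleration.
    return x + params["dt"] * params["acc"]


params = {"dt": jnp.array(0.1), "acc": jnp.array(2.0)}
x0 = jnp.zeros(3)

# Batch over the parameters only: in_axes=0 applies to every leaf of the dict,
# while None broadcasts the single state to all batch entries.
params_batched = {"dt": jnp.array([0.1, 0.1]), "acc": jnp.array([2.0, 0.0])}
out_params = jax.vmap(toy_step, in_axes=(0, None))(params_batched, x0)

# Batch over the state only: the parameters are shared across the batch.
x_batched = jnp.stack([jnp.zeros(3), jnp.ones(3)])
out_state = jax.vmap(toy_step, in_axes=(None, 0))(params, x_batched)

print(out_params.shape, out_state.shape)  # (2, 3) (2, 3)
```

In both cases the output carries the batch axis first, matching the shape of `next_state_batched.q` in the cells above.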

## Batched Kinematic Simulation

Batched kinematic simulation is done using either `x_xy.batch_generator` or `x_xy.offline_generator`.

In [15]:
gen = x_xy.build_generator(sys, x_xy.RCMG_Config(T=10.0, t_max=1.5))
batchsize = 8
gen_batched = x_xy.batch_generator(gen, batchsizes=batchsize)
seed = 1
qs, xs = gen_batched(jax.random.PRNGKey(seed))

In [16]:
qs.shape

(8, 1000, 2)
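
The shape reads as (batch size, timesteps, degrees of freedom): 8 sequences, `T / dt = 10.0 / 0.01 = 1000` timesteps, and one coordinate for each of the two joints. One way such a batched generator can be built is to split the PRNG key and `vmap` the single-sequence generator over the resulting keys. The sketch below assumes this approach and uses hypothetical names (`toy_generator`, `toy_batch_generator`); it is not taken from the `x_xy` implementation.

```python
import jax
import jax.numpy as jnp

T, dt, n_dof = 10.0, 0.01, 2
n_steps = round(T / dt)  # 1000 timesteps


def toy_generator(key):
    # Hypothetical single-sequence generator: a random walk in joint space.
    increments = jax.random.normal(key, (n_steps, n_dof)) * dt
    return jnp.cumsum(increments, axis=0)


def toy_batch_generator(gen, batchsize):
    # Wrap a generator so that one key yields `batchsize` independent sequences.
    def batched(key):
        keys = jax.random.split(key, batchsize)
        return jax.vmap(gen)(keys)

    return batched


toy_qs = toy_batch_generator(toy_generator, 8)(jax.random.PRNGKey(1))
print(toy_qs.shape)  # (8, 1000, 2)
```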