!!! note

    This example is available as a jupyter notebook [here](https://github.com/simon-bachhuber/ring/blob/main/docs/notebooks/batched_simulation.ipynb).

## Batched Dynamical Simulation

A `System` object is a registered JAX PyTree, i.e. a nested container whose leaves are arrays.

Because of this, multiple systems (or states) can be stacked along a new leading axis and processed with vectorized operations such as `jax.vmap`.
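Here is a minimal sketch of the stacking idea, using a hypothetical `ToySystem` NamedTuple in place of `ring.System`. Stacking two PyTrees leaf by leaf with `jax.tree_util.tree_map` and `jnp.stack` gives one PyTree whose leaves carry a new leading batch axis. This only illustrates the concept. Whether `sys.batch` is implemented exactly this way internally is an assumption.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    # Hypothetical stand-in for a PyTree such as ring.System.
    # NamedTuples are treated as PyTrees by JAX automatically.
    gravity: jax.Array
    damping: jax.Array


def stack_pytrees(*trees):
    # Stack corresponding leaves of all trees along a new leading (batch) axis.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


sys_a = ToySystem(gravity=jnp.array([0.0, 0.0, 9.81]), damping=jnp.array([2.0, 2.0]))
sys_b = sys_a._replace(gravity=sys_a.gravity * 0.0)  # same system, gravity disabled
batched = stack_pytrees(sys_a, sys_b)

print(batched.gravity.shape)  # (2, 3)
print(batched.damping.shape)  # (2, 2)
```

Every leaf gains a batch axis of size 2. `jax.vmap` with `in_axes=0` can then map over that axis.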

### Batched System

For example, simulating two different systems from the same initial state.

```python
import ring

import jax
import jax.numpy as jnp


xml_str = """
<x_xy model="double_pendulum">
    <options gravity="0 0 9.81" dt="0.01"/>
    <worldbody>
        <body name="upper" euler="0 90 0" joint="ry" damping="2">
            <geom type="box" mass="10" pos="0.5 0 0" dim="1 0.25 0.2"/>
            <body name="lower" pos="1 0 0" joint="ry" damping="2">
                <geom type="box" mass="10" pos="0.5 0 0" dim="1 0.25 0.2"/>
            </body>
        </body>
    </worldbody>
</x_xy>
"""

sys = ring.System.create(xml_str)
state = ring.State.create(sys)
```

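```python
# second system with gravity disabled
sys_nograv = sys.replace(gravity=sys.gravity * 0.0)
# stack both systems along a new leading (batch) axis
sys_batched = sys.batch(sys_nograv)

# in_axes=(0, None): map over the batched system, share the same initial state
next_state_batched = jax.vmap(ring.step, in_axes=(0, None))(sys_batched, state)
```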

```python
# the system without gravity (batch index 1) has not changed at all:
# it starts at rest and has gravity disabled, so no forces act on it
next_state_batched.q
```

```
Array([[0., 0.],
       [0., 0.]], dtype=float32)
```
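The `in_axes=(0, None)` pattern can be shown in isolation with a pure-JAX toy model. This sketch assumes a hypothetical damped 1-DoF pendulum `toy_step` (not `ring`'s dynamics) with its parameters stored in a dict. `in_axes=(0, None)` maps over axis 0 of every parameter leaf and shares one state across the batch.

```python
import jax
import jax.numpy as jnp


def toy_step(params, state, dt=0.01):
    # Hypothetical damped 1-DoF pendulum, one semi-implicit Euler step.
    qdd = -params["gravity"] * jnp.sin(state["q"]) - params["damping"] * state["qd"]
    qd = state["qd"] + dt * qdd
    q = state["q"] + dt * qd
    return {"q": q, "qd": qd}


# two "systems" stacked along axis 0: with and without gravity
params_batched = {
    "gravity": jnp.array([9.81, 0.0]),
    "damping": jnp.array([2.0, 2.0]),
}
# one shared, unbatched initial state (displaced, at rest)
state = {"q": jnp.array(0.5), "qd": jnp.array(0.0)}

# in_axes=(0, None): map over the parameters, broadcast the state
next_state = jax.vmap(toy_step, in_axes=(0, None))(params_batched, state)
print(next_state["q"])
```

With gravity disabled, the second entry keeps `q = 0.5`. The first entry starts to swing back towards `q = 0`.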

### Batched State

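Here, the same system is simulated from two different initial states.

```python
# second initial state with both joint velocities set to 1
second_state = ring.State.create(sys, qd=jnp.ones((2,)))
state_batched = state.batch(second_state)
# in_axes=(None, 0): share the system, map over the batched state
next_state_batched = jax.vmap(ring.step, in_axes=(None, 0))(sys, state_batched)
```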

```python
next_state_batched.q
```

```
Array([[0.        , 0.        ],
       [0.01004834, 0.00982152]], dtype=float32)
```
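To advance a batch of states for many steps instead of one, the vmapped step can be wrapped in `jax.lax.scan`. This sketch uses a hypothetical toy pendulum `toy_step` rather than `ring`'s dynamics. It mirrors the cell above: `in_axes=(None, 0)` shares one parameter set and maps over a batch of two initial states, one at rest and one with unit velocity. Whether `ring.step` is best rolled out this way is an assumption.

```python
import jax
import jax.numpy as jnp


def toy_step(params, state, dt=0.01):
    # Hypothetical damped 1-DoF pendulum (not ring's dynamics), semi-implicit Euler.
    qdd = -params["gravity"] * jnp.sin(state["q"]) - params["damping"] * state["qd"]
    qd = state["qd"] + dt * qdd
    return {"q": state["q"] + dt * qd, "qd": qd}


params = {"gravity": jnp.array(9.81), "damping": jnp.array(2.0)}  # one shared system
states = {"q": jnp.zeros(2), "qd": jnp.array([0.0, 1.0])}  # batch of 2 initial states

# in_axes=(None, 0): broadcast the parameters, map over axis 0 of every state leaf
batched_step = jax.vmap(toy_step, in_axes=(None, 0))


def body(state, _):
    state = batched_step(params, state)
    return state, state["q"]  # carry the batched state, record q at every step


final_state, q_traj = jax.lax.scan(body, states, xs=None, length=100)
print(q_traj.shape)  # (100, 2)
```

The state at rest in the equilibrium stays at `q = 0` for all 100 steps. The state with initial velocity moves away from it.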

## Batched Kinematic Simulation

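Batched kinematic simulation is done by passing the batch size as the `sizes` argument of `RCMG.to_lazy_gen`. The resulting generator produces a whole batch of random motion sequences per call.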

```python
batchsize = 8
seed = 1
gen = ring.RCMG(
    sys, ring.MotionConfig(T=10.0, t_max=1.5), keep_output_extras=True
).to_lazy_gen(batchsize)
(X, y), (_, q, x, _) = gen(jax.random.PRNGKey(seed))
```

```python
q.shape
```

```
(8, 1000, 2)
```
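The shape of `q` follows from the settings used above. The sketch below assumes the generator returns one sample per simulation step, which matches the output. Under that assumption the trajectory length is `T / dt`, and each of the two 1-DoF `ry` joints contributes one coordinate. The values are copied by hand, and `n_q` is a hypothetical name.

```python
batchsize = 8
T, dt = 10.0, 0.01  # MotionConfig(T=10.0) and <options dt="0.01"/>
n_q = 2  # hypothetical name: two 1-DoF "ry" joints -> one coordinate each

expected = (batchsize, round(T / dt), n_q)
print(expected)  # (8, 1000, 2)
```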