This example is available as a Jupyter notebook [here](https://github.com/SimiPixel/x_xy_v2/blob/main/docs/notebooks/control.ipynb).

## Balance an inverted pendulum on a cart

```python
import x_xy

# private helper (leading underscore); its interface may change between versions
from x_xy.algorithms.generator.pd_control import _pd_control

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

import mediapy as media
```


The `step` function also accepts generalized forces `taus` as its third argument, `step(sys, state, taus)`, with one entry per degree of freedom (here the cart's slide joint `px` and the pendulum's hinge `ry`).

Let's consider an inverted pendulum on a cart and apply a left-right force to the cart so that the pole stays upright.

```python
xml_str = """
<x_xy model="inv_pendulum">
    <options gravity="0 0 9.81" dt="0.01"/>
    <defaults>
        <geom edge_color="black" color="white"/>
    </defaults>
    <worldbody>
        <body name="cart" joint="px" damping="0.01">
            <geom type="box" mass="1" dim="0.4 0.1 0.1"/>
            <body name="pendulum" joint="ry" euler="0 -90 0" damping="0.01">
                <geom type="box" mass="0.5" pos="0.5 0 0" dim="1 0.1 0.1"/>
            </body>
        </body>
    </worldbody>
</x_xy>
"""

sys = x_xy.load_sys_from_str(xml_str)
# initial state: cart at the origin, pole tilted by 0.2 rad
state = x_xy.State.create(sys, q=jnp.array([0.0, 0.2]))

# wrap the simulation step with jax.jit once, outside the loop
step = jax.jit(x_xy.step)

xs = []
T = 10.0  # simulated time in seconds
for t in range(int(T / sys.dt)):
    # noisy measurement of the pole angle, in degrees
    measurement_noise = np.random.normal() * 5
    phi = jnp.rad2deg(state.q[1]) + measurement_noise
    # signed quadratic feedback on the measured angle
    cart_motor_input = 0.1 * phi * abs(phi)
    # only the cart is actuated; saturate the force at +-10
    taus = jnp.clip(jnp.array([cart_motor_input, 0.0]), -10, 10)
    state = step(sys, state, taus)
    xs.append(state.x)
```
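The control law in the loop above is a signed quadratic feedback on the noisy pole angle (in degrees), saturated at ±10. As a sketch, it can be pulled out into a pure function that is easy to test and to `jax.jit` on its own; `cart_force` and its parameter names are hypothetical, not part of `x_xy`.

```python
import jax.numpy as jnp


def cart_force(pole_angle, noise_deg=0.0, gain=0.1, limit=10.0):
    """Hypothetical helper mirroring the control law in the loop above.

    Returns the generalized forces for (cart slide, pole hinge);
    only the cart is actuated.
    """
    phi = jnp.rad2deg(pole_angle) + noise_deg  # measured pole angle in degrees
    force = gain * phi * jnp.abs(phi)  # signed quadratic: keeps the sign of phi
    return jnp.clip(jnp.array([force, 0.0]), -limit, limit)  # saturate at +-limit


# small tilt: 0.05 rad is about 2.86 degrees, well inside the saturation limit
print(cart_force(0.05))
# without noise, the initial tilt of 0.2 rad (about 11.5 degrees) already saturates the force
print(cart_force(0.2))
```

Because the force grows with the square of the angle, small tilts produce gentle corrections while larger ones quickly hit the saturation limit.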

```python
def show_video(sys, xs: list[x_xy.Transform]):
    assert sys.dt == 0.01
    # the simulation runs at 100 Hz; render every fourth state to get 25 fps
    frames = x_xy.render(
        sys,
        [xs[i] for i in range(0, len(xs), 4)],
        camera="c",
        add_cameras={-1: '<camera name="c" mode="targetbody" target="0" pos="0 -2 2"/>'},
    )
    # convert RGBA to RGB
    frames = [frame[..., :3] for frame in frames]
    media.show_video(frames, fps=25)

show_video(sys, xs)
```
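`show_video` hard-codes the time step through its `assert` and the stride of 4 (100 Hz / 4 = 25 fps). A minimal sketch that derives the stride from `dt` instead is below; `frame_stride` and `subsample` are hypothetical helpers, not part of `x_xy`.

```python
def frame_stride(dt: float, fps: int) -> int:
    """Number of simulation steps per rendered frame (hypothetical helper)."""
    stride = round(1.0 / (fps * dt))
    if stride < 1:
        raise ValueError("simulation step dt is too coarse for the requested fps")
    return stride


def subsample(xs: list, dt: float, fps: int = 25) -> list:
    """Keep every stride-th state so playback at `fps` runs close to real time."""
    return xs[:: frame_stride(dt, fps)]


# 10 s at dt = 0.01 gives 1000 states; at 25 fps every 4th state is kept
print(frame_stride(0.01, 25), len(subsample(list(range(1000)), 0.01)))
```

Inside `show_video`, `subsample(xs, sys.dt)` could replace both the `assert` and the list comprehension. If `1 / (fps * dt)` is not an integer, the rounding makes playback only approximately real-time.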



## PD Control

```python
xml_str = """
<x_xy>
    <options gravity="0 0 9.81" dt="0.01"/>
    <worldbody>
        <body name="pendulum" joint="ry" euler="0 90 0" damping="0.01" pos="0 0 1">
            <geom type="box" mass="0.5" pos="0.5 0 0" dim="1 0.1 0.1"/>
        </body>
    </worldbody>
</x_xy>
"""

sys = x_xy.load_sys_from_str(xml_str)
# proportional and derivative gains, one entry per degree of freedom
P, D = jnp.array([10.0]), jnp.array([1.0])

def simulate_pd_control(sys, P, D):
    controller = _pd_control(P, D)
    # constant reference signal of pi/2 rad (1000 samples, more than the 500 simulated steps)
    q_ref = jnp.ones((1000, 1)) * jnp.pi / 2
    controller_state = controller.init(sys, q_ref)
    state = x_xy.State.create(sys)

    # wrap the controller and the simulation step with jax.jit once, outside the loop
    apply = jax.jit(controller.apply)
    step = jax.jit(x_xy.step)

    xs = []
    T = 5.0  # simulated time in seconds
    for t in range(int(T / sys.dt)):
        controller_state, taus = apply(controller_state, sys, state)
        state = step(sys, state, taus)
        xs.append(state.x)
    return xs
```
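`_pd_control` is a private helper, and we have not checked exactly how it turns `q_ref` into torques (for instance, whether it also differentiates the reference). For orientation, the textbook PD law it is named after is sketched below; `pd_torque` is a hypothetical function, not part of `x_xy`. For a constant reference like the one above, `qd_ref` is zero, so the D term simply damps the joint velocity.

```python
import jax.numpy as jnp


def pd_torque(q, qd, q_ref, qd_ref, P, D):
    """Textbook PD law (hypothetical stand-in, not x_xy's implementation)."""
    # P term pulls q towards q_ref; D term acts on the velocity error
    return P * (q_ref - q) + D * (qd_ref - qd)


# same gains as above; pendulum at rest at q = 0, reference pi/2
P, D = jnp.array([10.0]), jnp.array([1.0])
print(pd_torque(jnp.zeros(1), jnp.zeros(1), jnp.array([jnp.pi / 2]), jnp.zeros(1), P, D))
```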

```python
xs = simulate_pd_control(sys, P, D)
show_video(sys, xs)
```



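Note the steady-state error. At rest the derivative term is zero, so only the proportional term can hold the pendulum against gravity, and it produces a torque only while the angle error is nonzero. Without an integral term (i.e. with PD rather than PID control), nothing removes this residual error.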
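Assuming the controller applies the textbook PD law to a constant reference (we have not checked the internals of `_pd_control`), the size of the error can be worked out directly. With gains $P$ and $D$ the controller torque is

$$
\tau = P\,(q_{\text{ref}} - q) - D\,\dot q .
$$

At a steady state $\dot q = \ddot q = 0$, so the D term and the joint damping drop out, and the P term alone must supply the torque $\tau_g(q^\ast)$ needed to hold the pendulum against gravity:

$$
P\,(q_{\text{ref}} - q^\ast) = \tau_g(q^\ast)
\quad\Longrightarrow\quad
e_{\text{ss}} = q_{\text{ref}} - q^\ast = \frac{\tau_g(q^\ast)}{P}.
$$

For the pendulum above ($m = 0.5$ kg, center of mass $d = 0.5$ m from the hinge, $g = 9.81$ m/s², $P = 10$), $|\tau_g| \le m g d \approx 2.45$ N m, so $|e_{\text{ss}}| \le 0.245$ rad (about 14°). A larger $P$ shrinks the error but never removes it while $g \neq 0$; an integral term, or setting $g = 0$, does.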

If we remove gravity, the steady-state error vanishes, as expected.
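To compare both cases numerically without the renderer, here is a minimal sketch in plain JAX. It does not use `x_xy`'s dynamics: it assumes a thin uniform rod (mass 0.5, length 1, hinged at one end, `q = 0` hanging straight down), the textbook PD law with a constant reference, and semi-implicit Euler integration inside `jax.lax.scan`. `simulate_rod` and its constants are our own choices, so the numbers illustrate the trend rather than reproduce the videos.

```python
import jax
import jax.numpy as jnp

# Assumed, simplified model (not x_xy's dynamics): a thin uniform rod hinged
# at one end; q = 0 means hanging straight down.
MASS, LENGTH, DAMPING = 0.5, 1.0, 0.01
INERTIA = MASS * LENGTH**2 / 3.0  # thin rod about its end


def simulate_rod(P, D, q_ref, g, dt=0.01, T=10.0):
    """Simulate the PD-controlled rod from rest at q = 0 and return the final angle."""

    def step(carry, _):
        q, qd = carry
        tau_pd = P * (q_ref - q) - D * qd  # textbook PD law, constant reference
        tau_gravity = -MASS * g * (LENGTH / 2) * jnp.sin(q)  # pulls the rod down
        qdd = (tau_pd + tau_gravity - DAMPING * qd) / INERTIA
        qd = qd + dt * qdd  # semi-implicit Euler: velocity first ...
        q = q + dt * qd  # ... then position with the new velocity
        return (q, qd), None

    init = (jnp.array(0.0), jnp.array(0.0))
    (q_final, _), _ = jax.lax.scan(step, init, xs=None, length=round(T / dt))
    return q_final


for g in (9.81, 0.0):
    error = jnp.pi / 2 - simulate_rod(P=10.0, D=1.0, q_ref=jnp.pi / 2, g=g)
    print(f"g = {g}: steady-state error = {float(error):.4f} rad")
```

With gravity the rod settles roughly 0.24 rad short of the reference, slightly less than m·g·d / P ≈ 0.245 rad because the gravitational torque shrinks as the rod tilts away from horizontal; with `g = 0.0` the final error is essentially zero.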

```python
# same system, but with the gravity vector set to zero
sys_nograv = sys.replace(gravity=sys.gravity * 0.0)
xs = simulate_pd_control(sys_nograv, P, D)
show_video(sys_nograv, xs)
```

