# Lecture 2 - Dynamics and Integration

In [None]:
import jax
import jax.numpy as jnp
import mujoco
import viser
from jax.experimental.ode import odeint
from judo.visualizers.model import ViserMjModel
from matplotlib import pyplot as plt


## Numerical integrators

Here we implement three simple numerical integrators: Euler, Midpoint, and RK4. Each integrator takes a dynamics function `f`, the current state `x`, control input `u`, and time step `dt`, and returns the next state.

The Euler update is given by:
$$x_{t+1} = x_t + f(x_t, u_t) \cdot dt$$

The Midpoint update is given by:
$$k_1 = f(x_t, u_t)$$
$$k_2 = f(x_t + \frac{dt}{2} k_1, u_t)$$
$$ x_{t+1}= x_t + k_2 \cdot dt $$

And RK4 is given by:
$$k_1 = f(x_t, u_t)$$
$$k_2 = f(x_t + \frac{dt}{2} k_1, u_t)$$
$$k_3 = f(x_t + \frac{dt}{2} k_2, u_t)$$
$$k_4 = f(x_t + dt \cdot k_3, u_t)$$
$$ x_{t+1}= x_t + \frac{dt}{6} (k_1 + 2k_2 + 2k_3 + k_4) $$

In [None]:
def euler_step(f: callable, x: jnp.ndarray, u: jnp.ndarray, dt: float) -> jnp.ndarray:
    """Advance x_dot = f(x, u) by one explicit (forward) Euler step.

    Args:
        f: Dynamics function f(x, u) returning x_dot.
        x: Current state.
        u: Current control input (held constant over the step).
        dt: Time step for integration.

    Returns:
        The state estimate after dt.
    """
    # Only the slope at the start of the interval is used.
    slope = f(x, u)
    return x + dt * slope

def midpoint_step(f: callable, x: jnp.ndarray, u: jnp.ndarray, dt: float) -> jnp.ndarray:
    """Advance x_dot = f(x, u) by one explicit midpoint (RK2) step.

    Args:
        f: Dynamics function f(x, u) returning x_dot.
        x: Current state.
        u: Current control input (held constant over the step).
        dt: Time step for integration.

    Returns:
        The state estimate after dt.
    """
    half_dt = dt / 2
    # Half an Euler step gives an estimate of the state at the interval midpoint...
    x_mid = x + half_dt * f(x, u)
    # ...and the slope there is used for the full step.
    return x + dt * f(x_mid, u)

def rk4_step(f: callable, x: jnp.ndarray, u: jnp.ndarray, dt: float) -> jnp.ndarray:
    """Advance x_dot = f(x, u) by one classical fourth-order Runge-Kutta step.

    Args:
        f: Dynamics function f(x, u) returning x_dot.
        x: Current state.
        u: Current control input (held constant over the step).
        dt: Time step for integration.

    Returns:
        The state estimate after dt.
    """
    half_dt = dt / 2
    # Slopes at the start, twice at the midpoint, and at the end of the interval.
    s1 = f(x, u)
    s2 = f(x + half_dt * s1, u)
    s3 = f(x + half_dt * s2, u)
    s4 = f(x + dt * s3, u)
    # Simpson-like weighting 1-2-2-1.
    blended = s1 + 2 * s2 + 2 * s3 + s4
    return x + (dt / 6) * blended

## Pendulum Dynamics
Here we implement continuous-time dynamics for a simple pendulum system. The state $\boldsymbol{x} = \begin{bmatrix} \theta \\ \dot{\theta} \end{bmatrix}$ consists of the angle $\theta$ (measured from the downward vertical, so $\theta = \pi$ is upright) and the angular velocity $\dot{\theta}$. The dynamics are given by:
$$ \dot{\boldsymbol{x}} = \boldsymbol{f}(\boldsymbol{x}, u) = \begin{bmatrix} \dot{\theta} \\ -\frac{mg \ell}{I} \sin(\theta) + \frac{1}{I} u \end{bmatrix} $$
where $g$ is the acceleration due to gravity, $\ell$ is the length of the pendulum, $m$ is the mass, $I = m \ell^2$ is the moment of inertia, and $u$ is the control torque applied at the pivot.

In [None]:
# Pendulum parameters
m = 1.0  # mass
l = 1.0  # length
g = 9.81  # gravity
I = m * l**2  # moment of inertia

def f_pendulum_continuous(x: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray:
    """Continuous-time dynamics of a simple pendulum.

    Implements I * theta_ddot = u - m * g * l * sin(theta), i.e.
    theta_ddot = -(m g l / I) sin(theta) + u / I.

    Args:
        x: state [theta, theta_dot]
        u: control input (torque); only u[0] is used.

    Returns:
        x_dot: time derivative of state, [theta_dot, theta_ddot].
    """
    # Unpack state.
    theta, theta_dot = x

    # Divide the *total* torque by I exactly once. Pre-scaling the gravity
    # torque by 1/I and then dividing the sum by I again would divide gravity
    # by I twice, which is only harmless when I == 1.
    theta_ddot = (u[0] - m * g * l * jnp.sin(theta)) / I

    # Repack state derivative.
    return jnp.array([theta_dot, theta_ddot])

def zero_controller(x: jnp.ndarray) -> jnp.ndarray:
    """Passive-dynamics controller: always one zero torque, whatever the state."""
    del x  # the state intentionally has no influence
    return jnp.zeros(1)

kp = 5.0  # proportional gain
kd = 1.0  # derivative gain
THETA_DES = jnp.pi  # setpoint angle for the PD controller

def pd_controller(x: jnp.ndarray) -> jnp.ndarray:
    """Proportional-derivative feedback driving the pendulum angle toward THETA_DES.

    Args:
        x: state [theta, theta_dot]
    Returns:
        u: control input (torque), shape (1,)
    """
    angle, rate = x
    angle_error = angle - THETA_DES
    # Negated sum of the proportional and damping terms.
    torque = -(kp * angle_error + kd * rate)
    return jnp.array([torque])

In [None]:
def rollout_dynamics(step_fn: callable, f: callable, x0: jnp.ndarray, dt: float, num_steps: int, controller: callable) -> jnp.ndarray:
    """Simulate a closed-loop trajectory with a fixed-step integrator.

    At each step the controller is evaluated on the current state, and
    ``step_fn(f, x, u, dt)`` advances the state by ``dt``.

    Returns:
        Array whose first row is ``x0``, followed by the ``num_steps``
        successive states.
    """
    def advance(state, _unused):
        successor = step_fn(f, state, controller(state), dt)
        # Carry the new state forward and also record it as the scan output.
        return successor, successor

    _, trajectory = jax.lax.scan(advance, x0, None, length=num_steps)
    return jnp.vstack((x0, trajectory))

def make_f_for_odeint(f: callable, controller: callable) -> callable:
    """Close the loop around ``f`` and adapt it to odeint's ``func(y, t)`` signature.

    The returned callable ignores ``t`` and evaluates ``f(x, controller(x))``.
    """
    def closed_loop_rhs(x, t):
        return f(x, controller(x))
    return closed_loop_rhs

# Simulate the passive (zero-torque) pendulum with each fixed-step integrator and with odeint.
x0 = jnp.array([0.5, 0.0])
T = 20.0
dt = 0.1
# round() rather than int(): int() truncates, so a quotient that lands just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would drop a step.
N = round(T / dt)
ts = jnp.linspace(0, T, N + 1)

f_pendulum_odeint = make_f_for_odeint(f_pendulum_continuous, zero_controller)

xs_euler = rollout_dynamics(euler_step, f_pendulum_continuous, x0, dt, N, zero_controller)
xs_midpoint = rollout_dynamics(midpoint_step, f_pendulum_continuous, x0, dt, N, zero_controller)
xs_rk4 = rollout_dynamics(rk4_step, f_pendulum_continuous, x0, dt, N, zero_controller)
xs_odeint = odeint(f_pendulum_odeint, x0, ts)

In [None]:
# Two stacked panels: pendulum angle (top) and total energy (bottom).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
for label, traj, extra in (
    ('Euler', xs_euler, {}),
    ('Midpoint', xs_midpoint, {}),
    ('RK4', xs_rk4, {}),
    ('odeint', xs_odeint, {'linestyle': '--'}),
):
    ax1.plot(ts, traj[:, 0], label=label, **extra)
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Angle [rad]')
ax1.set_title('Pendulum Angle over Time')
ax1.legend()

def compute_energy(x):
    """Total mechanical energy of a single-pendulum state [theta, theta_dot].

    Potential energy is zero at theta = 0, and the (1 - cos) form keeps it
    non-negative for positive m, g, l.
    """
    angle, rate = x
    potential = m * g * l * (1 - jnp.cos(angle))
    kinetic = 0.5 * I * rate**2
    return kinetic + potential

# Energy of every state along each trajectory; vmap maps compute_energy over the rows.
energy_of_rows = jax.vmap(compute_energy)
energies_euler = energy_of_rows(xs_euler)
energies_midpoint = energy_of_rows(xs_midpoint)
energies_rk4 = energy_of_rows(xs_rk4)
energies_odeint = energy_of_rows(xs_odeint)

for label, series, extra in (
    ('Euler', energies_euler, {}),
    ('Midpoint', energies_midpoint, {}),
    ('RK4', energies_rk4, {}),
    ('odeint', energies_odeint, {'linestyle': '--'}),
):
    ax2.plot(ts, series, label=label, **extra)
ax2.set_xlabel('Time [s]')
ax2.set_ylabel('Energy [J]')
ax2.set_title('Pendulum Energy over Time')
# Log scale so that both slow energy drift and blow-up are visible on one axis.
ax2.set_yscale('log')
ax2.legend()
plt.tight_layout()



In [None]:
# MJCF scene used only for visualization: four single pendulums side by side
# (Euler, Midpoint, RK4, odeint), each a hinge about the y axis. Colors follow
# matplotlib's default C0-C3 cycle so they match the plot legends.
# NOTE: the drawn rod length (0.8) is cosmetic; the simulated pendulum uses l = 1.0.
PENDULUM_XML ="""
<mujoco model="pendulum_comparison">
  <worldbody>
    <light diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
    <geom type="plane" size="3 3 0.1" rgba=".9 .9 .9 1"/>
    
    <!-- Euler (C0 blue) -->
    <body name="euler" pos="-1.5 0 1">
      <joint name="euler_joint" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.12 0.47 0.71 1"/>
      <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="0.12 0.47 0.71 1"/>
    </body>
    
    <!-- Midpoint (C1 orange) -->
    <body name="midpoint" pos="-0.5 0 1">
      <joint name="midpoint_joint" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="1.0 0.5 0.05 1"/>
      <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="1.0 0.5 0.05 1"/>
    </body>
    
    <!-- RK4 (C2 green) -->
    <body name="rk4" pos="0.5 0 1">
      <joint name="rk4_joint" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.17 0.63 0.17 1"/>
      <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="0.17 0.63 0.17 1"/>
    </body>
    
    <!-- odeint (C3 red) -->
    <body name="odeint" pos="1.5 0 1">
      <joint name="odeint_joint" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.84 0.15 0.16 1"/>
      <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="0.84 0.15 0.16 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Compile the visualization model and allocate its data buffer.
mj_spec = mujoco.MjSpec.from_string(PENDULUM_XML)
mj_model = mj_spec.compile()
mj_data = mujoco.MjData(mj_model)

# Reuse a viser server left over from an earlier cell run; otherwise start one.
if globals().get('server') is None:
    server = viser.ViserServer()
else:
    server.scene.reset()

viser_model = ViserMjModel(server, mj_spec)

# Animate the scene: record one keyframe per time step into a scene serializer.
serializer = server.get_scene_serializer()

# Joint qpos addresses are fixed once the model is compiled, so look them up
# once instead of on every frame.
qpos_addrs = [
    mj_model.joint(f'{name}_joint').qposadr
    for name in ('euler', 'midpoint', 'rk4', 'odeint')
]
trajectories = (xs_euler, xs_midpoint, xs_rk4, xs_odeint)

for i in range(len(ts)):
    # Set the joint angle of each pendulum from its integrator's trajectory.
    for addr, traj in zip(qpos_addrs, trajectories):
        mj_data.qpos[addr] = traj[i, 0]

    # mj_forward recomputes derived quantities (incl. body poses) from qpos
    # without advancing time; the trajectories above already did the integration.
    mujoco.mj_forward(mj_model, mj_data)
    viser_model.set_data(mj_data)
    serializer.insert_sleep(dt)

In [None]:
# Play back the keyframes recorded by the serializer above (viser scene serializer).
serializer.show()

## Double Pendulum Dynamics
We can also study the integrators on a double pendulum (two links connected in series). The state $\boldsymbol{x} = \begin{bmatrix} \theta_1 \\ \theta_2 \\ \dot{\theta}_1 \\ \dot{\theta}_2 \end{bmatrix}$ consists of the angles and angular velocities of both links. Both angles are measured from the downward vertical; they are absolute, not relative to each other. The dynamics are more complex due to the coupling between the two links — in fact, they are *chaotic*, meaning that small differences in initial conditions (or small numerical integration errors) can lead to vastly different trajectories over time.

In [None]:
# Physical constants.
m1 = m2 = 1.0  # masses
l1 = l2 = 1.0  # lengths
g = 9.81       # gravity

def f_double_pendulum_continuous(x: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray:
    """Continuous-time passive dynamics of a double pendulum.

    Point masses m1, m2 sit at the ends of massless rods of length l1, l2.
    Both angles are absolute, i.e. measured from the downward vertical.

    Args:
        x: state [theta1, theta2, theta1_dot, theta2_dot]
        u: control input. NOTE: currently ignored -- the model is passive, so
            torques returned by a controller have no effect.

    Returns:
        x_dot: time derivative of state
    """
    # Unpack state.
    t1, t2, t1_dot, t2_dot = x

    # Convenience quantities. alpha = det(M) / (m2 l1^2 l2^2) for the 2x2 mass
    # matrix M; it is positive for positive masses, so the division is safe.
    delta = t1 - t2
    sin_d = jnp.sin(delta)
    cos_d = jnp.cos(delta)
    alpha = m1 + m2 * sin_d**2

    # Closed-form solution of M(q) q_ddot = -(velocity terms) - (gravity terms),
    # derived from the Lagrangian. Each numerator is divided by its full
    # denominator (l_i * alpha).
    t1_ddot = (
        -sin_d * (m2 * l1 * t1_dot**2 * cos_d + m2 * l2 * t2_dot**2)
        - g * ((m1 + m2) * jnp.sin(t1) - m2 * jnp.sin(t2) * cos_d)
    ) / (l1 * alpha)
    t2_ddot = (
        sin_d * ((m1 + m2) * l1 * t1_dot**2 + m2 * l2 * t2_dot**2 * cos_d)
        + g * (m1 + m2) * (jnp.sin(t1) * cos_d - jnp.sin(t2))
    ) / (l2 * alpha)

    return jnp.array([t1_dot, t2_dot, t1_ddot, t2_ddot])

def zero_controller_double_pendulum(x: jnp.ndarray) -> jnp.ndarray:
    """Passive-dynamics controller: always two zero entries, whatever the state."""
    del x  # the state intentionally has no influence
    return jnp.zeros(2)

def pd_controller_double_pendulum(x: jnp.ndarray) -> jnp.ndarray:
    """Independent PD loops on both angles, each driven toward THETA_DES.

    Uses the module-level gains kp, kd shared with the single-pendulum PD controller.
    """
    theta1, theta2, omega1, omega2 = x
    angles = jnp.array([theta1, theta2])
    rates = jnp.array([omega1, omega2])
    # Elementwise PD law, one entry per joint.
    return -kp * (angles - THETA_DES) - kd * rates

# Wrap the passive (zero-torque) double-pendulum dynamics for odeint's func(y, t) signature.
f_for_odeint_double_pendulum = make_f_for_odeint(f_double_pendulum_continuous, zero_controller_double_pendulum)

# Simulate the passive dynamics with each fixed-step integrator and with odeint.
# NOTE: this rebinds T, dt, N and ts, which the single-pendulum cells also use.
x0_double = jnp.array([0.75, 0.75, 0.0, 0.0])
T = 50.0
dt = 0.1
# round() rather than int(): int() truncates, so a quotient that lands just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would drop a step.
N = round(T / dt)
ts = jnp.linspace(0, T, N + 1)
xs_euler_double = rollout_dynamics(euler_step, f_double_pendulum_continuous, x0_double, dt, N, zero_controller_double_pendulum)
xs_midpoint_double = rollout_dynamics(midpoint_step, f_double_pendulum_continuous, x0_double, dt, N, zero_controller_double_pendulum)
xs_rk4_double = rollout_dynamics(rk4_step, f_double_pendulum_continuous, x0_double, dt, N, zero_controller_double_pendulum)
xs_odeint_double = odeint(f_for_odeint_double_pendulum, x0_double, ts)

In [None]:
# Three stacked panels: angle 1, angle 2, and total energy of the double pendulum.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))
double_runs = (
    ('Euler', xs_euler_double, {}),
    ('Midpoint', xs_midpoint_double, {}),
    ('RK4', xs_rk4_double, {}),
    ('odeint', xs_odeint_double, {'linestyle': '--'}),
)
for ax, column, title in (
    (ax1, 0, 'Pendulum Angle 1 over Time'),
    (ax2, 1, 'Pendulum Angle 2 over Time'),
):
    for label, traj, extra in double_runs:
        ax.plot(ts, traj[:, column], label=label, **extra)
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Angle [rad]')
    ax.set_title(title)
    ax.legend()

def compute_energy(x):
    """Total mechanical energy of a double-pendulum state [theta1, theta2, theta1_dot, theta2_dot].

    Angles are absolute; potential energy is measured from the pivot height.
    """
    th1, th2, w1, w2 = x
    coupling = m2 * l1 * l2 * w1 * w2 * jnp.cos(th1 - th2)
    kinetic = 0.5 * (m1 + m2) * l1**2 * w1**2 + 0.5 * m2 * l2**2 * w2**2 + coupling
    potential = -(m1 + m2) * g * l1 * jnp.cos(th1) - m2 * g * l2 * jnp.cos(th2)
    return kinetic + potential

# Energy of every state along each trajectory; vmap maps compute_energy over the rows.
energy_of_rows = jax.vmap(compute_energy)
energies_euler = energy_of_rows(xs_euler_double)
energies_midpoint = energy_of_rows(xs_midpoint_double)
energies_rk4 = energy_of_rows(xs_rk4_double)
energies_odeint = energy_of_rows(xs_odeint_double)

for label, series, extra in (
    ('Euler', energies_euler, {}),
    ('Midpoint', energies_midpoint, {}),
    ('RK4', energies_rk4, {}),
    ('odeint', energies_odeint, {'linestyle': '--'}),
):
    ax3.plot(ts, series, label=label, **extra)
ax3.set_xlabel('Time [s]')
ax3.set_ylabel('Energy [J]')
ax3.set_title('Pendulum Energy over Time')
ax3.legend()
plt.tight_layout()

In [None]:
# MJCF scene used only for visualization: four double pendulums side by side
# (Euler, Midpoint, RK4, odeint), colored with matplotlib's C0-C3 cycle.
# Each second link is a body nested inside the first link, so the qpos of
# "<name>_joint_2" is the angle of link 2 RELATIVE to link 1, not an absolute angle.
# NOTE: the drawn link lengths (0.8) are cosmetic; the simulation uses l1 = l2 = 1.0.
DOUBLE_PENDULUM_XML = """
<mujoco model="double_pendulum_comparison">
  <worldbody>
    <light diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
    <geom type="plane" size="3 3 0.1" rgba=".9 .9 .9 1"/>
    <!-- Euler (C0 blue) -->
    <body name="euler_1" pos="-1.5 0 1.5">
      <joint name="euler_joint_1" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.12 0.47 0.71 1"/>
      <body name="euler_2" pos="0 0 -0.8">
        <joint name="euler_joint_2" type="hinge" axis="0 1 0"/>
        <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.12 0.47 0.71 0.7"/>
        <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="0.12 0.47 0.71 1"/>
      </body>
    </body>
    <!-- Midpoint (C1 orange) -->
    <body name="midpoint_1" pos="-0.5 0 1.5">
      <joint name="midpoint_joint_1" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="1.0 0.5 0.05 1"/>
      <body name="midpoint_2" pos="0 0 -0.8">
        <joint name="midpoint_joint_2" type="hinge" axis="0 1 0"/>
        <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="1.0 0.5 0.05 0.7"/>
        <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="1.0 0.5 0.05 1"/>
      </body>
    </body>
    <!-- RK4 (C2 green) -->
    <body name="rk4_1" pos="0.5 0 1.5">
      <joint name="rk4_joint_1" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.17 0.63 0.17 1"/>
      <body name="rk4_2" pos="0 0 -0.8">
        <joint name="rk4_joint_2" type="hinge" axis="0 1 0"/>
        <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.17 0.63 0.17 0.7"/>
        <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="0.17 0.63 0.17 1"/>
      </body>
    </body>
    <!-- odeint (C3 red) -->
    <body name="odeint_1" pos="1.5 0 1.5">
      <joint name="odeint_joint_1" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.84 0.15 0.16 1"/>
      <body name="odeint_2" pos="0 0 -0.8">
        <joint name="odeint_joint_2" type="hinge" axis="0 1 0"/>
        <geom type="capsule" fromto="0 0 0 0 0 -0.8" size="0.04" rgba="0.84 0.15 0.16 0.7"/>
        <geom type="sphere" pos="0 0 -0.8" size="0.08" rgba="0.84 0.15 0.16 1"/>
      </body>
    </body>
  </worldbody>
</mujoco>
"""

# Compile the double-pendulum visualization model and allocate its data buffer.
mj_spec_double_pendulum = mujoco.MjSpec.from_string(DOUBLE_PENDULUM_XML)
mj_model_double_pendulum = mj_spec_double_pendulum.compile()
mj_data_double_pendulum = mujoco.MjData(mj_model_double_pendulum)

# Reuse a viser server left over from an earlier cell run; otherwise start one.
if globals().get('server') is None:
    server = viser.ViserServer()
else:
    server.scene.reset()

viser_model = ViserMjModel(server, mj_spec_double_pendulum)

# Animate the scene: record one keyframe per time step into a scene serializer.
serializer = server.get_scene_serializer()

# Joint qpos addresses are fixed once the model is compiled, so look them up
# once instead of on every frame. Each entry is (joint_1 address, joint_2 address).
qpos_addrs = [
    (
        mj_model_double_pendulum.joint(f'{name}_joint_1').qposadr,
        mj_model_double_pendulum.joint(f'{name}_joint_2').qposadr,
    )
    for name in ('euler', 'midpoint', 'rk4', 'odeint')
]
trajectories = (xs_euler_double, xs_midpoint_double, xs_rk4_double, xs_odeint_double)

for i in range(len(ts)):
    for (addr_1, addr_2), traj in zip(qpos_addrs, trajectories):
        theta1 = traj[i, 0]
        theta2 = traj[i, 1]
        mj_data_double_pendulum.qpos[addr_1] = theta1
        # The dynamics use absolute angles, but joint_2 is a hinge inside the
        # first link's body, so MuJoCo expects the angle relative to link 1.
        mj_data_double_pendulum.qpos[addr_2] = theta2 - theta1

    # mj_forward recomputes derived quantities (incl. body poses) from qpos
    # without advancing time; the trajectories above already did the integration.
    mujoco.mj_forward(mj_model_double_pendulum, mj_data_double_pendulum)
    viser_model.set_data(mj_data_double_pendulum)
    serializer.insert_sleep(dt)

In [None]:
# Play back the double-pendulum keyframes recorded by the serializer above (viser scene serializer).
serializer.show()