### Multi-agent iLQGames example

Reference: 

*Fridovich-Keil, David, et al.* "[Efficient iterative linear-quadratic approximations for nonlinear multi-player general-sum differential games](https://arxiv.org/abs/1909.04694)." 2020 IEEE International Conference on Robotics and Automation (ICRA).

In [1]:
import jax 
import jax.numpy as jnp 
from jax import vmap, jit, grad

try:
    from lqrax import iLQR
except:
    %pip install lqrax
    from lqrax import iLQR

In [2]:
# global variables 
# Settings shared by all three agents, the main loop and the animation.
dt: float = 0.05  # time step handed to every agent
tsteps: int = 100  # horizon length: number of control steps per trajectory
device = jax.devices(backend="cpu")[0]  # jitted functions are pinned to the first CPU device



In [3]:
# Define the first agent (a differential-drive vehicle)
class DiffdriveAgent(iLQR):
    """Differential-drive (unicycle) agent.

    State ``xt = (x, y, theta)``; control ``ut = (v, omega)``: forward speed
    and turn rate.
    """

    def __init__(self, dt, x_dim, u_dim, Q, R):
        # Forward everything unchanged to the lqrax iLQR base class.
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return the state time-derivative ``(dx, dy, dtheta)``."""
        heading = xt[2]
        speed, turn_rate = ut[0], ut[1]
        return jnp.array([speed * jnp.cos(heading),
                          speed * jnp.sin(heading),
                          turn_rate])

In [4]:
# Define the second agent (a second-order point mass, e.g., a pedestrian)
class PointAgent(iLQR):
    """Second-order point mass (double integrator).

    State ``xt = (px, py, vx, vy)``; control ``ut = (ax, ay)`` is the
    acceleration.
    """

    def __init__(self, dt, x_dim, u_dim, Q, R):
        # Forward everything unchanged to the lqrax iLQR base class.
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return the state time-derivative: positions move with velocity, velocities with control."""
        px_dot, py_dot = xt[2], xt[3]
        vx_dot, vy_dot = ut[0], ut[1]
        return jnp.array([px_dot, py_dot, vx_dot, vy_dot])

In [5]:
# Define the third agent (a bicycle)
class BicycleAgent(iLQR):
    """Kinematic bicycle agent.

    State ``xt = (x, y, theta)``; control ``ut = (v, delta)``: forward speed
    and steering angle.
    """

    def __init__(self, dt, x_dim, u_dim, Q, R, wheelbase=0.03):
        """
        Args:
            dt, x_dim, u_dim, Q, R: forwarded unchanged to the lqrax iLQR base.
            wheelbase: length used in the turn-rate term
                ``v * tan(delta) / wheelbase``. Defaults to 0.03, the value
                previously hard-coded in ``dyn``.
        """
        # Stored before the base initializer in case it already evaluates
        # ``dyn`` (not visible here -- defensive).
        self.wheelbase = wheelbase
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return the state time-derivative ``(dx, dy, dtheta)``."""
        _, _, theta = xt
        v, delta = ut
        dx = v * jnp.cos(theta)
        dy = v * jnp.sin(theta)
        dtheta = v * jnp.tan(delta) / self.wheelbase
        return jnp.array([dx, dy, dtheta])

In [6]:
# Specify the nonlinear loss of the first agent
# Diagonal weight matrices handed to the iLQR base (presumably the quadratic
# terms of each LQ sub-problem -- confirm against lqrax). The task cost
# itself is diffdrive_runtime_loss below.
Q_diffdrive = jnp.diag(jnp.array((0.1, 0.1, 0.01)))
R_diffdrive = jnp.diag(jnp.array((1.0, 0.01)))

# State/control dimensions are read off the weight matrices (3 and 2).
diffdrive_ilqgames = DiffdriveAgent(
    dt=dt,
    x_dim=Q_diffdrive.shape[0],
    u_dim=R_diffdrive.shape[0],
    Q=Q_diffdrive,
    R=R_diffdrive,
)


def diffdrive_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Per-step cost of the differential-drive agent.

    Sums four terms:
    - squared planar distance to the reference waypoint;
    - a Gaussian-shaped proximity penalty ``10 * exp(-5 * d^2)`` for each of
      the two other agents (only the first two state entries are compared);
    - ``0.1 * ||ut * [1.0, 0.01]||^2`` control effort.
    """
    pos = xt[:2]

    def proximity_penalty(other):
        gap_sq = jnp.sum(jnp.square(pos - other[:2]))
        return 10.0 * jnp.exp(-5.0 * gap_sq)

    tracking = jnp.sum(jnp.square(pos - ref_xt[:2]))
    effort = 0.1 * jnp.sum(jnp.square(ut * jnp.array([1.0, 0.01])))
    return tracking + proximity_penalty(other_xt1) + proximity_penalty(other_xt2) + effort

 
def diffdrive_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total running cost of the differential-drive agent over the horizon.

    Evaluates diffdrive_runtime_loss at every time step (the leading axis of
    all five arrays), sums the results and multiplies by the agent's ``dt``.
    """
    per_step = vmap(diffdrive_runtime_loss)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return jnp.sum(per_step) * diffdrive_ilqgames.dt


def diffdrive_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Per-step gradients of diffdrive_runtime_loss along a trajectory.

    Returns:
        ``(a_traj, b_traj)``: the gradient with respect to the state and with
        respect to the control at every time step, stacked on the leading axis.
    """
    # A single reverse-mode pass gives both gradients (argnums=(0, 1)).
    grad_xu = grad(diffdrive_runtime_loss, argnums=(0, 1))
    a_traj, b_traj = vmap(grad_xu)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return a_traj, b_traj


# Compile the agent's rollout/linearization, LQ solve, loss and loss gradient
# for the CPU device selected above. The loss names are rebound to their
# jitted versions (the right-hand side still refers to the plain functions).
# NOTE(review): recent JAX releases deprecate jit's `device=` argument;
# confirm against the installed JAX version.
diffdrive_linearize_dyn, diffdrive_solve_ilqr = (
    jit(diffdrive_ilqgames.linearize_dyn, device=device),
    jit(diffdrive_ilqgames.solve, device=device),
)
diffdrive_loss, diffdrive_linearize_loss = (
    jit(diffdrive_loss, device=device),
    jit(diffdrive_linearize_loss, device=device),
)

In [7]:
# Specify the nonlinear loss of the second agent
# Diagonal weight matrices handed to the iLQR base (presumably the quadratic
# terms of each LQ sub-problem -- confirm against lqrax). The task cost
# itself is point_runtime_loss below.
Q_point = jnp.diag(jnp.array((0.1, 0.1, 0.001, 0.001)))
R_point = jnp.diag(jnp.array((0.01, 0.01)))

# State/control dimensions are read off the weight matrices (4 and 2).
point_ilqgames = PointAgent(
    dt=dt,
    x_dim=Q_point.shape[0],
    u_dim=R_point.shape[0],
    Q=Q_point,
    R=R_point,
)


def point_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Per-step cost of the point-mass agent.

    Sums four terms:
    - squared planar distance to the reference waypoint;
    - a Gaussian-shaped proximity penalty ``10 * exp(-5 * d^2)`` for each of
      the two other agents (only the first two state entries are compared);
    - ``0.1 * ||ut * [1.0, 0.5]||^2`` control effort.
    """
    pos = xt[:2]

    def proximity_penalty(other):
        gap_sq = jnp.sum(jnp.square(pos - other[:2]))
        return 10.0 * jnp.exp(-5.0 * gap_sq)

    tracking = jnp.sum(jnp.square(pos - ref_xt[:2]))
    effort = 0.1 * jnp.sum(jnp.square(ut * jnp.array([1.0, 0.5])))
    return tracking + proximity_penalty(other_xt1) + proximity_penalty(other_xt2) + effort


def point_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total running cost of the point-mass agent over the horizon.

    Evaluates point_runtime_loss at every time step (the leading axis of all
    five arrays), sums the results and multiplies by the agent's ``dt``.
    """
    per_step = vmap(point_runtime_loss)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return jnp.sum(per_step) * point_ilqgames.dt


def point_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Per-step gradients of point_runtime_loss along a trajectory.

    Returns:
        ``(a_traj, b_traj)``: the gradient with respect to the state and with
        respect to the control at every time step, stacked on the leading axis.
    """
    # A single reverse-mode pass gives both gradients (argnums=(0, 1)).
    grad_xu = grad(point_runtime_loss, argnums=(0, 1))
    a_traj, b_traj = vmap(grad_xu)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return a_traj, b_traj


# Compile the agent's rollout/linearization, LQ solve, loss and loss gradient
# for the CPU device selected above. The loss names are rebound to their
# jitted versions (the right-hand side still refers to the plain functions).
# NOTE(review): recent JAX releases deprecate jit's `device=` argument;
# confirm against the installed JAX version.
point_linearize_dyn, point_solve_ilqr = (
    jit(point_ilqgames.linearize_dyn, device=device),
    jit(point_ilqgames.solve, device=device),
)
point_loss, point_linearize_loss = (
    jit(point_loss, device=device),
    jit(point_linearize_loss, device=device),
)

In [None]:
# Specify the nonlinear loss of the third agent
# Diagonal weight matrices handed to the iLQR base (presumably the quadratic
# terms of each LQ sub-problem -- confirm against lqrax). The task cost
# itself is bicycle_runtime_loss below.
Q_bicycle = jnp.diag(jnp.array((0.1, 0.1, 0.01)))
R_bicycle = jnp.diag(jnp.array((1.0, 0.1)))

# State/control dimensions are read off the weight matrices (3 and 2).
bicycle_ilqgames = BicycleAgent(
    dt=dt,
    x_dim=Q_bicycle.shape[0],
    u_dim=R_bicycle.shape[0],
    Q=Q_bicycle,
    R=R_bicycle,
)


def bicycle_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Per-step cost of the bicycle agent.

    Sums four terms:
    - squared planar distance to the reference waypoint;
    - a Gaussian-shaped proximity penalty ``10 * exp(-5 * d^2)`` for each of
      the two other agents (only the first two state entries are compared);
    - ``0.1 * ||ut * [1.0, 0.01]||^2`` control effort.
    """
    pos = xt[:2]

    def proximity_penalty(other):
        gap_sq = jnp.sum(jnp.square(pos - other[:2]))
        return 10.0 * jnp.exp(-5.0 * gap_sq)

    tracking = jnp.sum(jnp.square(pos - ref_xt[:2]))
    effort = 0.1 * jnp.sum(jnp.square(ut * jnp.array([1.0, 0.01])))
    return tracking + proximity_penalty(other_xt1) + proximity_penalty(other_xt2) + effort


def bicycle_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total running cost of the bicycle agent over the horizon.

    Evaluates bicycle_runtime_loss at every time step (the leading axis of
    all five arrays), sums the results and multiplies by the agent's ``dt``.
    """
    per_step = vmap(bicycle_runtime_loss)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return jnp.sum(per_step) * bicycle_ilqgames.dt


def bicycle_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Per-step gradients of bicycle_runtime_loss along a trajectory.

    Returns:
        ``(a_traj, b_traj)``: the gradient with respect to the state and with
        respect to the control at every time step, stacked on the leading axis.
    """
    # A single reverse-mode pass gives both gradients (argnums=(0, 1)).
    grad_xu = grad(bicycle_runtime_loss, argnums=(0, 1))
    a_traj, b_traj = vmap(grad_xu)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return a_traj, b_traj


# Compile the agent's rollout/linearization, LQ solve, loss and loss gradient
# for the CPU device selected above. The loss names are rebound to their
# jitted versions (the right-hand side still refers to the plain functions).
# NOTE(review): recent JAX releases deprecate jit's `device=` argument;
# confirm against the installed JAX version.
bicycle_linearize_dyn, bicycle_solve_ilqr = (
    jit(bicycle_ilqgames.linearize_dyn, device=device),
    jit(bicycle_ilqgames.solve, device=device),
)
bicycle_loss, bicycle_linearize_loss = (
    jit(bicycle_loss, device=device),
    jit(bicycle_linearize_loss, device=device),
)

In [11]:
import numpy as np

# Start iLQGames iterations here.
# Each agent gets an initial state, a nominal control sequence with `tsteps`
# rows, and a straight-line reference path of `tsteps` waypoints ending at its
# goal (`[1:]` drops the linspace start point).
diffdrive_x0 = jnp.array([-4.0, -1.0, 0.0])
diffdrive_u_traj = jnp.tile(jnp.array([0.8, 0.0]), reps=(tsteps, 1))
diffdrive_ref_traj = jnp.linspace(
    jnp.array([-2.0, 0.0]), jnp.array([2.0, 0.0]), tsteps+1
)[1:]

# NOTE(review): PointAgent's state is (px, py, vx, vy), so jnp.pi/2 below is an
# initial y-velocity, not a heading -- confirm this is intended.
point_x0 = jnp.array([2.0, 0.1, -0.8, jnp.pi/2])
point_u_traj = jnp.zeros((tsteps, 2))
point_ref_traj = jnp.linspace(
    jnp.array([2.0, 0.0]), jnp.array([-2.0, 0.0]), tsteps+1
)[1:]

bicycle_x0 = jnp.array([-1.2, -5.0, jnp.pi/2.0])
bicycle_u_traj = jnp.tile(jnp.array([0.5, 0.0]), reps=(tsteps, 1))
bicycle_ref_traj = jnp.linspace(
    jnp.array([0.0, -2.0]), jnp.array([0.0, 2.0]), tsteps+1
)[1:]

num_iters = 200
step_size = 0.002
# Report losses about 10 times per run. max() keeps the modulus non-zero when
# num_iters < 10 (int(num_iters/10) would be 0 and raise ZeroDivisionError).
log_every = max(1, num_iters // 10)

# The counter is `it`, not `iter`, so the builtin iter() is not shadowed in
# the notebook namespace.
for it in range(num_iters + 1):
    # Roll out each agent's current controls and linearize its dynamics.
    diffdrive_x_traj, diffdrive_A_traj, diffdrive_B_traj = \
        diffdrive_linearize_dyn(diffdrive_x0, diffdrive_u_traj)
    point_x_traj, point_A_traj, point_B_traj = \
        point_linearize_dyn(point_x0, point_u_traj)
    bicycle_x_traj, bicycle_A_traj, bicycle_B_traj = \
        bicycle_linearize_dyn(bicycle_x0, bicycle_u_traj)

    # Linearize each agent's loss along the current rollouts. Every agent sees
    # the other two agents' current trajectories.
    diffdrive_a_traj, diffdrive_b_traj = \
        diffdrive_linearize_loss(
            diffdrive_x_traj, diffdrive_u_traj, diffdrive_ref_traj, point_x_traj, bicycle_x_traj)
    point_a_traj, point_b_traj = \
        point_linearize_loss(
            point_x_traj, point_u_traj, point_ref_traj, diffdrive_x_traj, bicycle_x_traj)
    bicycle_a_traj, bicycle_b_traj = \
        bicycle_linearize_loss(
            bicycle_x_traj, bicycle_u_traj, bicycle_ref_traj, diffdrive_x_traj, point_x_traj)

    # Compute each agent's descent direction on its controls (the solver's
    # second output is unused here).
    diffdrive_v_traj, _ = diffdrive_solve_ilqr(
        diffdrive_A_traj, diffdrive_B_traj, diffdrive_a_traj, diffdrive_b_traj)
    point_v_traj, _ = point_solve_ilqr(
        point_A_traj, point_B_traj, point_a_traj, point_b_traj)
    bicycle_v_traj, _ = bicycle_solve_ilqr(
        bicycle_A_traj, bicycle_B_traj, bicycle_a_traj, bicycle_b_traj)

    # Periodically log the losses of the current (pre-update) rollouts.
    if it % log_every == 0:
        diffdrive_loss_val = diffdrive_loss(
            diffdrive_x_traj, diffdrive_u_traj, diffdrive_ref_traj, point_x_traj, bicycle_x_traj)
        point_loss_val = point_loss(
            point_x_traj, point_u_traj, point_ref_traj, diffdrive_x_traj, bicycle_x_traj)
        bicycle_loss_val = bicycle_loss(
            bicycle_x_traj, bicycle_u_traj, bicycle_ref_traj, diffdrive_x_traj, point_x_traj)
        print(
            f'iter[{it:3d}/{num_iters}] | diffdrive loss: {diffdrive_loss_val:5.2f} | point loss: {point_loss_val:5.2f} | bicycle loss: {bicycle_loss_val:5.2f}')

    # Update controls: fixed-size step along each agent's descent direction.
    diffdrive_u_traj += step_size * diffdrive_v_traj
    point_u_traj += step_size * point_v_traj
    bicycle_u_traj += step_size * bicycle_v_traj

# The last iteration updated the controls after its rollout, so the *_x_traj
# arrays are one update stale. Roll out once more so the trajectories used for
# plotting match the final controls.
diffdrive_x_traj, _, _ = diffdrive_linearize_dyn(diffdrive_x0, diffdrive_u_traj)
point_x_traj, _, _ = point_linearize_dyn(point_x0, point_u_traj)
bicycle_x_traj, _, _ = bicycle_linearize_dyn(bicycle_x0, bicycle_u_traj)

iter[  0/200] | diffdrive loss: 25.32 | point loss: 108.37 | bicycle loss: 78.86
iter[ 20/200] | diffdrive loss: 13.44 | point loss: 30.35 | bicycle loss: 40.24
iter[ 40/200] | diffdrive loss:  8.36 | point loss:  8.62 | bicycle loss: 23.15
iter[ 60/200] | diffdrive loss:  6.23 | point loss:  3.01 | bicycle loss: 15.78
iter[ 80/200] | diffdrive loss:  5.01 | point loss:  2.45 | bicycle loss: 13.83
iter[100/200] | diffdrive loss:  4.33 | point loss:  3.81 | bicycle loss: 17.65
iter[120/200] | diffdrive loss:  4.07 | point loss:  2.20 | bicycle loss: 32.77
iter[140/200] | diffdrive loss:  3.93 | point loss:  1.13 | bicycle loss: 29.16
iter[160/200] | diffdrive loss:  3.88 | point loss:  1.00 | bicycle loss: 19.05
iter[180/200] | diffdrive loss:  3.63 | point loss:  1.05 | bicycle loss: 19.89
iter[200/200] | diffdrive loss:  3.50 | point loss:  0.90 | bicycle loss: 20.35


In [12]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# A single 4x4-inch axes that the animation redraws every frame.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=120, tight_layout=True)
imgs = list()  # appears unused by the animation below


def update(t):
    """Redraw animation frame ``t`` on ``ax``.

    Draws each agent's goal (the last reference waypoint) and then, per agent,
    its trail up to step ``t``, its body marker and a triangular heading
    "nose". Reads the module-level ``*_x_traj`` / ``*_ref_traj`` arrays.
    Returns an empty artist list (the animation does not blit).
    """
    ax.cla()
    ax.set_aspect('equal')
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    ax.axis('off')

    goals = ((diffdrive_ref_traj, 'C0'), (point_ref_traj, 'C1'), (bicycle_ref_traj, 'C2'))
    for ref, color in goals:
        ax.plot(ref[-1, 0], ref[-1, 1], linestyle='',
                marker='X', markersize=20, color=color, alpha=0.5)

    def draw_agent(traj, theta, deg, body_marker, body_size, nose_offset, color):
        # Trail (steps before t), body at step t, then a triangle placed
        # nose_offset ahead of the body along heading theta.
        now = traj[t]
        ax.plot(traj[:t, 0], traj[:t, 1],
                linestyle='-', linewidth=5, color=color, alpha=0.5)
        ax.plot(now[0], now[1], linestyle='',
                marker=body_marker, markersize=body_size, color=color)
        ax.plot(now[0] + np.cos(theta) * nose_offset,
                now[1] + np.sin(theta) * nose_offset,
                linestyle='', marker=(3, 0, deg + 30), markersize=15, color=color)

    # Differential drive: heading is the third state entry; square body.
    dd_theta = diffdrive_x_traj[t][2]
    dd_deg = np.rad2deg(dd_theta)
    draw_agent(diffdrive_x_traj, dd_theta, dd_deg, (4, 0, dd_deg + 45), 30, 0.32, 'C0')

    # Point mass: heading is the direction of the velocity (state[2:4]); round body.
    pt_now = point_x_traj[t]
    pt_theta = np.arctan2(pt_now[3], pt_now[2])
    pt_deg = np.rad2deg(pt_theta)
    draw_agent(point_x_traj, pt_theta, pt_deg, 'o', 25, 0.36, 'C1')

    # Bicycle: heading is the third state entry; square body.
    bc_theta = bicycle_x_traj[t][2]
    bc_deg = np.rad2deg(bc_theta)
    draw_agent(bicycle_x_traj, bc_theta, bc_deg, (4, 0, bc_deg + 45), 30, 0.33, 'C2')

    return []


# One frame per time step, 50 ms apart (equal to dt at the current dt = 0.05 s).
ani = animation.FuncAnimation(fig, update, frames=tsteps, interval=50)
# Close the current figure, presumably so the notebook shows only the video
# and not an extra static render.
plt.close()
video_html = ani.to_html5_video()
HTML(video_html)