### Multi-agent iLQGames example

Reference: 

*Fridovich-Keil, David, et al.* "[Efficient iterative linear-quadratic approximations for nonlinear multi-player general-sum differential games](https://arxiv.org/abs/1909.04694)." 2020 IEEE International Conference on Robotics and Automation (ICRA).

In [2]:
import jax 
import jax.numpy as jnp 
from jax import vmap, jit, grad

try:
    from lqrax import iLQR
except:
    %pip install lqrax
    from lqrax import iLQR

In [7]:
# Global settings shared by the cells below.
# dt is handed to every agent constructor (dt=dt).
dt = 0.05
# tsteps is the number of rows in every control and reference trajectory.
tsteps = 100
# First CPU device reported by JAX; passed as device= to every jit(...) call.
device = jax.devices("cpu")[0]

In [3]:
# Define the first agent (a differential-drive vehicle)
class DiffdriveAgent(iLQR):
    """Differential-drive vehicle.

    The state is read as ``[x, y, heading]`` and the control as
    ``[forward speed, turn rate]``; see :meth:`dyn`.
    """

    def __init__(self, dt, x_dim, u_dim, Q, R):
        # Nothing agent-specific to configure: forward everything to iLQR.
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return ``[x_dot, y_dot, heading_dot]`` for state ``xt`` and control ``ut``."""
        heading = xt[2]
        speed, turn_rate = ut[0], ut[1]
        x_dot = speed * jnp.cos(heading)
        y_dot = speed * jnp.sin(heading)
        return jnp.array([x_dot, y_dot, turn_rate])

In [5]:
# Define the second agent (a second-order point mass, e.g., a pedestrian)
class PointAgent(iLQR):
    """Point mass with state ``[x, y, vx, vy]`` driven by an acceleration control."""

    def __init__(self, dt, x_dim, u_dim, Q, R):
        # Nothing agent-specific to configure: forward everything to iLQR.
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return ``[vx, vy, ax, ay]``: position changes with velocity, velocity with the control."""
        vel_x, vel_y = xt[2], xt[3]
        acc_x, acc_y = ut[0], ut[1]
        return jnp.array([vel_x, vel_y, acc_x, acc_y])

In [4]:
# Define the third agent (a bicycle)
class BicycleAgent(iLQR):
    """Bicycle-style vehicle with state ``[x, y, heading]`` and control ``[speed, steering angle]``."""

    def __init__(self, dt, x_dim, u_dim, Q, R):
        # Nothing agent-specific to configure: forward everything to iLQR.
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return ``[x_dot, y_dot, heading_dot]`` for state ``xt`` and control ``ut``.

        The heading rate is ``speed * tan(steering) / wheelbase`` with a
        fixed wheelbase of 0.03.
        """
        wheelbase = 0.03
        # Unpacking requires xt to have exactly three entries and ut two.
        _, _, heading = xt
        speed, steer = ut
        yaw_rate = speed * jnp.tan(steer) / wheelbase
        return jnp.array([
            speed * jnp.cos(heading),
            speed * jnp.sin(heading),
            yaw_rate,
        ])

In [8]:
# Specify the nonlinear loss of the first agent
# Diagonal weight matrices handed to the iLQR base class: Q is 3x3 (one entry
# per state component), R is 2x2 (one entry per control component).
# NOTE(review): presumably the state/control weights of the quadratic
# sub-problem solved by iLQR.solve -- confirm against the lqrax documentation.
Q_diffdrive = jnp.diag(jnp.array([0.1, 0.1, 0.01]))
R_diffdrive = jnp.diag(jnp.array([1.0, 0.01]))

# Differential-drive agent: 3 state components, 2 control components.
diffdrive_ilqgames = DiffdriveAgent(dt=dt, x_dim=3, u_dim=2, Q=Q_diffdrive, R=R_diffdrive)


def diffdrive_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Cost of the differential-drive agent at a single time step.

    Sum of four terms:
      * squared distance between the agent's position ``xt[:2]`` and the
        reference position ``ref_xt[:2]``;
      * one penalty ``10 * exp(-5 * d^2)`` per other agent, where ``d`` is
        the distance between the two positions;
      * control effort ``0.1 * sum((ut * [1.0, 0.01])^2)``.

    Only the first two entries of ``ref_xt``, ``other_xt1`` and
    ``other_xt2`` are read.
    """
    position = xt[:2]

    def sq_dist(point):
        # Squared Euclidean distance from this agent's position.
        return jnp.sum(jnp.square(position - point[:2]))

    tracking = sq_dist(ref_xt)
    avoid_first = 10.0 * jnp.exp(-5.0 * sq_dist(other_xt1))
    avoid_second = 10.0 * jnp.exp(-5.0 * sq_dist(other_xt2))
    control_scale = jnp.array([1.0, 0.01])
    effort = 0.1 * jnp.sum(jnp.square(ut * control_scale))
    return tracking + avoid_first + avoid_second + effort

 
def diffdrive_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total cost of the differential-drive agent over a trajectory.

    Evaluates ``diffdrive_runtime_loss`` row by row (every argument is
    mapped over its leading axis), sums the per-step values and scales the
    sum by the agent's ``dt``.
    """
    # vmap maps over axis 0 of every argument by default.
    per_step = vmap(diffdrive_runtime_loss)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return jnp.sum(per_step) * diffdrive_ilqgames.dt


def diffdrive_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Gradients of the per-step cost along a trajectory.

    Returns ``(a_traj, b_traj)``: the gradient of
    ``diffdrive_runtime_loss`` with respect to the state and with respect
    to the control, each evaluated at every row of the inputs.
    """
    batch = (x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    # argnums 0 / 1 select the state / control argument of the per-step cost.
    state_grad = vmap(grad(diffdrive_runtime_loss, argnums=0))
    control_grad = vmap(grad(diffdrive_runtime_loss, argnums=1))
    return state_grad(*batch), control_grad(*batch)


# Wrap the agent's methods and the loss helpers with jax.jit, targeting the
# CPU device chosen above.
# NOTE(review): newer JAX releases appear to deprecate jit's `device`
# argument -- confirm against the installed JAX version.
diffdrive_linearize_dyn = jit(diffdrive_ilqgames.linearize_dyn, device=device)
diffdrive_solve_ilqr = jit(diffdrive_ilqgames.solve, device=device)
# These two names are rebound: from here on they refer to the jitted versions.
diffdrive_loss = jit(diffdrive_loss, device=device)
diffdrive_linearize_loss = jit(diffdrive_linearize_loss, device=device)

In [9]:
# Specify the nonlinear loss of the second agent
# Diagonal weight matrices handed to the iLQR base class: Q is 4x4 (one entry
# per state component), R is 2x2 (one entry per control component).
# NOTE(review): presumably the state/control weights of the quadratic
# sub-problem solved by iLQR.solve -- confirm against the lqrax documentation.
Q_point = jnp.diag(jnp.array([0.1, 0.1, 0.001, 0.001]))
R_point = jnp.diag(jnp.array([0.01, 0.01]))

# Point-mass agent: 4 state components, 2 control components.
point_ilqgames = PointAgent(
    dt=dt, x_dim=4, u_dim=2, Q=Q_point, R=R_point)


def point_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Cost of the point agent at a single time step.

    Sum of four terms:
      * squared distance between the agent's position ``xt[:2]`` and the
        reference position ``ref_xt[:2]``;
      * one penalty ``10 * exp(-5 * d^2)`` per other agent, where ``d`` is
        the distance between the two positions;
      * control effort ``0.1 * sum((ut * [1.0, 0.5])^2)``.

    Only the first two entries of ``ref_xt``, ``other_xt1`` and
    ``other_xt2`` are read.
    """
    position = xt[:2]

    def sq_dist(point):
        # Squared Euclidean distance from this agent's position.
        return jnp.sum(jnp.square(position - point[:2]))

    tracking = sq_dist(ref_xt)
    avoid_first = 10.0 * jnp.exp(-5.0 * sq_dist(other_xt1))
    avoid_second = 10.0 * jnp.exp(-5.0 * sq_dist(other_xt2))
    control_scale = jnp.array([1.0, 0.5])
    effort = 0.1 * jnp.sum(jnp.square(ut * control_scale))
    return tracking + avoid_first + avoid_second + effort


def point_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total cost of the point agent over a trajectory.

    Evaluates ``point_runtime_loss`` row by row (every argument is mapped
    over its leading axis), sums the per-step values and scales the sum by
    the agent's ``dt``.
    """
    # vmap maps over axis 0 of every argument by default.
    per_step = vmap(point_runtime_loss)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return jnp.sum(per_step) * point_ilqgames.dt


def point_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Gradients of the per-step cost along a trajectory.

    Returns ``(a_traj, b_traj)``: the gradient of ``point_runtime_loss``
    with respect to the state and with respect to the control, each
    evaluated at every row of the inputs.
    """
    batch = (x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    # argnums 0 / 1 select the state / control argument of the per-step cost.
    state_grad = vmap(grad(point_runtime_loss, argnums=0))
    control_grad = vmap(grad(point_runtime_loss, argnums=1))
    return state_grad(*batch), control_grad(*batch)


# Wrap the agent's methods and the loss helpers with jax.jit, targeting the
# CPU device chosen above.
# NOTE(review): newer JAX releases appear to deprecate jit's `device`
# argument -- confirm against the installed JAX version.
point_linearize_dyn = jit(point_ilqgames.linearize_dyn, device=device)
point_solve_ilqr = jit(point_ilqgames.solve, device=device)
# These two names are rebound: from here on they refer to the jitted versions.
point_loss = jit(point_loss, device=device)
point_linearize_loss = jit(point_linearize_loss, device=device)

In [10]:
# Specify the nonlinear loss of the third agent
# Diagonal weight matrices handed to the iLQR base class: Q is 3x3 (one entry
# per state component), R is 2x2 (one entry per control component).
# NOTE(review): presumably the state/control weights of the quadratic
# sub-problem solved by iLQR.solve -- confirm against the lqrax documentation.
Q_bicycle = jnp.diag(jnp.array([0.1, 0.1, 0.01]))
R_bicycle = jnp.diag(jnp.array([1.0, 0.1]))

# Bicycle agent: 3 state components, 2 control components.
bicycle_ilqgames = BicycleAgent(
    dt=dt, x_dim=3, u_dim=2, Q=Q_bicycle, R=R_bicycle)


def bicycle_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Cost of the bicycle agent at a single time step.

    Sum of four terms:
      * squared distance between the agent's position ``xt[:2]`` and the
        reference position ``ref_xt[:2]``;
      * one penalty ``10 * exp(-5 * d^2)`` per other agent, where ``d`` is
        the distance between the two positions;
      * control effort ``0.1 * sum((ut * [1.0, 0.01])^2)``.

    Only the first two entries of ``ref_xt``, ``other_xt1`` and
    ``other_xt2`` are read.
    """
    position = xt[:2]

    def sq_dist(point):
        # Squared Euclidean distance from this agent's position.
        return jnp.sum(jnp.square(position - point[:2]))

    tracking = sq_dist(ref_xt)
    avoid_first = 10.0 * jnp.exp(-5.0 * sq_dist(other_xt1))
    avoid_second = 10.0 * jnp.exp(-5.0 * sq_dist(other_xt2))
    control_scale = jnp.array([1.0, 0.01])
    effort = 0.1 * jnp.sum(jnp.square(ut * control_scale))
    return tracking + avoid_first + avoid_second + effort


def bicycle_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total cost of the bicycle agent over a trajectory.

    Evaluates ``bicycle_runtime_loss`` row by row (every argument is
    mapped over its leading axis), sums the per-step values and scales the
    sum by the agent's ``dt``.
    """
    # vmap maps over axis 0 of every argument by default.
    per_step = vmap(bicycle_runtime_loss)(
        x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    return jnp.sum(per_step) * bicycle_ilqgames.dt


def bicycle_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Gradients of the per-step cost along a trajectory.

    Returns ``(a_traj, b_traj)``: the gradient of ``bicycle_runtime_loss``
    with respect to the state and with respect to the control, each
    evaluated at every row of the inputs.
    """
    batch = (x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    # argnums 0 / 1 select the state / control argument of the per-step cost.
    state_grad = vmap(grad(bicycle_runtime_loss, argnums=0))
    control_grad = vmap(grad(bicycle_runtime_loss, argnums=1))
    return state_grad(*batch), control_grad(*batch)


# Wrap the agent's methods and the loss helpers with jax.jit, targeting the
# CPU device chosen above.
# NOTE(review): newer JAX releases appear to deprecate jit's `device`
# argument -- confirm against the installed JAX version.
bicycle_linearize_dyn = jit(bicycle_ilqgames.linearize_dyn, device=device)
bicycle_solve_ilqr = jit(bicycle_ilqgames.solve, device=device)
# These two names are rebound: from here on they refer to the jitted versions.
bicycle_loss = jit(bicycle_loss, device=device)
bicycle_linearize_loss = jit(bicycle_linearize_loss, device=device)

In [18]:
# Initial conditions for the three-agent scenario (diffdrive, point, bicycle).
# Every reference is a straight line sampled at tsteps+1 points with the
# first sample dropped, so each reference has tsteps rows.

# Diffdrive: starts at (-2.0, -0.1) with heading 0; the initial control is a
# constant [0.8, 0.0] at every step; reference runs from (-2, 0) to (2, 0).
diffdrive_x0 = jnp.array([-2.0, -0.1, 0.0])
diffdrive_u_traj = jnp.tile(jnp.array([0.8, 0.0]), reps=(tsteps, 1))
diffdrive_ref_traj = jnp.linspace(
    jnp.array([-2.0, 0.0]), jnp.array([2.0, 0.0]), tsteps+1
)[1:]

# Point mass: starts at (2.0, 0.1) with velocity (-0.8, 0.0); the initial
# control is all zeros; reference runs from (2, 0) to (-2, 0).
point_x0 = jnp.array([2.0, 0.1, -0.8, 0.0])
point_u_traj = jnp.zeros((tsteps, 2))
point_ref_traj = jnp.linspace(
    jnp.array([2.0, 0.0]), jnp.array([-2.0, 0.0]), tsteps+1
)[1:]

# Bicycle: starts at (-0.2, -2.0) with heading pi/2; the initial control is a
# constant [0.5, 0.0] at every step; reference runs from (0, -2) to (0, 2).
bicycle_x0 = jnp.array([-0.2, -2.0, jnp.pi/2.0])
bicycle_u_traj = jnp.tile(jnp.array([0.5, 0.0]), reps=(tsteps, 1))
bicycle_ref_traj = jnp.linspace(
    jnp.array([0.0, -2.0]), jnp.array([0.0, 2.0]), tsteps+1
)[1:]

num_iters = 200
step_size = 0.002
# Report the losses roughly ten times over the run. max(1, ...) keeps the
# modulo below well-defined when num_iters is smaller than 10.
log_every = max(1, num_iters // 10)
# The loop variable is named `it` so the builtin `iter` is not shadowed for
# the rest of the notebook session.
for it in range(num_iters + 1):
    # linearize dynamics at the current trajectory/control
    diffdrive_x_traj, diffdrive_A_traj, diffdrive_B_traj = \
        diffdrive_linearize_dyn(diffdrive_x0, diffdrive_u_traj)
    point_x_traj, point_A_traj, point_B_traj = \
        point_linearize_dyn(point_x0, point_u_traj)
    bicycle_x_traj, bicycle_A_traj, bicycle_B_traj = \
        bicycle_linearize_dyn(bicycle_x0, bicycle_u_traj)

    # linearize the loss function at the current trajectory/control; each
    # agent sees the current trajectories of the other two agents
    diffdrive_a_traj, diffdrive_b_traj = \
        diffdrive_linearize_loss(
            diffdrive_x_traj, diffdrive_u_traj, diffdrive_ref_traj, point_x_traj, bicycle_x_traj)
    point_a_traj, point_b_traj = \
        point_linearize_loss(
            point_x_traj, point_u_traj, point_ref_traj, diffdrive_x_traj, bicycle_x_traj)
    bicycle_a_traj, bicycle_b_traj = \
        bicycle_linearize_loss(
            bicycle_x_traj, bicycle_u_traj, bicycle_ref_traj, diffdrive_x_traj, point_x_traj)

    # compute descent direction on the control
    diffdrive_v_traj, _ = diffdrive_solve_ilqr(
        diffdrive_A_traj, diffdrive_B_traj, diffdrive_a_traj, diffdrive_b_traj)
    point_v_traj, _ = point_solve_ilqr(
        point_A_traj, point_B_traj, point_a_traj, point_b_traj)
    bicycle_v_traj, _ = bicycle_solve_ilqr(
        bicycle_A_traj, bicycle_B_traj, bicycle_a_traj, bicycle_b_traj)

    # periodically report the losses of the current (pre-update) trajectories
    if it % log_every == 0:
        diffdrive_loss_val = diffdrive_loss(
            diffdrive_x_traj, diffdrive_u_traj, diffdrive_ref_traj, point_x_traj, bicycle_x_traj)
        point_loss_val = point_loss(
            point_x_traj, point_u_traj, point_ref_traj, diffdrive_x_traj, bicycle_x_traj)
        bicycle_loss_val = bicycle_loss(
            bicycle_x_traj, bicycle_u_traj, bicycle_ref_traj, diffdrive_x_traj, point_x_traj)
        print(
            f'iter[{it:3d}/{num_iters}] | diffdrive loss: {diffdrive_loss_val:5.2f} | point loss: {point_loss_val:5.2f} | bicycle loss: {bicycle_loss_val:5.2f}')

    # update control: take a fixed-size step along each descent direction
    diffdrive_u_traj += step_size * diffdrive_v_traj
    point_u_traj += step_size * point_v_traj
    bicycle_u_traj += step_size * bicycle_v_traj

iter: 0
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter[  0/200] | diffdrive loss:  5.40 | point loss:  5.38 | bicycle loss:  6.37
iter: 1
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter: 2
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter: 3
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter: 4
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter: 5
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter: 6
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2       -2.         1.5707964]
iter: 7
diffdrive_x0: [-2.  -0.1  0. ]
point_x0: [ 2.   0.1 -0.8  0. ]
bicycle_x0: [-0.2      

In [13]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Single square axes that the update() callback below redraws for every frame.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=120, tight_layout=True)
# NOTE(review): imgs is not referenced anywhere else in the visible cells.
imgs = []


def update(t):
    """Redraw frame ``t`` of the diffdrive / point / bicycle animation.

    Clears the module-level ``ax`` and draws, in order: the three goal
    markers (last row of each reference trajectory), then for each agent
    its trail up to (but excluding) row ``t``, its body marker at row
    ``t`` and a small triangle placed ahead of it along its heading.
    Returns an empty list.
    """
    ax.cla()
    ax.set_aspect('equal')
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    ax.axis('off')

    # Goal markers, one per agent.
    goal_specs = (
        (diffdrive_ref_traj, 'C0'),
        (point_ref_traj, 'C1'),
        (bicycle_ref_traj, 'C2'),
    )
    for ref_traj, shade in goal_specs:
        ax.plot(ref_traj[-1, 0], ref_traj[-1, 1], linestyle='',
                marker='X', markersize=20, color=shade, alpha=0.5)

    def draw_agent(traj, state, heading, shade, nose_offset, round_body=False):
        # Trail, body and heading triangle for one agent.
        heading_deg = np.rad2deg(heading)
        ax.plot(traj[:t, 0], traj[:t, 1],
                linestyle='-', linewidth=5, color=shade, alpha=0.5)
        if round_body:
            ax.plot(state[0], state[1], linestyle='',
                    marker='o', markersize=25, color=shade)
        else:
            # Four-sided marker rotated to follow the heading.
            ax.plot(state[0], state[1], linestyle='',
                    marker=(4, 0, heading_deg+45), markersize=30, color=shade)
        ax.plot(state[0]+np.cos(heading)*nose_offset,
                state[1]+np.sin(heading)*nose_offset,
                linestyle='', marker=(3, 0, heading_deg+30), markersize=15, color=shade)

    # Diffdrive: heading is the third state component.
    dd_state = diffdrive_x_traj[t]
    draw_agent(diffdrive_x_traj, dd_state, dd_state[2], 'C0', 0.32)

    # Point mass: heading is the direction of the velocity (state[2], state[3]).
    pt_state = point_x_traj[t]
    draw_agent(point_x_traj, pt_state, np.arctan2(pt_state[3], pt_state[2]),
               'C1', 0.36, round_body=True)

    # Bicycle: heading is the third state component.
    bk_state = bicycle_x_traj[t]
    draw_agent(bicycle_x_traj, bk_state, bk_state[2], 'C2', 0.33)

    return []


# Animate update() over tsteps frames, 50 ms apart.
ani = animation.FuncAnimation(fig, update, frames=tsteps, interval=50)
# Presumably closes the figure so that only the video below is shown in the
# cell output -- confirm.
plt.close()
# NOTE(review): to_html5_video needs a movie writer (ffmpeg by default in
# matplotlib) to be available -- confirm in the target environment.
HTML(ani.to_html5_video())

In [19]:
# Define each point agent and loss functions
# Weight matrices for the point agents; rebinds Q_point / R_point with the
# same values used for the single point agent earlier in the notebook.
Q_point = jnp.diag(jnp.array([0.1, 0.1, 0.001, 0.001]))
R_point = jnp.diag(jnp.array([0.01, 0.01]))

# Three point agents built with identical parameters.
point_ilqgames1 = PointAgent(
    dt=dt, x_dim=4, u_dim=2, Q=Q_point, R=R_point)
point_ilqgames2 = PointAgent(
    dt=dt, x_dim=4, u_dim=2, Q=Q_point, R=R_point)
point_ilqgames3 = PointAgent(
    dt=dt, x_dim=4, u_dim=2, Q=Q_point, R=R_point)

def point_runtime_loss(xt, ut, ref_xt, other_xt1, other_xt2):
    """Cost of one point agent at a single time step.

    Adds up: the squared distance from ``xt[:2]`` to ``ref_xt[:2]``, a
    penalty ``10 * exp(-5 * d^2)`` for each of the two other agents
    (``d`` being the distance between positions), and the control effort
    ``0.1 * sum((ut * [1.0, 0.5])^2)``. Only the first two entries of
    ``ref_xt``, ``other_xt1`` and ``other_xt2`` are read.
    """
    own_xy = xt[:2]

    def proximity_penalty(other_state):
        # Largest (10) when the two positions coincide, decaying with distance.
        gap_sq = jnp.sum(jnp.square(own_xy - other_state[:2]))
        return 10.0 * jnp.exp(-5.0 * gap_sq)

    goal_term = jnp.sum(jnp.square(own_xy - ref_xt[:2]))
    scaled_ctrl = ut * jnp.array([1.0, 0.5])
    effort_term = 0.1 * jnp.sum(jnp.square(scaled_ctrl))
    return (goal_term
            + proximity_penalty(other_xt1)
            + proximity_penalty(other_xt2)
            + effort_term)


def point_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Total cost of one point agent over a trajectory.

    Evaluates ``point_runtime_loss`` row by row (every argument is mapped
    over its leading axis), sums the per-step values and scales the sum by
    the agents' time step.
    """
    runtime_loss_array = vmap(point_runtime_loss, in_axes=(
        0, 0, 0, 0, 0))(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    # Read dt from an agent created in this cell: all three point agents are
    # constructed with the same dt, and referring to `point_ilqgames` would
    # silently depend on an object this cell never creates.
    return runtime_loss_array.sum() * point_ilqgames1.dt


def point_linearize_loss(x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2):
    """Gradients of the per-step cost along a trajectory.

    Returns ``(a_traj, b_traj)``: the gradient of ``point_runtime_loss``
    with respect to the state and with respect to the control, each
    evaluated at every row of the inputs.
    """
    batch = (x_traj, u_traj, ref_x_traj, other_x_traj1, other_x_traj2)
    # argnums 0 / 1 select the state / control argument of the per-step cost.
    state_grad = vmap(grad(point_runtime_loss, argnums=0))
    control_grad = vmap(grad(point_runtime_loss, argnums=1))
    return state_grad(*batch), control_grad(*batch)


# Per-agent jitted entry points. The dynamics/solve wrappers are bound to each
# agent's own instance; the loss wrappers all wrap the same point_loss and
# point_linearize_loss functions.
# NOTE(review): newer JAX releases appear to deprecate jit's `device`
# argument -- confirm against the installed JAX version.
point1_linearize_dyn = jit(point_ilqgames1.linearize_dyn, device=device)
point1_solve_ilqr = jit(point_ilqgames1.solve, device=device)
point1_loss = jit(point_loss, device=device)
point1_linearize_loss = jit(point_linearize_loss, device=device)

point2_linearize_dyn = jit(point_ilqgames2.linearize_dyn, device=device)
point2_solve_ilqr = jit(point_ilqgames2.solve, device=device)
point2_loss = jit(point_loss, device=device)
point2_linearize_loss = jit(point_linearize_loss, device=device)

point3_linearize_dyn = jit(point_ilqgames3.linearize_dyn, device=device)
point3_solve_ilqr = jit(point_ilqgames3.solve, device=device)
point3_loss = jit(point_loss, device=device)
point3_linearize_loss = jit(point_linearize_loss, device=device)

In [42]:
import random

# Random initialization of start and goal positions
def random_position():
    """Sample a start and a goal lying on one random line through the origin.

    The line's slope is drawn uniformly from [-4, 4]. The start has
    x in [0, 3] and the goal has x in [-3, 0]; both y values follow from
    the slope. Uses the module-level ``random`` generator, so results are
    reproducible only if the caller seeds it.

    Returns:
        ``(start_x, start_y, end_x, end_y)``.
    """
    # Draw order (slope, start x, goal x) determines the sequence consumed
    # from the random generator.
    gradient = random.uniform(-4.0, 4.0)
    x_from = random.uniform(0.0, 3.0)
    x_to = random.uniform(-3.0, 0.0)
    return x_from, gradient * x_from, x_to, gradient * x_to


In [43]:

# Initial conditions for the three-point-agent scenario.
# Each agent gets a random start/goal pair from random_position(). Its
# reference is the straight line from start to goal sampled at tsteps+1
# points with the first sample dropped (tsteps rows), and its initial control
# is all zeros. The initial velocity is the same fixed (-0.8, 0.0) for every
# agent, independent of the sampled start and goal.
start1_x, start1_y, end1_x, end1_y = random_position()
point1_x0 = jnp.array([start1_x, start1_y, -0.8, 0.0])
point1_u_traj = jnp.zeros((tsteps, 2))
point1_ref_traj = jnp.linspace(
    jnp.array([start1_x, start1_y]), jnp.array([end1_x, end1_y]), tsteps+1
)[1:]

start2_x, start2_y, end2_x, end2_y = random_position()
point2_x0 = jnp.array([start2_x, start2_y, -0.8, 0.0])
point2_u_traj = jnp.zeros((tsteps, 2))
point2_ref_traj = jnp.linspace(
    jnp.array([start2_x, start2_y]), jnp.array([end2_x, end2_y]), tsteps+1
)[1:]

start3_x, start3_y, end3_x, end3_y = random_position()
point3_x0 = jnp.array([start3_x, start3_y, -0.8, 0.0])
point3_u_traj = jnp.zeros((tsteps, 2))
point3_ref_traj = jnp.linspace(
    jnp.array([start3_x, start3_y]), jnp.array([end3_x, end3_y]), tsteps+1
)[1:]

num_iters = 1000
step_size = 0.002
# Report the losses roughly ten times over the run. max(1, ...) keeps the
# modulo below well-defined when num_iters is smaller than 10.
log_every = max(1, num_iters // 10)
# The loop variable is named `it` so the builtin `iter` is not shadowed for
# the rest of the notebook session.
for it in range(num_iters + 1):
    # linearize dynamics at the current trajectory/control
    point1_x_traj, point1_A_traj, point1_B_traj = \
        point1_linearize_dyn(point1_x0, point1_u_traj)
    point2_x_traj, point2_A_traj, point2_B_traj = \
        point2_linearize_dyn(point2_x0, point2_u_traj)
    point3_x_traj, point3_A_traj, point3_B_traj = \
        point3_linearize_dyn(point3_x0, point3_u_traj)

    # linearize the loss function at the current trajectory/control; each
    # agent sees the current trajectories of the other two agents
    point1_a_traj, point1_b_traj = \
        point1_linearize_loss(
            point1_x_traj, point1_u_traj, point1_ref_traj, point2_x_traj, point3_x_traj)
    point2_a_traj, point2_b_traj = \
        point2_linearize_loss(
            point2_x_traj, point2_u_traj, point2_ref_traj, point1_x_traj, point3_x_traj)
    point3_a_traj, point3_b_traj = \
        point3_linearize_loss(
            point3_x_traj, point3_u_traj, point3_ref_traj, point1_x_traj, point2_x_traj)

    # compute descent direction on the control
    point1_v_traj, _ = point1_solve_ilqr(
        point1_A_traj, point1_B_traj, point1_a_traj, point1_b_traj)
    point2_v_traj, _ = point2_solve_ilqr(
        point2_A_traj, point2_B_traj, point2_a_traj, point2_b_traj)
    point3_v_traj, _ = point3_solve_ilqr(
        point3_A_traj, point3_B_traj, point3_a_traj, point3_b_traj)

    # periodically report the losses of the current (pre-update) trajectories
    if it % log_every == 0:
        point1_loss_val = point1_loss(
            point1_x_traj, point1_u_traj, point1_ref_traj, point2_x_traj, point3_x_traj)
        point2_loss_val = point2_loss(
            point2_x_traj, point2_u_traj, point2_ref_traj, point1_x_traj, point3_x_traj)
        point3_loss_val = point3_loss(
            point3_x_traj, point3_u_traj, point3_ref_traj, point1_x_traj, point2_x_traj)
        print(
            f'iter[{it:3d}/{num_iters}] | point1 loss: {point1_loss_val:5.2f} | point2 loss: {point2_loss_val:5.2f} | point3 loss: {point3_loss_val:5.2f}')

    # update control: take a fixed-size step along each descent direction
    point1_u_traj += step_size * point1_v_traj
    point2_u_traj += step_size * point2_v_traj
    point3_u_traj += step_size * point3_v_traj


iter[  0/1000] | point1 loss: 204.37 | point2 loss: 128.78 | point3 loss: 367.69
iter[100/1000] | point1 loss:  0.76 | point2 loss:  1.03 | point3 loss:  2.01
iter[200/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[300/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[400/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[500/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[600/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[700/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[800/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[900/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11
iter[1000/1000] | point1 loss:  0.43 | point2 loss:  0.60 | point3 loss:  1.11


In [45]:
# Report where each agent ends up relative to its goal.
print("=== Final Agent Positions ===")

# Agent 1: roll the final controls forward once more, then compare the last
# position of the rollout with the last row of the reference (the goal).
point1_final_x_traj, _, _ = point1_linearize_dyn(point1_x0, point1_u_traj)
point1_final_pos = point1_final_x_traj[-1, :2]  # [x, y]
point1_goal = point1_ref_traj[-1, :2]  # [x, y]
point1_dist_to_goal = jnp.linalg.norm(point1_final_pos - point1_goal)

# Agent 2
point2_final_x_traj, _, _ = point2_linearize_dyn(point2_x0, point2_u_traj)
point2_final_pos = point2_final_x_traj[-1, :2]  # [x, y]
point2_goal = point2_ref_traj[-1, :2]  # [x, y]
point2_dist_to_goal = jnp.linalg.norm(point2_final_pos - point2_goal)

# Agent 3
point3_final_x_traj, _, _ = point3_linearize_dyn(point3_x0, point3_u_traj)
point3_final_pos = point3_final_x_traj[-1, :2]  # [x, y]
point3_goal = point3_ref_traj[-1, :2]  # [x, y]
point3_dist_to_goal = jnp.linalg.norm(point3_final_pos - point3_goal)


def _print_agent_reports(rows):
    # One five-line section per agent; sections after the first are preceded
    # by a blank line.
    for agent_no, (start, final, goal, dist) in enumerate(rows, start=1):
        lead = "" if agent_no == 1 else "\n"
        print(f"{lead}Agent {agent_no}:")
        print(f"  Start: ({start[0]:.3f}, {start[1]:.3f})")
        print(f"  Final: ({final[0]:.3f}, {final[1]:.3f})")
        print(f"  Goal:  ({goal[0]:.3f}, {goal[1]:.3f})")
        print(f"  Distance to goal: {dist:.3f}")


_print_agent_reports((
    (point1_x0, point1_final_pos, point1_goal, point1_dist_to_goal),
    (point2_x0, point2_final_pos, point2_goal, point2_dist_to_goal),
    (point3_x0, point3_final_pos, point3_goal, point3_dist_to_goal),
))

# Summary statistics over the three goal distances.
print("\n=== Summary ===")
print(f"Average distance to goal: {(point1_dist_to_goal + point2_dist_to_goal + point3_dist_to_goal) / 3:.3f}")
print(f"Max distance to goal: {max(point1_dist_to_goal, point2_dist_to_goal, point3_dist_to_goal):.3f}")
print(f"Min distance to goal: {min(point1_dist_to_goal, point2_dist_to_goal, point3_dist_to_goal):.3f}")


=== Final Agent Positions ===
Agent 1:
  Start: (0.848, -2.576)
  Final: (-2.769, 8.408)
  Goal:  (-2.768, 8.408)
  Distance to goal: 0.000

Agent 2:
  Start: (2.764, -5.983)
  Final: (-1.339, 2.738)
  Goal:  (-1.267, 2.742)
  Distance to goal: 0.072

Agent 3:
  Start: (2.622, 10.063)
  Final: (-1.147, -4.674)
  Goal:  (-1.219, -4.679)
  Distance to goal: 0.072

=== Summary ===
Average distance to goal: 0.048
Max distance to goal: 0.072
Min distance to goal: 0.000


In [44]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Visualize the three point agents scenario: a single square axes that the
# update_three_point_agents() callback below redraws for every frame.
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=120, tight_layout=True)
# NOTE(review): imgs is not referenced anywhere else in the visible cells.
imgs = []

def update_three_point_agents(t):
    """Redraw frame ``t`` of the three-point-agent animation.

    Clears the module-level ``ax`` and draws, in order: the three goal
    markers (last row of each reference trajectory), then for each agent
    its trail up to and including row ``t``, its body marker at row ``t``
    and, when its speed exceeds 0.01, a small triangle ahead of it along
    its velocity direction. A legend is added on frame 0 only.

    The view is a square centred on the origin. Its half-width is at least
    5.0 and grows to cover every trajectory and reference point, because
    the randomly sampled starts and goals can lie well outside [-5, 5].

    Returns an empty list.
    """
    agents = (
        ('Agent 1', 'C0', point1_x_traj, point1_ref_traj),
        ('Agent 2', 'C1', point2_x_traj, point2_ref_traj),
        ('Agent 3', 'C2', point3_x_traj, point3_ref_traj),
    )

    # Largest |x| or |y| over all rollouts and references, plus a 5% margin.
    # Recomputed per frame; the arrays are small, so this is cheap.
    reach = max(
        float(np.max(np.abs(np.asarray(arr)[:, :2])))
        for _, _, x_traj, ref_traj in agents
        for arr in (x_traj, ref_traj)
    )
    half_width = max(5.0, 1.05 * reach)
    if not np.isfinite(half_width):
        # Diverged (inf/NaN) trajectories: fall back to the default view
        # rather than passing a non-finite limit to matplotlib.
        half_width = 5.0

    ax.cla()
    ax.set_aspect('equal')
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.axis('off')
    ax.set_title('Three Point Agents - iLQGames', fontsize=14)

    # Plot goals as X markers
    for name, color, _, ref_traj in agents:
        ax.plot(ref_traj[-1, 0], ref_traj[-1, 1], linestyle='',
                marker='X', markersize=20, color=color, alpha=0.5,
                label=f'{name} Goal')

    # Trajectory so far, current position and direction of travel per agent
    for name, color, x_traj, _ in agents:
        state = x_traj[t]
        heading = np.arctan2(state[3], state[2])
        heading_deg = np.rad2deg(heading)
        ax.plot(x_traj[:t+1, 0], x_traj[:t+1, 1],
                linestyle='-', linewidth=3, color=color, alpha=0.7, label=name)
        ax.plot(state[0], state[1], linestyle='',
                marker='o', markersize=15, color=color)
        # Only show direction if moving: the heading is meaningless at
        # (near-)zero velocity.
        if np.linalg.norm([state[2], state[3]]) > 0.01:
            ax.plot(state[0]+np.cos(heading)*0.2, state[1]+np.sin(heading)*0.2,
                    linestyle='', marker=(3, 0, heading_deg+30), markersize=8,
                    color=color)

    # Add legend only on first frame (ax.cla() removes it on later frames)
    if t == 0:
        ax.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0))

    return []

# Create animation: update_three_point_agents() is called for tsteps frames,
# 50 ms apart.
ani_three_point = animation.FuncAnimation(fig, update_three_point_agents, frames=tsteps, interval=50)
# Presumably closes the figure so that only the video below is shown in the
# cell output -- confirm.
plt.close()
# NOTE(review): to_html5_video needs a movie writer (ffmpeg by default in
# matplotlib) to be available -- confirm in the target environment.
HTML(ani_three_point.to_html5_video())
