## Lecture 9: Augmented Lagrangian Demo

In [1]:
import time
from typing import Callable

import numpy as onp
import jax
import jaxlie
import trimesh
import viser
from jax import numpy as jnp

%load_ext autoreload
%autoreload 2

Run the cell below, then click [here](http://localhost:8080) to open the visualization.

In [3]:
# Re-use an existing viser server when this cell is re-run: wipe its scene and
# GUI instead of starting a second server. Create one only if none exists yet.
server_exists = 'server' in dir() and server is not None
if server_exists:
    server.scene.reset()
    server.gui.reset()
else:
    server = viser.ViserServer()

server.scene.add_grid("grid")

# Start and goal positions of the drone (x, y, z).
init_pos = jnp.array([-2.0, -2.0, 1.0])
final_pos = jnp.array([2.0, 2.0, 1.0])

# Spherical obstacles, listed as (center, radius) pairs and then split into
# the two parallel arrays used by the constraint functions.
_obstacles = [
    ([0.0, 0.0, 1.0], 1.0),
    ([-1.5, -1.5, 1.25], 0.5),
    ([-1.5, -0.75, 0.5], 0.5),
    ([1.5, 1.0, 0.75], 0.5),
    ([1.0, 1.75, 1.0], 0.5),
]
obstacle_centers = jnp.array([center for center, _ in _obstacles])
obstacle_radii = jnp.array([radius for _, radius in _obstacles])

# Radius of the drone; added to each obstacle radius in the collision check.
drone_radius = 0.1

# Add sphere obstacle visualization.
# The arrays are converted to NumPy once up front so the loop iterates plain
# host data instead of indexing the JAX arrays element by element.
for i, (center, radius) in enumerate(
    zip(onp.array(obstacle_centers), onp.array(obstacle_radii))
):
    server.scene.add_icosphere(
        f"obstacle_{i}",
        radius=float(radius),
        color=(0.0, 0.0, 1.0),
        position=center,
    )

# Trajectory discretisation: number of waypoints and the step between them.
num_timesteps = 100
dt = 2.0 / num_timesteps

# Straight-line initial guess. linspace is vmapped over the three coordinates
# (giving one row per coordinate) and then transposed so each row is a waypoint.
x0 = jnp.transpose(
    jax.vmap(jnp.linspace, in_axes=(0, 0, None))(init_pos, final_pos, num_timesteps)
)

def make_segments(x: jnp.ndarray) -> onp.ndarray:
    """Convert a trajectory into line segments for drawing.

    ``x`` is reshaped to rows of 3 coordinates and ``init_pos`` is prepended,
    so the first segment runs from the start position to the first waypoint.
    Returns a NumPy array in which entry ``k`` holds the two end points of
    segment ``k``.
    """
    waypoints = x.reshape(-1, 3)
    path = jnp.concatenate([init_pos[None, :], waypoints], axis=0)
    starts, ends = path[:-1], path[1:]
    return onp.array(jnp.stack([starts, ends], axis=1))

# Colour convention: red = penalty method, green = augmented Lagrangian.
# Both methods start out drawn on the straight-line initial guess x0.
pm_trajectory_handle = server.scene.add_line_segments(
    "pm_trajectory",
    make_segments(x0),
    colors=(1.0, 0.2, 0.2),
)
al_trajectory_handle = server.scene.add_line_segments(
    "al_trajectory",
    make_segments(x0),
    colors=(0.2, 1.0, 0.4),
)
pm_points_handle = server.scene.add_point_cloud(
    "pm_points",
    onp.array(x0),
    colors=(1.0, 0.2, 0.2),
    point_shape="circle",
    point_size=0.1,
)
al_points_handle = server.scene.add_point_cloud(
    "al_points",
    onp.array(x0),
    colors=(0.2, 1.0, 0.4),
    point_shape="circle",
    point_size=0.1,
)

# Markers for the goal (yellow) and the start (cyan).
goal_handle = server.scene.add_icosphere(
    "goal",
    radius=0.1,
    color=(1.0, 1.0, 0.0),
    position=onp.array(final_pos),
)
start_handle = server.scene.add_icosphere(
    "start",
    radius=0.1,
    color=(0.0, 1.0, 1.0),
    position=onp.array(init_pos),
)

# Text fields that update_viz overwrites with the current violation values.
pm_violation_handle = server.gui.add_text(
    "pm_violation", "PM violation: --", disabled=False
)
al_violation_handle = server.gui.add_text(
    "al_violation", "AL violation: --", disabled=False
)

# --------------------------------------------------------------------------- #
#  Constraint / residual functions                                             #
# --------------------------------------------------------------------------- #

def final_point_constraint(x: jnp.ndarray) -> jnp.ndarray:
    """Return the offset of the last waypoint from ``final_pos``.

    The result has three components and is zero exactly when the trajectory
    ends at the goal. No other waypoint is constrained.
    """
    last_waypoint = x.reshape(-1, 3)[-1]
    return last_waypoint - final_pos

def smoothness_residual(x: jnp.ndarray) -> jnp.ndarray:
    """Return the flattened finite differences of the path, divided by ``dt``.

    ``init_pos`` is prepended to the waypoints, so the first difference is
    the step from the start position to the first waypoint.
    """
    path = jnp.concatenate([init_pos[None, :], x.reshape(-1, 3)], axis=0)
    steps = path[1:] - path[:-1]
    return steps.ravel() / dt

def single_obstacle_constraint(x: jnp.ndarray, center: jnp.ndarray, radius: float) -> jnp.ndarray:
    """Signed collision constraint of every waypoint against one sphere.

    For each waypoint this returns
    ``(radius + drone_radius)**2 - squared_distance_to_center``.
    A positive value means the waypoint is inside the inflated sphere
    (constraint violated); a value of zero or less means it is clear.
    """
    offsets = x.reshape(-1, 3) - center
    sq_dist = jnp.sum(offsets ** 2, axis=1)
    clearance = radius + drone_radius
    return (clearance ** 2 - sq_dist).ravel()

def obstacle_constraint(x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the signed collision constraint for every obstacle and waypoint.

    ``single_obstacle_constraint`` is vmapped over the obstacle centers and
    radii (the trajectory is shared), and the result is flattened to 1-D.
    """
    per_obstacle = jax.vmap(single_obstacle_constraint, in_axes=(None, 0, 0))
    return per_obstacle(x, obstacle_centers, obstacle_radii).ravel()

def obstacle_residual(x: jnp.ndarray) -> jnp.ndarray:
    """Obstacle violation clamped at zero, as used by the penalty method.

    Entries are positive only where a waypoint penetrates an obstacle and
    zero everywhere else.
    """
    raw = obstacle_constraint(x)
    return jnp.maximum(0.0, raw)

def all_constraints(x: jnp.ndarray) -> jnp.ndarray:
    """
    Stack every constraint into one vector g(x) that the AL drives to zero.

    The vector is the goal constraint on the final waypoint followed by the
    rectified obstacle constraints max(0, c(x)). Only the final waypoint is
    tied to final_pos; per the original author's note, constraining all
    waypoints fights smoothness and causes divergence.

    Because the obstacle entries are never negative, the equality-style
    multiplier update lm <- lm + rho * g cannot make those multipliers
    negative.
    """
    goal_gap = final_point_constraint(x)
    obstacle_violation = jnp.maximum(0.0, obstacle_constraint(x))
    return jnp.concatenate([goal_gap, obstacle_violation])

def total_violation(x: jnp.ndarray) -> float:
    """Return the summed obstacle violation as a Python float for display.

    Only the rectified obstacle constraints are summed; the goal constraint
    is not included.
    """
    violation = obstacle_residual(x)
    return float(violation.sum())

# --------------------------------------------------------------------------- #
#  Penalty method                                                              #
# --------------------------------------------------------------------------- #

def build_penalty_cost(w_smooth: float, w_goal: float, w_obstacle: jnp.ndarray) -> Callable:
    """Build the quadratic-penalty objective for the given weights.

    The returned function maps a flat trajectory to
    ``0.5 * (w_smooth * |smoothness|^2 + w_goal * |goal gap|^2
    + w_obstacle * |obstacle violation|^2)``.
    """
    def cost_fn(x: jnp.ndarray) -> float:
        smooth_term = 0.5 * w_smooth * jnp.sum(smoothness_residual(x) ** 2)
        goal_term = 0.5 * w_goal * jnp.sum(final_point_constraint(x) ** 2)
        obstacle_term = 0.5 * w_obstacle * jnp.sum(obstacle_residual(x) ** 2)
        return smooth_term + goal_term + obstacle_term

    return cost_fn

def pm_inner_step(
    x: jnp.ndarray,
    w_obstacle: jnp.ndarray,
    w_goal: jnp.ndarray,
    step_size: float,
    w_smooth: float,
) -> jnp.ndarray:
    """Take one gradient-descent step on the penalty cost.

    The cost is rebuilt from the current weights, differentiated with
    ``jax.grad``, and ``x`` is moved by ``step_size`` against the gradient.
    """
    cost_fn = build_penalty_cost(w_smooth, w_goal, w_obstacle)
    gradient = jax.grad(cost_fn)(x)
    return x - step_size * gradient

def pm_outer_step(
    x: jnp.ndarray,
    w_obstacle: jnp.ndarray,
    w_goal: jnp.ndarray,
    penalty_scale: float,
    w_max: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Outer update of the penalty method: grow both penalty weights.

    Each weight is multiplied by ``penalty_scale`` and capped at ``w_max``.
    The trajectory ``x`` is returned unchanged as the first tuple element.
    """
    scaled_obstacle = w_obstacle * penalty_scale
    scaled_goal = w_goal * penalty_scale
    new_w_obstacle = jnp.minimum(scaled_obstacle, w_max)
    new_w_goal = jnp.minimum(scaled_goal, w_max)
    return x, new_w_obstacle, new_w_goal

# --------------------------------------------------------------------------- #
#  Augmented Lagrangian                                                        #
# --------------------------------------------------------------------------- #

def create_augmented_lagrangian(lms: jnp.ndarray, rho: jnp.ndarray, w_smooth: float) -> Callable:
    """Build the augmented Lagrangian for fixed multipliers and penalty.

    The returned function maps a flat trajectory to
    ``0.5 * w_smooth * |smoothness|^2 + lms . g(x) + 0.5 * rho * |g(x)|^2``
    where ``g`` is ``all_constraints``.
    """
    def augmented_lagrangian(x: jnp.ndarray) -> jnp.ndarray:
        g = all_constraints(x)
        smooth_cost = 0.5 * w_smooth * jnp.sum(smoothness_residual(x) ** 2)
        multiplier_term = lms @ g
        penalty_term = 0.5 * rho * jnp.sum(g ** 2)
        return smooth_cost + multiplier_term + penalty_term

    return augmented_lagrangian

def al_inner_step(
    x: jnp.ndarray,
    lms: jnp.ndarray,
    rho: jnp.ndarray,
    step_size: float,
    w_smooth: float,
) -> jnp.ndarray:
    """Take one gradient-descent step on the augmented Lagrangian.

    The multipliers ``lms`` and the penalty ``rho`` are held fixed; only the
    trajectory ``x`` moves, by ``step_size`` against the gradient.
    """
    lagrangian = create_augmented_lagrangian(lms, rho, w_smooth)
    gradient = jax.grad(lagrangian)(x)
    return x - step_size * gradient

def al_outer_step(
    x: jnp.ndarray,
    lms: jnp.ndarray,
    rho: jnp.ndarray,
    penalty_scale: float,
    rho_max: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Outer update of the augmented Lagrangian method.

    The multipliers are advanced with ``lms + rho * g(x)`` using the penalty
    from *before* it is scaled. The penalty is then multiplied by
    ``penalty_scale`` and capped at ``rho_max``. ``x`` is returned unchanged.
    """
    violation = all_constraints(x)
    updated_lms = lms + rho * violation
    updated_rho = jnp.minimum(rho * penalty_scale, rho_max)
    return x, updated_lms, updated_rho

# --------------------------------------------------------------------------- #
#  GUI                                                                         #
# --------------------------------------------------------------------------- #

# Tunable hyper-parameters; their values are read when "Optimize Both" is clicked.
w_smooth_slider = server.gui.add_slider(
    "w_smooth", min=0.0, max=1.0, step=1e-3, initial_value=0.1
)
penalty_scale_slider = server.gui.add_slider(
    "penalty_scale", min=1.0, max=2.0, step=0.01, initial_value=1.1
)
# One cap shared by the penalty-method weights and the AL penalty rho.
w_max_slider = server.gui.add_slider(
    "w_max / rho_max", min=1.0, max=1e5, step=1.0, initial_value=1e3
)
num_iters_slider = server.gui.add_slider(
    "num_iters", min=1, max=100000, step=1, initial_value=10000
)
lr_slider = server.gui.add_slider(
    "learning_rate", min=1e-7, max=0.1, step=1e-5, initial_value=0.001
)
outer_every_slider = server.gui.add_slider(
    "outer_every", min=10, max=1000, step=10, initial_value=100
)

optimize_button = server.gui.add_button("Optimize Both")
play_pm_button = server.gui.add_button("Play (Penalty, red)")
play_al_button = server.gui.add_button("Play (AL, green)")

# Current flat trajectories of both methods; the optimisation callback
# overwrites them and the playback callbacks read them.
x_pm = x0.ravel()
x_al = x0.ravel()

# Counter bumped on every optimisation run so an older run can detect that it
# has been superseded.
_run_id = 0

def update_viz(x_pm, x_al):
    """Push both trajectories and their violation read-outs to the viewer.

    Updates the line segments, the waypoint point clouds and the two text
    fields for the penalty-method (``x_pm``) and augmented-Lagrangian
    (``x_al``) trajectories.
    """
    pm_trajectory_handle.points = make_segments(x_pm)
    al_trajectory_handle.points = make_segments(x_al)

    pm_waypoints = onp.array(x_pm).reshape(-1, 3)
    al_waypoints = onp.array(x_al).reshape(-1, 3)
    pm_points_handle.points = pm_waypoints
    al_points_handle.points = al_waypoints

    pm_violation = total_violation(x_pm)
    al_violation = total_violation(x_al)
    pm_violation_handle.value = f"PM violation: {pm_violation:.4f}"
    al_violation_handle.value = f"AL violation: {al_violation:.4f}"

# --------------------------------------------------------------------------- #
#  Optimization loop -- runs both methods in lockstep                         #
# --------------------------------------------------------------------------- #

@optimize_button.on_click
def run_optimization(_):
    """Run the penalty method and the augmented Lagrangian side by side.

    Both trajectories are reset to the straight-line guess ``x0`` and then
    advanced in lockstep: one gradient step each per iteration, and one outer
    update (penalty growth / multiplier update) after every ``outer_every``
    inner steps. The module-level ``x_pm`` and ``x_al`` are overwritten as the
    run progresses so the playback buttons see the latest result.

    Starting a new run bumps ``_run_id``; a run that notices the id has
    changed returns early (this relies on callbacks being able to overlap,
    which is assumed here rather than shown in this file).
    """
    global x_pm, x_al, _run_id

    _run_id += 1
    my_run_id = _run_id

    # Snapshot the GUI settings once so a run uses consistent hyper-parameters.
    lr            = lr_slider.value
    w_smooth      = w_smooth_slider.value
    penalty_scale = penalty_scale_slider.value
    w_max         = w_max_slider.value
    # Cast both counts to int so range() and the modulo below never see a float.
    num_iters     = int(num_iters_slider.value)
    outer_every   = int(outer_every_slider.value)

    # Reset both trajectories to the straight-line initialisation.
    x_pm       = x0.ravel()
    w_obstacle = jnp.array(1.0)
    w_goal     = jnp.array(1.0)

    x_al = x0.ravel()
    lms  = jnp.zeros_like(all_constraints(x_al))
    rho  = jnp.array(1.0)

    # Both methods share penalty_scale and w_max / rho_max. The slider values
    # are captured by the closures below and therefore baked into the compiled
    # functions as constants; only x and the weights / multipliers are traced.
    pm_inner_jit = jax.jit(lambda x, w_obs, w_g: pm_inner_step(x, w_obs, w_g, lr, w_smooth))
    pm_outer_jit = jax.jit(lambda x, w_obs, w_g: pm_outer_step(x, w_obs, w_g, penalty_scale, w_max))

    al_inner_jit = jax.jit(lambda x, lm, r: al_inner_step(x, lm, r, lr, w_smooth))
    al_outer_jit = jax.jit(lambda x, lm, r: al_outer_step(x, lm, r, penalty_scale, w_max))

    update_viz(x_pm, x_al)

    for i in range(num_iters):
        # A newer run has started: abandon this one.
        if _run_id != my_run_id:
            return

        x_pm = pm_inner_jit(x_pm, w_obstacle, w_goal)
        x_al = al_inner_jit(x_al, lms, rho)

        # Run the outer update only after a full batch of `outer_every` inner
        # steps. Testing `i % outer_every == 0` would fire it at i == 0,
        # i.e. after a single inner step on the initial guess.
        if (i + 1) % outer_every == 0:
            x_pm, w_obstacle, w_goal = pm_outer_jit(x_pm, w_obstacle, w_goal)
            x_al, lms, rho           = al_outer_jit(x_al, lms, rho)
            update_viz(x_pm, x_al)

    update_viz(x_pm, x_al)

# --------------------------------------------------------------------------- #
#  Playback                                                                    #
# --------------------------------------------------------------------------- #

@play_pm_button.on_click
def play_pm(_):
    """Animate the drone along the current penalty-method trajectory."""
    # Convert once, before the loop. This avoids a device-to-host copy per
    # frame and keeps playback on one consistent trajectory even if x_pm is
    # reassigned while the animation is running.
    waypoints = onp.array(x_pm).reshape(-1, 3)
    for position in waypoints:
        drone.position = position
        time.sleep(dt)

@play_al_button.on_click
def play_al(_):
    """Animate the drone along the current augmented-Lagrangian trajectory."""
    # Convert once, before the loop. This avoids a device-to-host copy per
    # frame and keeps playback on one consistent trajectory even if x_al is
    # reassigned while the animation is running.
    waypoints = onp.array(x_al).reshape(-1, 3)
    for position in waypoints:
        drone.position = position
        time.sleep(dt)

# --------------------------------------------------------------------------- #
#  Scene                                                                       #
# --------------------------------------------------------------------------- #

# Load the drone model and rotate it by pi/2 about the x-axis before it is
# added to the scene.
drone_mesh = trimesh.load("../assets/quadcopter_drone.glb").apply_transform(
    trimesh.transformations.rotation_matrix(onp.pi / 2, [1, 0, 0])
)
# Pass the start position as a NumPy array, like every other scene call in
# this file, rather than handing viser a JAX array.
drone = server.scene.add_mesh_trimesh(
    "drone", drone_mesh, scale=0.25, position=onp.array(init_pos)
)

# Block until interrupted. NOTE(review): presumably this keeps the process
# (and with it the viser server and its callbacks) alive -- confirm.
while True:
    time.sleep(0.1)

KeyboardInterrupt: 