## Lecture 8: Constrained Optimization

This notebook is a motivating example for constrained optimization. We'll try to fly a drone through an
environment with obstacles and see how hard it is to balance the competing objectives of reaching the goal,
maintaining a smooth trajectory (dynamic feasibility), and avoiding obstacles (satisfying safety constraints).

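Tuning the weights on these terms to guarantee collision avoidance *and* keep the problem numerically well-conditioned is hard:
a penalty term only discourages collisions rather than forbidding them, and making its weight large enough to do so tends to
make the problem ill-conditioned. This motivates the constrained methods we'll study in future lectures.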

In [133]:
import time
from typing import Callable

import numpy as onp
import jax
import jaxlie
import trimesh
import viser
from jax import numpy as jnp

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Run the cell below, then click [here](http://localhost:8080) to open the visualization.

In [134]:
# Reuse the existing viser server when this cell is re-run (clearing its scene
# and GUI); otherwise start a new one.
if 'server' in dir() and server is not None:
    server.scene.reset()
    server.gui.reset()
else:
    server = viser.ViserServer()

server.scene.add_grid("grid")

# Fixed start and goal positions of the drone.
init_pos = jnp.array([-2.0, -2.0, 1.0])
final_pos = jnp.array([2.0, 2.0, 1.0])

# Spherical obstacles: one center per row, with matching radii below.
obstacle_centers = jnp.array(
    [
        [0.0, 0.0, 1.0],
        [-1.5, -1.5, 1.25],
        [-1.5, -0.75, 0.5],
        [1.5, 1.0, 0.75],
        [1.0, 1.75, 1.0],
    ]
)
obstacle_radii = jnp.array([1.0, 0.5, 0.5, 0.5, 0.5])
drone_radius = 0.1

# Draw every obstacle as a blue sphere.
for idx, (center, radius) in enumerate(zip(obstacle_centers, obstacle_radii)):
    server.scene.add_icosphere(
        f"obstacle_{idx}",
        radius=float(radius),
        color=(0.0, 0.0, 1.0),
        position=onp.array(center),
    )

# Discretization: num_timesteps knots; playback sleeps dt per knot, 2.0 in total.
num_timesteps = 100
dt = 2.0 / num_timesteps
# Straight-line initial guess: linspace per coordinate gives (3, N); transpose to (N, 3).
x0 = jax.vmap(jnp.linspace, in_axes=(0, 0, None))(init_pos, final_pos, num_timesteps).T

# Knots as drawn: the fixed start followed by the initial guess.
_initial_knots = jnp.concatenate([init_pos[None, :], x0], axis=0)
_initial_segments = jnp.stack([_initial_knots[:-1], x0], axis=1)

trajectory_handle = server.scene.add_line_segments(
    "trajectory",
    onp.array(_initial_segments),
    colors=(1.0, 0.0, 0.0),
)
knot_point_handle = server.scene.add_point_cloud(
    "knot_points",
    onp.array(_initial_knots),
    colors=(1.0, 0.0, 0.0),
    point_size=0.025,
    point_shape="circle",
)
goal_handle = server.scene.add_icosphere(
    "goal", radius=0.1, color=(1.0, 1.0, 0.0), position=onp.array(final_pos)
)
start_handle = server.scene.add_icosphere(
    "start", radius=0.1, color=(0.0, 1.0, 0.0), position=onp.array(init_pos)
)


def goal_residual(x: jnp.ndarray) -> jnp.ndarray:
    """Offset of every knot point from the goal, flattened.

    ``x`` is viewed as rows of 3 coordinates and ``final_pos`` is subtracted from
    each row, so this term pulls *all* knot points (not just the last) toward the goal.
    """
    offsets = jnp.reshape(x, (-1, 3)) - final_pos
    return jnp.ravel(offsets)

def smoothness_residual(x: jnp.ndarray) -> jnp.ndarray:
    """Finite-difference velocities along the trajectory, flattened.

    ``init_pos`` is prepended to the knot points (``x`` viewed as rows of 3
    coordinates), so the first difference ties the trajectory to the fixed
    start. ``final_pos`` is *not* appended, so the end of the trajectory is not
    pinned by this term. Each difference is divided by ``dt``.
    """
    # Prepend only the start position; the goal is not concatenated here.
    x = jnp.concatenate([init_pos[None, :], x.reshape(-1, 3)], axis=0)
    return jnp.diff(x, axis=0).ravel() / dt

def single_obstacle_residual(x: jnp.ndarray, obstacle_center: jnp.ndarray, obstacle_radius: float) -> jnp.ndarray:
    """Hinge penalty on squared-distance penetration into one inflated sphere.

    For each knot point p (rows of 3 coordinates in ``x``) this returns
    ``max(0, (obstacle_radius + drone_radius)**2 - ||p - obstacle_center||**2)``:
    zero outside the sphere grown by the drone radius, positive inside it.
    """
    points = jnp.reshape(x, (-1, 3))
    offsets = points - obstacle_center
    dist_sq = jnp.sum(offsets**2, axis=-1)
    clearance_sq = (obstacle_radius + drone_radius) ** 2
    return jnp.ravel(jnp.maximum(0, clearance_sq - dist_sq))

def obstacle_residual(x: jnp.ndarray) -> jnp.ndarray:
    """Penetration residuals for every obstacle, flattened into one vector.

    ``single_obstacle_residual`` is vmapped over the rows of ``obstacle_centers``
    and the entries of ``obstacle_radii`` (the trajectory ``x`` is shared), so the
    result is ordered obstacle-major.
    """
    per_obstacle = jax.vmap(single_obstacle_residual, in_axes=(None, 0, 0))
    return jnp.ravel(per_obstacle(x, obstacle_centers, obstacle_radii))

def build_objective(
    w_goal: float = 1.0,
    w_smooth: float = 0.5,
    w_obstacle: float = 0.01,
) -> tuple[Callable, Callable]:
    """Assemble cost and residual functions from the weighted residual components.

    Args:
        w_goal: Weight on ``goal_residual``.
        w_smooth: Weight on ``smoothness_residual``.
        w_obstacle: Weight on ``obstacle_residual``.

    Terms whose weight is not strictly positive are dropped entirely.

    Returns:
        ``(cost_fn, residual_fn)``. ``residual_fn(x)`` concatenates
        ``sqrt(w_i) * r_i(x)`` over the active terms. ``cost_fn(x)`` equals
        ``0.5 * ||residual_fn(x)||^2``, computed as ``0.5 * sum_i w_i * ||r_i(x)||^2``.
        When no term is active, ``residual_fn`` returns an empty vector and
        ``cost_fn`` returns 0.
    """
    all_pairs = [
        (goal_residual, w_goal),
        (smoothness_residual, w_smooth),
        (obstacle_residual, w_obstacle),
    ]

    # Keep only positively weighted terms. This also guarantees jnp.sqrt below
    # never sees a negative weight.
    active_pairs = [(fn, w) for fn, w in all_pairs if w > 0]
    residual_fns = [fn for fn, _ in active_pairs]
    weights = [w for _, w in active_pairs]
    sqrt_weights = [jnp.sqrt(w) for w in weights]

    def _eval_blocks(x: jnp.ndarray) -> list[jnp.ndarray]:
        """Evaluate each active residual function separately, returning a list of blocks."""
        return [fn(x) for fn in residual_fns]

    def residual_fn(x: jnp.ndarray) -> jnp.ndarray:
        """Evaluate the full residual vector r(x), scaling each block by sqrt(weight)."""
        blocks = _eval_blocks(x)
        if not blocks:
            # jnp.concatenate rejects an empty list. That case is reachable when
            # every weight is 0, e.g. all GUI sliders at their minimum.
            return jnp.zeros((0,))
        return jnp.concatenate([sw * b for sw, b in zip(sqrt_weights, blocks)])

    def cost_fn(x: jnp.ndarray) -> float:
        """Evaluate the scalar cost f(x) = 0.5 * ||r(x)||^2, using the separate blocks."""
        blocks = _eval_blocks(x)
        # Multiply w * ||b||^2 directly instead of squaring sqrt(w) * b.
        return 0.5 * sum(w * jnp.sum(b**2) for w, b in zip(weights, blocks))  # type: ignore

    return cost_fn, residual_fn


def grad_descent_step(
    x: jnp.ndarray,
    cost_fn: Callable[[jnp.ndarray], float],
    step_size: float = 0.01,
) -> jnp.ndarray:
    """Return ``x`` moved by one fixed-size step along the negative gradient of ``cost_fn``."""
    descent_direction = -jax.grad(cost_fn)(x)
    return x + step_size * descent_direction

# Cost-term weights, adjustable from the GUI.
w_goal_slider = server.gui.add_slider(
    "w_goal", min=0.0, max=10.0, step=0.1, initial_value=1.0
)
w_smooth_slider = server.gui.add_slider(
    "w_smooth", min=0.0, max=10.0, step=0.1, initial_value=1.0
)
w_obstacle_slider = server.gui.add_slider(
    "w_obstacle", min=0.0, max=10000.0, step=0.1, initial_value=100.0
)

# Gradient-descent settings: iteration count and step size.
num_iters_slider = server.gui.add_slider(
    "num_iters", min=1, max=100000, step=1, initial_value=10000
)
lr_slider = server.gui.add_slider(
    "learning_rate", min=0.0000001, max=0.1, step=0.00001, initial_value=0.0001
)

optimize_button = server.gui.add_button("Optimize")
play_button = server.gui.add_button("Play")

# Current trajectory as a flat vector of knot coordinates (starts as the straight line).
x = jnp.ravel(x0)

@optimize_button.on_click
def run_optimization(_):
    """Run gradient descent from the straight-line guess using the current GUI settings.

    Rebuilds the weighted objective from the sliders and resets the global ``x``
    to ``x0``. It then takes up to ``num_iters`` jitted gradient steps,
    streaming each iterate to the trajectory and knot-point visualizations. If a
    step produces a non-finite value (plain gradient descent can diverge for
    large weights or step sizes), the loop stops and ``x`` keeps the last finite
    iterate. The Play button animates whatever ends up in ``x``.
    """
    global x
    cost_fn, _ = build_objective(
        w_goal=w_goal_slider.value,
        w_smooth=w_smooth_slider.value,
        w_obstacle=w_obstacle_slider.value,
    )

    # Read the step size once; it is baked into the jitted step function.
    step_size = lr_slider.value
    step_fn = jax.jit(lambda z: grad_descent_step(z, cost_fn, step_size=step_size))

    def _show(flat: jnp.ndarray) -> None:
        # Prepend the fixed start so the drawn path begins at init_pos.
        knots = jnp.concatenate([init_pos[None, :], flat.reshape(-1, 3)], axis=0)
        trajectory_handle.points = onp.array(jnp.stack([knots[:-1], knots[1:]], axis=1))
        knot_point_handle.points = onp.array(knots)

    x = x0.ravel()
    _show(x)  # Keep the display in sync with the reset trajectory.
    for it in range(num_iters_slider.value):
        x_next = step_fn(x)
        if not bool(jnp.all(jnp.isfinite(x_next))):
            print(f"Optimization diverged at iteration {it}; keeping the last finite iterate.")
            break
        x = x_next
        _show(x)

@play_button.on_click
def play_trajectory(_):
    """Animate the drone mesh through the knot points of the current trajectory ``x``.

    The trajectory is converted to host memory once, up front. This avoids
    re-converting it every frame, and a concurrent optimization run that
    reassigns ``x`` does not alter an animation already in progress. The drone
    is moved to one of the first ``num_timesteps`` knots every ``dt`` seconds.
    """
    waypoints = onp.array(x).reshape(-1, 3)
    for waypoint in waypoints[:num_timesteps]:
        drone.position = waypoint
        time.sleep(dt)
        

# Quadcopter mesh used for playback, rotated by pi/2 about the x-axis.
_drone_rotation = trimesh.transformations.rotation_matrix(onp.pi / 2, [1, 0, 0])
drone_mesh = trimesh.load("../assets/quadcopter_drone.glb").apply_transform(_drone_rotation)
drone = server.scene.add_mesh_trimesh(
    "drone",
    drone_mesh,
    scale=0.25,
    position=init_pos,
)

# Loop forever so the cell keeps running; stop it by interrupting the kernel.
while True:
    time.sleep(0.1)

KeyboardInterrupt: 