In [90]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact

from aa598.hw2_helper import simulate_dynamics
import cvxpy as cp
from cbfax.dynamics import *


In [91]:
# Robot and human share the same vehicle dynamics class
# (presumably provided by the `cbfax.dynamics` star import -- confirm).
robot, human = DynamicallyExtendedSimpleCar(), DynamicallyExtendedSimpleCar()

@jax.jit
def obstacle_constraint(state, obstacle, radius):
    """Signed clearance between an agent and an obstacle.

    Args:
        state: agent state; only the first two entries (planar position) are used.
        obstacle: obstacle state; only the first two entries are used.
        radius: minimum allowed centre-to-centre distance.

    Returns:
        ``||state[:2] - obstacle[:2]|| - radius`` -- non-negative when the
        separation is at least ``radius``.

    Note:
        The Euclidean norm is not differentiable at zero and ``jax.grad`` of
        ``jnp.linalg.norm`` yields NaN there. The "double where" guard below
        keeps the value unchanged (distance 0) while returning a zero
        gradient when the two positions coincide, so the linearised
        constraint never receives NaN coefficients.
    """
    offset = state[:2] - obstacle[:2]
    squared_distance = jnp.sum(offset * offset)
    coincident = squared_distance == 0
    # Feed sqrt a harmless value on the coincident branch so its (infinite)
    # derivative at 0 is never evaluated.
    safe_squared = jnp.where(coincident, 1.0, squared_distance)
    distance = jnp.where(coincident, 0.0, jnp.sqrt(safe_squared))
    return distance - radius


In [92]:
# --- horizon / simulation settings ---
dt = 0.1                 # time step handed to the dynamics
planning_horizon = 25    # number of control steps in each plan
num_time_steps = 30      # number of closed-loop replanning steps
num_sqp_iterations = 15  # SQP passes per replanning step
t = 0.                   # placeholder time value; has no effect, but a value is needed
radius = 1.              # minimum collision distance

# --- robot state / control bounds ---
v_min, v_max = 0., 1.5
acceleration_min, acceleration_max = -1.0, 1.0
steering_min, steering_max = -0.3, 0.3

# --- predicted human control noise ---
# NOTE: despite its name, the "variance" is used as a plain multiplier on
# standard-normal samples in the replanning loop; samples are then clipped
# to +/- the noise limit.
human_control_prediction_variance = 0.25
human_control_prediction_noise_limit = 0.25


In [93]:
# --- decision variables ---
xs = cp.Variable((planning_horizon + 1, robot.state_dim))  # planned states, one row per step
us = cp.Variable((planning_horizon, robot.control_dim))    # planned controls, one row per step
slack = cp.Variable(1)                                     # slack that keeps the problem feasible

# --- per-step linearised dynamics: x[k+1] = A[k] x[k] + B[k] u[k] + C[k] ---
As = [cp.Parameter((robot.state_dim, robot.state_dim)) for _ in range(planning_horizon)]
Bs = [cp.Parameter((robot.state_dim, robot.control_dim)) for _ in range(planning_horizon)]
Cs = [cp.Parameter((robot.state_dim,)) for _ in range(planning_horizon)]

# --- per-step linearised collision constraint: G[k] x[k] + h[k] ---
Gs = [cp.Parameter((robot.state_dim,)) for _ in range(planning_horizon + 1)]
hs = [cp.Parameter(1) for _ in range(planning_horizon + 1)]

# --- values refreshed between solves ---
xs_previous = cp.Parameter((planning_horizon + 1, robot.state_dim))  # previous state iterate
us_previous = cp.Parameter((planning_horizon, robot.control_dim))    # previous control iterate
initial_state = cp.Parameter((robot.state_dim,))                     # current robot state


In [94]:
# Objective weights.
beta1 = 0.2 # coefficient for control effort
beta2 = 2. # coefficient for progress
beta3 = 10. # coefficient for trust region
slack_penalty = 10000. # coefficient for slack variable
markup = 1.0 # geometric per-step weight on control effort (1.0 = uniform weighting)

# Terminal cost: penalise final heading (index 2) and lateral offset (index 1),
# reward final x position (index 0). The trust-region term keeps the new
# iterate close to the point the problem was linearised about, and the slack
# is penalised heavily so it is used only when the constraints would
# otherwise be infeasible.
objective = (
    beta2 * (xs[-1,2]**2 + xs[-1,1]**2 - xs[-1,0])
    + beta3 * (cp.sum_squares(xs - xs_previous) + cp.sum_squares(us - us_previous))
    + slack_penalty * slack**2
)
constraints = [xs[0] == initial_state, slack >= 0] # initial state and slack constraint

# Use a dedicated index name so the module-level `t` is not overwritten.
for k in range(planning_horizon):
    objective += beta1 * cp.sum_squares(us[k]) * markup**k
    constraints += [xs[k+1] == As[k] @ xs[k] + Bs[k] @ us[k] + Cs[k]] # dynamics constraint
    constraints += [
        xs[k,-1] <= v_max, xs[k,-1] >= v_min,                       # velocity limits
        us[k,0] <= acceleration_max, us[k,0] >= acceleration_min,   # acceleration limits
        us[k,1] <= steering_max, us[k,1] >= steering_min,           # steering limits
    ]
    constraints += [Gs[k] @ xs[k] + hs[k] >= -slack] # linearized collision avoidance constraint

# Constraints for the last planned state. The collision constraint is
# softened with the same slack as every other step; a hard `>= 0` here could
# make the whole problem infeasible, which is what the slack exists to prevent.
constraints += [
    xs[planning_horizon,-1] <= v_max,
    xs[planning_horizon,-1] >= v_min,
    Gs[planning_horizon] @ xs[planning_horizon] + hs[planning_horizon] >= -slack,
]
prob = cp.Problem(cp.Minimize(objective), constraints) # construct problem
    


In [95]:
# --- starting states (layout: x, y, heading, velocity) ---
human_state = jnp.array([-1., -2., jnp.pi/2, 1.])
robot_state = jnp.array([-3.0, -0., 0., 1.])

# --- logs filled in by the replanning loop ---
robot_trajectory = [robot_state]   # robot states actually visited
human_trajectory = [human_state]   # human states actually visited
robot_control_list = []            # control applied by the robot at each step
robot_trajectory_list = []         # full planned trajectory from each replanning step

# --- initial guess: zero controls rolled out through the robot dynamics ---
previous_controls = jnp.zeros((planning_horizon, robot.control_dim))
previous_states = simulate_dynamics(robot, robot_state, previous_controls, dt)
us_previous.value = np.array(previous_controls)
xs_previous.value = np.array(previous_states)

# --- jitted, batched linearisation helpers ---
# Linearises the robot's discrete step about every (state, control) pair;
# the replanning loop unpacks the result as three stacked arrays.
linearize_dynamics = jax.jit(
    lambda states, controls, ti: jax.vmap(linearize, (None, 0, 0, None))(
        lambda s, c, t: robot.discrete_step(s, c, t, dt), states, controls, ti
    )
)
# Gradient of the clearance w.r.t. its first argument, batched over steps.
# The second argument is named `controls`, but the replanning loop passes
# the predicted human (obstacle) states in that slot.
linearize_obstacle = jax.jit(
    lambda states, controls, radius: jax.vmap(
        jax.grad(obstacle_constraint), (0, 0, None)
    )(states, controls, radius)
)

In [96]:
solver = cp.CLARABEL

# Receding-horizon loop: at every time step, predict the human, run a fixed
# number of SQP iterations on the convexified problem, apply the first
# planned control and advance both agents by one step.
# NOTE: the human noise comes from NumPy's global RNG, which is not seeded
# here, so results differ from run to run.
for t in range(num_time_steps):
    print("timestep: %i"% t)
    initial_state.value = np.array(robot_state)
    # simulate human future trajectory, assuming some noisy behavior
    # (standard-normal samples scaled, then clipped to the noise limit)
    noisy_human_control = jnp.clip(jnp.array(np.random.randn(planning_horizon, human.control_dim) * human_control_prediction_variance), -human_control_prediction_noise_limit, human_control_prediction_noise_limit)
    human_future = simulate_dynamics(human, human_state, noisy_human_control, dt)

    for sqp_iteration in range(num_sqp_iterations):
        # Linearise dynamics and collision constraint about the current iterate.
        As_value, Bs_value, Cs_value = linearize_dynamics(previous_states[:-1], previous_controls, t)
        Gs_value = linearize_obstacle(previous_states, human_future, radius)
        # h = g(x_prev) - G . x_prev, so that G x + h is the first-order
        # expansion of the clearance g about x_prev.
        hs_value = jax.vmap(obstacle_constraint, [0, 0, None])(previous_states, human_future, radius) - jax.vmap(jnp.dot, [0, 0])(Gs_value, previous_states)

        # Push the linearisation into the CVXPY parameters. A distinct index
        # name avoids shadowing the SQP iteration counter.
        for k in range(planning_horizon):
            As[k].value = np.array(As_value[k])
            Bs[k].value = np.array(Bs_value[k])
            Cs[k].value = np.array(Cs_value[k])
            Gs[k].value = np.array(Gs_value[k])
            hs[k].value = np.array(hs_value[k:k+1])
        Gs[planning_horizon].value = np.array(Gs_value[planning_horizon])
        hs[planning_horizon].value = np.array(hs_value[planning_horizon:planning_horizon+1])

        result = prob.solve(solver=solver)

        if us.value is None:
            # The solver produced no solution (e.g. infeasible sub-problem).
            # Keep the last successful iterate rather than failing later with
            # an obscure error when the missing controls are simulated.
            print("warning: QP solve failed at timestep %i, SQP iteration %i (status: %s); keeping previous iterate" % (t, sqp_iteration, prob.status))
            break

        previous_controls = us.value
        # Roll the new controls through the true dynamics instead of taking
        # xs.value, so the next linearisation point is dynamically consistent.
        previous_states = simulate_dynamics(robot, robot_state, previous_controls, dt)
        xs_previous.value = np.array(previous_states)
        us_previous.value = np.array(previous_controls)

    # Apply only the first control of the plan.
    robot_control = previous_controls[0]
    robot_control_list.append(robot_control)
    # robot takes a step
    robot_state = robot.discrete_step(robot_state, robot_control, 0., dt)
    robot_trajectory.append(robot_state)
    robot_trajectory_list.append(previous_states)

    human_random_control = jnp.clip(jnp.array(np.random.randn(human.control_dim) * human_control_prediction_variance), -human_control_prediction_noise_limit, human_control_prediction_noise_limit)
    # human takes a step
    human_state = human.discrete_step(human_state, human_random_control, 0., dt)
    human_trajectory.append(human_state)

# Stack the logs into arrays for plotting.
robot_trajectory = jnp.stack(robot_trajectory)
human_trajectory = jnp.stack(human_trajectory)
robot_controls = jnp.stack(robot_control_list)

    

timestep: 0
timestep: 1
timestep: 2
timestep: 3
timestep: 4
timestep: 5
timestep: 6
timestep: 7
timestep: 8
timestep: 9
timestep: 10
timestep: 11
timestep: 12
timestep: 13
timestep: 14
timestep: 15
timestep: 16
timestep: 17
timestep: 18
timestep: 19
timestep: 20
timestep: 21
timestep: 22
timestep: 23
timestep: 24
timestep: 25
timestep: 26
timestep: 27
timestep: 28
timestep: 29


In [97]:
# plotting
@interact(i=(0,num_time_steps-1))
def plot(i):
    """Draw the scene and the control/velocity history at replanning step ``i``.

    Left panel: executed robot and human paths, the robot's plan computed at
    step ``i``, and a disc of diameter ``radius`` around each agent.
    Right panel: applied controls and robot velocity over all steps, with the
    controls at step ``i`` highlighted.

    Args:
        i: index of the replanning step to display (driven by the slider).
    """
    fig, axs = plt.subplots(1,2, figsize=(18,8))

    # ---- left panel: positions and plan ----
    ax = axs[0]
    robot_position = robot_trajectory[i, :2]
    human_position = human_trajectory[i, :2]
    # Each disc has radius `radius / 2`, so the two discs touch exactly when
    # the centres are `radius` apart.
    circle1 = plt.Circle(robot_position, radius / 2, color='C0', alpha=0.4)
    circle2 = plt.Circle(human_position, radius / 2, color='C1', alpha=0.4)
    ax.add_patch(circle1)
    ax.add_patch(circle2)
    ax.plot(robot_trajectory[:,0], robot_trajectory[:,1], "o-", markersize=3, color='C0')
    ax.plot(robot_trajectory_list[i][:,0], robot_trajectory_list[i][:,1], "o-", markersize=3, color='C2', label="planned")

    ax.plot(human_trajectory[:,0], human_trajectory[:,1], "o-", markersize=3, color='C1')
    ax.scatter(robot_trajectory[i:i+1,0], robot_trajectory[i:i+1,1], s=30,  color='C0', label="Robot")
    ax.scatter(human_trajectory[i:i+1,0], human_trajectory[i:i+1,1], s=30,  color='C1', label="Human")
    ax.grid()
    ax.legend()

    ax.set_xlim([-4,4])
    ax.set_ylim([-3, 2])
    ax.axis("equal")

    ax.set_title("heading=%.2f velocity=%.2f"%(robot_trajectory[i,2], robot_trajectory[i,3]))

    # ---- right panel: controls and velocity ----
    # Draw on the axes object explicitly instead of relying on pyplot's
    # notion of the "current" axes.
    ax = axs[1]
    ax.plot(robot_controls)
    ax.scatter([i], robot_controls[i:i+1, 0], label="Acceleration")
    ax.scatter([i], robot_controls[i:i+1, 1], label="Steering")
    ax.plot(robot_trajectory[:,-1], "o-", markersize=3, color='C0', label="Velocity")

    ax.legend()
    ax.grid()
    

interactive(children=(IntSlider(value=14, description='i', max=29), Output()), _dom_classes=('widget-interact'…