In [8]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact

from aa598.hw2_helper import simulate_dynamics
import cvxpy as cp
from cbfax.dynamics import *


Define system and problem parameters

In [9]:
robot = DynamicallyExtendedSimpleCar()
human = DynamicallyExtendedSimpleCar()
robot_state = jnp.array([12., 3., 0., 10.])
human_state = jnp.array([2., 5., 0., 10.])
# x_initial_leader = [12, 3, 0, 10]^T; [x_pos, y_pos, steering, v]^T
# x_initial_follower = [2, 5, 0, 10]^T
# x_ref_leader = x_ref_follower = [free, 5, 0, 10]^T

x_ref = jnp.array([50, 5., 0., 10.])

@jax.jit
def obstacle_constraint(state, obstacle, radius):
    return jnp.linalg.norm(state[:2] - obstacle[:2]) - radius

# The paper originally uses a kinematic single track model, whereas here, we use the bicycle model.
num_time_steps = 30
dt = 6 # timestep
Q = np.diag(np.array([0,1,0,100]))
R_u = np.diag(np.array([1,1]))
R_udot = np.diag(np.array([10000,1000]))
planning_horizon = 25
num_sqp_iterations = 15
t = 0. 

v_min = 0
v_max = 30
steering_max = 0.5 # converted from deg to radians
steering_min = -0.5
a_min = -8
a_max = 3
j_min = -10 # j is jerk
j_max = 6

alpha = 0.5 # courtesy constraint

radius = 1.

human_control_prediction_noise_limit = 0.25
human_control_prediction_variance = 0.25

In [10]:
xs_robot = cp.Variable([planning_horizon+1, robot.state_dim])  # cvx variable for states
xs_human = cp.Variable([planning_horizon+1, human.state_dim])
us_robot = cp.Variable([planning_horizon, robot.control_dim])  # cvx variable for controls
us_human = cp.Variable([planning_horizon, human.control_dim])
slack = cp.Variable(1) # slack variable to make sure the problem is feasible

As = [cp.Parameter([robot.state_dim, robot.state_dim]) for _ in range(planning_horizon)]  # parameters for linearized dynamics
Bs = [cp.Parameter([robot.state_dim, robot.control_dim]) for _ in range(planning_horizon)] # parameters for linearized dynamics
Cs = [cp.Parameter([robot.state_dim]) for _ in range(planning_horizon)] # parameters for linearized dynamics

Gs = [cp.Parameter([robot.state_dim]) for _ in range(planning_horizon+1)] # parameters for linearized constraints
hs = [cp.Parameter(1) for _ in range(planning_horizon+1)] # parameters for linearized constraints

xs_previous = cp.Parameter([planning_horizon+1, robot.state_dim]) # parameter for previous solution
us_previous = cp.Parameter([planning_horizon, robot.control_dim]) # parameter for previous solution
initial_state = cp.Parameter([robot.state_dim]) # parameter for current robot state

In [11]:
@jax.jit
def objective_cost_base(states, controls, reference_state, previous_controls, state_weight, control_weight, disturb_weight):
    cost = 0
    cost += cp.sum(cp.quad_form(states-reference_state, state_weight))
    cost += cp.sum(cp.quad_form(controls, control_weight))
    cost += cp.sum(cp.quad_form(controls-previous_controls, disturb_weight))
    cost += cp.quad_form(controls[0]-previous_controls, disturb_weight)

    return cost

In [12]:
robot_trajectory = [robot_state] # list to collect robot's state as it replans
human_trajectory = [human_state] # list to collect humans's state
robot_control_list = []  # list to collect robot's constrols as it replans
robot_trajectory_list = [] # list to collect robot's planned trajectories

# initial robot planned state and controls
previous_controls = jnp.zeros([planning_horizon, robot.control_dim]) # initial guess for robot controls
previous_states =  simulate_dynamics(robot, robot_state, previous_controls, dt) # initial guess for robot states
xs_previous.value = np.array(previous_states) # set xs_previous parameter value
us_previous.value = np.array(previous_controls) # set us_previous parameter value 

# jit the linearize dynamics and constraint functions to make it run faster
linearize_dynamics = jax.jit(lambda states, controls, ti: jax.vmap(linearize, [None, 0, 0, None])(lambda s, c, t: robot.discrete_step(s, c, t, dt), states, controls, ti))
linearize_obstacle = jax.jit(lambda states, controls, radius: jax.vmap(jax.grad(obstacle_constraint), [0, 0, None])(states, controls, radius))

In [13]:
# omega = 0.5
# alpha = 0.5

# objective = objective_cost_base(xs_robot-)

In [14]:

planning_horizon = 25 # planning horizon to compute cost over
n_human_samples = 64 # number of human future trajectories to sample
n_robot_samples = 32 # number of robot trajectories to sample for MPPI
human_control_prediction_noise_limit = 0.25
human_control_prediction_variance = 0.25
robot_control_noise_limit = 0.25
robot_control_noise_variance = 0.25

num_iterations = 20 # number of MPPI iteraciotns
num_time_steps = 50 # number of timesteps to simulate

# nominal controls
robot_nominal_controls = jnp.zeros([planning_horizon, robot.control_dim])
# assume human wants to follow a constant velocity mode (i.e., zero control input)
human_nominal_controls = jnp.zeros([planning_horizon, human.control_dim])

# making lists of things for plotting later
robot_trajectory = [robot_state]
human_trajectory = [human_state]
robot_controls_list = []
human_controls_list = []
human_samples = []
robot_nominal_controls_list = [robot_nominal_controls]

for ti in range(num_time_steps):
    # very simple human prediction model -- just gaussian noise about a constant velocity model.
    dus = jnp.clip(jnp.array(np.random.randn(n_human_samples, planning_horizon, human.control_dim) * human_control_prediction_variance), -human_control_prediction_noise_limit, human_control_prediction_noise_limit)
    human_controls_samples = jnp.clip(human_nominal_controls + dus, min=jnp.array([a_min, steering_min]), max=jnp.array([a_max, steering_max]))
    human_states_samples = jax.vmap(simulate_dynamics, [None, None, 0, None])(human, human_state, human_controls_samples, dt)
    human_samples.append(human_states_samples)
    
    for t in range(num_iterations):
        temperature = 1 - (t / num_iterations)
        # sampling robot control trajectories
        dus = jnp.clip(jnp.array(np.random.randn(n_robot_samples, planning_horizon, robot.control_dim) * robot_control_noise_variance), -robot_control_noise_limit, robot_control_noise_limit)
        robot_controls_samples = jnp.clip(robot_nominal_controls + dus, min=jnp.array([a_min, steering_min]), max=jnp.array([a_max, steering_max]))
        # simulate robot trajectory for each control trajectory sample
        robot_states_samples = jax.vmap(simulate_dynamics, [None, None, 0, None])(robot, robot_state, robot_controls_samples, dt)
        # evaluate cost of each robot trajectory sample
        trajectory_costs = jax.vmap(objective_cost_base)(robot_states_samples, robot_controls_samples, x_ref, us_previous, Q, R_u, R_udot)
        # weight for each trajectory sample
        weights = jax.nn.softmax(-trajectory_costs / temperature).reshape([-1, 1, 1])
        # compute new nominal control using weighted sum
        
        robot_nominal_controls = jnp.clip(robot_nominal_controls + (dus * weights).sum(0), min=jnp.array([a_min, steering_min]), max=jnp.array([a_max, steering_max]))
        
    # use final nominal control to step forward in time by one step
    robot_nominal_controls_list.append(robot_nominal_controls)
    robot_state = simulate_dynamics(robot, robot_state, robot_nominal_controls[:1], dt)[-1]
    human_state = simulate_dynamics(human, human_state, human_controls_samples[0][:1], dt)[-1]
    # collect the new state and controls for plotting purposes
    robot_trajectory.append(robot_state)
    human_trajectory.append(human_state)
    robot_controls_list.append(robot_nominal_controls[:1])
    human_controls_list.append(human_controls_samples[0][:1])

# turn things into jnp.array
robot_trajectory = jnp.stack(robot_trajectory)
human_trajectory = jnp.stack(human_trajectory)
human_samples = jnp.stack(human_samples)
robot_controls_list = jnp.concatenate(robot_controls_list, 0)
human_controls_list = jnp.concatenate(human_controls_list, 0)


ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 32, e.g. axis 0 of argument states of type float32[32,26,4];
  * some axes (2 of them) had size 4, e.g. axis 0 of argument reference_state of type float32[4];
  * some axes (2 of them) had size 2, e.g. axis 0 of argument control_weight of type int32[2,2];
  * one axis had size 25: axis 0 of argument previous_controls of type unknown

In [None]:
@interact(i=(0,num_time_steps-1))
def plot(i):
    fig, axs = plt.subplots(1,2, figsize=(18,8))
    ax = axs[0]
    robot_position = robot_trajectory[i, :2]
    human_position = human_trajectory[i, :2]
    circle1 = plt.Circle(robot_position, radius / 2, color='C0', alpha=0.4)
    circle2 = plt.Circle(human_position, radius / 2, color='C1', alpha=0.4)
    ax.add_patch(circle1)
    ax.add_patch(circle2)
    ax.plot(human_samples[i,:,:,0].T, human_samples[i,:,:,1].T, "o-", alpha=0.1, markersize=2, color='C1')
    ax.plot(robot_trajectory[:,0], robot_trajectory[:,1], "o-", markersize=3, color='C0')
    ax.plot(human_trajectory[:,0], human_trajectory[:,1], "o-", markersize=3, color='C1')
    ax.scatter(robot_trajectory[i:i+1,0], robot_trajectory[i:i+1,1], s=30,  color='C0', label="Robot")
    ax.scatter(human_trajectory[i:i+1,0], human_trajectory[i:i+1,1], s=30,  color='C1', label="Human")
    ax.grid()
    ax.legend()
    ax.axis("equal")
    ax.set_xlim([-4,4])
    ax.set_ylim([-3, 6])

    ax.set_title("heading=%.2f velocity=%.2f"%(robot_trajectory[i,2], robot_trajectory[i,3]))
    
    ax = axs[1]
    plt.plot(robot_controls_list)
    plt.scatter([i], robot_controls_list[i:i+1, 0], label="Acceleration")
    plt.scatter([i], robot_controls_list[i:i+1, 1], label="Steering")
    ax.plot(robot_trajectory[:,-1], "o-", markersize=3, color='C0', label="Velocity")

    ax.legend()
    ax.grid()
    