In [1]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact

from aa598.hw2_helper import simulate_dynamics
import cvxpy as cp
from cbfax.dynamics import *


## Define the system and problem parameters

In [None]:
# Both agents (the planned robot and the predicted human) use the same cbfax dynamics class.
# NOTE: the paper uses a kinematic single-track model; this notebook uses
# DynamicallyExtendedSimpleCar for both agents instead.
robot, human = DynamicallyExtendedSimpleCar(), DynamicallyExtendedSimpleCar()

# Horizon / timing.
N = 30
timestep = 6  # NOTE(review): unit/meaning not visible here (number of steps vs. dt?) — confirm

# Cost weights: Q is a diagonal weight over the 4 state components (presumably
# state tracking — TODO confirm); R_u / R_udot weight controls and control rates.
Q = np.diag([0, 1, 0, 100])
R_u, R_udot = 0, 0

# Box limits on speed, steering, acceleration and jerk.
v_min, v_max = 0, 30
steering_min, steering_max = -0.5, 0.5  # convert to radians later
a_min, a_max = -8, 3
j_min, j_max = -10, 6  # j is jerk

# Minimum allowed planar separation used by the collision penalty.
radius = 1.

# Reference values (kept as notes; not used in code yet):
#   x_initial_leader   = [12, 3, 0, 10]^T   state = [x_pos, y_pos, steering, v]^T
#   x_initial_follower = [2, 5, 0, 10]^T
#   x_ref_leader = x_ref_follower = [free, 5, 0, 10]^T


In [None]:
@jax.jit  # TO BE CHANGED
def objective_function_base(robot_states, robot_controls, human_states_samples, coeff=(0.1, 0.3, 0.5, 0.4, 3., 3.)):
    """Weighted sum of penalty terms for a robot trajectory (lower is better).

    Reads the module-level limits ``v_min``, ``v_max``, ``steering_min``,
    ``steering_max``, ``a_min``, ``a_max`` and ``radius`` from the parameter
    cell. ``jax.jit`` captures these globals as constants when the function is
    first traced, so re-run this cell after changing them.

    Args:
        robot_states: 2-D array indexed ``[t, i]``. The last column is treated
            as speed; columns 1 and 2 enter the progress term, and columns 0-1
            (planar position) enter the collision term.
        robot_controls: 2-D array indexed ``[t, j]``; column 0 is acceleration,
            column 1 is steering.
        human_states_samples: array whose trailing two axes broadcast against
            ``robot_states``. The leading axis (presumably one entry per sampled
            human trajectory — TODO confirm) is averaged over in the collision term.
        coeff: six weights for (turning effort, acceleration effort, speed
            limit, progress, collision, control limits).

    Returns:
        Scalar cost: ``coeff . [terms]``.
    """
    steering = robot_controls[:, 1]
    acceleration = robot_controls[:, 0]
    speeds = robot_states[:, -1]

    # Control effort: mean squared steering / acceleration over the horizon.
    turning_effort = (steering**2).mean()
    acceleration_effort = (acceleration**2).mean()

    # Speed limit: hinge penalty on the worst violation of [v_min, v_max].
    speed = jax.nn.relu(speeds.max() - v_max) + jax.nn.relu(v_min - speeds.min())

    # Progress to goal: final value of state[2] squared plus mean squared state[1].
    # NOTE(review): this drives state[1] toward 0, while the parameter cell's
    # reference lists y = 5 — confirm the intended target.
    progress = robot_states[-1, 2]**2 + (robot_states[:, 1]**2).mean()

    # Collision: closest planar approach over time for each leading-axis entry,
    # then relu(radius - mean of those minima).
    # NOTE(review): relu is applied after the mean, so a close sample can be
    # masked by distant ones — confirm this is intended vs. mean(relu(...)).
    min_dist = jnp.linalg.norm((robot_states - human_states_samples)[:, :, :2], axis=-1).min(-1)
    collision = jax.nn.relu(-(min_dist - radius).mean())

    # Control limits: hinge penalties on steering and acceleration bounds
    # (a_min / a_max are the acceleration limits defined in the parameter cell).
    control_limits = (
        jax.nn.relu(steering.max() - steering_max)
        + jax.nn.relu(steering_min - steering.min())
        + jax.nn.relu(acceleration.max() - a_max)
        + jax.nn.relu(a_min - acceleration.min())
    )

    terms = jnp.array([turning_effort, acceleration_effort, speed, progress, collision, control_limits])
    return jnp.dot(jnp.asarray(coeff), terms)
    
