In [1]:
import time
import numpy as np
import scipy.special
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jax import jit, jacfwd, jacrev, random, vmap
from jax.config import config

config.update("jax_enable_x64", True)

#### Bernstein Coefficient Polynomials

In [2]:
def bernstein_coeff_order10_new(n, tmin, tmax, t_actual):
    l = tmax - tmin
    t = (t_actual - tmin) / l

    P0 = scipy.special.binom(n, 0) * ((1 - t) ** (n - 0)) * t ** 0
    P1 = scipy.special.binom(n, 1) * ((1 - t) ** (n - 1)) * t ** 1
    P2 = scipy.special.binom(n, 2) * ((1 - t) ** (n - 2)) * t ** 2
    P3 = scipy.special.binom(n, 3) * ((1 - t) ** (n - 3)) * t ** 3
    P4 = scipy.special.binom(n, 4) * ((1 - t) ** (n - 4)) * t ** 4
    P5 = scipy.special.binom(n, 5) * ((1 - t) ** (n - 5)) * t ** 5
    P6 = scipy.special.binom(n, 6) * ((1 - t) ** (n - 6)) * t ** 6
    P7 = scipy.special.binom(n, 7) * ((1 - t) ** (n - 7)) * t ** 7
    P8 = scipy.special.binom(n, 8) * ((1 - t) ** (n - 8)) * t ** 8
    P9 = scipy.special.binom(n, 9) * ((1 - t) ** (n - 9)) * t ** 9
    P10 = scipy.special.binom(n, 10) * ((1 - t) ** (n - 10)) * t ** 10

    P0dot = -10.0 * (-t + 1) ** 9
    P1dot = -90.0 * t * (-t + 1) ** 8 + 10.0 * (-t + 1) ** 9
    P2dot = -360.0 * t ** 2 * (-t + 1) ** 7 + 90.0 * t * (-t + 1) ** 8
    P3dot = -840.0 * t ** 3 * (-t + 1) ** 6 + 360.0 * t ** 2 * (-t + 1) ** 7
    P4dot = -1260.0 * t ** 4 * (-t + 1) ** 5 + 840.0 * t ** 3 * (-t + 1) ** 6
    P5dot = -1260.0 * t ** 5 * (-t + 1) ** 4 + 1260.0 * t ** 4 * (-t + 1) ** 5
    P6dot = -840.0 * t ** 6 * (-t + 1) ** 3 + 1260.0 * t ** 5 * (-t + 1) ** 4
    P7dot = -360.0 * t ** 7 * (-t + 1) ** 2 + 840.0 * t ** 6 * (-t + 1) ** 3
    P8dot = 45.0 * t ** 8 * (2 * t - 2) + 360.0 * t ** 7 * (-t + 1) ** 2
    P9dot = -10.0 * t ** 9 + 9 * t ** 8 * (-10.0 * t + 10.0)
    P10dot = 10.0 * t ** 9

    P0ddot = 90.0 * (-t + 1) ** 8
    P1ddot = 720.0 * t * (-t + 1) ** 7 - 180.0 * (-t + 1) ** 8
    P2ddot = 2520.0 * t ** 2 * (-t + 1) ** 6 - 1440.0 * t * (-t + 1) ** 7 + 90.0 * (-t + 1) ** 8
    P3ddot = 5040.0 * t ** 3 * (-t + 1) ** 5 - 5040.0 * t ** 2 * (-t + 1) ** 6 + 720.0 * t * (-t + 1) ** 7
    P4ddot = 6300.0 * t ** 4 * (-t + 1) ** 4 - 10080.0 * t ** 3 * (-t + 1) ** 5 + 2520.0 * t ** 2 * (-t + 1) ** 6
    P5ddot = 5040.0 * t ** 5 * (-t + 1) ** 3 - 12600.0 * t ** 4 * (-t + 1) ** 4 + 5040.0 * t ** 3 * (-t + 1) ** 5
    P6ddot = 2520.0 * t ** 6 * (-t + 1) ** 2 - 10080.0 * t ** 5 * (-t + 1) ** 3 + 6300.0 * t ** 4 * (-t + 1) ** 4
    P7ddot = -360.0 * t ** 7 * (2 * t - 2) - 5040.0 * t ** 6 * (-t + 1) ** 2 + 5040.0 * t ** 5 * (-t + 1) ** 3
    P8ddot = 90.0 * t ** 8 + 720.0 * t ** 7 * (2 * t - 2) + 2520.0 * t ** 6 * (-t + 1) ** 2
    P9ddot = -180.0 * t ** 8 + 72 * t ** 7 * (-10.0 * t + 10.0)
    P10ddot = 90.0 * t ** 8

    P = np.hstack((P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10))
    Pdot = np.hstack((P0dot, P1dot, P2dot, P3dot, P4dot, P5dot, P6dot, P7dot, P8dot, P9dot, P10dot)) / l
    Pddot = np.hstack((P0ddot, P1ddot, P2ddot, P3ddot, P4ddot, P5ddot, P6ddot, P7ddot, P8ddot, P9ddot, P10ddot)) / (l ** 2)
    return P, Pdot, Pddot

#### Initializations

In [3]:
x_min = -6.0
x_max = 6.0

y_min = -6.0
y_max = 6.0

t_fin = 2.0
num = 25

In [4]:
tot_time = np.linspace(0.0, t_fin, num)
tot_time_copy = tot_time.reshape(num, 1)
P, Pdot, Pddot = bernstein_coeff_order10_new(10, tot_time_copy[0], tot_time_copy[-1], tot_time_copy)
nvar = np.shape(P)[1]
num = np.shape(P)[0]

In [5]:
x_obs_temp = np.hstack((-2.0, -0.79, 3.0, 4.0))
y_obs_temp = np.hstack((-2.0, 1.0, -0.80, 2.0))
num_obs = np.shape(x_obs_temp)[0]

a_obs = 1.0
b_obs = 1.0

x_obs = np.ones((num_obs, num)) * x_obs_temp[:, np.newaxis]
y_obs = np.ones((num_obs, num)) * y_obs_temp[:, np.newaxis]

In [6]:
x_init = -2.87
y_init = 2.96
vx_init = 0.0
ax_init = 0.0
vy_init = 0.0
ay_init = 0.0

In [7]:
x_fin = 1.4
y_fin = 0.2
vx_fin = 0.0
ax_fin = 0.0
vy_fin = 0.0
ay_fin = 0.0

In [8]:
rho_obs = 2.0
rho_eq = 10.0
weight_smoothness = 100

In [9]:
A_eq = np.vstack((P[0], Pdot[0], Pddot[0], P[-1], Pdot[-1], Pddot[-1]))
A_obs = np.tile(P, (num_obs, 1))
Q_smoothness = np.dot(Pddot.T, Pddot)

In [10]:
P_jax = jnp.asarray(P)
A_eq_jax = jnp.asarray(A_eq)
A_obs_jax = jnp.asarray(A_obs)
x_obs_jax = jnp.asarray(x_obs)
y_obs_jax = jnp.asarray(y_obs)
Q_smoothness_jax = jnp.asarray(Q_smoothness)



In [11]:
# bx_eq = np.array([-1.6721, -0.0158,  0.2543, -0.5678,  0.0000,  0.0000])
# by_eq = np.array([2.1997, -1.7899, -0.6161, -0.7362,  0.0000,  0.0000])

bx_eq = np.array([1.2147, -0.8816,  0.1860,  0.0862,  1.1351,  1.0330])
by_eq = np.array([0.0876,  0.9048, 0.0106, -0.3246,  0.2031,  1.6398])
bx_eq, by_eq

(array([ 1.2147, -0.8816,  0.186 ,  0.0862,  1.1351,  1.033 ]),
 array([ 0.0876,  0.9048,  0.0106, -0.3246,  0.2031,  1.6398]))

#### Compute Solution

In [12]:
def compute_sol(rho_obs, rho_eq, weight_smoothness, num_obs, bx_eq, by_eq, P, Pdot, Pddot, x_obs, y_obs, a_obs, b_obs):
    maxiter = 300
    nvar = np.shape(P)[1]
    num = np.shape(P)[0]

    cost_smoothness = weight_smoothness * np.dot(Pddot.T, Pddot)

    alpha_obs = np.zeros((num_obs, num))
    d_obs = np.ones((num_obs, num))

    lamda_x = np.zeros(nvar)
    lamda_y = np.zeros(nvar)

    res_obs = np.ones(maxiter)
    res_eq = np.ones(maxiter)
    d_min = np.ones(maxiter)

    cost = cost_smoothness + rho_obs * np.dot(A_obs.T, A_obs)
    cost_mat = np.vstack((np.hstack((cost, A_eq.T)), np.hstack((A_eq, np.zeros((np.shape(A_eq)[0], np.shape(A_eq)[0]))))))
    cost_mat_inv = np.linalg.inv(cost_mat)

    for i in range(0, maxiter):
        temp_x_obs = d_obs * np.cos(alpha_obs) * a_obs
        b_obs_x = x_obs.reshape(num * num_obs) + temp_x_obs.reshape(num * num_obs)

        temp_y_obs = d_obs * np.sin(alpha_obs) * b_obs
        b_obs_y = y_obs.reshape(num * num_obs) + temp_y_obs.reshape(num * num_obs)

        lincost_x = - lamda_x - rho_obs * np.dot(A_obs.T, b_obs_x)
        lincost_y = - lamda_y - rho_obs * np.dot(A_obs.T, b_obs_y)

        sol_x = np.dot(cost_mat_inv, np.hstack((-lincost_x, bx_eq)))
        sol_y = np.dot(cost_mat_inv, np.hstack((-lincost_y, by_eq)))

        primal_x = sol_x[0:nvar]
        dual_x = sol_x[nvar:nvar + 6]

        primal_y = sol_y[0:nvar]
        dual_y = sol_y[nvar:nvar + 6]
    
        x = np.dot(P, primal_x)
        y = np.dot(P, primal_y)

        wc_alpha = (x - x_obs)
        ws_alpha = (y - y_obs)
        alpha_obs = np.arctan2(ws_alpha * a_obs, wc_alpha * b_obs)
        
        c1_d = 1.0 * rho_obs * (a_obs ** 2 * np.cos(alpha_obs) ** 2 + b_obs ** 2 * np.sin(alpha_obs) ** 2)
        c2_d = 1.0 * rho_obs * (a_obs * wc_alpha * np.cos(alpha_obs) + b_obs * ws_alpha * np.sin(alpha_obs))

        d_temp = c2_d / c1_d
        d_obs = np.maximum(np.ones((num_obs, num)), d_temp)
        d_min[i] = np.amin(d_temp)

        res_x_obs_vec = wc_alpha - a_obs * d_obs * np.cos(alpha_obs)
        res_y_obs_vec = ws_alpha - b_obs * d_obs * np.sin(alpha_obs)

        lamda_x = lamda_x - rho_obs * np.dot(A_obs.T, res_x_obs_vec.reshape(num_obs * num))
        lamda_y = lamda_y - rho_obs * np.dot(A_obs.T, res_y_obs_vec.reshape(num_obs * num))
        
        res_obs[i] = np.linalg.norm(np.hstack((res_x_obs_vec, res_y_obs_vec)))

    slack_obs = np.sqrt((d_obs - 1))
    return x, y, primal_x, primal_y, dual_x, dual_y, alpha_obs.reshape(num_obs * num), d_obs.reshape(num_obs * num), lamda_x, lamda_y, slack_obs.reshape(num_obs * num)

In [13]:
x, y, primal_x, primal_y, dual_x, dual_y, alpha_obs, d_obs, lamda_x, lamda_y, slack_obs = compute_sol(rho_obs, rho_eq, weight_smoothness, num_obs, bx_eq, by_eq, P, Pdot, Pddot, x_obs, y_obs, a_obs, b_obs)

In [14]:
aug_sol = np.hstack((primal_x, primal_y, alpha_obs, d_obs))

In [16]:
aug_sol[:10]

array([ 1.2147    ,  1.03838   ,  0.87032667,  0.50331953,  0.30967066,
        0.30225991, -0.41747714,  0.00931756, -0.32192889, -0.14082   ])

In [17]:
aug_sol

array([ 1.2147    ,  1.03838   ,  0.87032667,  0.50331953,  0.30967066,
        0.30225991, -0.41747714,  0.00931756, -0.32192889, -0.14082   ,
        0.0862    ,  0.0876    ,  0.26856   ,  0.44999111,  0.05873069,
        0.63001091, -0.2179742 ,  0.04874829, -0.19779072, -0.33296   ,
       -0.36522   , -0.3246    ,  0.57594758,  0.602327  ,  0.62640225,
        0.64797424,  0.66755885,  0.68557653,  0.70216972,  0.71728074,
        0.73079064,  0.74262   ,  0.75275852,  0.76123235,  0.76803918,
        0.77308751,  0.7761721 ,  0.77700414,  0.77529287,  0.7708494 ,
        0.76366517,  0.75391886,  0.74189426,  0.72784852,  0.71195006,
        0.69449016,  0.6766204 , -0.42711221, -0.41076631, -0.40114934,
       -0.40106775, -0.4104395 , -0.42851249, -0.45485039, -0.4895181 ,
       -0.53280727, -0.58475464, -0.64465543, -0.71074508, -0.78018508,
       -0.84940287, -0.91469378, -0.97286206, -1.02165507, -1.05984848,
       -1.08698652, -1.10289044, -1.10711265, -1.09857435, -1.07

In [18]:
lamda_x_jax = jnp.asarray(lamda_x)
lamda_y_jax = jnp.asarray(lamda_y)

#### Cost Function

In [None]:
def cost_fun(aug_sol_jax, param_sol):
    x_init, vx_init, ax_init, x_fin, vx_fin, ax_fin, y_init, vy_init, ay_init, y_fin, vy_fin, ay_fin = param_sol

    bx_eq_jax = jnp.array(bx_eq)
    by_eq_jax = jnp.array(by_eq)

    c_x = aug_sol_jax[0:nvar]
    c_y = aug_sol_jax[nvar: 2 * nvar]

    num_tot = num_obs * num
    alpha_obs = aug_sol_jax[2 * nvar:2*nvar + num_tot]
    d_obs = aug_sol_jax[2 * nvar + num_tot:2 * nvar + 2 * num_tot]

    cost_smoothness_x = 0.5 * weight_smoothness * jnp.dot(c_x.T, jnp.dot(Q_smoothness_jax, c_x))
    cost_smoothness_y = 0.5 * weight_smoothness * jnp.dot(c_y.T, jnp.dot(Q_smoothness_jax, c_y))

    temp_x_obs = d_obs * jnp.cos(alpha_obs) * a_obs
    b_obs_x = x_obs_jax.reshape(num * num_obs) + temp_x_obs

    temp_y_obs = d_obs * jnp.sin(alpha_obs) * b_obs
    b_obs_y = y_obs_jax.reshape(num * num_obs) + temp_y_obs

    cost_obs_x = 0.5 * rho_obs * (jnp.sum((jnp.dot(A_obs_jax, c_x) - b_obs_x) ** 2))
    cost_obs_y = 0.5 * rho_obs * (jnp.sum((jnp.dot(A_obs_jax, c_y) - b_obs_y) ** 2))
    cost_slack = 0.5 * rho_obs * jnp.sum(jnp.maximum(jnp.zeros(num_tot), -d_obs + 1))

    cost_eq_x = 0.5 * rho_eq * (jnp.sum((jnp.dot(A_eq_jax, c_x) - bx_eq_jax) ** 2))
    cost_eq_y = 0.5 * rho_eq * (jnp.sum((jnp.dot(A_eq_jax, c_y) - by_eq_jax) ** 2))
    
    cost_x = cost_smoothness_x + cost_obs_x - jnp.dot(lamda_x_jax.T, c_x)
    cost_y = cost_smoothness_y + cost_obs_y - jnp.dot(lamda_y_jax.T, c_y)
    
    eps = 10 ** (-8.0)
    cost = cost_x + cost_y + eps * jnp.sum(c_x ** 2) + eps * jnp.sum(c_y ** 2) + eps * jnp.sum(d_obs ** 2) + eps * jnp.sum(alpha_obs ** 2) + cost_slack
    return cost 

In [20]:
aug_sol_jax = jnp.asarray(aug_sol)
params = jnp.hstack((x_init, vx_init, ax_init, x_fin, vx_fin, ax_fin, y_init, vy_init, ay_init, y_fin, vy_fin, ay_fin))
# cost_fun(aug_sol_jax, params)

In [21]:
# x_init, vx_init, ax_init, x_fin, vx_fin, ax_fin, y_init, vy_init, ay_init, y_fin, vy_fin, ay_fin = params

bx_eq_jax = jnp.array(bx_eq)
by_eq_jax = jnp.array(by_eq)

c_x = aug_sol_jax[0:nvar]
c_y = aug_sol_jax[nvar: 2 * nvar]

num_tot = num_obs * num
alpha_obs = aug_sol_jax[2 * nvar:2*nvar + num_tot]
d_obs = aug_sol_jax[2 * nvar + num_tot:2 * nvar + 2 * num_tot]

cost_smoothness_x = 0.5 * weight_smoothness * jnp.dot(c_x.T, jnp.dot(Q_smoothness_jax, c_x))
cost_smoothness_y = 0.5 * weight_smoothness * jnp.dot(c_y.T, jnp.dot(Q_smoothness_jax, c_y))

temp_x_obs = d_obs * jnp.cos(alpha_obs) * a_obs
b_obs_x = x_obs_jax.reshape(num * num_obs) + temp_x_obs

temp_y_obs = d_obs * jnp.sin(alpha_obs) * b_obs
b_obs_y = y_obs_jax.reshape(num * num_obs) + temp_y_obs

cost_obs_x = 0.5 * rho_obs * (jnp.sum((jnp.dot(A_obs_jax, c_x) - b_obs_x) ** 2))
cost_obs_y = 0.5 * rho_obs * (jnp.sum((jnp.dot(A_obs_jax, c_y) - b_obs_y) ** 2))
cost_slack = 0.5 * rho_obs * jnp.sum(jnp.maximum(jnp.zeros(num_tot), -d_obs + 1))

cost_eq_x = 0.5 * rho_eq * (jnp.sum((jnp.dot(A_eq_jax, c_x) - bx_eq_jax) ** 2))
cost_eq_y = 0.5 * rho_eq * (jnp.sum((jnp.dot(A_eq_jax, c_y) - by_eq_jax) ** 2))

cost_x = cost_smoothness_x + cost_obs_x - jnp.dot(lamda_x_jax.T, c_x)
cost_y = cost_smoothness_y + cost_obs_y - jnp.dot(lamda_y_jax.T, c_y)

eps = 10 ** (-8.0)
cost = cost_x + cost_y + eps * jnp.sum(c_x ** 2) + eps * jnp.sum(c_y ** 2) + eps * jnp.sum(d_obs ** 2) + eps * jnp.sum(alpha_obs ** 2) + cost_slack

In [23]:
c_x

DeviceArray([ 1.2147    ,  1.03838   ,  0.87032667,  0.50331953,
              0.30967066,  0.30225991, -0.41747714,  0.00931756,
             -0.32192889, -0.14082   ,  0.0862    ], dtype=float64)

In [24]:
cost_smoothness_x

DeviceArray(3038.4687723, dtype=float64)

In [None]:
aug_sol[:10]

#### Compute argmin derivative

In [None]:
hess_inp = jit(jacfwd(jacrev(cost_fun)))
hess_param = jit(jacfwd(jacrev(cost_fun), argnums=1))

In [None]:
aug_sol = np.hstack((primal_x, primal_y, alpha_obs, d_obs))
aug_sol_jax = jnp.asarray(aug_sol)

params = jnp.hstack((x_init, vx_init, ax_init, x_fin, vx_fin, ax_fin, y_init, vy_init, ay_init, y_fin, vy_fin, ay_fin))

F_yy = hess_inp(aug_sol, params)
F_xy = hess_param(aug_sol, params)
F_yy_inv = jnp.linalg.inv(F_yy)

dgx = jnp.dot(-F_yy_inv, F_xy)

In [None]:
aug_sol.shape, params.shape

In [None]:
cost_fun(aug_sol_jax, params)

#### Testing

In [None]:
maxiter = 300
nvar = np.shape(P)[1]
num = np.shape(P)[0]

cost_smoothness = weight_smoothness * np.dot(Pddot.T, Pddot)

alpha_obs = np.zeros((num_obs, num))
d_obs = np.ones((num_obs, num))

lamda_x = np.zeros(nvar)
lamda_y = np.zeros(nvar)

res_obs = np.ones(maxiter)
res_eq = np.ones(maxiter)
d_min = np.ones(maxiter)

cost = cost_smoothness + rho_obs * np.dot(A_obs.T, A_obs)
cost_mat = np.vstack((np.hstack((cost, A_eq.T)), np.hstack((A_eq, np.zeros((np.shape(A_eq)[0], np.shape(A_eq)[0]))))))
cost_mat_inv = np.linalg.inv(cost_mat)

In [None]:
cost.shape, A_eq.shape

In [None]:
np.hstack((cost, A_eq.T)).shape

In [None]:
temp_x_obs = d_obs*np.cos(alpha_obs)*a_obs
b_obs_x = x_obs.reshape(num*num_obs)+temp_x_obs.reshape(num*num_obs)

temp_y_obs = d_obs*np.sin(alpha_obs)*b_obs
b_obs_y = y_obs.reshape(num*num_obs)+temp_y_obs.reshape(num*num_obs)

lincost_x = -lamda_x-rho_obs*np.dot(A_obs.T, b_obs_x)
lincost_y = -lamda_y-rho_obs*np.dot(A_obs.T, b_obs_y)

sol_x = np.dot(cost_mat_inv, np.hstack(( -lincost_x, bx_eq )))
sol_y = np.dot(cost_mat_inv, np.hstack(( -lincost_y, by_eq )))

In [None]:
sol_x.shape, primal_x.shape

In [None]:
lincost_x.shape

In [None]:
bx_eq.shape

In [None]:
np.hstack(( -lincost_x, bx_eq )).shape

In [None]:
sol