In [1]:
%load_ext autoreload
%autoreload 2

import jax
from jax import grad, jit, jvp, random, vjp
from jax.scipy.sparse.linalg import cg, gmres
import jax.numpy as jnp
import nopt
import matplotlib.pyplot as plt
from IPython.display import HTML

from nopt.problems import InvertedPendulum, CartPole
from nopt import solve

In [35]:
#ip = InvertedPendulum()
#bcs = dict(x0=jnp.array([jnp.pi, 0.]),
#           xN=jnp.array([0., 0.]))

ip = CartPole()
bcs = dict(
    x0=jnp.array([0., 0., 0., 0.]),
    xN=jnp.array([0., 0., jnp.pi, 0.])
)

N = 20

p = nopt.NlpProblem(ip, boundary_conditions=bcs, N=N)

statedim = ip.statedim

def outer_callback(z, lam):
    x, u = p._splitz(z)
    c_norm = jnp.linalg.norm(p.c(z))
    print(f"--- ||c(x)|| = {c_norm}")
    print(f"--- x(0) = {x[:statedim]}")
    print(f"--- x(T) = {x[-statedim:]}")
    

xstar = solve(p, outer_callback=outer_callback, max_iters=10, solver=gmres)


--iter: 1
--- ||c(x)|| = 1.8626306056976318
--- x(0) = [ 0.03019904  0.0262285  -0.09700286 -0.01722335]
--- x(T) = [-2.9963978e-02  3.5936688e-03  2.0836949e+00 -1.6812413e-03]

--iter: 2
--- ||c(x)|| = 1.639798641204834
--- x(0) = [-0.00178374  0.05163572 -0.04460081 -0.04516439]
--- x(T) = [ 9.9694356e-04 -4.7078557e-02  2.3385777e+00 -2.9578120e-01]

--iter: 3
--- ||c(x)|| = 1.3963439464569092
--- x(0) = [-0.0037143   0.03999659 -0.03462758 -0.00695843]
--- x(T) = [ 0.0030481  -0.04071717  2.465222    0.05327076]

--iter: 4
--- ||c(x)|| = 1.3261936902999878
--- x(0) = [-0.00403085  0.03283401 -0.03701124  0.01174335]
--- x(T) = [ 0.00382979 -0.03648121  2.5092976   0.14941612]

--iter: 5
--- ||c(x)|| = 1.3059134483337402
--- x(0) = [-0.00284904  0.03010602 -0.0587497   0.02573767]
--- x(T) = [ 0.00277922 -0.03295768  2.5370972   0.17275862]

--iter: 6
--- ||c(x)|| = 1.2814241647720337
--- x(0) = [-0.00258343  0.0295151  -0.05867387  0.03103646]
--- x(T) = [ 0.00256534 -0.03239405 

In [36]:
anim = p.plot(xstar);
HTML(anim.to_html5_video())

In [18]:
x, u = p._splitz(xstar)

In [22]:
xN = x.reshape(-1, 4)[-1]

In [24]:
x, _, theta, _ = xN

In [27]:
jnp.sin(theta)

DeviceArray(0.5409801, dtype=float32)

In [28]:
jnp.cos(theta)

DeviceArray(-0.84103537, dtype=float32)

In [34]:
list(p.plot(xstar).frame_seq)

KeyboardInterrupt: 