### ILQR example
1. This example shows how to compute the iLQR solution for a 1-DOF actuated pendulum.
2. At this point, the overall cost only penalizes the control inputs and the deviation of the final state from the target state.

#### Discrete dynamics
1. The discrete-time dynamics of the pendulum are defined below.
2. For convenience, functions are provided for:
    - The discrete dynamics $x_{k+1} = f(x_k, u_k)$
    - The linearized system with $A = \frac{\partial f}{\partial x}$ and $B = \frac{\partial f}{\partial u}$, obtained by automatic differentiation

In [11]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr, jit, jacfwd, jacrev
from functools import partial

# Physical parameters of the pendulum and the integration step, passed to
# PendulumDynamics(m, g, L, dt) below.
M = 1 # mass, kg
G = 9.8 # gravitational acceleration, m s^-2
L = 1 # pendulum length, m
DT = 0.005 # sampling time of the discrete system, s

class PendulumDynamics:
    """Discrete-time dynamics of an actuated pendulum and its derivatives.

    The state is ``x = [theta, theta_dot]`` and ``u`` is the control input.
    One step of the discrete system is

        theta_next     = theta + dt * theta_dot
        theta_dot_next = theta_dot + dt * (-L*g*m*sin(theta) + u) / (L**2 * m)

    Shapes quoted in the method docstrings hold for ``x`` of shape ``(2,)``
    and a scalar (0-d) ``u``; the step itself returns a ``(2, 1)`` column.
    """

    def __init__(self, m, g, L, dt):
        """Store the parameters and build the jit-compiled step function.

        Args:
            m: Mass used in the dynamics.
            g: Gravitational acceleration used in the dynamics.
            L: Length used in the dynamics.
            dt: Step size of the discrete system.
        """
        # parameters (kept for reference; the step function below closes over
        # the constructor arguments, not over these attributes)
        self.m_, self.g_, self.L_, self.dt_ = m, g, L, dt

        def step(x, u):
            theta, theta_dot = x[0], x[1]
            drive = -L*g*m*jnp.sin(theta) + u
            return jnp.array([
                [dt*theta_dot + theta],
                [theta_dot + dt*drive/(L**2*m)],
            ])

        # compiled once per instance and reused by every method below
        self.f_ = jit(step)

    # Note: this is related to https://github.com/google/jax/issues/1251. The main
    # issue is the 'self' argument, which is an arbitrary Python object. Marking it
    # with static_argnums tells jit to compile only the computation applied to the
    # remaining arguments, and to re-trace and re-compile whenever the first
    # argument changes.
    @partial(jit, static_argnums=(0,))
    def f(self, x, u):
        """Evaluate x_{k+1} = f(x_k, u_k); returns a (2, 1) array."""
        return self.f_(x, u)

    def f_x(self, x, u):
        """Jacobian of the step with respect to the state; shape (2, 1, 2)."""
        jacobian_wrt_state = jacfwd(self.f_, argnums=0)
        return jacobian_wrt_state(x, u)

    def f_u(self, x, u):
        """Jacobian of the step with respect to the input; shape (2, 1)."""
        jacobian_wrt_input = jacfwd(self.f_, argnums=1)
        return jacobian_wrt_input(x, u)

    def f_xx(self, x, u):
        """Second derivative with respect to the state; shape (2, 2, 2)."""
        hessian = jacfwd(jacrev(self.f_, argnums=0), argnums=0)(x, u)
        # drop the singleton column axis inherited from the (2, 1) output
        return jnp.squeeze(hessian, axis=1)

    def f_uu(self, x, u):
        """Second derivative with respect to the input; shape (2, 1)."""
        second_derivative = jacfwd(jacrev(self.f_, argnums=1), argnums=1)
        return second_derivative(x, u)

    def f_xu(self, x, u):
        """Mixed derivative: d/dx of (df/du); shape (2, 2)."""
        mixed = jacfwd(jacrev(self.f_, argnums=1), argnums=0)(x, u)
        # drop the singleton column axis inherited from the (2, 1) output
        return jnp.squeeze(mixed, axis=1)


# Smoke test of the dynamics: build the model and evaluate one step at a sample point.
dyn = PendulumDynamics(M, G, L, DT)

# sample state [theta, theta_dot] and a scalar (0-d) control input
x = jnp.array([0.1, 0.1])
u = jnp.array(0.5)

print(x.shape)
print(u.shape)

# f__ = jit(dyn.f(x,u))

# Timing of the step and its derivatives. Each call is listed twice, presumably so
# that the first timing includes jit tracing/compilation and the second does not.
%time _ = dyn.f(x,u)
# %time _ = dyn.f(x,u)

# %time _ = dyn.f_x(x,u)
# %time _ = dyn.f_x(x,u)

# %time _ = dyn.f_u(x,u)
# %time _ = dyn.f_u(x,u)

# %time _ = dyn.f_xx(x,u)
# %time _ = dyn.f_xx(x,u)

# %time _ = dyn.f_uu(x,u)
# %time _ = dyn.f_uu(x,u)

# %time _ = dyn.f_xu(x,u)
# %time _ = dyn.f_xu(x,u)

# next state as a (2, 1) column
print(dyn.f(x,u))

(2,)
()
CPU times: user 49.5 ms, sys: 214 µs, total: 49.7 ms
Wall time: 48.4 ms
[[0.1005    ]
 [0.09760816]]


#### Iterative LQR algorithm
1. 

In [47]:
class IterativeLQR:
    """Container for an iterative LQR problem plus a forward rollout.

    Only the problem setup and the open-loop rollout of the current input
    sequence are implemented in this class as shown here.
    """

    def __init__(self, init_state, target_state, initial_guess, dt, start_time, end_time, dynamics, Q_k, R_k, Q_T, n_iterations):
        """Initialization of ILQR

        Args:
            init_state (ndarray): Initial state, shape (n_states,)
            target_state (ndarray): Target state
            initial_guess (ndarray): Initial guess for the input sequence,
                shape (n_timesteps, n_inputs)
            dt (double): Sampling time for discrete system
            start_time (double): Starting time, defaults to 0 for a single trajectory
            end_time (double): Ending time, defaults to final time for a single trajectory
            dynamics (PendulumDynamics): Dynamics x_{k+1} = f(x_k, u_k) and its respective derivatives
            Q_k (ndarray): Weights for states in the running cost
            R_k (ndarray): Weights for inputs in the running cost
            Q_T (ndarray): Weights for states in the terminal cost
            n_iterations (int): Maximum iterations for ilqr
        """

        # states
        self.init_state_ = init_state
        self.target_state_ = target_state
        self.inputs_ = initial_guess
        self.n_states_ = jnp.shape(init_state)[0] # The dimensions of the state vector
        self.n_inputs_ = jnp.shape(initial_guess)[1] # The dimension of the control vector

        # timing
        self.dt_ = dt
        self.start_time_ = start_time
        self.end_time_ = end_time
        self.time_span_ = jnp.arange(start_time, end_time, dt).flatten()
        self.n_timesteps_ = len(self.time_span_)

        # cost due to dynamics
        self.dynamics_ = dynamics

        # weighting for loss function, i.e. L = x_T^T Q_T x_T + sum of (x_k^T Q_k x_k + u_k^T R_k u_k)
        self.Q_k_ = Q_k # Weight for state vector
        self.R_k_ = R_k # Weight for control vector
        self.Q_T_ = Q_T # Weight for terminal state

        # max iterations to run
        self.n_iterations_ = n_iterations

        # costs
        self.expected_cost_reduction_ = 0
        self.expected_cost_reduction_grad_ = 0
        self.expected_cost_reduction_hess_ = 0

    def rollout(self):
        """Rollout of the simulated system given an initial state

        Applies the stored input sequence ``self.inputs_`` step by step through
        ``self.dynamics_.f``, starting from ``self.init_state_``.

        Returns:
            ndarray: States trajectory from the rollout, shape
                (n_timesteps + 1, n_states); row 0 is the initial state
            ndarray: Inputs trajectory from the rollout, shape
                (n_timesteps, n_inputs)
        """
        # we store states and inputs as:
        # state = [x_0, x_1, x_2, ..., x_N]
        # input = [u_0, u_1, ..., u_{N-1}]
        states = jnp.zeros((self.n_timesteps_ + 1, self.n_states_)) # x_0 to x_N
        inputs = jnp.zeros((self.n_timesteps_, self.n_inputs_)) # u_0 to u_{N-1}

        # JAX arrays are immutable: `.at[...].set(...)` returns an updated copy.
        # (It replaces jax.ops.index_update, which no longer exists in JAX.)
        states = states.at[0, :].set(self.init_state_)

        # integration of the system dynamics
        current_state = self.init_state_
        for i in range(self.n_timesteps_): # 0 to N-1
            current_input = jnp.squeeze(self.inputs_[i, :]) # u_k
            # the dynamics return a column (n_states, 1); flatten it to (n_states,)
            next_state = jnp.squeeze(self.dynamics_.f(current_state, current_input), axis=1) # x_{k+1} = f(x_k, u_k)
            # store both u_k and x_{k+1}
            states = states.at[i + 1, :].set(next_state)
            inputs = inputs.at[i, :].set(current_input)
            # rebinding is enough here since arrays are never modified in place
            current_state = next_state

        return states, inputs

# Problem setup for the iLQR rollout: time grid, boundary states, initial input
# guess and cost weights.
start_time = 0
end_time = 2
time_span = jnp.arange(start_time, end_time, DT)  # same grid IterativeLQR builds internally

# set states (start and end states are at rest)
n_states = 2 # position and velocity
n_inputs = 1 # inputs to the system
init_state = jnp.array([0.1, 0.])
target_state = jnp.array([jnp.pi, 0.])  # theta = pi with zero velocity

# initial guess for control inputs: constant 0.1, one row per time step
initial_input_guess = 0.1*jnp.ones((time_span.shape[0], n_inputs))

# define weights
Q_k = jnp.zeros((n_states, n_states)) # just find a valid trajectory first
R_k = 0.001*jnp.eye(n_inputs)  # running cost on the inputs
Q_T = 100*jnp.eye(n_states)  # terminal cost on the states

# iterations
n_iterations = 50

# ilqr
dynamics = PendulumDynamics(M, G, L, DT)
ilqr = IterativeLQR(init_state, target_state, initial_input_guess, DT, start_time, end_time, dynamics, Q_k, R_k, Q_T, n_iterations)

# time one open-loop rollout of the initial input guess
%time [states, inputs] = ilqr.rollout()
# %time [states, inputs] = ilqr.rollout()

# print(inputs)
# # print(inputs)
# print(initial_input_guess)


AttributeError: 'tuple' object has no attribute 'block_until_ready'

In [44]:
# Show the state trajectory (`states`) returned by ilqr.rollout().
print(states)

[[ 0.1         0.        ]
 [ 0.1        -0.00439184]
 [ 0.09997804 -0.00878367]
 [ 0.09993412 -0.01317444]
 [ 0.09986825 -0.01756307]
 [ 0.09978044 -0.02194848]
 [ 0.0996707  -0.02632961]
 [ 0.09953905 -0.0307054 ]
 [ 0.09938552 -0.03507476]
 [ 0.09921015 -0.03943664]
 [ 0.09901297 -0.04378996]
 [ 0.09879402 -0.04813368]
 [ 0.09855335 -0.05246671]
 [ 0.09829102 -0.05678801]
 [ 0.09800708 -0.06109652]
 [ 0.09770159 -0.06539118]
 [ 0.09737464 -0.06967095]
 [ 0.09702629 -0.07393476]
 [ 0.09665661 -0.07818159]
 [ 0.0962657  -0.0824104 ]
 [ 0.09585365 -0.08662013]
 [ 0.09542055 -0.09080977]
 [ 0.0949665  -0.09497829]
 [ 0.09449161 -0.09912466]
 [ 0.09399599 -0.10324786]
 [ 0.09347975 -0.10734688]
 [ 0.09294302 -0.11142073]
 [ 0.09238592 -0.11546838]
 [ 0.09180857 -0.11948886]
 [ 0.09121113 -0.12348116]
 [ 0.09059372 -0.12744431]
 [ 0.0899565  -0.13137734]
 [ 0.08929961 -0.13527927]
 [ 0.08862322 -0.13914913]
 [ 0.08792748 -0.14298598]
 [ 0.08721255 -0.14678888]
 [ 0.08647861 -0.15055688]
 

In [13]:
def f(x, u):
    """Advance the pendulum state by one step of length DT.

    Uses the module-level constants M, G, L and DT.

    Args:
        x: State, indexed as x[0] = theta and x[1] = theta_dot.
        u: Control input.

    Returns:
        The next state as a column array [[theta_next], [theta_dot_next]].
    """
    theta, theta_dot = x[0], x[1]
    drive = -L*G*M*jnp.sin(theta) + u
    theta_next = DT*theta_dot + theta                    # dt*theta_dot + theta
    theta_dot_next = theta_dot + DT*drive/(L**2*M)       # theta_dot + dt*(-L*g*m*sin(theta) + u)/(L**2*m)
    return jnp.array([[theta_next], [theta_dot_next]])

# Scratch version of the integration loop, run outside of IterativeLQR: apply a
# constant input to the free function f(x, u) for 100 steps.
# NOTE(review): the buffers below are sized for 1000 steps but the loop runs 100
# and never writes to them - confirm which is intended.
states = jnp.zeros((1000+1, 2)) # including initial state, x_0 to x_N
inputs = jnp.zeros((1000, 1)) # u_0 to u_{N-1}
current_state = jnp.array([0.1, 0.1])
inputs_ = jnp.array(0.5)
for _ in range(100):
    current_input = inputs_ # u_k
    # f returns a column (2, 1); flatten it back to the (2,) state layout
    next_state = jnp.squeeze(f(current_state, current_input), axis=1) # x_{k+1} = f(x_k, u_k)
    # JAX arrays are immutable, so advancing the state is a plain rebind
    # (jax.ops.index_update no longer exists in JAX)
    current_state = next_state

https://en.wikipedia.org/wiki/Backtracking_line_search
