### ILQR example
1. The example shows how to compute the ilqr solution to a 1 dof actuated pendulum
2. At this point we are only penalizing the control inputs and the deviation from the final state in the overall cost

#### Discrete dynamics
1. In the following the symbolic dynamics of the pendulum is defined.
2. For convenience, lambda functions are used for:
    - The discrete dynamics $x_{k+1} = f(u_k, x_k)$
    - The linearized system with $A = \frac{\partial f}{\partial x}$ and $B = \frac{\partial f}{\partial u}$

In [9]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr, jit, jacfwd, jacrev, ops, lax
from functools import partial

# dynamic parameters
m = 1 # kg
g = 9.8 # ms^-2
L = 1 # m
    
# initialize timing
dt = 0.005 # s
start_time = 0
end_time = 2
time_span = jnp.arange(start_time, end_time, dt)
n_timesteps = jnp.shape(time_span)[0]

n_states = 2 # position and velocity
n_inputs = 1 # inputs to the system
init_state = jnp.array([0.1, 0.])
target_state = jnp.array([jnp.pi, 0.])
init_inputs_guess = 0.1*jnp.ones((n_timesteps, n_inputs))
states = jnp.zeros((n_timesteps+1, n_states)) # including initial state, x_0 to x_N
inputs = jnp.zeros((n_timesteps, n_inputs))

# define weights
Q_k = jnp.zeros((n_states, n_states)) # just find a valid trajectory first
R_k = 0.001*jnp.eye(n_inputs)
Q_T = 100*jnp.eye(n_states)

# iterations
n_iterations = 50

################################ Dynamics ##############################
@jit
def f(x, u):
    return jnp.array([
        [dt*x[1] + x[0]],                                   # [dt*theta_dot + theta]
        [x[1] + dt*(-L*g*m*jnp.sin(x[0]) + u)/(L**2*m)]     # [theta_dot + dt*(-L*g*m*sin(theta) + u)/(L**2*m)]
        ])

@jit
def f_x(x, u):
    return jnp.squeeze(jacfwd(f, 0)(x, u), axis=1)

@jit
def f_u(x, u):
    return jacfwd(f, 1)(x, u)

@jit
def f_xx(x, u):
    return jnp.squeeze(jacfwd(jacrev(f))(x, u), axis=1)

@jit
def f_uu(x, u):
    return jacfwd(jacrev(f, 1), 1)(x, u)

@jit
def f_xu(x, u):
    return jnp.squeeze(jacfwd(jacrev(f, 1), 0)(x, u), axis=1)

################################ ILQR ##############################
# a = jnp.squeeze(init_input_guess[0,:])
print(init_inputs_guess.shape)

%time _ = f(init_state, jnp.squeeze(init_inputs_guess[0,:]))
%time _ = f(init_state, jnp.squeeze(init_inputs_guess[0,:]))

# %time _ = f_x(x,u)
# %time _ = f_x(x,u)

# %time _ = f_u(x,u)
# %time _ = f_u(x,u)

@jit
def rollout(x0, inputs, states):
    states = ops.index_update(states, ops.index[0,:], x0)
    current_state = x0    
    for i in range(0, inputs.shape[0]): # 0 to N-1
        current_input = jnp.squeeze(inputs[i,:]) # u_k
        next_state = jnp.squeeze(f(current_state, current_input), axis=1) # x_{k+1} = f(x_k, u_k)
        states = ops.index_update(states, ops.index[i+1,:], next_state)
        current_state = next_state

    return states
    
%time states = rollout(init_state, init_inputs_guess, states)
%time states = rollout(init_state, init_inputs_guess, states)

print(states)
    
# print(rollout(init_state, initial_input_guess))

(400, 1)
CPU times: user 40.5 ms, sys: 1 µs, total: 40.5 ms
Wall time: 40.1 ms
CPU times: user 1.37 ms, sys: 64 µs, total: 1.44 ms
Wall time: 1.4 ms
CPU times: user 706 ms, sys: 0 ns, total: 706 ms
Wall time: 706 ms
CPU times: user 732 ms, sys: 0 ns, total: 732 ms
Wall time: 731 ms
[[ 0.1         0.        ]
 [ 0.1        -0.00439184]
 [ 0.09997804 -0.00878367]
 [ 0.09993412 -0.01317444]
 [ 0.09986825 -0.01756307]
 [ 0.09978044 -0.02194848]
 [ 0.0996707  -0.02632961]
 [ 0.09953905 -0.0307054 ]
 [ 0.09938552 -0.03507476]
 [ 0.09921015 -0.03943664]
 [ 0.09901297 -0.04378996]
 [ 0.09879402 -0.04813368]
 [ 0.09855335 -0.05246671]
 [ 0.09829102 -0.05678801]
 [ 0.09800708 -0.06109652]
 [ 0.09770159 -0.06539118]
 [ 0.09737464 -0.06967095]
 [ 0.09702629 -0.07393476]
 [ 0.09665661 -0.07818159]
 [ 0.0962657  -0.0824104 ]
 [ 0.09585365 -0.08662013]
 [ 0.09542055 -0.09080977]
 [ 0.0949665  -0.09497829]
 [ 0.09449161 -0.09912466]
 [ 0.09399599 -0.10324786]
 [ 0.09347975 -0.10734688]
 [ 0.09294302 -

In [2]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr, jit, jacfwd, jacrev, ops
from functools import partial

M = 1 # kg
G = 9.8 # ms^-2
L = 1 # m
DT = 0.005 # s

class PendulumDynamics:
    def __init__(self, m, g, L, dt):
        # states
        self.m_ = m
        self.g_ = g
        self.L_ = L
        self.dt_ = dt

        # dynamics
        @jit
        def fdyn(x, u):
            return jnp.array([
                [dt*x[1] + x[0]],                                   # [dt*theta_dot + theta]
                [x[1] + dt*(-L*g*m*jnp.sin(x[0]) + u)/(L**2*m)]     # [theta_dot + dt*(-L*g*m*sin(theta) + u)/(L**2*m)]
                ])
        self.f_ = fdyn

    # Note: his is related to https://github.com/google/jax/issues/1251, the main issue is 
    # with the 'self' part that indicates an arbitrary class. By using static_argnums we're 
    # telling jit to compile only the computation that gets applied to the other arguments, 
    # and just to re-trace and re-compile every time the first argument changes its Python 
    # object id
    @partial(jit, static_argnums=(0,))
    def f(self, x, u):
        return self.f_(x, u)

    @partial(jit, static_argnums=(0,))
    def f_x(self, x, u):
        return jnp.squeeze(jacfwd(self.f_, 0)(x, u), axis=1)

    @partial(jit, static_argnums=(0,))
    def f_u(self, x, u):
        return jacfwd(self.f_, 1)(x, u)

    @partial(jit, static_argnums=(0,))
    def f_xx(self, x, u):
        return jnp.squeeze(jacfwd(jacrev(self.f_))(x, u), axis=1)

    @partial(jit, static_argnums=(0,))
    def f_uu(self, x, u):
        return jacfwd(jacrev(self.f_, 1), 1)(x, u)

    @partial(jit, static_argnums=(0,))
    def f_xu(self, x, u):
        return jnp.squeeze(jacfwd(jacrev(self.f_, 1), 0)(x, u), axis=1)


dyn = PendulumDynamics(M, G, L, DT)

x = jnp.array([0.1, 0.1])
u = jnp.array(0.5)

print(x.shape)
print(u.shape)

# f__ = jit(dyn.f(x,u))

%time f_ = dyn.f(x,u)
%time _ = dyn.f(x,u)

%time f_x_ = dyn.f_x(x,u)
%time _ = dyn.f_x(x,u)

%time f_u_ = dyn.f_u(x,u)
%time _ = dyn.f_u(x,u)

print("f = \n{}".format(f_))
print("f_x = \n{}".format(f_x_))
print("f_u = \n{}".format(f_u_))


# %time _ = dyn.f_xx(x,u)
# %time _ = dyn.f_xx(x,u)

# %time _ = dyn.f_uu(x,u)
# %time _ = dyn.f_uu(x,u)

# %time _ = dyn.f_xu(x,u)
# %time _ = dyn.f_xu(x,u)

(2,)
()
CPU times: user 49.3 ms, sys: 55 µs, total: 49.4 ms
Wall time: 47.9 ms
CPU times: user 36 µs, sys: 2 µs, total: 38 µs
Wall time: 44.3 µs
CPU times: user 68.2 ms, sys: 0 ns, total: 68.2 ms
Wall time: 67.3 ms
CPU times: user 35 µs, sys: 3 µs, total: 38 µs
Wall time: 42.4 µs
CPU times: user 56.1 ms, sys: 0 ns, total: 56.1 ms
Wall time: 55.3 ms
CPU times: user 33 µs, sys: 3 µs, total: 36 µs
Wall time: 41 µs
f = 
[[0.1005    ]
 [0.09760816]]
f_x = 
[[ 1.         0.005    ]
 [-0.0487552  1.       ]]
f_u = 
[[0.   ]
 [0.005]]


#### Iterative LQR algorithm
1. 

In [3]:
class IterativeLQR:
    def __init__(self, init_state, target_state, init_inputs_guess, dt, start_time, end_time, dynamics, Q_k, R_k, Q_T, n_iterations):
        """Initialization of ILQR

        Args:
            init_state (ndarray): Initial state
            target_state (ndarray): Target state
            initial_guess (ndarray): Initial guess for ilqr
            dt (double): Sampling time for discrete system
            start_time (double): Starting time, defaults to 0 for a single trajectory
            end_time (double): Ending time, defaults to final time for a single trajectory
            dynamics (PendulumDynamics): Dynamics x_{k_+1} = f(x_k, u_k) and its respective derivatives
            Q_k (ndarray): Weights for states in the running cost
            R_k (ndarray): Weights for inputs in the running cost
            Q_T (ndarray): Weights for states in the terminal cost
            n_iterations (double): Maximum interations for ilqr
        """
        
        # timing
        self.dt_ = dt
        self.start_time_ = start_time
        self.end_time_ = end_time
        self.time_span_ = jnp.arange(start_time, end_time, dt).flatten()
        self.n_timesteps_ = len(self.time_span_)
        
        # states
        self.init_state_ = init_state
        self.target_state_ = target_state
        self.n_states_ = jnp.shape(init_state)[0] # The dimensions of the state vector
        self.n_inputs_ = jnp.shape(init_inputs_guess)[1] # The dimension of the control vector
        
        self.states_ = jnp.zeros((self.n_timesteps_+1, self.n_states_)) # including initial state, x_0 to x_N
        self.inputs_ = jnp.zeros((self.n_timesteps_, self.n_inputs_))
        self.inputs_ = ops.index_update(self.inputs_, ops.index[:], init_inputs_guess)

        # cost due to dynamics
        self.dynamics_ = dynamics

        # weighting for loss function, i.e. L = x_T^T Q_T x_T + sum of (x_k^T Q_k x_k + u_k^T R_k u_k)
        self.Q_k_ = Q_k # Weight for state vector
        self.R_k_ = R_k # Weight for control vector
        self.Q_T_ = Q_T # Weight for terminal state

        # max iterations to run
        self.n_iterations_ = n_iterations

        # costs
        self.expected_cost_reduction_ = 0
        self.expected_cost_reduction_grad_ = 0
        self.expected_cost_reduction_hess_ = 0
        

#     @partial(jit, static_argnums=(0,))
    def compute_cost(self, states, inputs):
        """Computes the cost from all the terms, i.e. dynamics and cost as well as their derivatives:
        f_x, f_u, f_xx, f_ux, f_uu, 
        l_x, l_u, l_xx, l_ux, l_uu

        Args:
            states (ndarray): State trajectory
            inputs (ndarray): Input trajectory

        Returns:
            double: Total cost, i.e. terminal cost + running cost
        """
        # dynamics first derivatives

        # dynamics second derivatives

        # cost first derivatives

        # cost second derivatives

        # accumulate cost to go
        total_cost = 0
        for i in range(0, self.n_timesteps_):
#             current_x = states[i,:]
#             current_u = inputs[i,:].flatten()
#             current_cost = current_u.T @ self.R_k_ @ current_u
#             total_cost = total_cost + current_cost
            total_cost += inputs[i,:].T @ self.R_k_ @ inputs[i,:]
        # add terminal cost
        terminal_diff = (states[-1,:] - self.target_state_).flatten()
        terminal_cost = terminal_diff.T @ self.Q_T_ @ terminal_diff
        total_cost = total_cost + terminal_cost

        return total_cost
    
    def backward_pass(self):
        """Backward pass of iLQR

        Returns:
            ndarray: feedforward gain, k
            ndarray: feedback gain, K
            double: expected cost reduction
        """
        # starting from the last state
        V_xx = self.Q_T_ # since V_N = x_T^T Q_T x_T, V_xx(N) = Q_T_
#         end_difference = 
        V_x = self.Q_T_ @ (states[-1,:] - self.target_state_).T # V_x(N)

        # initialize control modifications to be stored
        k_trj = jnp.zeros((self.n_timesteps_, self.n_inputs_)) # (8b)
        K_trj = jnp.zeros((self.n_timesteps_, self.n_inputs_, self.n_states_)) # (8b)

        # initialize cost reduction
        expected_cost_reduction = 0
        expected_cost_reduction_grad = 0
        expected_cost_reduction_hess = 0

        # looping backwards from N-1 to 1 using initial value of V_{N}
        for i in reversed(range(0, self.n_timesteps_)):
            # current variables
            current_x = states[i,:]
            current_u = jnp.squeeze(inputs[i,:])

            # updates to partial derivatives of cost function
            l_xx = self.Q_k_
            l_uu = self.R_k_

            l_ux = jnp.zeros((self.n_inputs_, self.n_states_))
            l_x = self.Q_k_ @ jnp.zeros(self.n_states_).flatten()
            l_u = self.R_k_ @ (current_u).flatten()

            # get jacobian of discrete dynamics
            f_x = self.dynamics_.f_x(current_x, current_u) # V'_x
            f_u = self.dynamics_.f_u(current_x, current_u) # V'_u
    
            # all the Q vector/matrices
            Q_x = l_x + f_x.T @ V_x # (5a)
            Q_u = l_u + f_u.T @ V_x # (5b)
            Q_ux = l_ux + f_u.T @ V_xx @ f_x # (5c)
            Q_uu = l_uu + f_u.T @ V_xx @ f_u # (5d)
            Q_xx = l_xx + f_x.T @ V_xx @ f_x # (5e)

            # compute and store gains
            kSingValThreshold = 1e-4
            (_,s,_) = jnp.linalg.svd(Q_uu)
            if (jnp.min(s) < kSingValThreshold):
                print("Q_uu is non-singular")
            Q_uu_inv = jnp.linalg.inv(Q_uu) # TODO: this can be singular, try using (9)
            k = -Q_uu_inv @ Q_u # (6)
            K = -Q_uu_inv @ Q_ux # (6)

            k_trj = ops.index_update(k_trj, ops.index[i,:], k) 
            K_trj = ops.index_update(K_trj, ops.index[i,:,:], K) 

            # update the expected reduction (11a), delta V
            # similar to equation of delta J(\alpha)
            current_cost_reduction_grad = -Q_u.T @ k
            current_cost_reduction_hess = (0.5 * k.T @ (Q_uu) @ (k))
            current_cost_reduction = current_cost_reduction_grad + current_cost_reduction_hess

            expected_cost_reduction_grad += current_cost_reduction_grad
            expected_cost_reduction_hess += current_cost_reduction_hess
            expected_cost_reduction += current_cost_reduction

            # update hessian and gradient of value function for the next iteration
            V_x = Q_x + K.T @ Q_uu @ k + K.T @ Q_u + Q_ux.T @ k # (11b)
            V_xx = Q_xx + K.T @ Q_uu @ K + K.T @ Q_ux + Q_ux.T @ K # (11c)

        # store values
        self.expected_cost_reduction_grad_ = expected_cost_reduction_grad
        self.expected_cost_reduction_hess_ = expected_cost_reduction_hess
        self.expected_cost_reduction_ = expected_cost_reduction

        # store gains
        self.k_feedforward_ = k_trj
        self.K_feedback_ = K_trj

        return (k_trj, K_trj, expected_cost_reduction)
    
#     def rollout(self):
        
#         @partial(jit, static_argnums=(0,))
#         def _rollout(f, init_state, timesteps, states, inputs):
#             # store x0
#             states = ops.index_update(states,  ops.index[0,:], init_state)

#             print(type(timesteps))
            
#             # integration of the system dynamics
#             current_state = init_state
#             for i in range(0, len(inputs)): # 0 to N-1
#                 current_input = jnp.squeeze(inputs[i,:]) # u_k
#                 val = f(current_state, current_input)
#                 next_state = jnp.squeeze(val, axis=1) # x_{k+1} = f(x_k, u_k)
#                 states = ops.index_update(states, ops.index[i+1,:], next_state)
#                 inputs = ops.index_update(inputs, ops.index[i,:], current_input)
#                 # update current state
#                 current_state = ops.index_update(current_state,  ops.index[:], next_state)

#             return states, inputs
    
#         print(self.n_timesteps_)
#         (s, i) = _rollout(self.dynamics_.f, self.init_state_, self.n_timesteps_, self.states_, self.inputs_)
#         return s, i
    
    # @partial(jit, static_argnums=(0,))
    def rollout(self):
        """Rollout of the simulated system given an initial state

        Returns:
            ndarray: States trajectory from the rollout
            ndarray: Inputs trajectory from the rollout
        """
        # we store states and inputs as:
        # state = [., x_1, x_2, ..., x_N]
        # input =    [u_0, u_1, ..., u_{N-1}]
        # the first value in state is understood as self.init_state_
        states = jnp.zeros((self.n_timesteps_+1, self.n_states_)) # including initial state, x_0 to x_N
        inputs = jnp.zeros((self.n_timesteps_, self.n_inputs_)) # u_0 to u_{N-1}

        # store x0
        states = ops.index_update(states,  ops.index[0,:], self.init_state_)

        print(self.n_timesteps_)

        # integration of the system dynamics
        current_state = self.init_state_
        for i in range(0, self.n_timesteps_): # 0 to N-1
            current_input = jnp.squeeze(self.inputs_[i,:]) # u_k
            print("i = {}, current_input = {}, current_state = {}".format(i, current_input, current_state))
            next_state = jnp.squeeze(self.dynamics_.f(current_state, current_input), axis=1) # x_{k+1} = f(x_k, u_k)
#             print("i = {}, next_state = {}".format(i, next_state))
            # store both u_k and x_{k+1}
            states = ops.index_update(states, ops.index[i+1,:], next_state)
            inputs = ops.index_update(inputs, ops.index[i,:], current_input)
            # update current state
            current_state = ops.index_update(current_state,  ops.index[:], next_state)

        # store trajectories
#         self.states_ = ops.index_update(self.states_, ops.index[:], states)
#         self.inputs_ = ops.index_update(self.inputs_, ops.index[:], inputs)
            
        return states, inputs
    
#     @partial(jit, static_argnums=(0,))
#     def rollout(self):
#         """Rollout of the simulated system given an initial state

#         Returns:
#             ndarray: States trajectory from the rollout
#             ndarray: Inputs trajectory from the rollout
#         """
#         # we store states and inputs as:
#         # state = [., x_1, x_2, ..., x_N]
#         # input =    [u_0, u_1, ..., u_{N-1}]
#         # the first value in state is understood as self.init_state_
#         states = jnp.zeros((self.n_timesteps_+1, self.n_states_)) # including initial state, x_0 to x_N
#         inputs = jnp.zeros((self.n_timesteps_, self.n_inputs_)) # u_0 to u_{N-1}

#         # store x0
#         states = ops.index_update(states,  ops.index[0,:], self.init_state_)

#         # integration of the system dynamics
#         current_state = self.init_state_
#         for i in range(0, self.n_timesteps_): # 0 to N-1
#             current_input = jnp.squeeze(self.inputs_[i,:]) # u_k
# #             print("i = {}, current_input = {}, current_state = {}".format(i, current_input, current_state))
#             next_state = jnp.squeeze(self.dynamics_.f(current_state, current_input), axis=1) # x_{k+1} = f(x_k, u_k)
# #             print("i = {}, next_state = {}".format(i, next_state))
#             # store both u_k and x_{k+1}
#             states = ops.index_update(states, ops.index[i+1,:], next_state)
#             inputs = ops.index_update(inputs, ops.index[i,:], current_input)
#             # update current state
#             current_state = ops.index_update(current_state,  ops.index[:], next_state)

#         # store trajectories
#         self.states_ = ops.index_update(self.states_, ops.index[:], states)
#         self.inputs_ = ops.index_update(self.inputs_, ops.index[:], inputs)
# #         self.states_ = np.array(states)
# #         self.inputs_ = np.array(inputs)
            
#         return states, inputs
    
#     @partial(jit, static_argnums=(0,))
#     def rollout(self):
#         states = jnp.zeros((self.n_timesteps_+1, self.n_states_)) # including initial state, x_0 to x_N
#         inputs = jnp.zeros((self.n_timesteps_, self.n_inputs_)) # u_0 to u_{N-1}
        
# #         @partial(jit, static_argnums=(0,))
#         def _rollout(f, states, inputs):
#             current_state = jnp.array([0., 0.])
#             for i in range(0, inputs.shape[0]):
#                 current_input = jnp.squeeze(inputs[i,:])
#                 next_state = jnp.squeeze(f(current_state, current_input), axis=1)
#                 states = ops.index_update(states, ops.index[i+1,:], next_state)
#                 inputs = ops.index_update(inputs, ops.index[i,:], current_input)
#             return states, inputs
        
#         return _rollout(self.dynamics_.f, states, inputs)
    
    
start_time = 0
end_time = 2
time_span = jnp.arange(start_time, end_time, DT)

# set states (start and end states are at rest)
n_states = 2 # position and velocity
n_inputs = 1 # inputs to the system
init_state = jnp.array([0.1, 0.])
target_state = jnp.array([jnp.pi, 0.])

# initial guess for control inputs
initial_input_guess = 0.1*jnp.ones((time_span.shape[0], n_inputs))

# define weights
Q_k = jnp.zeros((n_states, n_states)) # just find a valid trajectory first
R_k = 0.001*jnp.eye(n_inputs)
Q_T = 100*jnp.eye(n_states)

# iterations
n_iterations = 50

# ilqr
dynamics = PendulumDynamics(M, G, L, DT)

ilqr = IterativeLQR(init_state, target_state, initial_input_guess, DT, start_time, end_time, dynamics, Q_k, R_k, Q_T, n_iterations)

# precompilation
%time (states, inputs) = ilqr.rollout()
# %time (states, inputs) = ilqr.rollout()

# %time _ = ilqr.rollout()
# %time _ = ilqr.rollout()

# %time (states, inputs) = ilqr.rollout()
# %time _ = ilqr.compute_cost(states, inputs)
# %time _ = ilqr.compute_cost(states, inputs)

# %time (k_trj, K_trj, expected_cost_reduction) = ilqr.backward_pass(states, inputs)

print(states)
# # print(inputs)
# print(initial_input_guess)


400
i = 0, current_input = 0.10000000149011612, current_state = [0.1 0. ]
i = 1, current_input = 0.10000000149011612, current_state = [ 0.1        -0.00439184]
i = 2, current_input = 0.10000000149011612, current_state = [ 0.09997804 -0.00878367]
i = 3, current_input = 0.10000000149011612, current_state = [ 0.09993412 -0.01317444]
i = 4, current_input = 0.10000000149011612, current_state = [ 0.09986825 -0.01756307]
i = 5, current_input = 0.10000000149011612, current_state = [ 0.09978044 -0.02194848]
i = 6, current_input = 0.10000000149011612, current_state = [ 0.0996707  -0.02632961]
i = 7, current_input = 0.10000000149011612, current_state = [ 0.09953905 -0.0307054 ]
i = 8, current_input = 0.10000000149011612, current_state = [ 0.09938552 -0.03507476]
i = 9, current_input = 0.10000000149011612, current_state = [ 0.09921015 -0.03943664]
i = 10, current_input = 0.10000000149011612, current_state = [ 0.09901297 -0.04378996]
i = 11, current_input = 0.10000000149011612, current_state = [ 0.

In [4]:
print(states)

[[ 0.1         0.        ]
 [ 0.1        -0.00439184]
 [ 0.09997804 -0.00878367]
 [ 0.09993412 -0.01317444]
 [ 0.09986825 -0.01756307]
 [ 0.09978044 -0.02194848]
 [ 0.0996707  -0.02632961]
 [ 0.09953905 -0.0307054 ]
 [ 0.09938552 -0.03507476]
 [ 0.09921015 -0.03943664]
 [ 0.09901297 -0.04378996]
 [ 0.09879402 -0.04813368]
 [ 0.09855335 -0.05246671]
 [ 0.09829102 -0.05678801]
 [ 0.09800708 -0.06109652]
 [ 0.09770159 -0.06539118]
 [ 0.09737464 -0.06967095]
 [ 0.09702629 -0.07393476]
 [ 0.09665661 -0.07818159]
 [ 0.0962657  -0.0824104 ]
 [ 0.09585365 -0.08662013]
 [ 0.09542055 -0.09080977]
 [ 0.0949665  -0.09497829]
 [ 0.09449161 -0.09912466]
 [ 0.09399599 -0.10324786]
 [ 0.09347975 -0.10734688]
 [ 0.09294302 -0.11142073]
 [ 0.09238592 -0.11546838]
 [ 0.09180857 -0.11948886]
 [ 0.09121113 -0.12348116]
 [ 0.09059372 -0.12744431]
 [ 0.0899565  -0.13137734]
 [ 0.08929961 -0.13527927]
 [ 0.08862322 -0.13914913]
 [ 0.08792748 -0.14298598]
 [ 0.08721255 -0.14678888]
 [ 0.08647861 -0.15055688]
 

In [5]:
def f(x, u):
    return jnp.array([
        [DT*x[1] + x[0]],                                   # [dt*theta_dot + theta]
        [x[1] + DT*(-L*G*M*jnp.sin(x[0]) + u)/(L**2*M)]     # [theta_dot + dt*(-L*g*m*sin(theta) + u)/(L**2*m)]
        ])

states = jnp.zeros((1000+1, 2)) # including initial state, x_0 to x_N
inputs = jnp.zeros((1000, 1)) # u_0 to u_{N-1}
current_state = jnp.array([0.1, 0.1])
inputs_ = jnp.array(0.5)
for i in range(0, 100): # 0 to N-1
    current_input = inputs_ # u_k
    next_state = f(current_state, current_input) # x_{k+1} = f(x_k, u_k) 
    # store both u_k and x_{k+1}
    # states[i+1,:] = np.array(next_state)
    # inputs[i,:] = np.squeeze(np.array(current_input))
    next_state = jnp.squeeze(next_state, axis=1)
    # print(next_state.shape)
    # print(next_state.shape)
    # print(current_state[0])
    # print(current_state[1])
    # print(next_state[0])
    # print(next_state[1])
    # current_state[0] = current_state.at[0].set(next_state[0])
    # current_state[1] = current_state.at[1].set(next_state[1])
    # update current state
    current_state = jax.ops.index_update(current_state, jax.ops.index[:], next_state)

In [6]:
a = jnp.array([[1], [2]])
b = jnp.array([1, 2])
print(a.shape)
print(b.shape)

print(b @ a)

(2, 1)
(2,)
[5]


https://en.wikipedia.org/wiki/Backtracking_line_search
