# Differential Dynamic Programming for Quadcopter

Let $\mathbf{x} = [x, v_x, y, v_y, z, v_z, x_q, y_q, z_q, \omega_x, \omega_y, \omega_z]^\top \in \mathbb{R}^{12}$ be a state vector, where $x, y, z$ are x-axis, y-axis, z-axis positions, respectively.  
$x_q, y_q, z_q$ are x-axis, y-axis, z-axis component quaternions, respectively.  
$v_x, v_y, v_z$ represent velocities on each axis.  
$\omega_x, \omega_y, \omega_z$ are angular velocities around each axis.  

Let $\mathbf{u} = [p_0, p_1, p_2, p_3] \in \mathbb{R}^4$ be a control vector, where $p$ represents the command RPM for a rotor. 

Import necessary modules

In [8]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import time
from functools import partial

from safe_control_gym.envs.gym_pybullet_drones.quadrotor import Quadrotor
from safe_control_gym.envs.gym_pybullet_drones.quadrotor_utils import QuadType, cmd2pwm, pwm2rpm
from safe_control_gym.math_and_models.symbolic_systems import SymbolicModel

Set run parameters

In [9]:
DEFAULT_NUM_DRONES = 1
DEFAULT_GUI = True
DEFAULT_RECORD_VISION = False
DEFAULT_PLOT = True
DEFAULT_USER_DEBUG_GUI = True
DEFAULT_OBSTACLES = True
DEFAULT_SIMULATION_FREQ_HZ = 240 # pybullet frequency
DEFAULT_CONTROL_FREQ_HZ = 60
DEFAULT_DURATION_SEC = 10
DEFAULT_OUTPUT_FOLDER = 'results'
DEFAULT_COLAB = True

#### Action vector ######## Thrust           X Torque             Y Torque             Z Torque
# ACT_LOWER_BOUND = np.array([0.,              -self.MAX_XY_TORQUE, -self.MAX_XY_TORQUE, -self.MAX_Z_TORQUE])
# ACT_UPPER_BOUND = np.array([self.MAX_THRUST, self.MAX_XY_TORQUE,  self.MAX_XY_TORQUE,  self.MAX_Z_TORQUE])

Helper functions

In [10]:
def rotZ(psi):
    '''Rotation matrix about Z axis following SDFormat http://sdformat.org/tutorials?tut=specify_pose&cat=specification&.

    Args:
    psi: Scalar rotation

    Returns:
    R: Rotation matrix
    '''
    R = jnp.array([[jnp.cos(psi), -jnp.sin(psi), 0.],
                    [jnp.sin(psi),  jnp.cos(psi), 0.],
                    [          0.,            0., 1.]], dtype=float)
    return R


def rotY(theta):
    '''Rotation matrix about Y axis following SDFormat http://sdformat.org/tutorials?tut=specify_pose&cat=specification&.

    Args:
    theta: Scalar rotation

    Returns:
    R: Rotation matrix
    '''
    R = jnp.array([[ jnp.cos(theta), 0., jnp.sin(theta)],
                    [             0., 1.,             0.],
                    [-jnp.sin(theta), 0., jnp.cos(theta)]], dtype=float)
    return R


def rotX(phi):
    '''Rotation matrix about X axis following SDFormat http://sdformat.org/tutorials?tut=specify_pose&cat=specification&.

    Args:
    phi: Scalar rotation

    Returns:
    R: Rotation matrix
    '''
    R = jnp.array([[ 1.,           0.,            0.],
                    [ 0., jnp.cos(phi), -jnp.sin(phi)],
                    [ 0., jnp.sin(phi),  jnp.cos(phi)]], dtype=float)
    return R


def rotXYZ(phi, theta, psi):
    '''Rotation matrix from euller angles  following SDFormat http://sdformat.org/tutorials?tut=specify_pose&cat=specification&.
    This represents the extrinsic X-Y-Z (or quivalently the intrinsic Z-Y-X (3-2-1)) euler angle rotation.

    Args:
    phi: roll (or rotation about X).
    theta: pitch (or rotation about Y).
    psi: yaw (or rotation about Z).

    Returns:
    R: Rotation matrix
    '''
    R = rotZ(psi) @ rotY(theta) @ rotX(phi)

    return R

def skew(vec):
    return jnp.array([[0., -vec[2], vec[1]],
                    [vec[2], 0., -vec[0]],
                    [-vec[1], vec[0], 0.]])


Define a class to describe a cartpole system


In [11]:
class Quadcopter3D:

    def __init__(self,
                 n_drones: int,
                 use_gui: bool,
                 record_video: bool,
                 debug_gui: bool,
                 simulation_freq: int,
                 control_freq: int,
                 duration: int,
                 output_folder: str,
                 use_colab: bool):
        # ddp iteration
        self.iteration = 10 
        # the number of drones for simulation
        self.n_drones = n_drones
        self.control_freq = control_freq
        # the number of dimensions for a single agent
        u_dim = 4
        x_dim = 12
        # time horizon
        self.T = duration
        # control u to be optimized
        self.u = jnp.ones((self.T, u_dim), dtype=float) # dim: T x u_dim
        # state x to be optimized
        self.x = jnp.zeros((self.T, x_dim), dtype=float) # dim: T x x_dim 
        #self.x = self.x.at[:, 2].set(jnp.pi * 2)
        self.updated_x = jnp.zeros((self.T, x_dim), dtype=float) # dim: T x x_dim 
        #self.desired_terminal_state = jnp.array([0, 0, jnp.pi, 0], dtype=float)
        # state x for test
        self.test_x = jnp.zeros((self.T, x_dim), dtype=float) # dim: T x x_dim 
        self.test_updated_x = jnp.zeros((self.T, x_dim), dtype=float) # dim: T x x_dim 

        # gains to be optimized
        self.k = jnp.zeros((self.T, u_dim), dtype=float) # dim: T x u_dim
        self.K = jnp.zeros((self.T, u_dim, x_dim), dtype=float) # dim: T x u_dim x x_dim
        # coefficients for the quadratic cost
        self.P = jnp.eye(x_dim, x_dim, dtype=float) # dim: x_dim x x_dim
        self.P = self.P.at[0, 0].set(10)
        self.P = self.P.at[2, 2].set(10)
        self.P = self.P.at[4, 4].set(10)
        self.R = jnp.eye(u_dim, u_dim, dtype=float)  # dim: u_dim x u_dim


        # gym-pybullet-drones environments
        # Create an environment
        # for initial rollout
        env_config = {'info_in_reset': True, 
                      'ctrl_freq': 60, 
                      'pyb_freq': 240, 
                      'physics': 'pyb', 
                      'gui': False, 
                      'quad_type': 3, # 3-D 
                      'normalized_rl_action_space': False, 
                      'episode_len_sec': duration, 
                      'init_state': None, 
                      'randomized_init': False, 
                      'init_state_randomization_info': None, 
                      'inertial_prop': None, 
                      'randomized_inertial_prop': False, 
                      'inertial_prop_randomization_info': None, 
                      'task': 'stabilization', 
                      'task_info': None, 
                      'cost': 'rl_reward', 
                      'disturbances': None, 
                      'adversary_disturbance': None, 
                      'adversary_disturbance_offset': 0.0, 
                      'adversary_disturbance_scale': 0.01, 
                      'constraints': None, 
                      'done_on_violation': False, 
                      'use_constraint_penalty': False, 
                      'constraint_penalty': -1, 
                      'verbose': False, 
                      'norm_act_scale': 0.1, 
                      'obs_goal_horizon': 0, 
                      'rew_state_weight': 1.0, 
                      'rew_act_weight': 0.0001, 
                      'rew_exponential': True, 
                      'done_on_out_of_bound': True,
                      'task_info': {'stabilization_goal': [0, 0, 1],
                                    'stabilization_goal_tolerance': 0.0,
                                    'proj_point': [0, 0, 0.5], 
                                    'proj_normal': [0, 1, 1]}}
        self.random_env = Quadrotor(**env_config)
        self.random_env.seed(0)
        init_state, _ = self.random_env.reset()
        env_config_static = env_config.copy()
        env_config_static['randomized_init'] = False
        env_config_static['init_state'] = init_state
        env_config_static['gui'] = use_gui
        env_config_static_test = env_config_static.copy()
        env_config_static_test['gui'] = use_gui
        self.static_env = Quadrotor(**env_config_static)
        #self.test_env = Quadrotor(**env_config_static_test)

        # system parameters
        self.M = self.static_env.MASS
        self.Iyy = self.static_env.J[1, 1]
        self.g = self.static_env.GRAVITY_ACC
        self.length = self.static_env.L
        self.dt = self.static_env.CTRL_TIMESTEP
        self.Ixx = self.static_env.J[0, 0]
        self.Izz = self.static_env.J[2, 2]
        self.J = jnp.array([[self.Ixx, 0., 0,],
                            [0., self.Iyy, 0.],
                            [0., 0., self.Izz]], dtype=float)
        self.Jinv = jnp.array([[1. / self.Ixx, 0., 0,],
                               [0., 1. / self.Iyy, 0.],
                               [0., 0., 1. / self.Izz]], dtype=float)
        self.gamma = self.static_env.KM / self.static_env.KF

        
        # goals are in the environment

    def rollout_initial(self):
        break_counter = 0
        self.static_env.seed(0)
        observation, _ = self.static_env.reset()
        for i in range(self.T * self.control_freq):
            self.x = self.x.at[i].set(jnp.array(observation))
            observation, reward, terminated, info = self.static_env.step(np.array(self.u[i], dtype=float))
            # mbrl cartpole env does not allow further steps after termination
            if terminated:
                break
            break_counter += 1
        print('\trollout_initial break at :', break_counter)
        #self.random_env.close()

    def update_trajectory(self, mu=1):
        observation, _ = self.static_env.reset()
        break_counter = 0
        for i in range(self.T * self.control_freq):
            # state
            self.updated_x = self.updated_x.at[i].set(jnp.array(observation, dtype=float))
            # new_control = current_control + delta_control,
            # where delta_control = k + K @ delta_x
            new_control = self.u[i] + mu * self.k[i].ravel() + \
                self.K[i] @ (self.updated_x[i] - self.x[i]).T
            # update the current_control with the new_control
            self.u = self.u.at[i].set(new_control)
            observation, reward, terminated, info = self.static_env.step(np.array(self.u[i], dtype=float))
            if terminated:
                break
            break_counter += 1
        print('\tupdate_trajectory break at :', break_counter)
        self.x = self.updated_x.copy()

    def get_trajectory_cost_value(self, x, u, P, R):
        '''
            x: dim of T x x_dim
            u: dim of T x u_dim
            P: dim of x_dim x x_dim
            R: dim of u_dim x u_dim
        '''
        trajectory_cost = 0
        # accumulate over the planning horizon
        # except the terminal state
        for i in range(self.T - 1):
            # quadratic cost
            one_step_cost = self.get_trajectory_cost_one_step(x[i], u[i], P, R)
            #print('one_step_cost:', one_step_cost)
            trajectory_cost += one_step_cost
        return trajectory_cost

    def get_terminal_cost(self, x_terminal, eta=1):
        x_goal = self.static_env.X_GOAL
        error = x_terminal - x_goal
        penalty = eta * jnp.eye(len(x_terminal), dtype=float)
        return error @ penalty @ error.T
    
    def get_trajectory_cost_one_step(self, x, u, P, R):
        '''
            x: dim of 1 x x_dim
            u: dim of 1 x u_dim
            P: dim of x_dim x x_dim
            R: dim of u_dim x u_dim
        '''
        x_goal = self.static_env.X_GOAL
        u_goal = self.static_env.U_GOAL
        return (x - x_goal) @ P @ (x - x_goal).T + (u - u_goal) @ R @ (u - u_goal).T

    # system dynamics
    def F(self, 
          x: jnp.ndarray, 
          u: jnp.ndarray):
        '''
            x: 
                0: x-position
                1: x-velocity
                2: y-position
                3: y-velocity
                4: z-position
                5: z-velocity
                6: x-quaternion
                7: y-quaternion
                8: z-quaternion
                9: x-angular velocity
                10: y-angular velocity
                11: z-angular velocity
            u: 
                0: thrust
                1: x-torque
                2: y-torque
                3: z-torque
        '''
        # PyBullet Euler angles use the SDFormat for rotation matrices.
        Rob = rotXYZ(x[6], x[7], x[8])  # rotation matrix transforming a vector in the body frame to the world frame.
        # From Ch. 2 of Luis, Carlos, and Jérôme Le Ny. 'Design of a trajectory tracking controller for a
        # nanoquadcopter.' arXiv preprint arXiv:1608.05786 (2016).

        # Defining the dynamics function.
        # We are using the velocity of the base wrt to the world frame expressed in the world frame.
        # Note that the reference expresses this in the body frame.
        oVdot_cg_o = Rob @ jnp.array([0., 0., jnp.sum(u)], dtype=float) / self.M - jnp.array([0., 0., self.g], dtype=float)
        pos_ddot = oVdot_cg_o
        Mb = jnp.array([self.length / jnp.sqrt(2.) * (u[0] + u[1] - u[2] - u[3]),
                        self.length / jnp.sqrt(2.) * (-u[0] + u[1] + u[2] - u[3]),
                        self.gamma * (-u[0] + u[1] - u[2] + u[3])], dtype=float).T
        rate_dot = self.Jinv @ (Mb - (skew(x[9:].T) @ self.J @ x[9:].T))
        ang_dot = jnp.array([[1., jnp.sin(x[6]) * jnp.tan(x[7]), jnp.cos(x[6]) * jnp.tan(x[7])],
                             [0., jnp.cos(x[6]), -jnp.sin(x[6])],
                             [0., jnp.sin(x[6]) / jnp.cos(x[7]), jnp.cos(x[6]) / jnp.cos(x[7])]]) @ x[9:].T
        next_quaternion = x[6:9] + ang_dot * self.dt
        next_angular_velocity = x[9:] + rate_dot * self.dt
        next_state_pos_vel = jnp.array([x[0] + x[1] * self.dt, 
                                x[1] + pos_ddot[0] * self.dt, 
                                x[2] + x[3] * self.dt,
                                x[3] + pos_ddot[1] * self.dt, 
                                x[4] + x[5] * self.dt,
                                x[5] + pos_ddot[2] * self.dt,
                                ], dtype=float)
        next_state = jnp.concatenate((next_state_pos_vel, next_quaternion, next_angular_velocity))
        return next_state

    def get_cost_derivatives(self, x, u, P, R):
        L_x = jax.grad(self.get_trajectory_cost_one_step, argnums=0) \
            (x, u, P, R) # dim: x_dim
        L_u = jax.grad(self.get_trajectory_cost_one_step, argnums=1) \
            (x, u, P, R) # dim: u_dim
        L_xu = jax.jacfwd(jax.grad(self.get_trajectory_cost_one_step, argnums=0), argnums=1) \
            (x, u, P, R) # dim: x_dim x u_dim
        L_ux = jax.jacfwd(jax.grad(self.get_trajectory_cost_one_step, argnums=1), argnums=0) \
            (x, u, P, R) # dim: u_dim x x_dim
        L_xx = jax.jacfwd(jax.grad(self.get_trajectory_cost_one_step, argnums=0), argnums=0) \
            (x, u, P, R) # dim: x_dim x x_dim
        L_uu = jax.jacfwd(jax.grad(self.get_trajectory_cost_one_step, argnums=1), argnums=1) \
            (x, u, P, R) # dim: u_dim x u_dim
        return L_x, L_u, L_xu, L_ux, L_xx, L_uu

    def test_render(self, mu=1):
        print('testing...')
        self.test_env.seed(0)
        observation, _ = self.test_env.reset()
        for i in range(self.T * self.control_freq):
            print('test step:', i)
            observation = jnp.array(observation, dtype=float)
            # new_control = current_control + delta_control,
            # where delta_control = k + K @ delta_x
            new_control = self.u[i] + mu * self.k[i].ravel() + \
                self.K[i] @ (observation - self.test_x[i]).T
            self.test_x = self.test_x.at[i].set(observation)
            # update the current_control with the new_control
            self.u = self.u.at[i].set(new_control)
            observation, reward, terminated, info = self.test_env.step(np.array(self.u[i], dtype=float))
            if terminated:
                break
            time.sleep(1/240)
        self.test_env.close()

    def ddp(self):
        self.rollout_initial()
        for i in range(self.iteration):
            t1 = time.time()
            cost = self.get_trajectory_cost_value(x=self.x, 
                                                  u=self.u, 
                                                  P=self.P, 
                                                  R=self.R)
            cost += self.get_terminal_cost(x_terminal=self.x[-1])
            #print('cost.shape:', cost.shape)
            print(f'ddp iteration: {i}, initial cost: {cost}')

            L_x, L_u, L_xu, L_ux, L_xx, L_uu = self.get_cost_derivatives(self.x[-1], 
                                                                         self.u[-1], 
                                                                         self.P, 
                                                                         self.R)

            # get terminal gains
            k = -jnp.linalg.inv(L_uu) @ L_u
            K = -jnp.linalg.inv(L_uu) @ L_ux

            # set gains at the final time step
            self.k = self.k.at[-1].set(k) 
            self.K = self.K.at[-1].set(K)

            V_x = jax.grad(self.get_terminal_cost, argnums=0) \
                (self.x[-1]) # dim: x_dim
            V_xx = jax.jacfwd(jax.grad(self.get_terminal_cost, argnums=0), argnums=0) \
                (self.x[-1]) # dim: x_dim x x_dim

            print('backward pass...')
            # propagating backwards to obtain gains over time steps
            for j in reversed(range(self.T - 1)):
                #print(f'\tbackward step: {j}')
                # get partial derivatives of dynamics
                F_x = jax.jacfwd(self.F, argnums=0)(self.x[j], 
                                                    self.u[j]) # dim: x_dim x x_dim
                F_u = jax.jacfwd(self.F, argnums=1)(self.x[j], 
                                                    self.u[j]) # dim: x_dim x u_dim
                F_xx = jax.jacfwd(jax.jacfwd(self.F, argnums=0), argnums=0) \
                    (self.x[j], 
                     self.u[j]) # dim: x_dim x x_dim x x_dim
                F_ux = jax.jacfwd(jax.jacfwd(self.F, argnums=1), argnums=0) \
                    (self.x[j], 
                     self.u[j]) # dim: x_dim x u_dim x x_dim
                F_uu = jax.jacfwd(jax.jacfwd(self.F, argnums=1), argnums=1) \
                    (self.x[j], 
                     self.u[j]) # dim: x_dim x u_dim x u_dim

                # get partial derivatives of instantaneous cost 
                L_x, L_u, L_xu, L_ux, L_xx, L_uu = self.get_cost_derivatives(self.x[j], 
                                                                             self.u[j], 
                                                                             self.P, 
                                                                             self.R)

                # get Q function
                Q_x = L_x + V_x @ F_x # dim: x_dim
                Q_u = L_u + V_x @ F_u # dim: u_dim
                Q_xx = L_xx + F_x.T @ V_xx @ F_x #+ V_x @ F_xx # dim: x_dim x x_dim
                Q_ux = L_ux + F_u.T @ V_xx @ F_x #+ jnp.einsum('i, ijk -> jk', V_x, F_ux) # dim: u_dim x x_dim
                Q_uu = L_uu + F_u.T @ V_xx @ F_u #+ jnp.einsum('i, ijk -> jk', V_x, F_uu) # dim: u_dim x u_dim

                # set gains each time step
                k = -jnp.linalg.inv(Q_uu) @ Q_u
                K = -jnp.linalg.inv(Q_uu) @ Q_ux
                self.k = self.k.at[j].set(k) 
                self.K = self.K.at[j].set(K)

                # compute the partial derivatives of the value function 
                V_x = Q_x - Q_ux.T @ jnp.linalg.inv(Q_uu) @ Q_u # dim: x_dim
                V_xx = Q_xx - Q_ux.T @ jnp.linalg.inv(Q_uu) @ Q_ux # dim: x_dim x x_dim
                
            self.update_trajectory()
            t2 = time.time()
            print('elapsed iteration time: ', t2 - t1)
            cost = self.get_trajectory_cost_value(x=self.x, 
                                                  u=self.u, 
                                                  P=self.P, 
                                                  R=self.R)
            cost += self.get_terminal_cost(x_terminal=self.x[-1])
            print(f'ddp iteration: {i}, optimized cost: {cost}')
        self.static_env.close()


In [12]:
quadcopter_3d = Quadcopter3D(n_drones=DEFAULT_NUM_DRONES,
                             use_gui=DEFAULT_GUI,
                             record_video=DEFAULT_RECORD_VISION,
                             debug_gui=DEFAULT_USER_DEBUG_GUI,
                             simulation_freq=DEFAULT_SIMULATION_FREQ_HZ,
                             control_freq=DEFAULT_CONTROL_FREQ_HZ,
                             duration=DEFAULT_DURATION_SEC,
                             output_folder=DEFAULT_OUTPUT_FOLDER,
                             use_colab=DEFAULT_COLAB
                             )

  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [13]:
quadcopter_3d.ddp()

	rollout_initial break at : 34
ddp iteration: 0, initial cost: 126.8449935913086
backward pass...
	update_trajectory break at : 74
elapsed iteration time:  5.8454649448394775
ddp iteration: 0, optimized cost: 116.84886169433594
ddp iteration: 1, initial cost: 116.84886169433594
backward pass...
	update_trajectory break at : 113
elapsed iteration time:  5.9968366622924805
ddp iteration: 1, optimized cost: 111.21208953857422
ddp iteration: 2, initial cost: 111.21208953857422
backward pass...
	update_trajectory break at : 209
elapsed iteration time:  6.861333608627319
ddp iteration: 2, optimized cost: 107.96401977539062
ddp iteration: 3, initial cost: 107.96401977539062
backward pass...
	update_trajectory break at : 414
elapsed iteration time:  7.831307888031006
ddp iteration: 3, optimized cost: 106.5362777709961
ddp iteration: 4, initial cost: 106.5362777709961
backward pass...
	update_trajectory break at : 414
elapsed iteration time:  7.813297271728516
ddp iteration: 4, optimized cost: 