In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pydrake.all import (AbstractValue, Cylinder, DiagramBuilder,
                         DirectCollocation,
                         FiniteHorizonLinearQuadraticRegulatorOptions,
                         FramePoseVector,
                         LogVectorOutput,
                         MakeFiniteHorizonLinearQuadraticRegulator,
                         MeshcatVisualizerCpp, MultibodyPlant, Parser,
                         PiecewisePolynomial, Rgba, RigidTransform,
                         RotationMatrix, SceneGraph, Simulator, Solve,
                         StartMeshcat, TrajectorySource,
                         namedview, JointIndex,
                         AddMultibodyPlantSceneGraph,
                         Wing, ConstantVectorSource,
                         SpatialVelocity, DirectTranscription,
                         PlanarJoint, FixedOffsetFrame)

from lib.systems import Barometer, SpatialForceConcatinator
from lib.helpers import MakeNamedView

In [None]:
# Start the Meshcat visualizer server.
# Run this cell only once: each StartMeshcat() call launches a new instance that
# consumes another port. Later cells use this global `meshcat` for visualization.
meshcat = StartMeshcat()

## Starship Cross-Section
![Starship belly-flop cross-section](https://everydayastronaut.com/wp-content/uploads/Articles/Belly_Flop/Belly-Flop-MAIN-Reshoot.00_08_53_06.Still003-800x450.jpg)


# This is Starship!
The flaps are modeled as flat plates using the `Wing` (flat-plate wing) model.
The aerodynamic model of the body is still a work in progress.

In [None]:
class Starship:
    """Starship as a Drake diagram: the URDF MultibodyPlant plus flat-plate aerodynamics.

    Every aerodynamic surface is a ``Wing`` system. Each wing reads the body
    poses and spatial velocities, a constant zero wind, and the barometer's
    "Density" output. The wings' spatial forces are concatenated and applied
    to the plant.

    The built diagram exports:
      * input port 0: the plant's actuation input,
      * output "geometry_query": SceneGraph's query port (for visualization),
      * output "state": the plant's full state output.

    Args:
        time_step: MultibodyPlant time step; 0 gives a continuous-time plant.
        planar: If True, attach "body" to the world through a PlanarJoint
            named "Joint_body". Otherwise the body is left floating.
    """

    def __init__(self, time_step = 0.001, planar = False):
        builder = DiagramBuilder()
        plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=time_step)

        # Ground model currently disabled:
        # Parser(plant).AddModelFromFile("ground.urdf", "ground")
        Parser(plant).AddModelFromFile("starship/urdf/starship.urdf", "starship")

        if planar:
            # Both joint frames are rotated by pi/2 about x, so the PlanarJoint's
            # local y axis lines up with world z (motion in the world x-z plane).
            x_rotation = RigidTransform(RotationMatrix.MakeXRotation(np.pi/2))
            planar_joint_frame = plant.AddFrame(
                FixedOffsetFrame("planar_joint_frame", plant.world_frame(), x_rotation))
            planar_body_frame = plant.AddFrame(
                FixedOffsetFrame("planar_body_frame", plant.GetFrameByName("body"), x_rotation))
            plant.AddJoint(PlanarJoint("Joint_body", planar_joint_frame, planar_body_frame))
        plant.Finalize()

        # No wind.
        no_wind = builder.AddSystem(ConstantVectorSource([0, 0, 0]))

        body_index = plant.GetFrameByName("body").body().index()
        barometer = builder.AddSystem(Barometer(body_index=body_index))

        air_density = 1.204  # presumably sea-level air density [kg/m^3]
        # (frame whose body carries the surface, pose of the surface on that body,
        #  surface area). The order fixes each wing's concatenator input index.
        wing_specs = [
            ("body", RigidTransform(), 395.82),                                       # belly_1
            ("body", RigidTransform(RotationMatrix.MakeYRotation(np.pi/2)), 61.164),  # belly_2
            ("Link_Leg_R", RigidTransform(p=[0, -1.88, 0]), 47.859),
            ("Link_Leg_L", RigidTransform(p=[0, 1.88, 0]), 47.859),
            ("Link_Arm_R", RigidTransform(p=[0, -1.26, 0]), 16.295),
            ("Link_Arm_L", RigidTransform(p=[0, 1.26, 0]), 16.295),
        ]
        wings = [
            builder.AddSystem(Wing(body_index=plant.GetFrameByName(frame_name).body().index(),
                                   X_BodyWing=X_BodyWing,
                                   surface_area=surface_area,
                                   fluid_density=air_density))
            for frame_name, X_BodyWing, surface_area in wing_specs
        ]

        concat = builder.AddSystem(SpatialForceConcatinator(len(wings)))

        builder.Connect(plant.get_body_poses_output_port(), barometer.GetInputPort("body_poses_input_port"))

        for i, wing in enumerate(wings):
            # Inputs: kinematics, wind, and the density computed by the barometer.
            builder.Connect(plant.get_body_poses_output_port(), wing.get_body_poses_input_port())
            builder.Connect(plant.get_body_spatial_velocities_output_port(),
                            wing.get_body_spatial_velocities_input_port())
            builder.Connect(no_wind.get_output_port(0), wing.get_wind_velocity_input_port())
            builder.Connect(barometer.GetOutputPort("Density"), wing.get_fluid_density_input_port())

            # Output: this wing's force goes into concatenator input i.
            builder.Connect(wing.get_spatial_force_output_port(), concat.get_input_port(i))

        # All aerodynamic forces are applied to the plant.
        builder.Connect(concat.get_output_port(0), plant.get_applied_spatial_force_input_port())

        builder.ExportOutput(scene_graph.get_query_output_port(), "geometry_query")
        builder.ExportInput(plant.get_actuation_input_port())
        # Exported last so existing port indices are unchanged; consumed by the
        # trajectory-stabilization cell via GetOutputPort("state").
        builder.ExportOutput(plant.get_state_output_port(), "state")

        diagram = builder.Build()

        self.diagram_ = diagram
        self.plant_ = plant

        self.named_view_ = MakeNamedView(self.plant_)

        self.planar = planar

    def GetDiagram(self):
        """Return the built Starship diagram."""
        return self.diagram_

    def GetPlant(self):
        """Return the MultibodyPlant inside the diagram."""
        return self.plant_

    def CreateDefaultContext(self):
        """Return a fresh default context for the Starship diagram."""
        return self.diagram_.CreateDefaultContext()

    def StateView(self, q):
        """Wrap a state vector (or decision-variable vector) in the plant's named view."""
        return self.named_view_(q)

    def init(self, context, **kwargs):
        """Set the initial state of the Starship body in a plant context.

        Args:
            context: The plant's context (callers pass the plant subcontext).
            joint_pos: Planar model only; [x, y, theta] of "Joint_body".
            joint_vel: Planar model only; [vx, vy, theta_dot] of "Joint_body".
            X_WB: Floating model only; RigidTransform pose of "body" in world.
            V_WB: Floating model only; SpatialVelocity of "body" in world.

        Raises:
            KeyError: if the kwargs required by the current model are missing.
        """
        if self.planar:
            joint_pos = kwargs["joint_pos"]
            joint_vel = kwargs["joint_vel"]
            joint = self.plant_.GetJointByName("Joint_body")

            joint.set_translation(context, joint_pos[:2])
            joint.set_translational_velocity(context, joint_vel[:2])

            joint.set_rotation(context, joint_pos[2])
            joint.set_angular_velocity(context, joint_vel[2])
        else:
            X_WB = kwargs["X_WB"]
            V_WB = kwargs["V_WB"]
            body = self.plant_.GetBodyByName("body")
            self.plant_.SetFreeBodyPose(context, body, X_WB)
            self.plant_.SetFreeBodySpatialVelocity(context=context, body=body, V_WB=V_WB)


   

# Let's Drop Starship

In [None]:
# Drop-test settings.
is2d = False        # True: planar (PlanarJoint) model; False: floating-base model
TIME_STEP = 1e-3    # MultibodyPlant discrete time step [s]
SIM_DURATION = 5.0  # simulated time [s]

builder = DiagramBuilder()

starship = Starship(TIME_STEP, is2d)
starship_plant = starship.GetPlant()
starship_system = builder.AddSystem(starship.GetDiagram())

visualizer = MeshcatVisualizerCpp.AddToBuilder(
    builder, starship_system.GetOutputPort("geometry_query"), meshcat)
meshcat.Delete()

diagram = builder.Build()
simulator = Simulator(diagram)

sim_context = simulator.get_mutable_context()
plant_context = starship_plant.GetMyContextFromRoot(sim_context)

# Passive drop: hold every actuator input at zero. Sizing the vector from the
# port avoids hard-coding the URDF's actuator count.
actuation_port = starship_plant.get_actuation_input_port()
actuation_port.FixValue(plant_context, np.zeros(actuation_port.size()))

# Initial condition: 50 m up, descending at 20 m/s, tilted by pi/8.
if is2d:
    starship.init(context = plant_context, joint_pos = [0, 50, np.pi/8], joint_vel = [0, -20, 0])
else:
    starship.init(context = plant_context,
                  V_WB = SpatialVelocity(w=[0, 0, 0], v=[0, 0, -20]),
                  X_WB = RigidTransform(RotationMatrix.MakeYRotation(-np.pi/8), [0, 0, 50]))

simulator.set_target_realtime_rate(1.0)
visualizer.StartRecording()
simulator.AdvanceTo(SIM_DURATION)
visualizer.PublishRecording()

# Trajectory Optimization

In [None]:
def dirtran_starship(is2d=False, feasibility_check_only=True):
    """Set up (and optionally solve) a direct-transcription problem for Starship.

    Builds a discrete-time (time_step=0.001) Starship model, shows it in
    Meshcat, and creates a 25-sample DirectTranscription program.

    Args:
        is2d: If True, use the planar PlanarJoint model; otherwise use the
            floating-base model. The model is now built to match this flag;
            previously the flag came from a notebook global while the model
            was always built non-planar.
        feasibility_check_only: If True (the default, matching the original
            debugging behavior), solve the program with only the dynamics
            constraints, print the outcome and any infeasible constraints,
            and return None. If False, also add input/state constraints,
            solve, and return the trajectories.

    Returns:
        None when ``feasibility_check_only`` is True; otherwise
        ``(x_sol, u_sol)``, the reconstructed state and input trajectories.

    Raises:
        AssertionError: if the full optimization fails.
    """
    builder = DiagramBuilder()

    starship = Starship(0.001, is2d)
    starship_system = builder.AddSystem(starship.GetDiagram())
    starship_plant = starship.GetPlant()

    MeshcatVisualizerCpp.AddToBuilder(
        builder, starship_system.GetOutputPort("geometry_query"), meshcat)
    meshcat.Delete()

    diagram = builder.Build()
    context = diagram.CreateDefaultContext()
    # Publishes the default configuration (before the initial condition is set).
    diagram.Publish(context)

    plant_context = starship_plant.GetMyContextFromRoot(context)

    # Initial condition: 50 m up, descending at 20 m/s, tilted by pi/8.
    if is2d:
        starship.init(context = plant_context, joint_pos = [0, 50, np.pi/8], joint_vel = [0, -20, 0])
    else:
        starship.init(context = plant_context,
                      V_WB = SpatialVelocity(w=[0, 0, 0], v=[0, 0, -20]),
                      X_WB = RigidTransform(RotationMatrix.MakeYRotation(-np.pi/8), [0, 0, 50]))

    q0 = starship.StateView(np.hstack([starship_plant.GetPositions(plant_context),
                                       starship_plant.GetVelocities(plant_context)]))

    num_time_samples = 25
    dirtran = DirectTranscription(starship_system, starship.CreateDefaultContext(), num_time_samples)
    prog = dirtran.prog()

    if feasibility_check_only:
        # Debug pass: dynamics constraints only.
        result = Solve(prog)
        print(result.is_success())
        print(result.GetSolution())
        for c in result.GetInfeasibleConstraints(prog):
            print(c)
        return None

    # Input limits at every knot point.
    velocity_limit = 13  # max servo velocity (rad/sec)
    for u_i in dirtran.input():
        dirtran.AddConstraintToAllKnotPoints(-velocity_limit <= u_i)
        dirtran.AddConstraintToAllKnotPoints(u_i <= velocity_limit)

    # Initial conditions.
    prog.AddBoundingBoxConstraint(q0[:], q0[:], dirtran.initial_state())

    # State constraints.
    # NOTE(review): the field names below (body_x/y/z, Joint_*_p) are assumed to
    # come from MakeNamedView for the floating-base model; confirm them for the
    # planar model.
    s = starship.StateView(dirtran.state())
    joint_limit = 1.57  # ~pi/2 rad
    for joint_pos in (s.Joint_Leg_R_p, s.Joint_Leg_L_p, s.Joint_Arm_R_p, s.Joint_Arm_L_p):
        dirtran.AddConstraintToAllKnotPoints(joint_pos <= joint_limit)
        dirtran.AddConstraintToAllKnotPoints(joint_pos >= -joint_limit)

    # Never climb above the start height; stay within a +/-5 box in x and y.
    dirtran.AddConstraintToAllKnotPoints(s.body_z <= q0.body_z)
    for coordinate in (s.body_x, s.body_y):
        dirtran.AddConstraintToAllKnotPoints(coordinate <= 5)
        dirtran.AddConstraintToAllKnotPoints(coordinate >= -5)

    def plot_traj(times, states):
        """Draw the current body-position trajectory in Meshcat during the solve."""
        s = starship.StateView(states)
        vertices = np.vstack([s.body_x, s.body_y, s.body_z])
        print(sum(s.body_z))
        print(sum(s.body_x))
        print(sum(s.body_y))
        meshcat.SetLine("dircol", vertices, rgba = Rgba(0, 0, 0.5))

    dirtran.AddStateTrajectoryCallback(plot_traj)

    result = Solve(prog)
    assert result.is_success(), "Optimization failed"

    x_sol = dirtran.ReconstructStateTrajectory(result)
    u_sol = dirtran.ReconstructInputTrajectory(result)
    return x_sol, u_sol


In [None]:
# Run the direct-transcription setup defined above; this cell does not keep
# the return value.
dirtran_starship()

In [None]:
def dircol_starship(is2d = True, feasibility_check_only = True):
    """Set up (and optionally solve) a direct-collocation problem for Starship.

    Builds a continuous-time Starship model (time_step=0, as DirectCollocation
    needs continuous dynamics), shows it in Meshcat, and creates
    DirectCollocation programs whose total duration is bounded to [0.5, 2.0] s.

    Args:
        is2d: If True (default), use the planar PlanarJoint model. The full
            optimization hard-codes a 14-dimensional planar state layout, so
            it is only supported for the planar model.
        feasibility_check_only: If True (default, matching the original
            debugging behavior), solve the N=25 program with only the dynamics
            and equal-time-interval constraints, print the outcome and any
            infeasible constraints, and return None. If False, run the full
            two-pass optimization.

    Returns:
        None when ``feasibility_check_only`` is True; otherwise
        ``(x_traj, u_traj)``, the state and input trajectories of the final
        (N=41) solve.

    Raises:
        NotImplementedError: if the full optimization is requested for the
            floating-base (3D) model.
        AssertionError: if a solve in the full optimization fails.
    """
    if not (is2d or feasibility_check_only):
        raise NotImplementedError(
            "The full dircol optimization uses the 14-state planar layout; "
            "call with is2d=True.")

    builder = DiagramBuilder()

    starship = Starship(0, is2d)
    starship_system = builder.AddSystem(starship.GetDiagram())
    starship_plant = starship.GetPlant()

    MeshcatVisualizerCpp.AddToBuilder(
        builder, starship_system.GetOutputPort("geometry_query"), meshcat)
    meshcat.Delete()

    diagram = builder.Build()
    context = diagram.CreateDefaultContext()
    diagram.Publish(context)

    plant_context = starship_plant.GetMyContextFromRoot(context)

    # Initial condition: 50 m up, descending at 20 m/s, tilted by pi/8.
    if is2d:
        starship.init(context = plant_context, joint_pos = [0, 50, np.pi/8], joint_vel = [0, -20, 0])
    else:
        starship.init(context = plant_context,
                      V_WB = SpatialVelocity(w=[0, 0, 0], v=[0, 0, -20]),
                      X_WB = RigidTransform(RotationMatrix.MakeYRotation(-np.pi/8), [0, 0, 50]))

    q0 = starship.StateView(np.hstack([starship_plant.GetPositions(plant_context),
                                       starship_plant.GetVelocities(plant_context)]))

    velocity_limit = 13  # max servo velocity (rad/sec)
    x_traj = None
    u_traj = None

    # SNOPT is more reliable if we solve it twice: a coarse N=25 solve, then an
    # N=41 solve warm-started from the coarse solution.
    for N in [25, 41]:
        # Per-interval bounds 0.5/N .. 2.0/N s => total duration in [0.5, 2.0] s.
        dircol = DirectCollocation(starship_system, starship.CreateDefaultContext(), N, 0.5 / N, 2.0 / N)
        prog = dircol.prog()
        dircol.AddEqualTimeIntervalsConstraints()

        if feasibility_check_only:
            # Debug pass: dynamics and timing constraints only.
            result = Solve(prog)
            print(result.is_success())
            print(result.GetSolution())
            for c in result.GetInfeasibleConstraints(prog):
                print(c)
            return None

        # Input limits.
        u = dircol.input()
        for u_i in u:
            dircol.AddConstraintToAllKnotPoints(-velocity_limit <= u_i)
            dircol.AddConstraintToAllKnotPoints(u_i <= velocity_limit)

        # Initial conditions (also shown in Meshcat).
        prog.AddBoundingBoxConstraint(q0[:], q0[:], dircol.initial_state())
        context.SetContinuousState(q0[:])
        diagram.Publish(context)

        # TODO: state constraints explored earlier and currently disabled, e.g.
        # leg/arm joint limits:
        #   s = starship.StateView(dircol.state())
        #   for joint_pos in (s.Joint_Leg_R_p, s.Joint_Leg_L_p, s.Joint_Arm_R_p, s.Joint_Arm_L_p):
        #       dircol.AddConstraintToAllKnotPoints(joint_pos <= 1.57)
        #       dircol.AddConstraintToAllKnotPoints(joint_pos >= -1.57)

        # Running cost: control effort.
        dircol.AddRunningCost(u.dot(u))

        # Final-state cost toward a zero target. Only index 9 is weighted (10):
        # per the layout below that is the third Joint_body velocity, presumably
        # the pitch rate. NOTE(review): Joint_body_theta itself has weight 0 —
        # confirm the intended pitch objective.
        sf_d = starship.StateView(np.zeros(14))
        sf_d.Joint_body_theta = 0
        prog.AddQuadraticErrorCost(np.diag([0, 0, 0,       # Joint_body positions
                                            0, 0, 0, 0,    # leg/arm joint positions
                                            0, 0, 10,      # Joint_body velocities
                                            0, 0, 0, 0]),  # leg/arm joint velocities
                                   sf_d[:], dircol.final_state())

        # Warm-start the refined pass with the previous solution.
        if x_traj is not None:
            dircol.SetInitialTrajectory(u_traj, x_traj)

        def plot_traj(times, states):
            """Draw the current planar trajectory (x, 0, y) in Meshcat during the solve."""
            s = starship.StateView(states)
            vertices = np.vstack([s.Joint_body_x, 0*s.Joint_body_x, s.Joint_body_y])
            print(s.Joint_body_theta)
            meshcat.SetLine("dircol", vertices, rgba = Rgba(0, 0, 0.5))

        dircol.AddStateTrajectoryCallback(plot_traj)

        result = Solve(prog)
        assert result.is_success()

        x_traj = dircol.ReconstructStateTrajectory(result)
        u_traj = dircol.ReconstructInputTrajectory(result)

    return x_traj, u_traj

In [None]:
# Run direct collocation with the function's default settings. The return value
# is not kept, so nothing from this cell is available to later cells.
dircol_starship()

# Trajectory Stabilization

In [None]:
def finite_horizon_lqr(x_traj, u_traj, Q=None, R=None, is2d=True, time_step=0.0):
    """Stabilize a nominal Starship trajectory with a finite-horizon LQR and simulate it.

    Args:
        x_traj: Nominal state trajectory (e.g. from trajectory optimization).
            Its time span sets the controller horizon [t0, tf], and its value
            at t0 is used as the simulation's initial state.
        u_traj: Nominal input trajectory matching ``x_traj``.
        Q: State cost matrix, also used as the final cost ``Qf``. Defaults to
            the identity, which is a placeholder; TODO tune.
        R: Input cost matrix. Defaults to the identity, which is a placeholder;
            TODO tune.
        is2d: Build the planar model. It must match the model the trajectories
            came from; dircol_starship defaults to the planar model.
        time_step: Plant time step. Defaults to 0 (continuous time) to match
            the continuous model used for direct collocation.

    Note:
        The controller reads the plant state from the Starship diagram's
        exported "state" output port, which the diagram must provide.
    """
    builder = DiagramBuilder()

    starship = Starship(time_step, is2d)
    starship_system = builder.AddSystem(starship.GetDiagram())

    visualizer = MeshcatVisualizerCpp.AddToBuilder(
        builder, starship_system.GetOutputPort("geometry_query"), meshcat)
    meshcat.Delete()

    Q = np.eye(x_traj.rows()) if Q is None else np.atleast_2d(Q)
    R = np.eye(u_traj.rows()) if R is None else np.atleast_2d(R)

    options = FiniteHorizonLinearQuadraticRegulatorOptions()
    options.Qf = Q
    # Regulate around the nominal trajectories rather than a fixed point.
    options.x0 = x_traj
    options.u0 = u_traj

    controller = builder.AddSystem(
        MakeFiniteHorizonLinearQuadraticRegulator(
            system = starship_system,
            context = starship.CreateDefaultContext(),
            t0 = x_traj.start_time(),
            tf = x_traj.end_time(),
            Q = Q,
            R = R,
            options = options))

    builder.Connect(controller.get_output_port(), starship_system.get_input_port())
    builder.Connect(starship_system.GetOutputPort("state"), controller.get_input_port())

    diagram = builder.Build()

    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()

    # Start on the nominal trajectory. Setting positions and velocities through
    # the plant works for both continuous and discrete plants.
    context.SetTime(x_traj.start_time())
    starship_plant = starship.GetPlant()
    plant_context = starship_plant.GetMyMutableContextFromRoot(context)
    starship_plant.SetPositionsAndVelocities(
        plant_context, x_traj.value(x_traj.start_time()).ravel())

    # TODO plot the desired state vs the actual state

    simulator.set_target_realtime_rate(1.0)
    visualizer.StartRecording()
    simulator.AdvanceTo(x_traj.end_time())
    # The recording is viewed in the Meshcat window.
    visualizer.PublishRecording()


# Backward-compatible alias for the original function name.
finite_horizon_lqe = finite_horizon_lqr
        


In [None]:
# NOTE(review): this is called with no arguments, but the stabilization code
# needs the nominal state/input trajectories from trajectory optimization. It
# cannot run until they are passed in.
finite_horizon_lqr() #todo pass in inputs from trajectory optimization

In [None]:
# Sanity-check the link inertia tensors (entries as listed for the URDF;
# presumably kg*m^2 about each link's reference frame — TODO confirm).
body =  {'ixx':1293500,
         'ixy':5.3286E-09,
         'ixz':3212.3,
         'iyy':14695000,
         'iyz':5.3029E-12,
         'izz':14695000}
leg_R = {'ixx':5374.2,
         'ixy':-4459.8,
         'ixz':-1.9212E-11,
         'iyy':45100,
         'iyz':-8.3506E-13,
         'izz':50379}

leg_L = {'ixx':5374.2,
         'ixy':4459.8,
         'ixz':-8.7792E-13,
         'iyy':45100,
         'iyz':-1.0595E-12,
         'izz':50379}
arm_R = {'ixx':919.98,
         'ixy':-634.06,
         'ixz':-0.59721,
         'iyy':4386.3,
         'iyz':-0.83543,
         'izz':5273.3}

arm_L = {'ixx':919.98,
         'ixy':634.06,
         'ixz':-0.5972,
         'iyy':4386.3,
         'iyz':0.83542,
         'izz':5273.3}


def inertia_matrix(entries):
    """Assemble the symmetric 3x3 inertia matrix from ixx/ixy/ixz/iyy/iyz/izz entries."""
    return np.array([[entries['ixx'], entries['ixy'], entries['ixz']],
                     [entries['ixy'], entries['iyy'], entries['iyz']],
                     [entries['ixz'], entries['iyz'], entries['izz']]])


def inertia_is_physically_valid(I, rel_tol=1e-9):
    """Check whether a symmetric inertia matrix could be physically valid.

    Args:
        I: Symmetric 3x3 inertia matrix.
        rel_tol: Slack on the triangle inequality, relative to the largest
            principal moment. It keeps boundary cases such as thin plates
            (where I3 == I1 + I2) from failing on rounding error.

    Returns:
        ``(positive_definite, triangle_ok)``. ``positive_definite`` means all
        principal moments are > 0. ``triangle_ok`` means the largest principal
        moment does not exceed the sum of the other two. Positive-definiteness
        alone is not sufficient for a real mass distribution.
    """
    moments = np.linalg.eigvalsh(I)  # real, ascending for symmetric input
    positive_definite = bool(np.all(moments > 0))
    triangle_ok = bool(moments[2] <= moments[0] + moments[1] + rel_tol * abs(moments[2]))
    return positive_definite, triangle_ok


for name, entries in [('body', body), ('leg_R', leg_R), ('leg_L', leg_L),
                      ('arm_R', arm_R), ('arm_L', arm_L)]:
    positive_definite, triangle_ok = inertia_is_physically_valid(inertia_matrix(entries))
    print(f"{name}: positive definite={positive_definite}, triangle inequality={triangle_ok}")