In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pydrake.all import (AbstractValue, Cylinder, DiagramBuilder,
                         DirectCollocation,
                         FiniteHorizonLinearQuadraticRegulatorOptions,
                         FramePoseVector,
                         LogVectorOutput,
                         MakeFiniteHorizonLinearQuadraticRegulator,
                         MeshcatVisualizerCpp, MultibodyPlant, Parser,
                         PiecewisePolynomial, Rgba, RigidTransform,
                         RotationMatrix, SceneGraph, Simulator, Solve,
                         StartMeshcat, TrajectorySource,
                         namedview, JointIndex,
                         AddMultibodyPlantSceneGraph,
                         Wing, ConstantVectorSource,
                         SpatialVelocity, DirectTranscription)

from lib.systems import Barometer, SpatialForceConcatinator

In [None]:
# Start the visualizer. Each StartMeshcat() instance consumes a port, so only
# create one if this kernel does not already have it; re-running this cell is safe.
if "meshcat" not in globals():
    meshcat = StartMeshcat()

## Starship Cross-section
![image](https://everydayastronaut.com/wp-content/uploads/Articles/Belly_Flop/Belly-Flop-MAIN-Reshoot.00_08_53_06.Still003-800x450.jpg)


# This is Starship!
The flaps are modeled simply, using the flat-plate `Wing` model.
The aerodynamic model of the body is still a work in progress.

In [None]:
class Starship:
    """Starship model: a MultibodyPlant loaded from URDF plus flat-plate aerodynamics.

    The internal diagram contains the plant and its SceneGraph, a Barometer that
    feeds fluid density to the aerodynamic surfaces, a constant zero wind source,
    and six `Wing` systems (two on the body, one on each leg and arm link). The
    wings' spatial forces are concatenated into the plant's applied-spatial-force
    input.

    Exported ports of the built diagram:
        input 0:                  the plant's actuation input.
        output "geometry_query":  SceneGraph query port (for visualization).
        output "state":           the plant's full state [q; v], so that
                                  state-feedback controllers can be connected.
    """

    # Density handed to every Wing constructor (presumably air at sea level,
    # kg/m^3 -- TODO confirm); the runtime density comes from the Barometer.
    FLUID_DENSITY = 1.204

    def __init__(self, time_step = 0.001):
        """Build the Starship diagram.

        Args:
            time_step: MultibodyPlant time step; 0 gives a continuous-time plant
                (needed by DirectCollocation), > 0 a discrete-time one.
        """
        builder = DiagramBuilder()
        plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=time_step)
        Parser(plant).AddModelFromFile("starship/urdf/starship.urdf", "starship")
        plant.Finalize()

        def body_index(frame_name):
            # Index of the body that owns the named frame.
            return plant.GetFrameByName(frame_name).body().index()

        # No wind.
        no_wind = builder.AddSystem(ConstantVectorSource([0, 0, 0]))

        barometer = builder.AddSystem(Barometer(body_index=body_index("body")))

        # (frame name, X_BodyWing, surface area) for each aerodynamic surface.
        # The list order fixes each surface's input index on the concatenator.
        wing_specs = [
            ("body", RigidTransform(), 395.82),
            ("body", RigidTransform(RotationMatrix.MakeYRotation(np.pi/2)), 61.164),
            ("Link_Leg_R", RigidTransform(p=[0, -1.88, 0]), 47.859),
            ("Link_Leg_L", RigidTransform(p=[0, 1.88, 0]), 47.859),
            ("Link_Arm_R", RigidTransform(p=[0.22, -1.24, 0]), 16.295),
            ("Link_Arm_L", RigidTransform(p=[0.22, 1.24, 0]), 16.295),
        ]
        wings = [
            builder.AddSystem(Wing(body_index=body_index(frame_name),
                                   X_BodyWing=X_BodyWing,
                                   surface_area=surface_area,
                                   fluid_density=self.FLUID_DENSITY))
            for frame_name, X_BodyWing, surface_area in wing_specs
        ]

        concat = builder.AddSystem(SpatialForceConcatinator(len(wings)))

        builder.Connect(plant.get_body_poses_output_port(),
                        barometer.GetInputPort("body_poses_input_port"))

        for i, control_surface in enumerate(wings):
            # inputs
            builder.Connect(plant.get_body_poses_output_port(),
                            control_surface.get_body_poses_input_port())
            builder.Connect(plant.get_body_spatial_velocities_output_port(),
                            control_surface.get_body_spatial_velocities_input_port())
            builder.Connect(no_wind.get_output_port(0),
                            control_surface.get_wind_velocity_input_port())
            builder.Connect(barometer.GetOutputPort("Density"),
                            control_surface.get_fluid_density_input_port())
            # output
            builder.Connect(control_surface.get_spatial_force_output_port(),
                            concat.get_input_port(i))

        # Summed aerodynamic forces act on the plant.
        builder.Connect(concat.get_output_port(0), plant.get_applied_spatial_force_input_port())

        builder.ExportOutput(scene_graph.get_query_output_port(), "geometry_query")
        builder.ExportOutput(plant.get_state_output_port(), "state")
        builder.ExportInput(plant.get_actuation_input_port())

        self.diagram_ = builder.Build()
        self.plant_ = plant

    def GetDiagram(self):
        """Return the built Starship diagram."""
        return self.diagram_

    def GetPlant(self):
        """Return the MultibodyPlant inside the diagram."""
        return self.plant_

    def CreateDefaultContext(self):
        """Return a fresh default context for the Starship diagram."""
        return self.diagram_.CreateDefaultContext()

    def MakeNamedView(self):
        """Return a `namedview` class naming every entry of the plant state [q; v].

        One-dof joints contribute "<joint>_p" (position) and "<joint>_v"
        (velocity). Each floating base body contributes positions
        "<body>_qw, _qx, _qy, _qz, _x, _y, _z" and velocities
        "<body>_wx, _wy, _wz, _vx, _vy, _vz". Joints without state (welds) are
        skipped; any other multi-dof joint trips the assertion.
        """
        plant = self.plant_
        names_pos = [None] * plant.num_positions()
        names_vel = [None] * plant.num_velocities()

        for ind in range(plant.num_joints()):
            joint = plant.get_joint(JointIndex(ind))
            if joint.num_positions() == 0 and joint.num_velocities() == 0:
                continue  # welds add nothing to the state
            # TODO: Handle planar joints, etc.
            assert joint.num_positions() == 1 and joint.num_velocities() == 1, joint.name()
            names_pos[joint.position_start()] = joint.name() + "_p"
            names_vel[joint.velocity_start()] = joint.name() + "_v"

        for ind in plant.GetFloatingBaseBodies():
            body = plant.get_body(ind)
            body_name = body.name()
            q_start = body.floating_positions_start()
            # floating_velocities_start() is offset into [q; v]; shift it into v.
            v_start = body.floating_velocities_start() - plant.num_positions()
            for offset, suffix in enumerate(("qw", "qx", "qy", "qz", "x", "y", "z")):
                names_pos[q_start + offset] = body_name + "_" + suffix
            for offset, suffix in enumerate(("wx", "wy", "wz", "vx", "vy", "vz")):
                names_vel[v_start + offset] = body_name + "_" + suffix

        return namedview("state", names_pos + names_vel)

# Let's Drop Starship

In [None]:
# Simulate an unactuated drop of Starship and record it in Meshcat.
DROP_DURATION = 5.0  # simulated seconds

builder = DiagramBuilder()

starship = Starship()
starship_plant = starship.GetPlant()
starship_system = builder.AddSystem(starship.GetDiagram())

visualizer = MeshcatVisualizerCpp.AddToBuilder(
    builder, starship_system.GetOutputPort("geometry_query"), meshcat)

meshcat.Delete()

diagram = builder.Build()
simulator = Simulator(diagram)

sim_context = simulator.get_mutable_context()
plant_context = starship_plant.GetMyContextFromRoot(sim_context)

# Initial condition: body at z = 50, rotated by -pi/8 about y, moving at -20 along z.
V_WB = SpatialVelocity(w=[0, 0, 0], v=[0, 0, -20])
X_WB = RigidTransform(RotationMatrix.MakeYRotation(-np.pi/8), [0, 0, 50])

starship_body = starship_plant.GetBodyByName("body")
starship_plant.SetFreeBodyPose(plant_context, starship_body, X_WB)
starship_plant.SetFreeBodySpatialVelocity(context=plant_context,
                                          body=starship_body,
                                          V_WB=V_WB)

# Zero actuation; size the command from the port instead of hard-coding 4.
actuation_port = starship_plant.get_actuation_input_port()
actuation_port.FixValue(plant_context, np.zeros(actuation_port.size()))

simulator.set_target_realtime_rate(1.0)
visualizer.StartRecording()
simulator.AdvanceTo(DROP_DURATION)
visualizer.PublishRecording()

# Trajectory Optimization

In [None]:
def dirtran_starship():
    """Set up (but do not yet solve) a direct-transcription problem for Starship.

    Builds a Starship diagram with a 0.001 s discrete time step, publishes it
    once to Meshcat, and creates a DirectTranscription with 25 knot points and
    box limits on every input at every knot point.

    Returns:
        The DirectTranscription object. Call `.prog()` on it to add costs and
        constraints and to solve it. Work in progress: there are no costs, no
        state/terminal constraints, and no solve yet.
    """
    builder = DiagramBuilder()

    starship = Starship(0.001)
    starship_system = builder.AddSystem(starship.GetDiagram())

    # Wire the Meshcat visualizer into the diagram (kept for its side effect).
    MeshcatVisualizerCpp.AddToBuilder(
        builder, starship_system.GetOutputPort("geometry_query"), meshcat)

    meshcat.Delete()

    diagram = builder.Build()
    context = diagram.CreateDefaultContext()
    diagram.Publish(context)

    num_time_samples = 25
    dirtran = DirectTranscription(starship_system, starship.CreateDefaultContext(),
                                  num_time_samples)

    # Input limits: same +/-13 bound used by the direct-collocation problem.
    input_limit = 13
    u = dirtran.input()
    for i in range(len(u)):
        dirtran.AddConstraintToAllKnotPoints(-input_limit <= u[i])
        dirtran.AddConstraintToAllKnotPoints(u[i] <= input_limit)

    # TODO: add costs and state/terminal constraints, then Solve(dirtran.prog()).
    return dirtran


In [None]:
# Build the direct-transcription problem (work in progress; nothing is solved here).
dirtran_starship();

In [None]:
def dircol_starship():
    """Optimize a Starship belly-flop trajectory with direct collocation.

    Uses a continuous-time Starship (time_step 0, as DirectCollocation needs).
    Constraints:
      * initial state fixed to the drop condition;
      * final body roll rate (body_wx) = 0;
      * each input within +/-13 at every knot point;
      * leg/arm joint positions within +/-1.57 rad at every knot point.
    Costs: the running cost u.u plus a quadratic terminal cost on body_wx.

    The problem is solved first with 25 knot points and then with 41. The
    second solve is warm-started from the first solution.

    Returns:
        (x_traj, u_traj): state and input trajectories from the final solve.

    Raises:
        AssertionError: if any solve fails.
    """
    builder = DiagramBuilder()

    starship = Starship(0)
    starship_system = builder.AddSystem(starship.GetDiagram())
    starship_plant = starship.GetPlant()

    MeshcatVisualizerCpp.AddToBuilder(
        builder, starship_system.GetOutputPort("geometry_query"), meshcat)

    meshcat.Delete()

    diagram = builder.Build()
    context = diagram.CreateDefaultContext()
    plant_context = starship_plant.GetMyContextFromRoot(context)

    # Initial condition: body at z = 50, rotated by -pi/8 about y, moving at -20 along z.
    V_WB = SpatialVelocity(w=[0, 0, 0], v=[0, 0, -20])
    X_WB = RigidTransform(RotationMatrix.MakeYRotation(-np.pi/8), [0, 0, 50])
    starship_body = starship_plant.GetBodyByName("body")
    starship_plant.SetFreeBodyPose(plant_context, starship_body, X_WB)
    starship_plant.SetFreeBodySpatialVelocity(context=plant_context,
                                              body=starship_body,
                                              V_WB=V_WB)
    x0 = np.hstack([starship_plant.GetPositions(plant_context),
                    starship_plant.GetVelocities(plant_context)])
    diagram.Publish(context)  # show the initial condition

    StateView = starship.MakeNamedView()
    num_states = x0.size

    input_limit = 13  # bound on every actuation input (units of the plant's actuators)
    joint_limit = 1.57  # rad, roughly pi/2
    limited_joints = ("Joint_Leg_R_p", "Joint_Leg_L_p", "Joint_Arm_R_p", "Joint_Arm_L_p")

    # Terminal cost: weight only the final body roll rate, driving it to zero.
    final_weights = StateView(np.zeros(num_states))
    final_weights.body_wx = 10
    Qf = np.diag(final_weights[:])
    final_state_desired = np.zeros(num_states)

    x_traj = None
    u_traj = None
    # SNOPT is more reliable when a coarse solution seeds a finer solve.
    for N in [25, 41]:
        dircol = DirectCollocation(starship_system, starship.CreateDefaultContext(),
                                   N, 0.5 / N, 2.0 / N)
        prog = dircol.prog()
        dircol.AddEqualTimeIntervalsConstraints()

        # Input limits.
        u = dircol.input()
        for i in range(len(u)):
            dircol.AddConstraintToAllKnotPoints(-input_limit <= u[i])
            dircol.AddConstraintToAllKnotPoints(u[i] <= input_limit)

        # Initial conditions.
        prog.AddBoundingBoxConstraint(x0, x0, dircol.initial_state())

        # Final roll rate fixed to zero.
        sf = StateView(dircol.final_state())
        prog.AddBoundingBoxConstraint(0, 0, sf.body_wx)

        # Joint position limits.
        s = StateView(dircol.state())
        for joint_name in limited_joints:
            q = getattr(s, joint_name)
            dircol.AddConstraintToAllKnotPoints(q <= joint_limit)
            dircol.AddConstraintToAllKnotPoints(q >= -joint_limit)

        # Cost.
        dircol.AddRunningCost(u.dot(u))
        prog.AddQuadraticErrorCost(Qf, final_state_desired, dircol.final_state())

        # Warm start from the previous (coarser) solution.
        if x_traj is not None:
            dircol.SetInitialTrajectory(u_traj, x_traj)

        result = Solve(prog)
        assert result.is_success(), (
            f"DirectCollocation with N={N} failed ({result.get_solver_id().name()})")

        x_traj = dircol.ReconstructStateTrajectory(result)
        u_traj = dircol.ReconstructInputTrajectory(result)

    return x_traj, u_traj

In [None]:
# Keep whatever the optimization returns so the stabilization section can reuse it.
dircol_result = dircol_starship()

# Trajectory Stabilization

In [None]:
def finite_horizon_lqr(x_traj, u_traj, Q=None, R=None):
    """Stabilize a nominal Starship trajectory with finite-horizon LQR and simulate it.

    The controller is linearized around (x_traj, u_traj) on a continuous-time
    Starship model, the same kind of model dircol_starship optimizes on. The
    simulation starts at the trajectory's start time and nominal initial state,
    runs until the trajectory's end time, and is recorded in Meshcat.

    Args:
        x_traj: nominal state trajectory (e.g. from direct collocation).
        u_traj: nominal input trajectory.
        Q: running state cost matrix (also used as the final cost Qf).
            Defaults to identity -- placeholder, TODO tune.
        R: input cost matrix. Defaults to identity -- placeholder, TODO tune.
    """
    builder = DiagramBuilder()

    starship = Starship(0)
    starship_system = builder.AddSystem(starship.GetDiagram())

    visualizer = MeshcatVisualizerCpp.AddToBuilder(
        builder, starship_system.GetOutputPort("geometry_query"), meshcat)

    meshcat.Delete()

    num_states = x_traj.rows()
    num_inputs = u_traj.rows()
    Q = np.eye(num_states) if Q is None else np.asarray(Q)
    R = np.eye(num_inputs) if R is None else np.asarray(R)

    options = FiniteHorizonLinearQuadraticRegulatorOptions()
    options.Qf = Q
    options.x0 = x_traj
    options.u0 = u_traj

    t0 = x_traj.start_time()
    tf = x_traj.end_time()
    controller = builder.AddSystem(
        MakeFiniteHorizonLinearQuadraticRegulator(
            system=starship_system,
            context=starship.CreateDefaultContext(),
            t0=t0,
            tf=tf,
            Q=Q,
            R=R,
            options=options))

    builder.Connect(controller.get_output_port(), starship_system.get_input_port())
    builder.Connect(starship_system.GetOutputPort("state"), controller.get_input_port())

    diagram = builder.Build()

    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()
    # Start on the nominal trajectory; the controller is time-varying from t0.
    context.SetTime(t0)
    starship_context = starship_system.GetMyMutableContextFromRoot(context)
    starship_context.SetContinuousState(x_traj.value(t0).ravel())

    # TODO: log the state (e.g. LogVectorOutput) and plot desired vs actual.

    simulator.set_target_realtime_rate(1.0)
    visualizer.StartRecording()
    simulator.AdvanceTo(tf)
    visualizer.PublishRecording()
        

In [None]:
# Stabilize the direct-collocation trajectory. Reuse a stored optimization
# result when one exists; otherwise run the optimization here.
try:
    lqr_trajectories = dircol_result
except NameError:
    lqr_trajectories = dircol_starship()
if lqr_trajectories is None:
    raise RuntimeError("dircol_starship() must return (x_traj, u_traj) to run finite_horizon_lqr")
finite_horizon_lqr(*lqr_trajectories)