In [None]:
!pip install drake

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pydot
from copy import deepcopy
from IPython.display import SVG, display, HTML
from pydrake.all import (
    Value,
    Context,
    Diagram,
    DiagramBuilder,
    DirectCollocation,
    FiniteHorizonLinearQuadraticRegulatorOptions,
    InputPortIndex,
    LeafSystem,
    Linearize,
    LinearQuadraticRegulator,
    MakeFiniteHorizonLinearQuadraticRegulator,
    PiecewisePolynomial,
    PlanarSceneGraphVisualizer,
    PortSwitch,
    Simulator,
    Solve,
    VectorLogSink,
)

import requests

# Fetch the notebook's helper module (it provides CartMultiPendulumSystem) and execute it
# in this namespace.
# SECURITY NOTE(review): this executes unreviewed remote code from the `main` branch; pin
# the URL to a specific commit hash (or vendor utils.py) so runs are reproducible and safe.
r = requests.get(
    "https://raw.githubusercontent.com/wei-chen-li/wei-chen-li.github.io/main/content/post/triple-pendulum-swingup/notebooks/utils.py",
    timeout=30,
)
r.raise_for_status()  # never exec an HTML error page returned for a 404/5xx
# Compiling with a filename makes tracebacks from the helper code point at "utils.py".
exec(compile(r.text, "utils.py", "exec"))

#### Helper functions

In [None]:
def _joint_world_matrix(num_states, inverse=False):
    """Build the (num_states x num_states) joint->world transform, or its exact inverse.

    Args:
        num_states: Total state length; must be an even integer >= 2 (one position half
            and one velocity half, each holding 1 + num_pendulums entries).
        inverse: If True, return the world->joint matrix instead.

    Raises:
        ValueError: If num_states is not an even integer >= 2.
    """
    if num_states != int(num_states) or num_states < 2 or int(num_states) % 2:
        raise ValueError(f"num_states must be an even integer >= 2, got {num_states!r}")
    num_pendulums = int(num_states) // 2 - 1

    # Entry 0 of each half passes through unchanged; entries 1..n are remapped.
    half = np.eye(num_pendulums + 1)
    if inverse:
        # Exact inverse of the all-ones lower-triangular matrix: 1 on the diagonal,
        # -1 on the first sub-diagonal (successive differences).
        half[1:, 1:] = np.eye(num_pendulums) - np.eye(num_pendulums, k=-1)
    else:
        # All-ones lower-triangular: world angle k = sum of joint angles 1..k.
        half[1:, 1:] = np.tri(num_pendulums)
    # The same map is applied to the position half and the velocity half.
    return np.kron(np.eye(2), half)


def JointState2WorldState(arg1):
    """Convert a cart-multi-pendulum state from joint (relative) to world (absolute) coordinates.

    The state is [positions, velocities], two halves of equal length. In each half, entry 0
    (presumably the cart coordinate) is kept. Entries 1..n become cumulative sums of the
    joint entries.

    Args:
        arg1: Either a scalar state count (Python/NumPy integer or integral float; even and
            >= 2), or a state vector / array whose first axis has that length.

    Returns:
        The transform matrix T when given a count, otherwise ``T @ state``.

    Raises:
        ValueError: If the state count is not an even integer >= 2.
    """
    if np.ndim(arg1) == 0:
        return _joint_world_matrix(arg1)
    state = np.asarray(arg1)
    return _joint_world_matrix(len(state)).dot(state)


def WorldState2JointState(arg1):
    """Convert a state from world (absolute) to joint (relative) coordinates.

    This is the exact inverse of JointState2WorldState. Joint angle k is world angle k
    minus world angle k-1, and likewise for velocities.

    Args:
        arg1: Either a scalar state count (even integer >= 2), or a state vector / array
            whose first axis has that length.

    Returns:
        The inverse transform matrix when given a count, otherwise the converted state.

    Raises:
        ValueError: If the state count is not an even integer >= 2.
    """
    if np.ndim(arg1) == 0:
        return _joint_world_matrix(arg1, inverse=True)
    state = np.asarray(arg1)
    return _joint_world_matrix(len(state), inverse=True).dot(state)

### Trajectory optimization

In [None]:
# Plant: a cart carrying pendulum links (here two, given m1/l1 and m2/l2) with unit masses
# and lengths.
cart_multi_pendulum = CartMultiPendulumSystem(m_cart=1, m1=1,l1=1, m2=1,l2=1)
num_pendulums = cart_multi_pendulum.num_pendulums

# Transcription settings. With N knot points and time steps in [MIN, MAX], the trajectory
# lasts at most (N - 1) * MAX_TIME_STEP seconds (2.9 s here).
NUM_TIME_SAMPLES = 30
MIN_TIME_STEP = 0.01
MAX_TIME_STEP = 0.1
INPUT_COST_WEIGHT = 10
# NOTE(review): the initial guess spans 4 s, longer than the 2.9 s allowed above, so the
# guessed time steps start out infeasible. Consider shortening it if the solver struggles.
INITIAL_GUESS_DURATION = 4.0

dircol = DirectCollocation(
    cart_multi_pendulum,
    cart_multi_pendulum.CreateDefaultContext(),
    num_time_samples=NUM_TIME_SAMPLES,
    minimum_time_step=MIN_TIME_STEP,
    maximum_time_step=MAX_TIME_STEP,
)
prog = dircol.prog()

# Boundary conditions are written in world coordinates (absolute link angles) and
# converted to the joint coordinates used by the plant state.
# Start: all world coordinates zero (presumably hanging at rest).
x_start = WorldState2JointState([0] * (2 * num_pendulums + 2))
prog.AddBoundingBoxConstraint(x_start, x_start, dircol.initial_state())

# Target: cart at 0, every link at world angle pi, all velocities zero.
x_target = WorldState2JointState([0] + [np.pi] * num_pendulums + [0] * (num_pendulums + 1))
prog.AddBoundingBoxConstraint(x_target, x_target, dircol.final_state())

# Penalize control effort (squared cart force).
dircol.AddRunningCost(INPUT_COST_WEIGHT * dircol.input()[0] ** 2)

# Initial guess: empty input trajectory, straight-line interpolation of the state.
u_trj_init = PiecewisePolynomial()
x_trj_init = PiecewisePolynomial.FirstOrderHold(
    [0.0, INITIAL_GUESS_DURATION], np.vstack([x_start, x_target]).T)
dircol.SetInitialTrajectory(u_trj_init, x_trj_init)

result = Solve(prog)
print("Found trajectory? ", result.is_success())
# Fail loudly: every later cell reconstructs and tracks this solution, and would silently
# use a meaningless trajectory otherwise.
if not result.is_success():
    raise RuntimeError(f"DirectCollocation failed: {result.get_solution_result()}")

In [None]:
# Recover continuous trajectories from the collocation solution and plot them.
x_trj = dircol.ReconstructStateTrajectory(result)
u_trj = dircol.ReconstructInputTrajectory(result)

t = np.linspace(x_trj.start_time(), x_trj.end_time(), 1000)

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

# State rows 1..num_pendulums hold the pendulum (joint) angles; vector_values returns
# one column per time sample, hence the transpose.
ax1.plot(t, x_trj.vector_values(t)[1:num_pendulums+1].T, '--')
ax1.legend([rf'$\theta_{i}$' for i in range(1,num_pendulums+1)])
ax1.set_title('Planned swing-up trajectory')
ax1.set_ylabel(r'$\theta$ (rad)')
ax1.set_yticks([-np.pi, 0, np.pi])
# Raw strings: '\p' is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
ax1.set_yticklabels([r'$-\pi$', r'$0$', r'$\pi$'])

ax2.plot(t, u_trj.vector_values(t).T, 'k--')
ax2.set_ylabel('$f$ (N)')
ax2.set_xlim(u_trj.start_time(), u_trj.end_time())
ax2.set_xlabel('$t$ (s)')

plt.show()

### Trajectory tracking

In [None]:
# Time-varying LQR that tracks the planned (x_trj, u_trj).
# The first num_pendulums + 1 state entries are weighted 10, the remaining ones 1.
# The same Q doubles as the terminal cost Qf.
Q = np.diag(np.repeat([10, 1], num_pendulums + 1))
R = np.diag([1])

options = FiniteHorizonLinearQuadraticRegulatorOptions()
options.Qf = Q
options.x0, options.u0 = x_trj, u_trj  # nominal trajectories to linearize about

swingup_controller = MakeFiniteHorizonLinearQuadraticRegulator(
    cart_multi_pendulum,
    cart_multi_pendulum.CreateDefaultContext(),
    t0=x_trj.start_time(), tf=x_trj.end_time(),
    Q=Q, R=R,
    options=options,
)
swingup_controller.set_name("swingup_controller")

### Stabilizing controller and swing-up/stabilize switching

In [None]:
class StabilizingController(LeafSystem):
    """Infinite-horizon LQR controller that balances the plant about a fixed state.

    The constructor linearizes the plant diagram about ``state`` with zero input. It then
    computes an LQR gain ``K`` from ``Q`` and ``R``.

    Input port:
        "x": plant state, same layout and length as ``state``.
    Output ports:
        0 "f_cart": ``u0 - K @ (x - x0)``. The pendulum-angle entries of the error are
            wrapped to [-pi, pi), so whole revolutions count as zero error.
        1 "in_ROA": 1.0 if every wrapped pendulum-angle error is smaller in magnitude than
            ``roa_threshold_deg``, else 0.0. Velocities and the cart coordinate are not
            checked, so this is a heuristic, not a verified region of attraction.
    """

    def __init__(self, cart_multi_pendulum: Diagram, state, Q, R, roa_threshold_deg=5.0):
        """
        Args:
            cart_multi_pendulum: Plant diagram. Its subsystem named 'plant' holds the
                continuous state, and its input port 0 is the control input.
            state: Equilibrium state ``x0`` to stabilize (length = plant state size).
            Q: LQR state-cost matrix.
            R: LQR input-cost matrix.
            roa_threshold_deg: Angle-error threshold in degrees for "in_ROA". The default
                of 5 matches the previously hard-coded value.
        """
        super().__init__()

        x0 = np.array(state)
        u0 = np.zeros(1)  # the equilibrium input is taken to be zero

        # Linearize the plant diagram about (x0, u0).
        context = cart_multi_pendulum.CreateDefaultContext()
        plant = cart_multi_pendulum.GetSubsystemByName('plant')
        cart_multi_pendulum.GetMutableSubsystemState(plant, context) \
            .get_mutable_continuous_state().SetFromVector(x0)
        cart_multi_pendulum.get_input_port(0).FixValue(context, u0)
        linear_plant = Linearize(cart_multi_pendulum, context)

        K, _ = LinearQuadraticRegulator(linear_plant.A(), linear_plant.B(), Q, R)

        self.K, self.x0, self.u0 = K, x0, u0
        # State is [positions, velocities]; position entries 1..n are the pendulum angles.
        self._angle_slice = slice(1, len(x0) // 2)
        self._roa_threshold = np.deg2rad(roa_threshold_deg)

        self.DeclareVectorInputPort("x", len(x0))
        self.DeclareVectorOutputPort("f_cart", 1, self.DoCalcOutput0)
        self.DeclareVectorOutputPort("in_ROA", 1, self.DoCalcOutput1)

    def _wrapped_error(self, context):
        """Return ``x - x0`` with the pendulum-angle entries wrapped to [-pi, pi)."""
        delta_x = self.get_input_port().Eval(context) - self.x0
        delta_x[self._angle_slice] = self.wrap_to_pi(delta_x[self._angle_slice])
        return delta_x

    def DoCalcOutput0(self, context, output):
        """Write the LQR control ``u0 - K @ (x - x0)`` to port "f_cart"."""
        output.SetFromVector(self.u0 - self.K @ self._wrapped_error(context))

    def DoCalcOutput1(self, context, output):
        """Write 1.0 to "in_ROA" when all pendulum-angle errors are within the threshold."""
        angle_error = self._wrapped_error(context)[self._angle_slice]
        in_roa = np.all(np.abs(angle_error) < self._roa_threshold)
        output.SetFromVector([float(in_roa)])

    @staticmethod
    def wrap_to_pi(angle):
        """Wrap angle(s) in radians into the half-open interval [-pi, pi)."""
        return (angle + np.pi) % (2 * np.pi) - np.pi


class Latch(LeafSystem):
    """Sampled one-way switch that drives a PortSwitch selector.

    Every ``period_sec`` seconds the scalar input "sig" is sampled. While not yet latched,
    the selection is ``latched_port`` if the sample is positive, else ``initial_port``.
    Once ``latched_port`` has been selected it is kept forever. The selection
    (an InputPortIndex) is exposed on the abstract output port "sel".

    Args:
        initial_port: Port index selected before latching (default 1).
        latched_port: Port index selected once "sig" has been positive (default 2).
        period_sec: Sampling period of the update event (default 1e-3 s, i.e. 1000 Hz).
    """

    def __init__(self, initial_port=1, latched_port=2, period_sec=1e-3):
        super().__init__()

        self._initial_port = InputPortIndex(initial_port)
        self._latched_port = InputPortIndex(latched_port)

        self.DeclareVectorInputPort("sig", 1)

        state_index = self.DeclareAbstractState(Value(self._initial_port))
        self._state_index = int(state_index)
        self.DeclarePeriodicUnrestrictedUpdateEvent(
            period_sec=period_sec,
            offset_sec=0.0,
            update=self.Update)

        self.DeclareStateOutputPort("sel", state_index)

    def Update(self, context: Context, state):
        """Periodic update: switch to ``latched_port`` the first time "sig" is positive."""
        current = context.get_abstract_state(self._state_index).get_value()
        if int(current) == int(self._latched_port):
            selected = current  # already latched: never switch back
        else:
            sig = self.get_input_port().Eval(context)
            selected = self._latched_port if sig[0] > 0 else self._initial_port
        state.get_mutable_abstract_state(self._state_index).set_value(selected)

In [None]:
builder = DiagramBuilder()

# Add *copies* of the plant and the swing-up controller, since a system can belong to only
# one diagram. The copies get new names so the original globals are never rebound.
# Re-running this cell therefore copies the same unowned systems rather than ones already
# owned by a previously built diagram.
plant_diagram = builder.AddSystem(deepcopy(cart_multi_pendulum))
swingup = builder.AddSystem(deepcopy(swingup_controller))
builder.Connect(plant_diagram.get_output_port(0), swingup.get_input_port())

# Balancing controller about the upright target (same weighting as the tracking LQR).
Q_stabilize = np.diag([10] * (num_pendulums+1) + [1] * (num_pendulums+1))
R_stabilize = np.diag([1])
stabilize_controller = builder.AddSystem(
    StabilizingController(plant_diagram, x_target, Q_stabilize, R_stabilize))
builder.Connect(plant_diagram.get_output_port(0), stabilize_controller.get_input_port())

# The two ports declared here presumably get PortSwitch indices 1 (swing-up) and
# 2 (stabilize), matching the InputPortIndex values Latch selects between.
switch = builder.AddSystem(PortSwitch(1))
builder.Connect(switch.get_output_port(), plant_diagram.get_input_port())
builder.Connect(swingup.get_output_port(), switch.DeclareInputPort('port0'))
builder.Connect(stabilize_controller.get_output_port(0), switch.DeclareInputPort('port1'))

# Latch the selector onto the stabilizer the first time "in_ROA" is positive.
latch = builder.AddSystem(Latch())
builder.Connect(stabilize_controller.get_output_port(1), latch.get_input_port())
builder.Connect(latch.get_output_port(), switch.get_port_selector_input_port())

visualizer = builder.AddSystem(PlanarSceneGraphVisualizer(
    plant_diagram.GetSubsystemByName("scene_graph"), show=False))
builder.Connect(plant_diagram.get_output_port(1), visualizer.get_geometry_query_input_port())

# Loggers for the state, the applied force, and the in-ROA flag.
logger_x = builder.AddSystem(VectorLogSink(2 * num_pendulums + 2))
builder.Connect(plant_diagram.get_output_port(0), logger_x.get_input_port())

logger_u = builder.AddSystem(VectorLogSink(1))
builder.Connect(switch.get_output_port(), logger_u.get_input_port())

logger_s = builder.AddSystem(VectorLogSink(1))
builder.Connect(stabilize_controller.get_output_port(1), logger_s.get_input_port())

diagram = builder.Build()

display(SVG(pydot.graph_from_dot_data(diagram.GetGraphvizString())[0].create_svg()))

In [None]:
simulator = Simulator(diagram)

x0 = [0] * (2*num_pendulums+2)

context = simulator.get_mutable_context()
context.SetTime(0.0)
context.SetContinuousState(x0)

simulator.Initialize()
visualizer.reset_recording()
visualizer.start_recording()
simulator.AdvanceTo(10)
visualizer.stop_recording()

In [None]:
# Figure height grows with the number of links (~2.1 in each).
visualizer.fig.set_size_inches([6, num_pendulums*2.1])
# Placeholder title so tight_layout reserves room for one.
# NOTE(review): presumably the animation replaces it with the current time; confirm.
visualizer.ax.set_title('t = ?')
visualizer.fig.tight_layout()
visualizer.ax.set_xlim([-3, 3])
# Stretch the axes horizontally to span the full figure width (set after tight_layout).
bbox = visualizer.ax.get_position()
bbox.x0, bbox.x1 = 0, 1
visualizer.ax.set_position(bbox)

# Render the recorded frames as an interactive JavaScript animation (cell output).
HTML(visualizer.get_recording_as_animation().to_jshtml())

In [None]:
# Compare the planned (dashed) and simulated (solid) trajectories.
log_x = logger_x.FindLog(context)
log_u = logger_u.FindLog(context)
log_s = logger_s.FindLog(context)

# First logged time at which the stabilizer reported "in_ROA" (None if never). The latch
# samples at its own period, so the actual switch may occur slightly later.
roa_idxs = np.flatnonzero(log_s.data())
switch_time = log_s.sample_times()[roa_idxs[0]] if roa_idxs.size else None

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

trj_t = np.linspace(x_trj.start_time(), x_trj.end_time(), 1000)
sim_t = log_x.sample_times()
theta_rows = slice(1, num_pendulums + 1)  # pendulum-angle rows of the state

ax1.plot(trj_t, x_trj.vector_values(trj_t)[theta_rows].T, '--')
ax1.set_prop_cycle(None)  # restart the colour cycle so sim lines match their plan lines
ax1.plot(sim_t, log_x.data()[theta_rows].T, '-')
ax2.plot(trj_t, u_trj.vector_values(trj_t).T, 'k--')
ax2.plot(sim_t, log_u.data().T, 'k-')

# `is not None`, not truthiness: a switch at t = 0.0 must still be drawn.
if switch_time is not None:
    for ax in (ax1, ax2):
        ax.axvline(switch_time, color='k', linestyle=':')

ax1.legend([rf'$\theta_{i}^{{trj}}$' for i in range(1,num_pendulums+1)] + [rf'$\theta_{i}^{{sim}}$' for i in range(1,num_pendulums+1)])
ax1.set_title('Planned (dashed) vs. simulated (solid)')
ax1.set_ylabel(r'$\theta$ (rad)')
ax1.set_yticks([-np.pi, 0, np.pi])
# Raw strings: '\p' is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
ax1.set_yticklabels([r'$-\pi$', r'$0$', r'$\pi$'])

ax2.legend(['$f^{trj}$', '$f^{sim}$'])
ax2.set_ylabel('$f$ (N)')
ax2.set_xlabel('$t$ (s)')
ax2.set_xlim(sim_t[0], sim_t[-1])  # sample times are increasing

plt.show()