In [0]:
import collections
import scipy.linalg
import scipy.integrate
import numpy as np

PendulumParams = collections.namedtuple(
  "PendulumParams", "length mass g"
)

In [9]:
def pendulum_dynamics(params):
  
  def dynamics(x, u, t):
    del u, t
    w = np.product(params)*np.sin(x[0])
    # assume point mass and massless arm
    inertia = params.mass * params.length**2
    amat = np.array([[0, 1], [w/inertia, 0]])
    bmat = np.array([[0], [np.reciprocal(inertia)]])
    return amat, bmat
  
  return dynamics


def dynamical_system(dynamics):
  
  def compute_dynamics(x, u, t):
    amat, bmat = dynamics(x, u, t)
    return np.dot(amat, x) + np.dot(bmat, u)
  
  return compute_dynamics


def controlled_system(policy, system):
  
  def compute_dxdt(x, t):
    u = policy(x, t)
    dxdt = system(x, u, t)
    return dxdt
  
  return compute_dxdt


def continuous_lqr_solve(dynamics, x_goal, t, qmat, rmat):
  amat, bmat = dynamics(x_goal, np.zeros(1), t)
  pmat = scipy.linalg.solve_continuous_are(amat, bmat, qmat, rmat)
  return scipy.linalg.solve(rmat, np.dot(bmat.T, pmat))
  

def lqr_policy(kmat, x_goal):
  def policy(x, t):
    return np.dot(kmat, x_goal - x)
  return policy
                  

x_init = np.array([np.pi/2 - 0.2, 0.])
x_goal = np.array([np.pi/2, 0.])
params = PendulumParams(mass=1., length=1., g=-9.81)


dynamics = pendulum_dynamics(params)
kmat = continuous_lqr_solve(dynamics, x_goal, 0.,
                            np.diag(np.array([1., 1e-2])),
                            np.ones(1)*1e-5)
policy = lqr_policy(kmat, x_goal)

scipy.integrate.odeint(controlled_system(policy, dynamical_system(dynamics)),
                       x_init, np.linspace(0, 2., 10))

array([[1.37079633e+00, 0.00000000e+00],
       [1.50064743e+00, 2.30054249e-01],
       [1.52018635e+00, 2.11649634e-02],
       [1.52196856e+00, 1.92084376e-03],
       [1.52213030e+00, 1.74342563e-04],
       [1.52214498e+00, 1.58324214e-05],
       [1.52214631e+00, 1.44330321e-06],
       [1.52214643e+00, 1.34434916e-07],
       [1.52214644e+00, 1.33820520e-08],
       [1.52214644e+00, 3.30573074e-09]])