In [1]:
%matplotlib inline

In [2]:
from __future__ import print_function

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import time

In [4]:
from ilqr.cost import QRCost, FiniteDiffCost
from ilqr.mujoco_dynamics import MujocoDynamics
from ilqr.mujoco_controller import iLQR, RecedingHorizonController
from ilqr.examples.cartpole import CartpoleDynamics
from ilqr.dynamics import constrain

from scipy.optimize import approx_fprime

import mujoco_py
from mujoco_py import MjViewer
import os

In [5]:
def on_iteration(iteration_count, xs, us, J_opt, accepted, converged):
    J_hist.append(J_opt)
    info = "converged" if converged else ("accepted" if accepted else "failed")
    final_state = xs[-1]
    print("iteration", iteration_count, info, J_opt, final_state)

In [6]:
xml_path = os.path.join('..', 'ilqr', 'xmls', 'inverted_double_pendulum.xml')
dynamics = MujocoDynamics(xml_path, frame_skip = 4, use_multiprocessing = True)
print(dynamics.dt)

Finished loading process 64203
Finished loading process 64204
Finished loading process 64205
Finished loading process 64206
Finished loading process 64207
Finished loading process 64208
Finished loading process 64209
Finished loading process 64210
Finished loading process 64211
Finished loading process 64212
Finished loading process 64213
Finished loading process 64214
0.04
Finished loading process 64215
Finished loading process 64216
Finished loading process 64217
Finished loading process 64218


In [7]:
print(dynamics.state_size)
x_goal = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

# Instantenous state cost.
Q = np.eye(dynamics.state_size)
Q[0, 0] = 5.0
Q[1, 1] = 50.0
Q[2, 2] = 50.0
Q[3, 3] = 20.0
Q[4, 4] = 700.0
Q[5, 5] = 700.0



# Terminal state cost.
Q_terminal = 10 * Q

# Instantaneous control cost.
R = np.eye(1)

cost1 = QRCost(Q, R, Q_terminal=Q_terminal, x_goal=x_goal)

6


In [8]:
def l(x, u, i):
    c0 = x[0] ** 2
    c1 = 10 * ((np.cos(x[1]) - 1) ** 2)
    c2 = 10 * ((np.cos(x[2]) - 1) ** 2)
    c3 = x[3] ** 2
    c4 = x[4] ** 2
    c5 = x[5] ** 2
    cu = 0.5 * u[0] ** 2
    return c0 + c1 + c2 + c3 + c4 + c5 + cu

cost2 = FiniteDiffCost(l, lambda x, i: l(x, [0.0], i), 6, 1, use_multiprocessing = True)

Finished loading process 64219
Finished loading process 64220
Finished loading process 64221
Finished loading process 64222
Finished loading process 64223
Finished loading process 64224
Finished loading process 64225
Finished loading process 64226
Finished loading process 64227
Finished loading process 64228
Finished loading process 64229
Finished loading process 64230
Finished loading process 64231
Finished loading process 64232
Finished loading process 64233
Finished loading process 64234


In [9]:
N = 100
x0 = np.array([0.0, np.random.uniform(-np.pi, np.pi), np.random.uniform(-np.pi, np.pi), 0.0, 0.0, 0.0])
#x0 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
us_init = np.random.uniform(-1, 1, (N, dynamics.action_size))
#us_init = np.zeros((N, dynamics.action_size))
ilqr = iLQR(dynamics, cost2, N)
mpc = RecedingHorizonController(x0, ilqr)

In [10]:
t0 = time.time()
J_hist = []
controls = mpc.control(us_init, step_size = 3, initial_n_iterations = 500, subsequent_n_iterations = 100, on_iteration = on_iteration)
us = []
for i in range(30):
    print('ITERATION', i, '\n')
    if i == 29:
        us.append(next(controls)[2])
    else:
        us.append(next(controls)[1])
    
print('time', time.time() - t0)
us = np.concatenate(us)

ITERATION 0 

iteration 0 accepted 44985.92331231052 [  2.99460753 -13.84109699  49.1551775   -0.1342446  -10.67556058
  20.50230234]
iteration 1 accepted 44287.0175796551 [  2.99456037 -13.83028564  49.13434721  -0.13235598 -10.58677193
  20.25701955]
iteration 2 accepted 43810.01840165975 [  2.99757704 -13.76857026  49.1066828   -0.22085751 -10.18797772
  19.18025536]
iteration 3 accepted 42393.615347921106 [  2.87122954 -12.66138181  48.35584266   1.20972412  -8.47191904
  12.07200582]
iteration 4 accepted 41903.74465489674 [  2.89784879 -12.58538874  48.2916001    1.37076108  -8.40191666
  11.4237543 ]
iteration 5 accepted 40804.14851198039 [  2.94515031 -12.43027836  48.0515019    1.69028904  -8.37690742
   9.57544803]
iteration 6 accepted 40519.989201205135 [  2.99305577 -12.32207347  48.01040206   1.58043204  -7.40450583
   8.34803111]
iteration 7 accepted 40180.225596126504 [  2.99257996 -12.35882951  47.85211967   1.49502228  -7.63688088
   7.62224799]
iteration 8 accepted 400



iteration 1 accepted 40522.67211360037 [  2.63842856 -13.31633237  47.91603572  -3.2226727   -2.97937104
  -1.98314591]
iteration 2 accepted 40133.64417356158 [  2.33004874 -13.15925226  48.46935231  -2.8485398   -2.97876003
   0.95436331]
iteration 3 accepted 39606.072721398785 [  2.45315563 -12.95552319  48.86161124  -2.45616929  -1.98741202
   1.08543779]
iteration 4 accepted 39559.11589948319 [  2.4608849  -12.93377163  48.85003178  -2.40879416  -1.95736012
   1.07832694]
iteration 5 accepted 39463.96310953276 [  2.46711609 -12.91398086  48.84195954  -2.36554051  -1.92957466
   1.07007352]
iteration 6 accepted 39437.496914284144 [  2.45808746 -12.86486967  48.89039881  -2.3279493   -1.82371937
   1.11505687]
iteration 7 accepted 39417.70716237294 [  2.45931294 -12.86108695  48.88802213  -2.3199335   -1.81918857
   1.1125991 ]
iteration 8 failed 39417.70716237296 [  2.45931294 -12.86108695  48.88802213  -2.3199335   -1.81918857
   1.1125991 ]
iteration 9 failed 39417.70716237296 [  

iteration 5 accepted 38894.60801410092 [  2.06325648 -13.84581504  49.74698094  -2.63630654  -5.02152803
   4.62817537]
iteration 6 failed 38894.60801410094 [  2.06325648 -13.84581504  49.74698094  -2.63630654  -5.02152803
   4.62817537]
iteration 7 failed 38894.60801410094 [  2.06325648 -13.84581504  49.74698094  -2.63630654  -5.02152803
   4.62817537]
iteration 8 failed 38894.60801410094 [  2.06325648 -13.84581504  49.74698094  -2.63630654  -5.02152803
   4.62817537]
iteration 9 failed 38894.60801410094 [  2.06325648 -13.84581504  49.74698094  -2.63630654  -5.02152803
   4.62817537]
iteration 10 failed 38894.60801410094 [  2.06325648 -13.84581504  49.74698094  -2.63630654  -5.02152803
   4.62817537]
iteration 11 accepted 38892.28622509372 [  2.06302938 -13.84586895  49.74847738  -2.63713604  -5.02094738
   4.62847016]
iteration 12 accepted 38891.839132806264 [  2.06307724 -13.84585889  49.74860371  -2.63709086  -5.0208378
   4.62844563]
iteration 13 failed 38891.83913280627 [  2.0630

iteration 8 failed 39340.71860236061 [  0.18144098 -16.82180968  49.55347603  -4.30029346  -8.22203421
   7.9091277 ]
iteration 9 accepted 39339.80300839459 [  0.19495817 -16.82201742  49.55299573  -4.29305014  -8.22318587
   7.90993578]
iteration 10 accepted 39339.470559800735 [  0.19519047 -16.82201784  49.55292735  -4.29287623  -8.22321534
   7.90992455]
iteration 11 failed 39339.470559800735 [  0.19519047 -16.82201784  49.55292735  -4.29287623  -8.22321534
   7.90992455]
iteration 12 accepted 39339.1441763312 [  0.20123227 -16.82203001  49.55329011  -4.2909784   -8.22307239
   7.9100135 ]
iteration 13 failed 39339.1441763312 [  0.20123227 -16.82203001  49.55329011  -4.2909784   -8.22307239
   7.9100135 ]
iteration 14 failed 39339.1441763312 [  0.20123227 -16.82203001  49.55329011  -4.2909784   -8.22307239
   7.9100135 ]
iteration 15 failed 39339.1441763312 [  0.20123227 -16.82203001  49.55329011  -4.2909784   -8.22307239
   7.9100135 ]
iteration 16 failed 39339.1441763312 [  0.2012

iteration 10 accepted 34608.56917523218 [ -2.08001612 -18.52600356  52.20654658  -4.82967647   3.74054325
   2.20117253]
iteration 11 failed 34608.56917523218 [ -2.08001612 -18.52600356  52.20654658  -4.82967647   3.74054325
   2.20117253]
iteration 12 accepted 34608.02466958817 [ -2.07561272 -18.52606425  52.20585353  -4.82837743   3.73966466
   2.20013633]
iteration 13 failed 34608.02466958818 [ -2.07561272 -18.52606425  52.20585353  -4.82837743   3.73966466
   2.20013633]
iteration 14 failed 34608.02466958818 [ -2.07561272 -18.52606425  52.20585353  -4.82837743   3.73966466
   2.20013633]
iteration 15 failed 34608.02466958818 [ -2.07561272 -18.52606425  52.20585353  -4.82837743   3.73966466
   2.20013633]
iteration 16 accepted 34606.620906969896 [ -2.06608028 -18.52606352  52.2059812   -4.82493258   3.74093602
   2.19954635]
iteration 17 accepted 34597.18204309562 [ -2.05462957 -18.52625683  52.20551706  -4.82608941   3.73977422
   2.1992303 ]
iteration 18 failed 34597.18204309562 [

iteration 56 failed 31852.686760589455 [ -0.40839675 -17.38960121  52.01366989  -1.21855436   6.10360631
  -2.1767981 ]
iteration 57 accepted 31852.56199749427 [ -0.40794544 -17.3895877   52.01345721  -1.21852921   6.10346958
  -2.17679499]
iteration 58 failed 31852.561997494267 [ -0.40794544 -17.3895877   52.01345721  -1.21852921   6.10346958
  -2.17679499]
iteration 59 failed 31852.561997494267 [ -0.40794544 -17.3895877   52.01345721  -1.21852921   6.10346958
  -2.17679499]
iteration 60 accepted 31852.210046899614 [ -0.40810382 -17.38968655  52.01280078  -1.21852179   6.102414
  -2.17679771]
iteration 61 failed 31852.210046899603 [ -0.40810382 -17.38968655  52.01280078  -1.21852179   6.102414
  -2.17679771]
iteration 62 failed 31852.210046899603 [ -0.40810382 -17.38968655  52.01280078  -1.21852179   6.102414
  -2.17679771]
iteration 63 accepted 31852.03734226765 [ -0.40789875 -17.38956825  52.01346778  -1.21835125   6.10351043
  -2.17681707]
iteration 64 failed 31852.03734226766 [ -0

iteration 6 failed 32344.196646900233 [ -0.10445067 -14.18118129  49.08915625   0.26786349   4.02751088
   5.68844365]
iteration 7 failed 32344.196646900233 [ -0.10445067 -14.18118129  49.08915625   0.26786349   4.02751088
   5.68844365]
iteration 8 failed 32344.196646900233 [ -0.10445067 -14.18118129  49.08915625   0.26786349   4.02751088
   5.68844365]
iteration 9 failed 32344.196646900233 [ -0.10445067 -14.18118129  49.08915625   0.26786349   4.02751088
   5.68844365]
iteration 10 accepted 32344.141563776237 [ -0.10446579 -14.18118436  49.08916135   0.26788203   4.02751591
   5.68842639]
iteration 11 accepted 32343.848451203874 [ -0.10445245 -14.18121942  49.08914939   0.26794701   4.02757505
   5.68838809]
iteration 12 failed 32343.848451203863 [ -0.10445245 -14.18121942  49.08914939   0.26794701   4.02757505
   5.68838809]
iteration 13 accepted 32343.531064620773 [ -0.10196378 -14.18112353  49.08921109   0.26938561   4.02796742
   5.68772328]
iteration 14 failed 32343.531064620784

iteration 1 accepted 19775.754108436824 [  0.41658999 -15.05792641  50.75307838  -2.69598156   0.95537913
  -1.53926687]
iteration 2 accepted 19684.38391554777 [  1.18819204 -14.21565995  50.65179409  -0.66496301   1.39933579
  -0.08309322]
iteration 3 accepted 19565.92708693221 [ 7.40461008e-01 -1.43942986e+01  5.05353507e+01 -1.03667550e+00
  1.82322595e+00  7.32273885e-03]
iteration 4 accepted 19542.844612257977 [  0.76426774 -14.7687738   50.40164591  -1.72888678   2.41956098
   0.0562767 ]
iteration 5 accepted 19480.377796983412 [  0.65778366 -14.87892074  50.41113392  -1.92156395   2.29540645
   0.0518743 ]
iteration 6 accepted 19462.18440961545 [ 5.57849881e-01 -1.50162107e+01  5.04178185e+01 -2.23030308e+00
  2.18974566e+00  4.34883672e-03]
iteration 7 accepted 19439.067988893894 [  0.26607457 -15.20065408  50.43683338  -2.8475486    2.14117254
  -0.1080089 ]
iteration 8 accepted 19346.735240330596 [ -0.23602348 -15.46004058  50.51924376  -3.07179722   1.56927424
  -0.13061806]

iteration 69 converged 17997.169488137803 [  0.32841913 -18.60541358  50.34464635   0.45485939   0.48057215
   0.18963605]
ITERATION 23 

iteration 0 accepted 14338.93424384703 [  0.33123415 -18.67935193  50.3690781    0.19473492   0.57268567
   0.19164852]
iteration 1 accepted 14338.30885860385 [  0.32419163 -18.67268093  50.35616627   0.24038133   0.53109237
   0.17163635]
iteration 2 accepted 14338.195690124125 [  0.31355931 -18.67162938  50.34116215   0.28443808   0.46246756
   0.15765   ]
iteration 3 accepted 14338.08674221868 [  0.31043551 -18.67192265  50.33707442   0.29611864   0.44120373
   0.15541865]
iteration 4 accepted 14338.044201082703 [  0.30562207 -18.67241429  50.33089723   0.31392886   0.4086509
   0.15218503]
iteration 5 converged 14338.035955933481 [  0.30281559 -18.6726497   50.32737123   0.32433724   0.38985842
   0.15042995]
ITERATION 24 

iteration 0 accepted 9105.523190127915 [  0.2961811  -18.73469528  50.34354359   0.10728258   0.44108374
   0.13970028]
iter

In [11]:
viewer = MjViewer(dynamics.sim)
dynamics.set_state(x0)
print(dynamics.get_state())
for i in range(us.shape[0]):
    dynamics.step(us[i])
    viewer.render()

Creating window glfw
[0.         0.60347394 1.28892961 0.         0.         0.        ]


In [12]:
print(us.shape)

(187, 1)
