Idea: what if we run many more optimizations, each with a much shorter time limit? Can we bootstrap the warmstarts to make that short time limit count?

In [1]:
using Revise

In [2]:
using MeshCatMechanisms
using MeshCat
using RigidBodyDynamics
using Gurobi
using Flux
using ProgressMeter
using MLDataPattern
using JLD2
using Plots; gr()

Plots.GRBackend()

In [3]:
import LCPSim
import LearningMPC
import BoxValkyries
reload("LearningMPC")
reload("BoxValkyries")

INFO: Recompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/BoxValkyries.ji for module BoxValkyries.


In [4]:
robot = BoxValkyries.BoxValkyrie(true, LCPSim.planar_revolute_base);
mvis = MechanismVisualizer(robot)
IJuliaCell(mvis)

Listening on 127.0.0.1:7001...
zmq_url=tcp://127.0.0.1:6001
web_url=http://127.0.0.1:7001/static/




In [19]:
function create_net()
    net = Chain(
        Dense(22, 64, elu),
        Dense(64, 64, elu),
        Dense(64, 11)
    )
    loss = (x, y) -> Flux.mse(net(x), y)
    net, loss
end

create_net (generic function with 1 method)

In [22]:
net, loss = create_net()
net_controller = x -> Flux.Tracker.data(net(state_vector(x)))
net_params = params(net)
optimizer = Flux.Optimise.ADADelta(net_params)

xstar = BoxValkyries.nominal_state(robot)

mpc_params = LearningMPC.MPCParams(
    Δt=0.05,
    horizon=10,
    mip_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0, 
        TimeLimit=5, 
        MIPGap=1e-1, 
        MIPGapAbs=5e-1,
        FeasibilityTol=1e-3),
    lcp_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0))

Q, R = BoxValkyries.default_costs(robot)
feet = findbody.(robot.mechanism, ["lf", "rf"])
lqrsol = LearningMPC.LQRSolution(xstar, Q, R, mpc_params.Δt, Point3D.(default_frame.(feet), 0., 0., 0.))
lqrsol.S .= 1 ./ mpc_params.Δt .* Q

mpc_controller = LearningMPC.MPCController(robot.mechanism, 
    robot.environment, mpc_params, lqrsol, 
    [net_controller, lqrsol]);

sample_sink = LearningMPC.MPCSampleSink{Float64}(true)
playback_sink = LearningMPC.PlaybackSink(mvis, mpc_params.Δt)

mpc_controller.callback = LearningMPC.call_each(
    sample_sink,
    playback_sink,
)

live_viewer = LearningMPC.live_viewer(mvis)


dagger_controller = LearningMPC.call_each(
    LearningMPC.dagger_controller(
        mpc_controller,
        net_controller,
        0.2),
    live_viewer
    )

termination = x -> begin
    (configuration(x)[2] < 0.5 || 
     configuration(x)[3] > π/4 ||
     configuration(x)[3] < -π/4)
end

dataset = LearningMPC.Dataset(lqrsol)

function collect_into!(data::Vector{<:LearningMPC.Sample})
    empty!(sample_sink)
    LearningMPC.randomize!(x0, xstar, 0.0, 1.5)
    results = LCPSim.simulate(x0, 
        dagger_controller,
        robot.environment, mpc_params.Δt, 30, 
        mpc_params.lcp_solver;
        termination=termination);
    append!(data, sample_sink.samples)
end

x0 = MechanismState{Float64}(robot.mechanism)

features(s::LearningMPC.Sample) = (s.state, s.uJ[:, 1])

Academic license - for non-commercial use only
Academic license - for non-commercial use only


features (generic function with 1 method)

In [24]:
datasets = Vector{LearningMPC.Dataset{Float64}}()
all_training_data = Vector{Tuple{Vector{Float64}, Vector{Float64}}}()
all_validation_data = Vector{Tuple{Vector{Float64}, Vector{Float64}}}()
losses = Vector{Tuple{Float64, Float64}}()

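@showprogress for iteration in 1:20
    # Collect fresh rollouts for this iteration: two for training,
    # then one each for testing and validation.
    dataset = LearningMPC.Dataset(lqrsol)
    for rollout in 1:2
        collect_into!(dataset.training_data)
    end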
    collect_into!(dataset.testing_data)
    collect_into!(dataset.validation_data)
    append!(all_training_data, features.(dataset.training_data))
    append!(all_validation_data, features.(dataset.validation_data))
    
    # Retrain on everything collected so far, recording the losses per epoch.
    @time for epoch in 1:10
        Flux.train!(loss, shuffleobs(all_training_data), optimizer)
        push!(losses, 
            (mean(xy -> Flux.Tracker.data(loss(xy...)), 
                  all_training_data),
             mean(xy -> Flux.Tracker.data(loss(xy...)), 
                  all_validation_data)))
    end
    push!(datasets, dataset)
    
    jldopen("box-val-revolute.jld2", "w") do file
        file["datasets"] = datasets
        file["net"] = net
        file["lqrsol"] = lqrsol
        file["mpc_params"] = Dict(
            "Δt" => mpc_params.Δt,
            "horizon" => mpc_params.horizon,
        )
        file["all_training_data"] = all_training_data
        file["all_validation_data"] = all_validation_data
        file["losses"] = losses
    end
    
    plt = plot(first.(losses), label="training")
    plot!(plt, last.(losses), label="validation")
    ylims!(plt, (0, ylims(plt)[2]))
    display(plt)
end

captured: ErrorException("Unrecognized solution status: loaded")

  0.128445 seconds (58.96 k allocations: 5.459 MiB)
Progress:   5%|██                                       |  ETA: 0:29:19
  0.049811 seconds (72.02 k allocations: 7.929 MiB)
Progress:  10%|████                                     |  ETA: 0:23:26
  0.083658 seconds (113.04 k allocations: 12.449 MiB, 10.53% gc time)
Progress:  15%|██████                                   |  ETA: 0:27:24
  0.140275 seconds (168.38 k allocations: 18.571 MiB, 6.50% gc time)
Progress:  20%|████████                                 |  ETA: 0:28:50
  0.148565 seconds (224.30 k allocations: 24.749 MiB, 6.21% gc time)
Progress:  25%|██████████                               |  ETA: 0:28:28
  0.214130 seconds (282.01 k allocations: 31.127 MiB, 3.92% gc time)
Progress:  30%|████████████                             |  ETA: 0:28:02
  0.496793 seconds (321.25 k allocations: 35.451 MiB, 52.77% gc time)
Progress:  35%|██████████████                           |  ETA: 0:26:17
  0.257415 seconds (387.80 k allocations: 42.773 MiB, 5.85% gc time)
Progress:  40%|████████████████                         |  ETA: 0:24:09
  0.315862 seconds (446.67 k allocations: 49.268 MiB, 4.93% gc time)
Progress:  45%|██████████████████                       |  ETA: 0:22:35
  0.365952 seconds (487.69 k allocations: 53.788 MiB, 4.32% gc time)
Progress:  50%|████████████████████                     |  ETA: 0:20:33
  0.409100 seconds (557.83 k allocations: 61.514 MiB, 5.72% gc time)
Progress:  55%|███████████████████████                  |  ETA: 0:18:54

captured: ErrorException("Unrecognized solution status: loaded")

  0.439191 seconds (627.39 k allocations: 69.182 MiB, 7.06% gc time)
Progress:  60%|█████████████████████████                |  ETA: 0:17:08
  0.447059 seconds (662.46 k allocations: 73.049 MiB, 4.80% gc time)
Progress:  65%|███████████████████████████              |  ETA: 0:14:47
  0.479230 seconds (698.74 k allocations: 77.049 MiB, 6.39% gc time)
Progress:  70%|█████████████████████████████            |  ETA: 0:12:26
  0.526922 seconds (767.14 k allocations: 84.602 MiB, 5.87% gc time)
Progress:  75%|███████████████████████████████          |  ETA: 0:10:48
  0.548123 seconds (815.32 k allocations: 89.923 MiB, 5.26% gc time)
Progress:  80%|█████████████████████████████████        |  ETA: 0:08:47
  0.589017 seconds (880.72 k allocations: 97.133 MiB, 6.40% gc time)
Progress:  85%|███████████████████████████████████      |  ETA: 0:06:37
  0.609262 seconds (950.33 k allocations: 104.827 MiB, 4.67% gc time)
Progress:  90%|█████████████████████████████████████    |  ETA: 0:04:29
  0.686763 seconds (1.01 M allocations: 111.837 MiB, 5.32% gc time)
Progress:  95%|███████████████████████████████████████  |  ETA: 0:02:15
  0.705001 seconds (1.06 M allocations: 116.384 MiB, 5.08% gc time)
Progress: 100%|█████████████████████████████████████████| Time: 0:44:22


In [33]:
x0 = MechanismState{Float64}(robot.mechanism)
LearningMPC.randomize!(x0, xstar, 0.0, 1.0)
results = LCPSim.simulate(x0, 
    net_controller,
    robot.environment, mpc_params.Δt, 100, 
    mpc_params.lcp_solver,
    termination=x -> false);
LearningMPC.playback(mvis, results, 0.05)

In [35]:
online_mpc_controller = LearningMPC.OnlineMPCController(
    robot.mechanism,
    robot.environment,
    mpc_params,
    lqrsol,
    [lqrsol, net_controller]
);

LearningMPC.OnlineMPCController{Float64,LearningMPC.MPCParams{Gurobi.GurobiSolver,Gurobi.GurobiSolver},RigidBodyDynamics.MechanismState{Float64,Float64,Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{RigidBodyDynamics.Joint{Float64,RigidBodyDynamics.Prismatic{Float64}},1},Array{RigidBodyDynamics.Joint{Float64,RigidBodyDynamics.Revolute{Float64}},1},Array{RigidBodyDynamics.Joint{Float64,RigidBodyDynamics.Fixed{Float64}},1}},3},TypeSortedCollections.TypeSortedCollection{Tuple{Array{RigidBodyDynamics.Spatial.GeometricJacobian{StaticArrays.SArray{Tuple{3,1},Float64,2,3}},1},Array{RigidBodyDynamics.Spatial.GeometricJacobian{StaticArrays.SArray{Tuple{3,1},Float64,2,3}},1},Array{RigidBodyDynamics.Spatial.GeometricJacobian{StaticArrays.SArray{Tuple{3,0},Float64,2,0}},1}},3},TypeSortedCollections.TypeSortedCollection{Tuple{Array{RigidBodyDynamics.Spatial.WrenchMatrix{StaticArrays.SArray{Tuple{3,5},Float64,2,15}},1},Array{RigidBodyDynamics.Spatial.WrenchMatrix{StaticArrays.SArray{

In [41]:
x0 = MechanismState{Float64}(robot.mechanism)
LearningMPC.randomize!(x0, xstar, 0.0, 1.0)
results = LCPSim.simulate(x0, 
    online_mpc_controller,
    robot.environment, mpc_params.Δt, 100, 
    mpc_params.lcp_solver,
    termination=x -> false);
LearningMPC.playback(mvis, results, 0.05)

cost.(warmstarts) = [43.3878, 57.9562]
cost.(warmstarts) = [14.3816, 128.244]
cost.(warmstarts) = [9.47058, 83.5072]
cost.(warmstarts) = [4.76008, 91.9679]
cost.(warmstarts) = [2.59061, 63.6374]
cost.(warmstarts) = [1.61861, 56.9274]
cost.(warmstarts) = [1.03664, 46.9526]
cost.(warmstarts) = [0.771254, 44.6018]
cost.(warmstarts) = [0.695533, 42.0804]
cost.(warmstarts) = [0.726613, 38.8269]
cost.(warmstarts) = [0.812213, 37.7745]
cost.(warmstarts) = [0.920329, 36.0123]
cost.(warmstarts) = [1.081, 35.1758]
cost.(warmstarts) = [1.1755, 36.2615]
cost.(warmstarts) = [1.27987, 36.9602]
cost.(warmstarts) = [1.43556, 36.659]
cost.(warmstarts) = [1.49673, 37.1217]
cost.(warmstarts) = [1.55826, 37.0315]
cost.(warmstarts) = [1.69655, 36.8085]
cost.(warmstarts) = [1.72253, 38.034]
cost.(warmstarts) = [1.75582, 36.2201]
cost.(warmstarts) = [1.86907, 36.6439]
cost.(warmstarts) = [1.85505, 36.5424]
cost.(warmstarts) = [1.93842, 35.6882]
cost.(warmstarts) = [1.92259, 35.9033]
cost.(warmstarts) = [2.00

In [40]:
LearningMPC.playback(mvis, results, 0.05)

This really isn't working well. Training isn't very effective (the final loss is relatively high, and there's a lot of overfitting), and neither is the trained policy (LQR seems to outperform it consistently).

I'm also concerned that the bootstrap idea itself is flawed: even if it worked, we would still have a lot of bad old samples in the training set. We *might* be able to decay their importance over time, but that's one more hyperparameter that I'm not sure I want to tune.
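
For reference, the decay idea could look something like the sketch below. It assumes each sample is tagged with its age, measured in DAgger iterations since it was collected. `sample_weight`, `weighted_mse`, and the decay rate `γ` are hypothetical names for illustration, not part of LearningMPC, and the code uses only Base Julia.

```julia
# Hypothetical sketch: down-weight samples by age (in DAgger iterations).
# γ in (0, 1] is the extra hyperparameter; γ = 1 recovers the plain MSE.
sample_weight(age::Integer, γ::Real) = γ^age

# Weighted mean squared error over per-sample predictions and targets.
function weighted_mse(predictions, targets, ages, γ)
    weights = sample_weight.(ages, γ)
    errors = [sum(abs2, p .- t) / length(t) for (p, t) in zip(predictions, targets)]
    sum(weights .* errors) / sum(weights)
end

# A fresh sample (age 0) with a small error and an old sample (age 1)
# with a large error: the old sample matters less as γ shrinks.
predictions = [[0.0], [0.0]]
targets = [[1.0], [3.0]]
ages = [0, 1]
weighted_mse(predictions, targets, ages, 1.0)  # plain average of 1 and 9
weighted_mse(predictions, targets, ages, 0.5)  # old sample counts half as much
```

Even in this toy form, the result depends directly on `γ`, which is the tuning burden described above.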

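The online MPC version also seems little better than LQR, because the network's warmstart is never chosen. In the `cost.(warmstarts)` output above, the first entry (presumably LQR, given the `[lqrsol, net_controller]` ordering) is always far lower than the second.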

So far, the only case where the network does something non-trivial correctly is the simple hopper, which is simple enough that the MIQP is almost fast enough to run online as-is. 

That's frustrating. 

How can we do better? 

* Learn better
* Use more domain knowledge

I have a feeling that I should re-visit learning the optimal value function instead of the policy. There are a couple of potential advantages:

* It's a much simpler function (scalar instead of 11-dimensional)
* I can exploit my knowledge of the model when I compute `u` from dJ/dx
* It might be more amenable to bootstrapping
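
As a reminder of how the model enters in the second point, here is the simplest setting, which is an assumption for illustration and not the contact-rich model used in this notebook: continuous-time, control-affine dynamics $\dot{x} = f(x) + B(x) u$ with running cost $x^\top Q x + u^\top R u$ and $R \succ 0$. Minimizing the right-hand side of the Hamilton-Jacobi-Bellman equation over $u$ then gives the control in closed form:

$$
u^*(x) = \arg\min_u \left[ x^\top Q x + u^\top R u + \frac{\partial J}{\partial x} \bigl( f(x) + B(x) u \bigr) \right]
       = -\tfrac{1}{2} R^{-1} B(x)^\top \left( \frac{\partial J}{\partial x} \right)^{\!\top}
$$

With contact, the minimization over $u$ is presumably no longer closed-form, but it would still be a small optimization that uses the known model, so the network only has to represent the scalar $J$.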

On bootstrapping: currently, if I terminate an optimization prematurely, I end up with a bad (x, u) pair in the dataset. Even if the network later learns a better warmstart near that x, the dataset still contains the bad pair.

On the other hand, if I'm instead learning (x, (J_lower_bound, J_upper_bound)), then a bad sample just consists of a bound that's too wide. If I later learn a tighter bound for that sample, then the original bad data has no effect! 
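
One way to make this concrete is a loss that is zero whenever the predicted cost-to-go lies inside the known interval. This is only a sketch: `interval_loss` and `total_interval_loss` are hypothetical names, and it assumes both bounds are valid, for example a lower bound from the MIQP solver's best bound and an upper bound from its incumbent objective.

```julia
# Hypothetical sketch: penalize a predicted cost-to-go J only when it falls
# outside the interval [lb, ub] recorded for that state.
interval_loss(J, lb, ub) = max(0, lb - J) + max(0, J - ub)

# Total loss for one state that has accumulated several bound samples.
total_interval_loss(J, bounds) = sum(interval_loss(J, lb, ub) for (lb, ub) in bounds)

loose = (0.0, 100.0)  # from an optimization that was cut off early
tight = (40.0, 45.0)  # from a later, better solve at the same state

# Any prediction that satisfies the tight bound also satisfies the loose one,
# so the loose sample contributes nothing to the loss.
total_interval_loss(42.0, [loose, tight])  # 0.0
total_interval_loss(50.0, [loose, tight])  # 5.0, all from the tight bound
```

Under these assumptions the old, wide interval never needs to be removed or down-weighted: it simply stops being an active constraint once a tighter one exists.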

