In [1]:
using Revise

In [2]:
using MeshCat
using MeshCatMechanisms
using RigidBodyDynamics
using JuMP
using Gurobi
using JLD2
using Flux
using MLDataPattern
using ProgressMeter
using Plots; gr()

Plots.GRBackend()

In [3]:
import LCPSim
import BoxValkyries
import FluxExtensions

[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/LearningMPC.ji for module LearningMPC.
[39m[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/BoxValkyries.ji for module BoxValkyries.
[39m

In [4]:
# Build the BoxValkyrie model and a MeshCat visualizer for it. `robot` and
# `mvis` are globals that every later cell reuses.
robot = BoxValkyries.BoxValkyrie()
mvis = MechanismVisualizer(robot)
# Embed the visualizer in this notebook cell's output.
IJuliaCell(mvis)

Listening on 127.0.0.1:7000...
zmq_url=tcp://127.0.0.1:6000
web_url=http://127.0.0.1:7000/static/


In [5]:
"""
    create_net(ninputs=22, nhidden=64)

Build a small fully-connected value network and an interval loss for it.

The network maps an `ninputs`-dimensional input vector to a single scalar
through two `elu` hidden layers of width `nhidden` followed by a linear output
layer. The defaults reproduce the original hard-coded architecture
(22 inputs, 64 hidden units).

The returned `loss(x, lb, ub)` is zero whenever the prediction `net(x)` lies
inside `[lb, ub]`, and otherwise equals the distance from the prediction to the
violated bound (`lb - y` below, `y - ub` above). The network is therefore only
pushed into the interval, not towards a point target.

# Returns
`(net, loss)`.
"""
function create_net(ninputs::Integer=22, nhidden::Integer=64)
    net = Chain(
        Dense(ninputs, nhidden, elu),
        Dense(nhidden, nhidden, elu),
        Dense(nhidden, 1)
    )
    loss = (x, lb, ub) -> begin
        y = net(x)
        # Two-sided hinge around [lb, ub]. The zero branch is written as
        # `0 .* y` (presumably so it keeps the same tracked array type as `y`).
        sum(ifelse.(y .< lb, lb .- y, ifelse.(y .> ub, y .- ub, 0 .* y)))
    end
    return net, loss
end

"""
    features(s::LearningMPC.Sample)

Return the training tuple for one MPC sample: the sample's state, then the MIP
objective bound, then the MIP objective value. The tuple is splatted into
`loss(x, lb, ub)`, so the bound acts as the lower target and the objective
value as the upper target.
"""
function features(s::LearningMPC.Sample)
    mip = s.mip
    return (s.state, mip.objective_bound, mip.objective_value)
end

features (generic function with 1 method)

In [15]:
# Nominal state of the robot, used as the LQR reference point.
xstar = BoxValkyries.nominal_state(robot)

# MPC settings: 0.05 s time step over a 10-step horizon. Each solver gets its
# own Gurobi environment with output silenced. The MIP solver is additionally
# capped at 5 s per solve and uses a loosened feasibility tolerance.
mpc_params = let
    mip_solver = GurobiSolver(Gurobi.Env(), OutputFlag=0,
                              TimeLimit=5,
                              FeasibilityTol=1e-3)
    lcp_solver = GurobiSolver(Gurobi.Env(), OutputFlag=0)
    LearningMPC.MPCParams(
        Δt=0.05,
        horizon=10,
        mip_solver=mip_solver,
        lcp_solver=lcp_solver)
end

# LQR solution built from `xstar`, the default cost matrices, and one point at
# the origin of each foot's ("lf", "rf") default frame (presumably the contact
# points — confirm in LearningMPC.LQRSolution).
Q, R = BoxValkyries.default_costs(robot)
feet = findbody.(robot.mechanism, ["lf", "rf"])
lqrsol = LearningMPC.LQRSolution(xstar, Q, R, mpc_params.Δt, Point3D.(default_frame.(feet), 0., 0., 0.))
# Disabled alternative: overwrite the LQR cost-to-go matrix with a scaled Q.
# lqrsol.S .= 1 ./ mpc_params.Δt .* Q

# Value network and its interval loss (see `create_net`).
net, loss = create_net()
# Wrapper whose call returns a value together with `jac`, which
# `net_qp_controller` uses as a linear cost term. NOTE(review): presumably the
# network's Jacobian w.r.t. its input — confirm in FluxExtensions.
tangent_net = FluxExtensions.TangentPropagator(net)
net_params = params(net)
# ADAM optimizer over all trainable parameters of `net`.
optimizer = Flux.Optimise.ADAM(net_params)

# Controller that poses a single LCP step as a MIP. The objective is the LQR
# input cost plus the LQR state cost plus a linear term `gradient' * x`.
# `gradient` is the flattened `tangent_net` output at the current state
# (presumably a first-order model of the learned cost-to-go).
# Returns the optimized input for that step.
net_qp_controller = let x_net = MechanismState(robot.mechanism)
    function (state)
        # Linear cost term from the network, evaluated at the incoming state.
        _, jacobian = tangent_net(Vector(state))
        gradient = vec(Flux.Tracker.data(jacobian))

        # Copy the incoming state into the cached `x_net` used to build the model.
        set_configuration!(x_net, configuration(state))
        set_velocity!(x_net, velocity(state))

        model = Model(solver=mpc_params.mip_solver)
        _, stages = LCPSim.optimize(x_net, robot.environment, mpc_params.Δt, 1, model)
        @assert length(stages) == 1
        stage = stages[1]
        xnext = stage.state.state
        unext = stage.input
        # NOTE(review): the quadratic costs act on raw x and u, not on offsets
        # from `xstar` or a nominal input — confirm this matches how
        # LQRSolution defines Q and R.
        @objective(model, Min,
                   unext' * lqrsol.R * unext + xnext' * lqrsol.Q * xnext + gradient' * xnext)
        # NOTE(review): the solve status is not inspected, so an infeasible or
        # time-limited solve is not detected here.
        solve(model)
        getvalue(stage.input)
    end
end

# MPC expert controller. `lqrsol` is passed as its own argument and again in the
# list `[lqrsol, net_qp_controller]`. NOTE(review): the role of that list
# (e.g. candidate controllers for warm-starting) is defined in LearningMPC —
# confirm there.
mpc_controller = LearningMPC.MPCController(robot.mechanism, 
    robot.environment, mpc_params, lqrsol, 
    [lqrsol, net_qp_controller]);

# `sample_sink` collects MPC samples; `collect_into!` reads `sample_sink.samples`.
sample_sink = LearningMPC.MPCSampleSink{Float64}(true)
# Sink tied to the visualizer and time step (presumably replays MPC trajectories).
playback_sink = LearningMPC.PlaybackSink(mvis, mpc_params.Δt)

# Route the controller's callback to both sinks.
mpc_controller.callback = LearningMPC.call_each(
    sample_sink,
    playback_sink,
)

live_viewer = LearningMPC.live_viewer(mvis)

# Rollout stopping criterion: true once configuration coordinate 2 drops below
# 0.5 or coordinate 3 leaves [-π/4, π/4]. NOTE(review): presumably base height
# and pitch — confirm against BoxValkyrie's joint layout.
termination = x -> begin
    q = configuration(x)
    q[2] < 0.5 || abs(q[3]) > π/4
end

# Controller driven during data collection. It combines the MPC expert with the
# learned `net_qp_controller` via `LearningMPC.dagger_controller`.
# NOTE(review): `0.2` is a mixing parameter (presumably the probability of
# acting with the learned controller) — confirm in LearningMPC.
# `call_each` also invokes `live_viewer` on each call.
dagger_controller = LearningMPC.call_each(
    LearningMPC.dagger_controller(
        mpc_controller,
        net_qp_controller,
        0.2),
    live_viewer
    )

# Initial dataset; the training loop assigns a fresh `LearningMPC.Dataset` to
# `dataset` on every iteration.
dataset = LearningMPC.Dataset(lqrsol)

# Reference state that each rollout's start is randomized from.
x_init = BoxValkyries.nominal_state(robot)
# Scratch state that `collect_into!` randomizes and simulates from.
x0 = MechanismState{Float64}(robot.mechanism)

"""
    collect_into!(data; num_steps=50)

Run one randomized rollout of `dagger_controller` and append the MPC samples
recorded during it to `data`.

Steps:
1. Clear `sample_sink`, so only this rollout's samples are appended.
2. Call `LearningMPC.randomize!(x0, x_init, 0.0, 1.5)`, presumably
   re-randomizing `x0` around `x_init`.
3. Simulate from `x0` with the LCP solver, time step `mpc_params.Δt`,
   `num_steps` as the step count, and `termination` as the stopping criterion.

`num_steps` defaults to the previously hard-coded 50. The function relies on
the globals `sample_sink`, `x0`, `x_init`, `dagger_controller`, `robot`,
`mpc_params` and `termination`.

Returns `data`.
"""
function collect_into!(data::Vector{<:LearningMPC.Sample}; num_steps::Integer=50)
    empty!(sample_sink)
    LearningMPC.randomize!(x0, x_init, 0.0, 1.5)
    # The simulation result itself is not needed; samples arrive via `sample_sink`.
    LCPSim.simulate(x0,
        dagger_controller,
        robot.environment, mpc_params.Δt, num_steps,
        mpc_params.lcp_solver;
        termination=termination)
    append!(data, sample_sink.samples)
    return data
end


collect_into! (generic function with 1 method)

In [18]:
# Training state accumulated across DAgger iterations; checkpointed to
# "box-atlas-value-interval.jld2" after every iteration.
datasets = Vector{LearningMPC.Dataset{Float64}}()
all_training_data = Vector{Tuple{Vector{Float64}, Float64, Float64}}()
all_validation_data = Vector{Tuple{Vector{Float64}, Float64, Float64}}()
# One (training, validation) mean-loss pair per epoch.
losses = Vector{Tuple{Float64, Float64}}()

# A `(state, lb, ub)` feature tuple is usable only if none of its entries is NaN.
isvalidfeature(x) = !(any(isnan, x[1]) || isnan(x[2]) || isnan(x[3]))

# Mean interval loss over `data`. Returns NaN for an empty set, because `mean`
# of an empty collection throws and would abort the whole run.
meanloss(data) = isempty(data) ? NaN : mean(xy -> Flux.Tracker.data(loss(xy...)), data)

@showprogress for iteration in 1:100
    # Gather fresh rollouts: two for training, one each for testing and validation.
    dataset = LearningMPC.Dataset(lqrsol)
    for rollout in 1:2
        collect_into!(dataset.training_data)
    end
    collect_into!(dataset.testing_data)
    collect_into!(dataset.validation_data)
    append!(all_training_data, features.(dataset.training_data))
    append!(all_validation_data, features.(dataset.validation_data))
    filter!(isvalidfeature, all_training_data)
    filter!(isvalidfeature, all_validation_data)
    
    # 20 epochs over all training data gathered so far.
    @time for epoch in 1:20
        Flux.train!(loss, shuffleobs(all_training_data), optimizer)
        push!(losses, (meanloss(all_training_data), meanloss(all_validation_data)))
    end
    push!(datasets, dataset)
    
    # Checkpoint everything, overwriting the previous iteration's file.
    jldopen("box-atlas-value-interval.jld2", "w") do file
        file["datasets"] = datasets
        file["net"] = net
        file["lqrsol"] = lqrsol
        file["mpc_params"] = Dict(
            "Δt" => mpc_params.Δt,
            "horizon" => mpc_params.horizon,
        )
        file["all_training_data"] = all_training_data
        file["all_validation_data"] = all_validation_data
        file["losses"] = losses
    end
    
    plt = plot(first.(losses), label="training")
    plot!(plt, last.(losses), label="validation")
    ylims!(plt, (0, ylims(plt)[2]))
    display(plt)
end

  0.239717 seconds (185.67 k allocations: 15.783 MiB)


[32mProgress:   1%|                                         |  ETA: 6:09:15[39m

  0.422717 seconds (389.96 k allocations: 34.457 MiB, 4.28% gc time)


[32mProgress:   2%|█                                        |  ETA: 6:20:26[39m

  0.897804 seconds (644.30 k allocations: 56.957 MiB, 29.89% gc time)


[32mProgress:   3%|█                                        |  ETA: 6:51:52[39m

  0.793201 seconds (791.20 k allocations: 69.700 MiB, 3.02% gc time)


[32mProgress:   4%|██                                       |  ETA: 6:03:47[39m

  0.992442 seconds (998.20 k allocations: 87.833 MiB, 3.40% gc time)


[32mProgress:   5%|██                                       |  ETA: 5:51:20[39m

  1.125275 seconds (1.14 M allocations: 100.985 MiB, 2.84% gc time)


[32mProgress:   6%|██                                       |  ETA: 5:43:09[39m

  1.201516 seconds (1.22 M allocations: 107.725 MiB, 3.50% gc time)


[32mProgress:   7%|███                                      |  ETA: 5:35:49[39m

  1.324435 seconds (1.33 M allocations: 117.762 MiB, 3.17% gc time)


[32mProgress:   8%|███                                      |  ETA: 5:33:14[39m

  1.447341 seconds (1.46 M allocations: 130.030 MiB, 3.35% gc time)


[32mProgress:   9%|████                                     |  ETA: 5:33:35[39m

  1.628167 seconds (1.61 M allocations: 143.092 MiB, 3.65% gc time)


[32mProgress:  10%|████                                     |  ETA: 5:25:08[39m

  1.731112 seconds (1.76 M allocations: 156.616 MiB, 3.65% gc time)


[32mProgress:  11%|█████                                    |  ETA: 5:32:32[39m

  2.202879 seconds (1.97 M allocations: 174.982 MiB, 14.12% gc time)


[32mProgress:  12%|█████                                    |  ETA: 5:39:53[39m

  2.037998 seconds (2.14 M allocations: 190.332 MiB, 2.87% gc time)


[32mProgress:  13%|█████                                    |  ETA: 5:32:14[39m

  2.158034 seconds (2.23 M allocations: 198.365 MiB, 3.19% gc time)


[32mProgress:  14%|██████                                   |  ETA: 5:25:15[39m

  2.446001 seconds (2.50 M allocations: 221.609 MiB, 3.28% gc time)


[32mProgress:  15%|██████                                   |  ETA: 5:30:55[39m

  2.598450 seconds (2.66 M allocations: 236.068 MiB, 3.45% gc time)


[32mProgress:  16%|███████                                  |  ETA: 5:29:54[39m

  2.840789 seconds (2.89 M allocations: 256.573 MiB, 3.15% gc time)


[32mProgress:  17%|███████                                  |  ETA: 5:32:33[39m

  2.899036 seconds (3.02 M allocations: 267.175 MiB, 3.27% gc time)


[32mProgress:  18%|███████                                  |  ETA: 5:20:46[39m

  3.067812 seconds (3.19 M allocations: 283.118 MiB, 3.27% gc time)


[32mProgress:  19%|████████                                 |  ETA: 5:14:48[39m

  3.216821 seconds (3.34 M allocations: 296.642 MiB, 3.45% gc time)


[32mProgress:  20%|████████                                 |  ETA: 5:13:12[39m

  3.607265 seconds (3.50 M allocations: 310.820 MiB, 9.95% gc time)


[32mProgress:  21%|█████████                                |  ETA: 5:06:10[39m

  3.818281 seconds (3.71 M allocations: 329.235 MiB, 9.65% gc time)


[32mProgress:  22%|█████████                                |  ETA: 5:03:23[39m

  3.954507 seconds (3.95 M allocations: 349.879 MiB, 3.31% gc time)


[32mProgress:  23%|█████████                                |  ETA: 5:03:21[39m

  3.898277 seconds (4.08 M allocations: 361.405 MiB, 3.21% gc time)


[32mProgress:  24%|██████████                               |  ETA: 4:59:41[39m

  4.054834 seconds (4.21 M allocations: 373.212 MiB, 3.40% gc time)


[32mProgress:  25%|██████████                               |  ETA: 4:56:31[39m

  4.198930 seconds (4.37 M allocations: 386.944 MiB, 3.41% gc time)


[32mProgress:  26%|███████████                              |  ETA: 4:53:02[39m

  4.457515 seconds (4.60 M allocations: 407.943 MiB, 3.37% gc time)


[32mProgress:  27%|███████████                              |  ETA: 4:49:07[39m

  4.576929 seconds (4.76 M allocations: 421.288 MiB, 3.50% gc time)


[32mProgress:  28%|███████████                              |  ETA: 4:41:43[39m

  4.844677 seconds (5.02 M allocations: 444.902 MiB, 3.37% gc time)


[32mProgress:  29%|████████████                             |  ETA: 4:41:17[39m

  4.988382 seconds (5.18 M allocations: 458.528 MiB, 3.24% gc time)


[32mProgress:  30%|████████████                             |  ETA: 4:36:16[39m

  5.212198 seconds (5.39 M allocations: 476.892 MiB, 3.45% gc time)


[32mProgress:  31%|█████████████                            |  ETA: 4:33:12[39m

  5.390962 seconds (5.61 M allocations: 496.512 MiB, 3.40% gc time)


[32mProgress:  32%|█████████████                            |  ETA: 4:30:49[39m

  5.547081 seconds (5.76 M allocations: 509.626 MiB, 3.35% gc time)


[32mProgress:  33%|██████████████                           |  ETA: 4:24:32[39m

  5.701526 seconds (5.89 M allocations: 520.832 MiB, 3.24% gc time)


[32mProgress:  34%|██████████████                           |  ETA: 4:20:23[39m

  5.901495 seconds (6.12 M allocations: 541.335 MiB, 3.40% gc time)


[32mProgress:  35%|██████████████                           |  ETA: 4:18:13[39m

  6.147832 seconds (6.29 M allocations: 556.536 MiB, 3.40% gc time)


[32mProgress:  36%|███████████████                          |  ETA: 4:15:10[39m

  6.206394 seconds (6.46 M allocations: 571.135 MiB, 3.32% gc time)


[32mProgress:  37%|███████████████                          |  ETA: 4:10:25[39m

  6.439035 seconds (6.72 M allocations: 594.380 MiB, 3.25% gc time)


[32mProgress:  38%|████████████████                         |  ETA: 4:09:03[39m

  6.792947 seconds (6.93 M allocations: 612.744 MiB, 3.24% gc time)


[32mProgress:  39%|████████████████                         |  ETA: 4:06:34[39m

  6.994527 seconds (7.28 M allocations: 643.479 MiB, 3.30% gc time)


[32mProgress:  40%|████████████████                         |  ETA: 4:07:15[39m

  7.131789 seconds (7.40 M allocations: 654.313 MiB, 3.31% gc time)


[32mProgress:  41%|█████████████████                        |  ETA: 4:01:33[39m

  7.229655 seconds (7.51 M allocations: 663.611 MiB, 3.19% gc time)


[32mProgress:  42%|█████████████████                        |  ETA: 3:55:12[39m

  7.549790 seconds (7.78 M allocations: 687.187 MiB, 3.25% gc time)


[32mProgress:  43%|██████████████████                       |  ETA: 3:52:24[39m

  7.928685 seconds (8.00 M allocations: 706.065 MiB, 6.24% gc time)


[32mProgress:  44%|██████████████████                       |  ETA: 3:50:14[39m

  7.892541 seconds (8.19 M allocations: 723.032 MiB, 3.30% gc time)


[32mProgress:  45%|██████████████████                       |  ETA: 3:46:28[39m

  8.031142 seconds (8.36 M allocations: 738.836 MiB, 3.33% gc time)


[32mProgress:  46%|███████████████████                      |  ETA: 3:43:28[39m

  8.254630 seconds (8.57 M allocations: 756.829 MiB, 3.17% gc time)


[32mProgress:  47%|███████████████████                      |  ETA: 3:39:38[39m

  8.515062 seconds (8.74 M allocations: 772.119 MiB, 3.30% gc time)


[32mProgress:  48%|████████████████████                     |  ETA: 3:36:11[39m

  8.656666 seconds (8.94 M allocations: 789.600 MiB, 3.28% gc time)


[32mProgress:  49%|████████████████████                     |  ETA: 3:32:37[39m

  8.965173 seconds (9.20 M allocations: 813.126 MiB, 3.37% gc time)


[32mProgress:  50%|████████████████████                     |  ETA: 3:29:49[39m

  9.090050 seconds (9.44 M allocations: 834.372 MiB, 3.46% gc time)


[32mProgress:  51%|█████████████████████                    |  ETA: 3:26:40[39m

  9.226113 seconds (9.61 M allocations: 849.323 MiB, 3.38% gc time)


[32mProgress:  52%|█████████████████████                    |  ETA: 3:23:14[39m

  9.411492 seconds (9.80 M allocations: 866.311 MiB, 3.34% gc time)


[32mProgress:  53%|██████████████████████                   |  ETA: 3:18:43[39m

  9.617954 seconds (10.02 M allocations: 885.558 MiB, 3.33% gc time)


[32mProgress:  54%|██████████████████████                   |  ETA: 3:15:41[39m

  9.996870 seconds (10.35 M allocations: 914.015 MiB, 3.36% gc time)


[32mProgress:  55%|███████████████████████                  |  ETA: 3:13:22[39m

  9.971648 seconds (10.38 M allocations: 917.037 MiB, 3.31% gc time)


[32mProgress:  56%|███████████████████████                  |  ETA: 3:07:15[39m

 10.360004 seconds (10.67 M allocations: 942.470 MiB, 3.21% gc time)


[32mProgress:  57%|███████████████████████                  |  ETA: 3:04:04[39m

 10.416609 seconds (10.87 M allocations: 960.322 MiB, 3.18% gc time)


[32mProgress:  58%|████████████████████████                 |  ETA: 3:00:10[39m

 10.719111 seconds (11.11 M allocations: 980.596 MiB, 3.37% gc time)


[32mProgress:  59%|████████████████████████                 |  ETA: 2:56:32[39m

 10.790225 seconds (11.23 M allocations: 991.570 MiB, 3.19% gc time)


[32mProgress:  60%|█████████████████████████                |  ETA: 2:51:12[39m

 11.246843 seconds (11.48 M allocations: 1013.149 MiB, 3.25% gc time)


[32mProgress:  61%|█████████████████████████                |  ETA: 2:47:03[39m

 11.318890 seconds (11.71 M allocations: 1.009 GiB, 3.38% gc time)


[32mProgress:  62%|█████████████████████████                |  ETA: 2:43:36[39m

 11.416279 seconds (11.89 M allocations: 1.024 GiB, 3.32% gc time)


[32mProgress:  63%|██████████████████████████               |  ETA: 2:39:28[39m

 11.785841 seconds (12.15 M allocations: 1.047 GiB, 3.29% gc time)


[32mProgress:  64%|██████████████████████████               |  ETA: 2:36:27[39m

 11.890621 seconds (12.34 M allocations: 1.063 GiB, 3.39% gc time)


[32mProgress:  65%|███████████████████████████              |  ETA: 2:32:11[39m

 12.516134 seconds (12.55 M allocations: 1.082 GiB, 3.16% gc time)


[32mProgress:  66%|███████████████████████████              |  ETA: 2:28:16[39m

 12.295779 seconds (12.76 M allocations: 1.100 GiB, 3.32% gc time)


[32mProgress:  67%|███████████████████████████              |  ETA: 2:23:59[39m

 12.490320 seconds (12.98 M allocations: 1.118 GiB, 3.26% gc time)


[32mProgress:  68%|████████████████████████████             |  ETA: 2:19:35[39m

 12.678906 seconds (13.15 M allocations: 1.133 GiB, 3.23% gc time)


[32mProgress:  69%|████████████████████████████             |  ETA: 2:15:07[39m

 13.040319 seconds (13.34 M allocations: 1.149 GiB, 3.29% gc time)


[32mProgress:  70%|█████████████████████████████            |  ETA: 2:10:33[39m

 12.992691 seconds (13.50 M allocations: 1.163 GiB, 3.37% gc time)


[32mProgress:  71%|█████████████████████████████            |  ETA: 2:06:19[39m

 13.296536 seconds (13.81 M allocations: 1.190 GiB, 3.24% gc time)


[32mProgress:  72%|██████████████████████████████           |  ETA: 2:02:37[39m

 13.676808 seconds (14.00 M allocations: 1.206 GiB, 3.32% gc time)


[32mProgress:  73%|██████████████████████████████           |  ETA: 1:58:19[39m

 13.757951 seconds (14.26 M allocations: 1.229 GiB, 3.38% gc time)


[32mProgress:  74%|██████████████████████████████           |  ETA: 1:54:40[39m

 14.016259 seconds (14.52 M allocations: 1.251 GiB, 3.38% gc time)


[32mProgress:  75%|███████████████████████████████          |  ETA: 1:50:44[39m

 14.195636 seconds (14.71 M allocations: 1.267 GiB, 3.22% gc time)


[32mProgress:  76%|███████████████████████████████          |  ETA: 1:46:37[39m

 14.364234 seconds (14.97 M allocations: 1.290 GiB, 3.18% gc time)


[32mProgress:  77%|████████████████████████████████         |  ETA: 1:42:33[39m

 14.640224 seconds (15.17 M allocations: 1.307 GiB, 3.44% gc time)


[32mProgress:  78%|████████████████████████████████         |  ETA: 1:37:50[39m

 14.767246 seconds (15.37 M allocations: 1.324 GiB, 3.35% gc time)


[32mProgress:  79%|████████████████████████████████         |  ETA: 1:33:18[39m

 15.029000 seconds (15.64 M allocations: 1.347 GiB, 3.37% gc time)


[32mProgress:  80%|█████████████████████████████████        |  ETA: 1:29:11[39m

 15.304580 seconds (15.88 M allocations: 1.368 GiB, 3.28% gc time)


[32mProgress:  81%|█████████████████████████████████        |  ETA: 1:24:47[39m

 15.692534 seconds (16.01 M allocations: 1.379 GiB, 4.70% gc time)


[32mProgress:  82%|██████████████████████████████████       |  ETA: 1:20:12[39m

 15.613574 seconds (16.28 M allocations: 1.402 GiB, 3.28% gc time)


[32mProgress:  83%|██████████████████████████████████       |  ETA: 1:15:53[39m

 15.922398 seconds (16.48 M allocations: 1.419 GiB, 3.19% gc time)


[32mProgress:  84%|██████████████████████████████████       |  ETA: 1:11:35[39m

 16.046561 seconds (16.72 M allocations: 1.440 GiB, 3.27% gc time)


[32mProgress:  85%|███████████████████████████████████      |  ETA: 1:07:10[39m

 16.215868 seconds (16.90 M allocations: 1.455 GiB, 3.31% gc time)


[32mProgress:  86%|███████████████████████████████████      |  ETA: 1:02:42[39m

 16.406325 seconds (17.11 M allocations: 1.473 GiB, 3.29% gc time)


[32mProgress:  87%|████████████████████████████████████     |  ETA: 0:58:11[39m

 16.981960 seconds (17.32 M allocations: 1.491 GiB, 3.14% gc time)


[32mProgress:  88%|████████████████████████████████████     |  ETA: 0:53:46[39m

 16.836645 seconds (17.48 M allocations: 1.505 GiB, 3.28% gc time)


[32mProgress:  89%|████████████████████████████████████     |  ETA: 0:49:12[39m

 17.018708 seconds (17.74 M allocations: 1.527 GiB, 3.30% gc time)


[32mProgress:  90%|█████████████████████████████████████    |  ETA: 0:44:51[39m

 17.535522 seconds (17.98 M allocations: 1.548 GiB, 3.31% gc time)


[32mProgress:  91%|█████████████████████████████████████    |  ETA: 0:40:31[39m

 17.517764 seconds (18.13 M allocations: 1.561 GiB, 3.25% gc time)


[32mProgress:  92%|██████████████████████████████████████   |  ETA: 0:36:01[39m

 17.543769 seconds (18.28 M allocations: 1.574 GiB, 3.30% gc time)




 17.754233 seconds (18.45 M allocations: 1.589 GiB, 3.40% gc time)


[32mProgress:  94%|███████████████████████████████████████  |  ETA: 0:26:59[39m

 17.953571 seconds (18.67 M allocations: 1.608 GiB, 3.18% gc time)


[32mProgress:  95%|███████████████████████████████████████  |  ETA: 0:22:32[39m

 18.154138 seconds (18.90 M allocations: 1.627 GiB, 3.34% gc time)


[32mProgress:  96%|███████████████████████████████████████  |  ETA: 0:18:02[39m

 18.587294 seconds (19.15 M allocations: 1.649 GiB, 3.28% gc time)


[32mProgress:  97%|████████████████████████████████████████ |  ETA: 0:13:34[39m

 18.893004 seconds (19.40 M allocations: 1.671 GiB, 3.32% gc time)


[32mProgress:  98%|████████████████████████████████████████ |  ETA: 0:09:05[39m

 18.813435 seconds (19.62 M allocations: 1.689 GiB, 3.27% gc time)


[32mProgress:  99%|█████████████████████████████████████████|  ETA: 0:04:32[39m

 19.092336 seconds (19.83 M allocations: 1.708 GiB, 3.19% gc time)


[32mProgress: 100%|█████████████████████████████████████████| Time: 7:33:09[39m


In [19]:
# Signal that the long training cell above has completed.
print("finished", '\n')

finished
