In [1]:
using Revise

In [2]:
using MeshCatMechanisms
using MeshCat
using RigidBodyDynamics
using Gurobi
using Flux
using ProgressMeter
using MLDataPattern
using JLD2
using Plots; gr()

Plots.GRBackend()

In [3]:
import LCPSim
import LearningMPC
import BoxValkyries
# Explicitly reload the two project packages after importing them, so this
# session runs their current source (Julia 0.6 `reload`).
reload("LearningMPC")
reload("BoxValkyries")

[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/LearningMPC.ji for module LearningMPC.
[39m[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/BoxValkyries.ji for module BoxValkyries.


In [4]:
# Build the BoxValkyrie robot model and a visualizer for it; both are used by
# the cells below.
robot = BoxValkyries.BoxValkyrie();
mvis = MechanismVisualizer(robot)
# Last expression of the cell: shows the visualizer in the notebook output.
IJuliaCell(mvis)

Listening on 127.0.0.1:7000...
zmq_url=tcp://127.0.0.1:6000
web_url=http://127.0.0.1:7000/static/


In [5]:
# Build the policy network together with its training loss.
#
# The network is a `Chain` of three `Dense` layers (20 -> 32 -> 32 -> 10); the
# first two use `elu`, the last uses `Dense`'s default activation.
#
# Returns the tuple `(net, loss)`, where `loss(x, y)` evaluates `Flux.mse`
# between `net(x)` and `y`.
function create_net()
    hidden_in = Dense(20, 32, elu)
    hidden_mid = Dense(32, 32, elu)
    output = Dense(32, 10)
    net = Chain(hidden_in, hidden_mid, output)
    # The loss closes over `net`, so training the loss trains this network.
    loss(x, y) = Flux.mse(net(x), y)
    return net, loss
end

create_net (generic function with 1 method)

In [11]:
# --- Learned policy -----------------------------------------------------------
net, loss = create_net()
# Controller wrapper around the network: evaluates it on `state_vector(x)` and
# unwraps the result with `Flux.Tracker.data` so callers get plain data.
net_controller = x -> Flux.Tracker.data(net(state_vector(x)))
net_params = params(net)
optimizer = Flux.Optimise.ADADelta(net_params)

# Nominal state; used below as the LQR reference and as the center for
# randomized initial states.
xstar = BoxValkyries.nominal_state(robot)

# --- MPC configuration --------------------------------------------------------
# Both Gurobi solvers run with OutputFlag=0; the MIP solver additionally gets a
# time limit and explicit gap/feasibility tolerances.
mpc_params = LearningMPC.MPCParams(
    Δt=0.05,
    horizon=10,
    mip_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0, 
        TimeLimit=60, 
        MIPGap=1e-1, 
        MIPGapAbs=5e-1,
        FeasibilityTol=1e-3),
    lcp_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0))

Q, R = BoxValkyries.default_costs(robot)
feet = findbody.(robot.mechanism, ["lf", "rf"])
# LQR solution about `xstar`; the last argument is one point at the origin of
# each foot's default frame (presumably the contact points -- TODO confirm).
lqrsol = LearningMPC.LQRSolution(xstar, Q, R, mpc_params.Δt, Point3D.(default_frame.(feet), 0., 0., 0.))
# Overwrite `S` in place with Q / Δt.
# NOTE(review): this replaces whatever `LQRSolution` stored in `S`; the
# rationale is not visible in this notebook.
lqrsol.S .= 1 ./ mpc_params.Δt .* Q

# NOTE(review): the trailing vector `[net_controller, lqrsol]` looks like a list
# of controllers used to seed the MPC solve -- confirm against LearningMPC.
mpc_controller = LearningMPC.MPCController(robot.mechanism, 
    robot.environment, mpc_params, lqrsol, 
    [net_controller, lqrsol]);

# Sinks attached to the MPC controller's callback: `sample_sink.samples` is what
# `collect_into!` harvests, and `playback_sink` is tied to the visualizer.
sample_sink = LearningMPC.MPCSampleSink{Float64}(true)
playback_sink = LearningMPC.PlaybackSink(mvis, mpc_params.Δt)

mpc_controller.callback = LearningMPC.call_each(
    sample_sink,
    playback_sink,
)

live_viewer = LearningMPC.live_viewer(mvis)


# Controller used for data collection: a DAgger-style combination of the MPC
# controller and the network, chained with the live viewer.
# NOTE(review): 0.2 is presumably the mixing probability/weight -- confirm.
dagger_controller = LearningMPC.call_each(
    LearningMPC.dagger_controller(
        mpc_controller,
        net_controller,
        0.2),
    live_viewer
    )

# Termination predicate that always returns false.
termination = x -> false

dataset = LearningMPC.Dataset(lqrsol)

"""
    collect_into!(data::Vector{<:LearningMPC.Sample})

Run one simulated rollout under `dagger_controller` and append the samples that
`sample_sink` recorded during that rollout to `data`.

The global state `x0` is randomized in place around `xstar` before the rollout,
and `sample_sink` is emptied first, so both globals are mutated by this call.

Returns `data` (the result of `append!`).
"""
function collect_into!(data::Vector{<:LearningMPC.Sample})
    # Start from an empty sink so only this rollout's samples are appended.
    empty!(sample_sink)
    LearningMPC.randomize!(x0, xstar, 0.0, 1.5)
    # The simulation result is intentionally discarded: the samples of interest
    # are captured by `sample_sink` via the MPC controller's callback.
    LCPSim.simulate(x0,
        dagger_controller,
        robot.environment, mpc_params.Δt, 30,
        mpc_params.lcp_solver;
        termination=termination)
    return append!(data, sample_sink.samples)
end

# Global state buffer; `collect_into!` randomizes it in place before each rollout.
x0 = MechanismState{Float64}(robot.mechanism)

# Convert a sample into an (input, target) training pair: the sample's `state`
# and the first column of its `uJ` matrix.
function features(s::LearningMPC.Sample)
    first_column = s.uJ[:, 1]
    return (s.state, first_column)
end

Academic license - for non-commercial use only
Academic license - for non-commercial use only


features (generic function with 1 method)

In [12]:
# Containers aggregated across all data-collection iterations.
datasets = Vector{LearningMPC.Dataset{Float64}}()
all_training_data = Vector{Tuple{Vector{Float64}, Vector{Float64}}}()
all_validation_data = Vector{Tuple{Vector{Float64}, Vector{Float64}}}()
losses = Vector{Tuple{Float64, Float64}}()

for iteration in 1:20
    # 1. Collect: two rollouts of training data, one each of testing and
    #    validation data, into a fresh dataset.
    dataset = LearningMPC.Dataset(lqrsol)
    for rollout in 1:2
        collect_into!(dataset.training_data)
    end
    collect_into!(dataset.testing_data)
    collect_into!(dataset.validation_data)
    append!(all_training_data, map(features, dataset.training_data))
    append!(all_validation_data, map(features, dataset.validation_data))

    # 2. Train on everything gathered so far, logging the mean training and
    #    validation loss after every epoch.
    @showprogress for epoch in 1:20
        Flux.train!(loss, shuffleobs(all_training_data), optimizer)
        train_loss = mean(all_training_data) do xy
            Flux.Tracker.data(loss(xy...))
        end
        validation_loss = mean(all_validation_data) do xy
            Flux.Tracker.data(loss(xy...))
        end
        push!(losses, (train_loss, validation_loss))
    end
    push!(datasets, dataset)

    # 3. Checkpoint: the file is opened with "w" and rewritten in full each
    #    iteration.
    jldopen("box-val.jld2", "w") do io
        io["datasets"] = datasets
        io["net"] = net
        io["lqrsol"] = lqrsol
        io["mpc_params"] = Dict(
            "Δt" => mpc_params.Δt,
            "horizon" => mpc_params.horizon,
        )
        io["all_training_data"] = all_training_data
        io["all_validation_data"] = all_validation_data
        io["losses"] = losses
    end

    # 4. Plot both loss curves with the y-axis anchored at zero.
    plt = plot(map(first, losses), label="training")
    plot!(plt, map(last, losses), label="validation")
    ylims!(plt, (0, ylims(plt)[2]))
    display(plt)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m


captured: ErrorException("Unrecognized solution status: loaded")
captured: ErrorException("Unrecognized solution status: loaded")
captured: ErrorException("Unrecognized solution status: loaded")


[32mProgress:  45%|██████████████████                       |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m


captured: ErrorException("Unrecognized solution status: loaded")


[32mProgress:  40%|████████████████                         |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m
[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m
[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m
[32mProgress:  45%|██████████████████                       |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  65%|███████████████████████████              |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  75%|███████████████████████████████          |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  75%|███████████████████████████████          |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  65%|███████████████████████████              |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  70%|█████████████████████████████            |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  90%|█████████████████████████████████████    |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m
[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:00:00[39m

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m


In [41]:
# Roll out the learned policy by itself (`net_controller`, no MPC) from a
# randomized initial state, then replay the trajectory in the visualizer.
x0 = MechanismState{Float64}(robot.mechanism)
LearningMPC.randomize!(x0, xstar, 0.5, 2.0)
results = LCPSim.simulate(x0,
    net_controller,
    robot.environment, mpc_params.Δt, 100,
    mpc_params.lcp_solver;
    termination=termination);
# Use the simulation time step instead of repeating the literal 0.05, so the
# replay stays in sync if `mpc_params.Δt` changes (same value today).
# NOTE(review): assumes the third argument of `playback` is the time step
# between frames, as with `PlaybackSink(mvis, mpc_params.Δt)` -- confirm.
LearningMPC.playback(mvis, results, mpc_params.Δt)

In [None]:
# Replay the most recent `results` again. The time step comes from
# `mpc_params.Δt` (currently 0.05) rather than a duplicated literal.
# NOTE(review): assumes the third argument of `playback` is the time step
# between frames, as with `PlaybackSink(mvis, mpc_params.Δt)` -- confirm.
LearningMPC.playback(mvis, results, mpc_params.Δt)

In [28]:
    # Extra training epochs on the already-aggregated data (no new rollouts),
    # followed by a refreshed plot of the full loss history.
    @showprogress for epoch in 1:100
        Flux.train!(loss, shuffleobs(all_training_data), optimizer)
        train_loss = mean(all_training_data) do xy
            Flux.Tracker.data(loss(xy...))
        end
        validation_loss = mean(all_validation_data) do xy
            Flux.Tracker.data(loss(xy...))
        end
        push!(losses, (train_loss, validation_loss))
    end
    plt = plot(map(first, losses), label="training")
    plot!(plt, map(last, losses), label="validation")
    ylims!(plt, (0, ylims(plt)[2]))
    plt

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m
