In [None]:
using Revise

In [None]:
using RigidBodyDynamics
using DrakeVisualizer
# Open a Drake viewer window only if none is open yet.
if !DrakeVisualizer.any_open_windows()
    DrakeVisualizer.new_window()
end
using RigidBodyTreeInspector
using Gurobi
import StochasticOptimization
using Plots; gr()
using JLD2
using ProgressMeter

In [None]:
import LCPSim
import LearningMPC
import BoxValkyries
import Nets

In [None]:
# Build the box-Valkyrie robot model and its nominal state `xstar`, which is used
# further down for the LQR solution and as the center for randomized initial states.
# NOTE(review): the meaning of the `false` constructor flag is not visible here —
# confirm against BoxValkyries.
robot = BoxValkyries.BoxValkyrie(false)
mechanism = robot.mechanism
xstar = BoxValkyries.nominal_state(robot)

# Register the robot geometry with the visualizer and pose it at the nominal state.
# NOTE(review): the `:cartpole` path looks like a leftover from another notebook;
# presumably it is only a visualizer path name — confirm before renaming.
basevis = Visualizer()[:cartpole]
setgeometry!(basevis, robot)
settransform!(basevis[:robot], xstar)

# MPC settings: time step `Δt`, horizon, and the two Gurobi solvers (one for the
# mixed-integer program, one passed to the LCP simulator). Both run with
# `OutputFlag=0`; the MIP solver additionally sets a time limit, a MIP gap of 1e-1
# and a feasibility tolerance of 1e-3.
# NOTE(review): `horizon` is presumably counted in steps of `Δt` — confirm.
mpc_params = LearningMPC.MPCParams(
    Δt=0.05,
    horizon=10,
    mip_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0, TimeLimit=120, MIPGap=1e-1, FeasibilityTol=1e-3),
    lcp_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0))

# Default cost matrices for the robot and the LQR solution about `xstar`, computed
# with the MPC time step.
Q, R = BoxValkyries.default_costs(robot)
lqrsol = LearningMPC.LQRSolution(xstar, Q, R, mpc_params.Δt)

# Policy network: hidden widths of 32 (four entries) with leaky-ReLU activation.
hidden_widths = fill(32, 4)
activation = Nets.leaky_relu
net = LearningMPC.control_net(mechanism, hidden_widths, activation)

# Controller wrapper: feed the state vector of `x` through the network.
net_controller = function (x)
    return Nets.predict(net, state_vector(x))
end

# MPC controller for the mechanism and its environment. It receives the MPC
# parameters, the LQR solution, and a list holding the network controller and the
# LQR solution.
# NOTE(review): the final list presumably supplies warm-start controllers — confirm
# against LearningMPC.MPCController.
mpc_controller = LearningMPC.MPCController(mechanism, 
    robot.environment, mpc_params, lqrsol, 
    [net_controller, lqrsol]);

# Sinks driven by the MPC callback: `sample_sink` accumulates the samples that
# `collect_into!` later appends to the dataset; `playback_sink` is attached to the
# robot visualizer.
sample_sink = LearningMPC.MPCSampleSink{Float64}()
playback_sink = LearningMPC.PlaybackSink(basevis[:robot], mpc_params.Δt)

# Invoke both sinks from the controller's callback.
mpc_controller.callback = LearningMPC.call_each(
    sample_sink,
    playback_sink)

# Viewer bound to the robot visualizer; it is combined with the DAgger controller
# right below.
live_viewer = LearningMPC.live_viewer(mechanism, basevis[:robot])

# Controller used for data collection: a DAgger combination of the MPC controller
# and the network controller, paired with the live viewer through `call_each`.
# NOTE(review): 0.2 is presumably the mixing probability/weight between the two
# controllers — confirm against LearningMPC.dagger_controller.
dagger_controller = LearningMPC.call_each(
    LearningMPC.dagger_controller(
        mpc_controller,
        net_controller,
        0.2),
    live_viewer
    )

# Early-stop predicate for rollouts: true when configuration entry 2 drops below 0.5
# or entry 3 leaves the interval [-π/4, π/4].
# NOTE(review): entries 2 and 3 are presumably height and pitch — confirm.
termination = function (x)
    q = configuration(x)
    return q[2] < 0.5 || abs(q[3]) > π/4
end

# Dataset (constructed from the LQR solution) whose training/testing/validation
# vectors are filled by `collect_into!` in the training loop.
dataset = LearningMPC.Dataset(lqrsol)

# Adam optimizer state for the network.
updater = Nets.adam_updater(net)

# Loss built from the network and `gradient_sensitivity`; the same value is also
# interpolated into the output file names used later in the notebook.
# NOTE(review): presumably weights a gradient-matching term in the loss — confirm
# against LearningMPC.sensitive_loss.
gradient_sensitivity = 0.2
loss = LearningMPC.sensitive_loss(net, gradient_sensitivity)
adam_opts = Nets.AdamOpts(learning_rate=1e-3, batch_size=1)

"""
    collect_into!(data)

Run one randomized rollout under `dagger_controller` and append the samples that
`sample_sink` recorded during that rollout to `data`.

The sink is emptied first so only samples from this rollout are appended. The shared
state `x0` is randomized around `xstar` in place before simulating, and `termination`
is handed to the simulator as its termination predicate.

Relies on the notebook globals `sample_sink`, `x0`, `xstar`, `dagger_controller`,
`robot`, `mpc_params` and `termination`; `x0` must be defined before the first call.

Returns the result of `append!` on `data`.
"""
function collect_into!(data::Vector{<:LearningMPC.Sample})
    empty!(sample_sink)
    LearningMPC.randomize!(x0, xstar, 0.1, 0.5)
    # The simulation result is intentionally discarded: the samples of interest are
    # read from `sample_sink` afterwards, not from the returned trajectory.
    LCPSim.simulate(x0,
        dagger_controller,
        robot.environment, mpc_params.Δt, 30,
        mpc_params.lcp_solver;
        termination=termination)
    return append!(data, sample_sink.samples)
end

"""
    all_losses(net, dataset)

Return the tuple `(training_loss, validation_loss)` of `net` evaluated on `dataset`.
"""
function all_losses(net, dataset)
    training = LearningMPC.training_loss(net, dataset)
    validation = LearningMPC.validation_loss(net, dataset)
    return (training, validation)
end

# Shared state object reused for every rollout: `collect_into!` and the final
# evaluation cell randomize it in place before simulating.
x0 = MechanismState{Float64}(mechanism)

In [None]:
# Loss history (one `(training, validation)` pair per optimizer call) and one network
# snapshot per outer iteration.
losses = Tuple{Float64, Float64}[]
snapshots = LearningMPC.Snapshot{Float64}[]

@showprogress for iteration in 1:20
    # Collect new rollouts: two into the training set, then one each into the
    # testing and validation sets.
    for target in (dataset.training_data, dataset.training_data,
                   dataset.testing_data, dataset.validation_data)
        collect_into!(target)
    end

    # Five calls to the Adam update on the current training set, recording the
    # losses after each call.
    for epoch in 1:5
        training_features = LearningMPC.features.(dataset.training_data)
        Nets.adam_update!(net.params.data, updater, loss, training_features, adam_opts)
        push!(losses, all_losses(net, dataset))
    end

    push!(snapshots, LearningMPC.Snapshot(net.params.data, net))

    # Checkpoint: rewrite the file with the dataset and all snapshots so far.
    jldopen("box-val-$gradient_sensitivity.jld2", "w") do file
        file["dataset"] = dataset
        file["snapshots"] = snapshots
    end

    # Redraw the loss curves with the lower y-limit pinned at zero.
    plt = plot(map(first, losses), label="training")
    plot!(plt, map(last, losses), label="validation")
    ylims!(plt, (0, last(ylims(plt))))
    display(plt)
end

In [None]:
# Base name (without extension) for the saved loss history and figures, tagged with
# the gradient sensitivity.
# NOTE(review): this global shadows `Base.basename` for the rest of the session;
# consider renaming it together with its uses in the following cell.
basename = "box-val-losses-sensitivity-$gradient_sensitivity"

In [None]:
# Persist the loss history; the comparison cell reads these files back per sensitivity.
jldopen("$basename.jld2", "w") do file
    file["losses"] = losses
end;

# Plot both curves with the y-axis running from zero to the largest recorded loss,
# then export the figure as PDF and PNG and show it as the cell result.
plt = plot(map(first, losses); label="training", ylim=(0, maximum(maximum, losses)))
plot!(plt, map(last, losses); label="validation")
savefig(plt, "$basename.pdf")
savefig(plt, "$basename.png")
plt

In [None]:
# Overlay the saved training and validation curves for each gradient sensitivity.
plt = plot()

for sensitivity in (0.0, 0.2, 0.5, 0.8)
    path = "box-val-losses-sensitivity-$sensitivity.jld2"
    history = jldopen(file -> file["losses"], path, "r")

    plot!(plt, map(first, history), label="training-$sensitivity")
    plot!(plt, map(last, history), label="validation-$sensitivity")
end

# Pin the lower y-limit at zero while keeping the automatic upper limit.
ylims!(plt, (0, last(ylims(plt))))
plt

In [None]:
# Evaluate the network controller on its own: randomize `x0` around `xstar` (final
# argument 1.0 here versus 0.5 during data collection) and simulate under
# `net_controller` without a termination predicate.
# NOTE(review): 200 is presumably the number of simulation steps — confirm against
# LCPSim.simulate.
LearningMPC.randomize!(x0, xstar, 0.1, 1.0)
results = LCPSim.simulate(x0, 
    net_controller,
    robot.environment, mpc_params.Δt, 200, mpc_params.lcp_solver);

In [None]:
# Replay the simulated rollout on the robot visualizer, passing the MPC time step.
LearningMPC.playback(basevis[:robot], results, mpc_params.Δt)