In [None]:
using Revise

In [None]:
using RigidBodyDynamics
using DrakeVisualizer
DrakeVisualizer.any_open_windows() || DrakeVisualizer.new_window()
using RigidBodyTreeInspector
using Gurobi
import StochasticOptimization
using Plots
using JLD2
using ProgressMeter

In [None]:
import LCPSim
import LearningMPC
import BoxValkyries
import Nets

In [None]:
# Reload these packages in the running session so that the definitions used
# below come from their current source. The LCPSim reload is left disabled.
# reload("LCPSim")
reload("LearningMPC")
reload("BoxValkyries")

In [None]:
# Build the box-Valkyrie model, grab its mechanism, and look up the nominal
# state that is passed to the cost, LQR, and randomization calls below.
boxval = BoxValkyries.BoxValkyrie()
mechanism = boxval.mechanism
xstar = BoxValkyries.nominal_state(boxval)

# Visualizer subtree for this robot: load the model geometry and pose the
# :robot node at the nominal state.
basevis = Visualizer()[:box_robot]
setgeometry!(basevis, boxval)
settransform!(basevis[:robot], xstar)

# MPC configuration: time step, horizon length, and the two Gurobi solvers
# (one for the mixed-integer MPC problem, one for the LCP simulation step).
# NOTE(review): Δt is presumably in seconds — confirm against LearningMPC.
mpc_params = LearningMPC.MPCParams(
    Δt=0.04,
    horizon=1,
    mip_solver=GurobiSolver(
        Gurobi.Env(),
        OutputFlag=0,         # suppress solver log output
        TimeLimit=120,        # Gurobi time limit (seconds)
        MIPGap=0.1,           # relative MIP optimality gap
        MIPGapAbs=5,          # absolute MIP optimality gap
        FeasibilityTol=1e-3), # primal feasibility tolerance
    lcp_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0))

# One contact point per foot body, placed at the origin of that body's
# default frame. NOTE(review): "rf"/"lf" are presumably the right and left
# foot bodies — confirm against the BoxValkyries model.
feet = [findbody(mechanism, name) for name in ["rf", "lf"]]
contacts = map(body -> Point3D(default_frame(body), 0., 0, 0), feet)

# LQR solution built from the nominal state, the package's default costs,
# the foot contacts, and the MPC time step.
Q, R = BoxValkyries.default_costs(xstar)
lqrsol = LearningMPC.LQRSolution(xstar, Q, R, contacts, mpc_params.Δt)
# NOTE(review): zeroes "element 1" of the LQR solution; which quantity that
# index refers to is defined in LearningMPC — confirm.
LearningMPC.zero_element!(lqrsol, 1)

# Policy network: hidden-layer widths and activation handed to
# LearningMPC.control_net for this mechanism.
hidden_widths = [32, 32, 32, 32]
activation = Nets.leaky_relu
net = LearningMPC.control_net(mechanism, hidden_widths, activation)

# Controller that evaluates the network on the state vector of `x`.
net_controller = x -> Nets.predict(net, state_vector(x))

# MPC controller for this mechanism and environment.
# NOTE(review): the final list [net_controller, lqrsol] is presumably the set
# of controllers used to seed/warm-start the MPC solve — confirm in LearningMPC.
mpc_controller = LearningMPC.MPCController(mechanism, 
    boxval.environment, mpc_params, lqrsol, 
    [net_controller, lqrsol]);

# Sinks attached to the MPC controller: one accumulates samples (read back by
# collect!), the other is bound to the robot's visualizer node.
sample_sink = LearningMPC.MPCSampleSink{Float64}()
playback_sink = LearningMPC.PlaybackSink(basevis[:robot], mpc_params.Δt)

# Invoke both sinks from the MPC controller's callback.
mpc_controller.callback = LearningMPC.call_each(
    sample_sink,
    playback_sink)

# Viewer bound to the robot's visualizer node; combined with controllers below.
live_viewer = LearningMPC.live_viewer(mechanism, basevis[:robot])

# Controller used for data collection: a DAgger controller over the MPC and
# network controllers, combined with the live viewer via call_each.
# NOTE(review): 0.2 is presumably the probability/weight of acting with the
# MPC controller rather than the network — confirm against dagger_controller.
dagger_controller = LearningMPC.call_each(
    LearningMPC.dagger_controller(
        mpc_controller,
        net_controller,
        0.2),
    live_viewer
    )

# Early-stop predicate for rollouts: true once the configuration leaves the
# allowed region (first coordinate below -1, second below 0.5, or third
# outside [-1.2, 1.2]).
# NOTE(review): the coordinates are presumably base x, base height, and base
# rotation — confirm against the BoxValkyries joint ordering.
termination = x -> begin
    q = configuration(x)
    q[1] < -1 || q[2] < 0.5 || abs(q[3]) > 1.2
end

# Dataset container (training / testing / validation splits are filled by
# collect! in the training loop), constructed from the LQR solution.
dataset = LearningMPC.Dataset(lqrsol)

# Adam optimizer state for the network.
updater = Nets.adam_updater(net)

# Training loss for the network.
# NOTE(review): 0.2 is presumably the weight on the sensitivity term of the
# loss — confirm against LearningMPC.sensitive_loss.
loss = LearningMPC.sensitive_loss(net, 0.2)
# Adam options: learning rate 0.1e-4 (i.e. 1e-5) and a batch size of 1.
adam_opts = Nets.AdamOpts(learning_rate=0.1e-4, batch_size=1)

"""
    collect!(data; steps=100)

Run one simulated rollout under `dagger_controller` and append the samples it
produced to `data`.

The global `sample_sink` is emptied first, the global state `x0` is randomized
around `xstar`, and the simulation is run for at most `steps` steps (it may
stop earlier via the global `termination` predicate). Whatever `sample_sink`
holds afterwards is appended to `data`.

# Arguments
- `data`: vector of samples to extend in place.

# Keywords
- `steps`: number of simulation steps passed to `LCPSim.simulate`
  (defaults to 100, the previously hard-coded value).

# Returns
`data`, as returned by `append!`.
"""
function collect!(data::Vector{<:LearningMPC.Sample}; steps::Integer=100)
    # Start from a clean sink so only this rollout's samples are appended.
    empty!(sample_sink)
    # NOTE(review): 0.1 and 0.5 are presumably the perturbation magnitudes for
    # the configuration and velocity — confirm against LearningMPC.randomize!.
    LearningMPC.randomize!(x0, xstar, 0.1, 0.5)
    # The simulation results themselves are not needed here; samples are
    # gathered through sample_sink while the simulation runs.
    LCPSim.simulate(x0,
        dagger_controller,
        boxval.environment, mpc_params.Δt, steps,
        mpc_params.lcp_solver;
        termination=termination)
    return append!(data, sample_sink.samples)
end

In [None]:
# Shared simulation state for the mechanism; it is passed to
# LearningMPC.randomize! and LCPSim.simulate before each rollout
# (in collect! and in the final evaluation cell).
x0 = MechanismState{Float64}(mechanism)

In [None]:
"""
    all_losses(net, dataset)

Return the tuple `(training, validation)` holding
`LearningMPC.training_loss(net, dataset)` and
`LearningMPC.validation_loss(net, dataset)`, in that order.
"""
function all_losses(net, dataset)
    training = LearningMPC.training_loss(net, dataset)
    validation = LearningMPC.validation_loss(net, dataset)
    return (training, validation)
end

In [None]:
# History of (training, validation) losses and one network snapshot per
# outer iteration.
losses = Vector{Tuple{Float64, Float64}}()
snapshots = Vector{LearningMPC.Snapshot{Float64}}()

@showprogress for iteration in 1:100
    # Gather fresh rollouts: two into the training split, then one each into
    # the testing and validation splits.
    for rollout in 1:2
        collect!(dataset.training_data)
    end
    collect!(dataset.testing_data)
    collect!(dataset.validation_data)

    # Five Adam updates over the training features, recording both losses
    # after every update.
    for epoch in 1:5
        training_features = LearningMPC.features.(dataset.training_data)
        Nets.adam_update!(net.params.data, updater, loss,
            training_features, adam_opts)
        push!(losses, all_losses(net, dataset))
    end

    push!(snapshots, LearningMPC.Snapshot(net.params.data, net))

    # Rewrite the checkpoint file with the dataset and snapshots so far.
    jldopen("box-val-improved.jld2", "w") do file
        file["dataset"] = dataset
        file["snapshots"] = snapshots
    end

    # Show the training and validation loss curves side by side.
    training_curve = plot(first.(losses))
    validation_curve = plot(last.(losses))
    display(plot(training_curve, validation_curve))
end

In [None]:

# Evaluate the learned policy on its own: randomize the start state around
# xstar, then simulate 200 steps with the network controller and live viewer
# (no MPC controller and no early-termination predicate).
LearningMPC.randomize!(x0, xstar, 0.1, 0.5)
results = let policy = LearningMPC.call_each(net_controller, live_viewer)
    LCPSim.simulate(x0, policy, boxval.environment,
        mpc_params.Δt, 200, mpc_params.lcp_solver)
end;