In [None]:
using Revise

In [None]:
using RigidBodyDynamics
using DrakeVisualizer
# Open a DrakeVisualizer window only when none is already open.
if !DrakeVisualizer.any_open_windows()
    DrakeVisualizer.new_window()
end
using RigidBodyTreeInspector
using Gurobi
import StochasticOptimization
using Plots; gr()
using JLD2
using ProgressMeter

In [None]:
import LCPSim
import LearningMPC
import Hoppers
import Nets

In [None]:
# Build the hopper model and its nominal state `xstar` (used below as the
# LQR linearization state).
robot = Hoppers.Hopper()
mechanism = robot.mechanism
xstar = Hoppers.nominal_state(robot)

# Load the hopper geometry into the visualizer and pose it at `xstar`.
basevis = Visualizer()[:hopper]
setgeometry!(basevis, robot)
settransform!(basevis[:robot], xstar)

# MPC settings: time step Δt, a 10-step horizon, and two Gurobi solvers,
# each with its own Gurobi environment. `lcp_solver` is the one passed to
# LCPSim.simulate in this notebook; `mip_solver` is presumably used for the
# MPC optimization itself (confirm in LearningMPC.MPCParams). Solver output
# is silenced via OutputFlag=0.
mpc_params = LearningMPC.MPCParams(
    Δt=0.05,
    horizon=10,
    mip_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0, TimeLimit=120, MIPGap=1e-1, FeasibilityTol=1e-3),
    lcp_solver=GurobiSolver(Gurobi.Env(), OutputFlag=0))

# LQR solution about `xstar` using the hopper's default cost weights Q, R.
# The last argument is a single point at the origin of the foot body's
# default frame. NOTE(review): its role (presumably a contact point) is
# inferred from LearningMPC.LQRSolution — confirm.
Q, R = Hoppers.default_costs(robot)
foot = findbody(mechanism, "foot")
lqrsol = LearningMPC.LQRSolution(xstar, Q, R, mpc_params.Δt, [Point3D(default_frame(foot), 0., 0., 0.)])
# Overwrite lqrsol.S in place with Q scaled by 1/Δt.
# NOTE(review): S is presumably the LQR cost-to-go matrix — confirm.
lqrsol.S .= 1 ./ mpc_params.Δt .* Q

# Policy network built by LearningMPC.control_net with hidden widths
# [16, 8, 8, 8] and leaky-ReLU activation.
hidden_widths = [16, 8, 8, 8]
activation = Nets.leaky_relu
net = LearningMPC.control_net(mechanism, hidden_widths, activation)

# Controller view of the net: evaluate it on the state vector of a
# MechanismState.
net_controller = x -> Nets.predict(net, state_vector(x))

# MPC controller for the hopper. NOTE(review): the trailing vector
# [net_controller, lqrsol] presumably supplies warm-start/fallback policies;
# confirm against LearningMPC.MPCController.
mpc_controller = LearningMPC.MPCController(mechanism, 
    robot.environment, mpc_params, lqrsol, 
    [net_controller, lqrsol]);

# Consumers of MPC results: sample_sink accumulates samples (collect_into!
# reads them back from sample_sink.samples); playback_sink and live_viewer
# draw into the robot visualizer.
sample_sink = LearningMPC.MPCSampleSink{Float64}()
playback_sink = LearningMPC.PlaybackSink(basevis[:robot], 0.25 * mpc_params.Δt)
live_viewer = LearningMPC.live_viewer(mechanism, basevis[:robot])

# Only the sample recorder is attached to the MPC callback; uncomment the
# other entries to also animate or trace each callback invocation.
mpc_controller.callback = LearningMPC.call_each(
    sample_sink,
#     playback_sink,
#     (args...) -> println("tick")
#     (x, results) -> live_viewer(x)
)


# Controller used for data-collection rollouts: LearningMPC's DAgger
# combination of the MPC and the learned policy, followed by the live viewer.
# NOTE(review): 0.5 is presumably the probability of acting with the MPC vs.
# the net — confirm in LearningMPC.dagger_controller.
dagger_controller = LearningMPC.call_each(
    LearningMPC.dagger_controller(
        mpc_controller,
        net_controller,
        0.5),
    live_viewer
    )

# Termination predicate for LCPSim.simulate that never fires, so rollouts
# always run for their full step count.
termination = x -> false

# Container for training/testing/validation samples (see collect_into!),
# constructed from the LQR solution.
dataset = LearningMPC.Dataset(lqrsol)

# Training objective and optimizer. `gradient_sensitivity` is the weight
# passed to LearningMPC.sensitive_loss and is also embedded in the checkpoint
# filename written by the training loop.
gradient_sensitivity = 0.2
learning_loss = LearningMPC.sensitive_loss(net, gradient_sensitivity)
adam_opts = Nets.AdamOpts(learning_rate=2e-2, batch_size=1)
# NOTE(review): the two zero arrays are sized from the net's input transform
# (length n_in) and output transform (n_out × (1 + n_in)); presumably they
# are the optimizer's initial state buffers — confirm against
# Nets.AdamOptimizer.
optimizer = Nets.AdamOptimizer(learning_loss, adam_opts, net,
    zeros(length(net.input_tform.v)),
    zeros(length(net.output_tform.v), 1 + length(net.input_tform.v)))


# Scratch state that collect_into! overwrites with each rollout's
# randomized initial condition.
x0 = MechanismState{Float64}(mechanism)

# Reference state passed to LearningMPC.randomize! when drawing initial
# conditions: configuration [1.0, 1.0] and zero velocity.
x_init = MechanismState{Float64}(mechanism)
set_configuration!(x_init, [1.0, 1.0])
set_velocity!(x_init, [0., 0.])
# Alternative reference: the nominal state.
# x_init = xstar

"""
    collect_into!(data; num_steps=50, duplicate_tol=1e-2)

Run one closed-loop rollout of `dagger_controller` from a randomized initial
state. Append the samples gathered in `sample_sink` during that rollout to
`data`, skipping near-duplicate states, and return `data`.

Relies on the notebook globals `sample_sink`, `x0`, `x_init`,
`dagger_controller`, `robot`, `mpc_params` and `termination`.

# Keywords
- `num_steps`: number of simulation steps passed to `LCPSim.simulate`
  (default 50, the previously hard-coded value).
- `duplicate_tol`: a sample is dropped when some *later* sample's state lies
  within this Euclidean distance of it (default 1e-2).
"""
function collect_into!(data::Vector{<:LearningMPC.Sample};
                       num_steps::Integer=50, duplicate_tol::Real=1e-2)
    empty!(sample_sink)
    LearningMPC.randomize!(x0, x_init, 0.5, 1.0)
    q = configuration(x0)
    # If the first configuration coordinate fell below the second, raise it
    # to match the second. NOTE(review): presumably keeps the body from
    # starting below the foot — confirm coordinate meaning.
    if q[1] - q[2] < 0
        set_configuration!(x0, [q[2], q[2]])
    end

    # The simulation result itself is not needed here; the rollout is run for
    # its side effect of filling `sample_sink` (attached to the MPC callback).
    LCPSim.simulate(x0,
        dagger_controller,
        robot.environment, mpc_params.Δt, num_steps,
        mpc_params.lcp_solver;
        termination=termination)

    samples = sample_sink.samples
    # Keep sample i only if no later sample is within `duplicate_tol` of it,
    # so each cluster of near-identical states contributes its last member.
    keep = filter(1:length(samples)) do i
        !any(j -> norm(samples[j].state .- samples[i].state) < duplicate_tol,
             (i + 1):length(samples))
    end
    append!(data, samples[keep])
end

"""
    all_losses(net, dataset)

Evaluate `net` on `dataset` and return the pair
`(training_loss, validation_loss)` as computed by LearningMPC.
"""
function all_losses(net, dataset)
    training = LearningMPC.training_loss(net, dataset)
    validation = LearningMPC.validation_loss(net, dataset)
    return (training, validation)
end

In [None]:
# Number of entries in the network's parameter vector.
net.params.data |> length

In [None]:
# LearningMPC.randomize!(x0, x_init, 0.0, 0.0)
# results = LCPSim.simulate(x0, 
#     mpc_controller,
#     robot.environment, mpc_params.Δt, 50, mpc_params.lcp_solver);

In [None]:
# LearningMPC.playback(basevis[:robot], results, mpc_params.Δt)

In [None]:
# DAgger training. Each outer iteration:
# 1. gathers new rollouts into the dataset;
# 2. takes 10 optimizer steps on the training set;
# 3. snapshots the net and checkpoints the dataset and snapshots to disk;
# 4. plots the loss history so far.
losses = Tuple{Float64, Float64}[]
snapshots = LearningMPC.Snapshot{Float64}[]
gr()

@showprogress for outer_iter in 1:10
    # Two rollouts go to the training set, one each to testing and validation.
    for rollout in 1:2
        collect_into!(dataset.training_data)
    end
    collect_into!(dataset.testing_data)
    collect_into!(dataset.validation_data)

    @showprogress for update_iter in 1:10
        Nets.update!(net.params.data, optimizer, LearningMPC.features.(dataset.training_data))
        # Geometric learning-rate decay: 3% per update.
        optimizer.opts.learning_rate *= (1 - 3e-2)
        push!(losses, all_losses(net, dataset))
    end

    push!(snapshots, LearningMPC.Snapshot(net.params.data, net))

    # Opened in "w" mode, so the file always holds only the latest state.
    jldopen("hopper-$gradient_sensitivity.jld2", "w") do file
        file["dataset"] = dataset
        file["snapshots"] = snapshots
    end

    plt = plot(first.(losses), label="training", yscale=:log10)
    plot!(plt, last.(losses), label="validation")
    # Floor the y-axis at 1 only when the autoscaled upper limit exceeds 1;
    # otherwise a lower limit of 1 would lie above the upper limit.
    let lims = ylims(plt)
        ylims!(plt, (lims[2] > 1 ? 1 : lims[1], lims[2]))
    end
    display(plt)
end

In [None]:
# @showprogress for i in 1:100
#     Nets.update!(net.params.data, optimizer, LearningMPC.features.(dataset.training_data))
#     optimizer.opts.learning_rate *= (1 - 3e-2)
#     push!(losses, all_losses(net, dataset))
# end

In [None]:
# Plot the training and validation loss histories on a log scale.
plt = plot(first.(losses), label="training", yscale=:log10)
plot!(plt, last.(losses), label="validation")
# Floor the y-axis at 1 only when the autoscaled upper limit exceeds 1;
# otherwise a lower limit of 1 would lie above the upper limit.
let lims = ylims(plt)
    ylims!(plt, (lims[2] > 1 ? 1 : lims[1], lims[2]))
end
plt

In [None]:
# Roll out the learned policy on its own (no MPC) for 200 steps, starting
# from a fresh randomization of x0 around x_init.
LearningMPC.randomize!(x0, x_init, 0.1, 0.5)
results = LCPSim.simulate(
    x0, net_controller, robot.environment,
    mpc_params.Δt, 200, mpc_params.lcp_solver);

In [None]:
# Replay the most recent rollout in the robot visualizer at the MPC time step.
let robotvis = basevis[:robot]
    LearningMPC.playback(robotvis, results, mpc_params.Δt)
end

In [None]:
# Switch the Plots backend to Plotly for the interactive surface plot.
Plots.plotly()

In [None]:
"""
    slice(data; tol=1e-1)

Return, in order, the elements of `data` whose `state` vector `x` satisfies
`abs(x[1] - x[2]) < tol` and `abs(x[3] - x[4]) < tol`. In other words, keep
the samples close to the subspace of states of the form `[a, a, b, b]`,
which is the subspace evaluated by the surface plot of the net.

# Keywords
- `tol`: closeness threshold (default `1e-1`, the previously hard-coded
  value).
"""
function slice(data; tol::Real=1e-1)
    filter(data) do sample
        x = sample.state
        (abs(x[1] - x[2]) < tol) && (abs(x[3] - x[4]) < tol)
    end
end

In [None]:
# Compare recorded training targets with the net on the [a, a, b, b] slice.
# Scatter: the first and third state coordinates of the sliced training
# samples against s.uJ[2, 1]. NOTE(review): s.uJ[2, 1] is presumably the
# second control input of the MPC solution; confirm this in LearningMPC.Sample.
# The slice is computed once (not three times) inside a `let`, so no extra
# global is created.
plt = let sliced = slice(dataset.training_data)
    plot([s.state[1] for s in sliced], [s.state[3] for s in sliced],
         [s.uJ[2, 1] for s in sliced], line=nothing, marker=:dot, markersize=0.3)
end
# Surface: second output of the net at states [x, x, y, y].
surface!(plt, linspace(0, 2), linspace(-4, 4), (x, y) -> net([x, x, y, y])[2])
# plot!(plt, [s.state[1] for s in dataset.training_data], [s.state[3] for s in dataset.training_data],
#      [net(s.state)[2] for s in dataset.training_data], line=nothing, marker=:dot, markersize=0.3, markercolor=:red)
zlims!(plt, -10, 50)
plt