In [None]:
using Revise

In [None]:
using RigidBodyDynamics
using LearningMPC
using Gurobi
using DrakeVisualizer
using CoordinateTransformations
using ProgressMeter
using MLDataPattern
using JLD2
using ProfileView
# Launch a DrakeVisualizer window unless one is already open (evaluates to `true` in that case).
DrakeVisualizer.any_open_windows() ? true : DrakeVisualizer.new_window()

In [None]:
import CartPoles
import Nets

In [None]:
# Re-load the CartPoles module from source so edits made after the `import` take effect.
Base.reload("CartPoles")

In [None]:
# One training sample: a state vector paired with a target matrix.
const Sample = Tuple{Array{Float64,1}, Array{Float64,2}}

In [None]:
# Empty sample buffers, filled later as (state vector, target matrix) pairs.
train_data = Vector{Sample}()
test_data = Vector{Sample}()
validation_data = Vector{Sample}()

# Build the cart-pole model and (re)draw its geometry under a fresh visualizer subtree.
cartpole = CartPoles.CartPole()
mechanism = cartpole.mechanism
basevis = Visualizer()[:cartpole]
delete!(basevis)  # drop geometry left over from a previous run of this cell
setgeometry!(basevis, cartpole)

# Default network input map (identity on a 4-vector, zero offset) and output map
# (scales the first of 2 outputs by 20, zeroes the second, zero offset).
x_to_u = AffineMap(eye(4), zeros(4))
v_to_y = AffineMap(diagm([20.0, 0.0]), zeros(2))

# Network layer widths and hidden-layer activation.
widths = [4, 16, 16, 8, 2]
activation = Nets.leaky_relu

In [None]:
# Offline MPC settings, and the all-zero state (zero configuration, zero velocity).
cartpole_mpc_params = CartPoles.CartPoleMPCParams(gap = 1e-2, Δt = 0.04)
xstar = MechanismState(mechanism, fill(0.0, 2), fill(0.0, 2))

# LQR solution about xstar, reusing the MPC cost weights Q and R. The second argument is a
# zero vector of length num_velocities(xstar) (presumably the nominal input — confirm).
lqrsol = CartPoles.LQRSolution(xstar,
                               fill(0.0, num_velocities(xstar)),
                               cartpole_mpc_params.Q,
                               cartpole_mpc_params.R)
lqr_controller = CartPoles.LQRController(lqrsol)

In [None]:
# Restore trained network parameters, layer widths, and input/output maps from disk.
# These overwrite the defaults for `widths`, `x_to_u` and `v_to_y` set in the setup cell.
params, widths, x_to_u, v_to_y = jldopen("cart-pole-dagger-0.04-params.jld2", "r") do file
    (file["params"], file["widths"], file["x_to_u"], file["v_to_y"])
end;
# NOTE(review): `activation` is not read from the file; it must match the activation the
# saved parameters were trained with.
net = Nets.Net(Nets.Params(widths, params), activation, x_to_u, v_to_y)

In [None]:
# Network policy: evaluate the trained net on the full state vector of a MechanismState.
net_controller = x -> Nets.predict(net, state_vector(x))

# Probability that a collected sample goes to train_data / validation_data; the remainder go
# to test_data. These globals were not assigned anywhere in this notebook, so the callback
# below would throw an UndefVarError on its first sample. Values already defined in the
# session are kept.
# NOTE(review): the default split here is a placeholder — confirm against the intended setup.
isdefined(Main, :p_train) || (p_train = 0.5)
isdefined(Main, :p_validate) || (p_validate = 0.25)

# Full MPC controller. The vector [net_controller, lqr_controller] is passed as candidate
# policies (presumably used for warm starts — confirm in CartPoles.MPCController).
mpc_controller = CartPoles.MPCController(cartpole, cartpole_mpc_params, xstar, lqrsol,
                                         [net_controller, lqr_controller]);

# After each MPC solve that produced both LCP updates and a jacobian, record one sample:
# the state vector, paired with hcat(first update's input, jacobian). Each sample is
# randomly assigned to the train / validation / test buffers.
mpc_controller.callback = (x, results) -> begin
    if !isnull(results.lcp_updates) && !isnull(results.jacobian)
        xv = state_vector(x)
        yJ = hcat(get(results.lcp_updates)[1].input, get(results.jacobian))
        r = rand()
        if r < p_train
            push!(train_data, (xv, yJ))
        elseif r < p_train + p_validate
            push!(validation_data, (xv, yJ))
        else
            push!(test_data, (xv, yJ))
        end
    end
end

In [None]:
# Scratch state for `mechanism`. Later cells reuse it as the simulation initial state and
# as the pose shown in the visualizer.
x0 = MechanismState{Float64}(mechanism)

In [None]:
# Random initial condition: first coordinate uniform in [-1, 1), second in [-π/2, π/2).
# Both velocities are uniform in [-0.5, 0.5).
# NOTE(review): presumably (cart position, pole angle) — confirm the coordinate order.
q0 = [2 * (rand() - 0.5),
      π * (rand() - 0.5)]
v0 = rand(2) .- 0.5;

In [None]:
# Reset x0 to the sampled initial condition, then roll out the network policy
# (300 is presumably the number of steps).
# Use set_configuration!/set_velocity! rather than writing into configuration(x0) /
# velocity(x0): those return references to the state's storage, and writing through them
# does not invalidate the state's cached quantities.
set_configuration!(x0, q0)
set_velocity!(x0, v0)
# NOTE(review): `LCPSim` is not imported in this notebook; it must be brought into scope
# by one of the `using` packages — confirm.
results_net = LCPSim.simulate(x0, net_controller, cartpole.environment,
                              cartpole_mpc_params.Δt, 300, GurobiSolver(OutputFlag=0));

In [None]:
# Show the first recorded state of the net rollout, then replay the whole rollout using
# half the simulation Δt as the playback interval.
set_configuration!(x0, configuration(first(results_net).state))
settransform!(basevis[:robot], x0)

playback(basevis[:robot], results_net, cartpole_mpc_params.Δt / 2)

In [None]:
# Parameters for online MPC: same Δt and gap as the offline MPC, but horizon = 5.
online_params = CartPoles.CartPoleMPCParams(Δt = 0.04, gap=1e-2, horizon=5)

# Pass `online_params`, not `cartpole_mpc_params`, so the horizon = 5 setting actually
# takes effect.
# NOTE(review): lqrsol was built from cartpole_mpc_params.Q/R; this assumes online_params
# uses the same default Q/R — confirm.
online_controller = CartPoles.OnlineMPCController(cartpole,
    online_params, xstar, lqrsol, [net_controller, lqr_controller]);

In [None]:
# Evaluate the online controller once at x0. This first call also triggers JIT compilation.
x0 |> online_controller

In [None]:
# Reset x0 to the same sampled initial condition, then roll out the online MPC controller
# (10 is presumably the number of steps).
# set_configuration!/set_velocity! invalidate the state's cache; writing through
# configuration(x0) / velocity(x0) directly would not.
set_configuration!(x0, q0)
set_velocity!(x0, v0)
results_online = LCPSim.simulate(x0, online_controller, cartpole.environment,
                                 cartpole_mpc_params.Δt, 10, GurobiSolver(OutputFlag=0));

In [None]:
# Show the first recorded state of the online-MPC rollout, then replay it using half the
# simulation Δt as the playback interval. This uses `results_online`; the cell previously
# replayed `results_net` by copy-paste mistake.
set_configuration!(x0, configuration(results_online[1].state))
settransform!(basevis[:robot], x0)

playback(basevis[:robot], results_online, 0.5 * cartpole_mpc_params.Δt)

In [None]:
# Profile a 50-step rollout of the network policy from the current x0.
# Clearing the buffer before the warm-up is equivalent to clearing it after, because the
# warm-up is not run under @profile. The warm-up run is there so compilation is not
# included in the timed/profiled run.
# NOTE(review): the warm-up uses GurobiSolver(OutputFlag=0), while the profiled run builds
# a fresh Gurobi.Env() inside the timed expression.
Profile.clear()
LCPSim.simulate(x0, net_controller, cartpole.environment,
                cartpole_mpc_params.Δt, 50, GurobiSolver(OutputFlag=0));
@time @profile LCPSim.simulate(x0, net_controller, cartpole.environment,
                               cartpole_mpc_params.Δt, 50,
                               GurobiSolver(Gurobi.Env(), OutputFlag=0));
ProfileView.view()

In [None]:
# Time and profile a single call of the online MPC controller at x0, then open the flame
# graph. This assumes the earlier `online_controller(x0)` cell has already run, so
# compilation is not part of the profile.
Profile.clear()
@time @profile online_controller(x0)
ProfileView.view()