In [None]:
using Revise

In [None]:
using RigidBodyDynamics
using LearningMPC
using Gurobi
using DrakeVisualizer
using CoordinateTransformations
using Plots; gr()
using ProgressMeter
using MLDataPattern
using JLD2
# Open a visualizer window only when none is open yet (`||` short-circuits otherwise).
DrakeVisualizer.any_open_windows() || DrakeVisualizer.new_window()

In [None]:
import CartPoles
import Nets

In [None]:
# Development aid: reload the CartPoles package so local edits take effect in this session.
reload("CartPoles")

In [None]:
# One supervised sample: (state vector, target matrix).
# The MPC callback below stores `(state_vector(x), hcat(input, jacobian))` in this form.
const Sample = Tuple{Vector{Float64}, Matrix{Float64}}

In [None]:
# Build the cart-pole model and register its geometry under the :cartpole
# subtree of the visualizer.
cartpole = CartPoles.CartPole()
mechanism = cartpole.mechanism
basevis = Visualizer()[:cartpole]
# NOTE(review): presumably clears anything left under :cartpole from an earlier run
# before the geometry is re-added -- confirm against DrakeVisualizer.
delete!(basevis)
setgeometry!(basevis, cartpole)

# Sample stores. They are filled by the MPC controller callback, which routes each
# new sample at random to exactly one of the three.
train_data = Sample[]
test_data = Sample[]
validation_data = Sample[]

# Affine maps handed to every `Nets.Net` constructed in this notebook.
# `x_to_u` is the identity on 4 coordinates; `v_to_y` is diag(20, 0) with zero offset.
# NOTE(review): presumably Nets applies these to the network input and output
# respectively -- confirm in Nets.
x_to_u = AffineMap(eye(4), zeros(4))
v_to_y = AffineMap(diagm([20., 0]), zeros(2))

# Size vector passed to `Nets.Params` (presumably the layer widths, 4 inputs -> 2 outputs)
# and the activation function passed to `Nets.Net`.
widths = [4, 16, 16, 8, 2]
activation = Nets.leaky_relu
"""
    sensitive_loss(λ)

Return a loss function `(params, x, y) -> value` for a single sample.

The returned function builds a `Nets.Net` from `params` (using the global `widths`,
`activation`, `x_to_u` and `v_to_y`), evaluates `Nets.predict_sensitivity` at `x`,
and returns the weighted sum of squared differences to `y`. The weights form a
1×5 row: the first column is weighted by `1 - λ` and the remaining four columns by `λ`.

Targets produced by the MPC callback hold the input in the first column followed by
the jacobian, so `λ` trades off matching the value against matching its sensitivities.
"""
function sensitive_loss(λ)
    column_weights = [1.0-λ λ λ λ λ]
    function loss(params, x, y)
        model = Nets.Net(Nets.Params(widths, params), activation, x_to_u, v_to_y)
        return sum(abs2, column_weights .* (Nets.predict_sensitivity(model, x) .- y))
    end
    return loss
end
# Initial parameters: the `.data` of a random `Nets.Params` draw, scaled by 0.1.
start_params = 0.1 * randn(Nets.Params{Float64}, widths).data;

# Fractions of collected samples routed to the training and validation sets;
# whatever remains (1 - p_train - p_validate) goes to the test set.
p_train = 0.6
p_validate = 0.2

In [None]:
# MPC settings with a 0.04 time step.
# NOTE(review): `gap` is presumably the solver's optimality-gap tolerance -- confirm in CartPoles.
cartpole_mpc_params = CartPoles.CartPoleMPCParams(Δt = 0.04, gap=1e-2)
# Reference state with zero configuration and zero velocity; used by both the LQR
# solution and the MPC controller.
xstar = MechanismState(mechanism, zeros(2), zeros(2))

# LQR solution about `xstar`, built with a zero vector of length `num_velocities(xstar)`
# and the Q/R matrices taken from the MPC parameters.
lqrsol = CartPoles.LQRSolution(xstar, zeros(num_velocities(xstar)), cartpole_mpc_params.Q, cartpole_mpc_params.R)
lqr_controller = CartPoles.LQRController(lqrsol)

In [None]:
# Work on a copy so `start_params` keeps the initial values.
params = copy(start_params)
# Network used by `net_controller`.
# NOTE(review): presumably `Nets.Params` references `params` without copying, so the
# in-place optimizer updates in the training loop are visible through `net` -- confirm.
net = Nets.Net(Nets.Params(widths, params), activation, x_to_u, v_to_y)

In [None]:
# Learned policy: evaluate the global `net` on the state vector of `x`.
net_controller = x -> Nets.predict(net, state_vector(x))

# MPC controller for the cart-pole. The last argument passes the learned and LQR
# controllers (presumably used by the MPC for warm starts -- confirm in CartPoles).
mpc_controller = CartPoles.MPCController(
    cartpole,
    cartpole_mpc_params,
    xstar,
    lqrsol,
    [net_controller, lqr_controller]
);

# Data-collection hook. When a result carries both LCP updates and a jacobian, store
# (state vector, [first-step input  jacobian]) in one of the three datasets, chosen by
# a single uniform draw: training with probability p_train, validation with
# probability p_validate, test otherwise.
mpc_controller.callback = (x, results) -> begin
    if !(isnull(results.lcp_updates) || isnull(results.jacobian))
        # For debugging, the solution can be replayed here with
        # playback(basevis[:robot], get(results.lcp_updates), mpc_controller.params.Δt)
        sample = (
            state_vector(x),
            hcat(get(results.lcp_updates)[1].input, get(results.jacobian))
        )
        draw = rand()
        bucket = draw < p_train ? train_data : (draw < p_train + p_validate ? validation_data : test_data)
        push!(bucket, sample)
    end
end

# Training objective: `sensitive_loss` with λ = 0.1. Validation reuses the same function.
train_loss = sensitive_loss(0.1)
validate_loss = train_loss
# Loss histories; the training loop appends one entry to each per iteration.
training_losses = Float64[]
validation_losses = Float64[]

# Probability that `controller` delegates a step to the MPC controller instead of the net.
p_mpc = 0.2

# Scratch state that `controller` uses to push the current configuration to the visualizer.
x_control = MechanismState{Float64}(mechanism)

# Mixed policy used while collecting data: mirror the current configuration into the
# visualizer, then act with the MPC controller with probability `p_mpc` and with the
# learned net otherwise.
controller = x -> begin
    set_configuration!(x_control, configuration(x))
    settransform!(basevis[:robot], x_control)
    use_mpc = rand() < p_mpc
    return use_mpc ? mpc_controller(x) : net_controller(x)
end

In [None]:
# State object reused as the (re-randomized) initial condition of every rollout.
x0 = MechanismState{Float64}(mechanism)

# File that receives the collected train/test/validation samples after each iteration.
fname = "cart-pole-dagger-0.04.jld2"

# Stop a rollout as soon as the second configuration coordinate leaves [-π/2, π/2]
# (presumably the pole angle -- confirm in CartPoles). The predicate is the same for
# every rollout, so it is built once instead of once per iteration.
termination = x -> (configuration(x)[2] > π/2 || configuration(x)[2] < -π/2)

# Mean of `validate_loss` over `data` at the current `params`.
# Returns NaN for an empty dataset: `mean` throws on an empty collection, and a NaN
# placeholder keeps the loss histories aligned with the iteration index.
function mean_loss(data)
    isempty(data) && return NaN
    return mean(xy -> validate_loss(params, xy[1], xy[2]), data)
end

# Data-aggregation training loop. Each iteration:
#   1. draws a random initial state (configuration in [-1, 1) × [-π/2, π/2),
#      velocities in [-0.5, 0.5)),
#   2. simulates the mixed `controller`; the MPC callback adds samples as a side effect,
#   3. runs `Nets.adam!` on the shuffled training set with a decaying learning rate,
#   4. records the mean training/validation losses,
#   5. checkpoints the datasets and the current parameters to disk.
@showprogress for i in 1:1000
    set_configuration!(x0, [2 * (rand() - 0.5), π * (rand() - 0.5)])
    set_velocity!(x0, (rand(2) .- 0.5))

    # Only the side effect matters here (samples pushed by the MPC callback), so the
    # simulation results are not kept.
    LCPSim.simulate(x0, controller, cartpole.environment, cartpole_mpc_params.Δt, 100, GurobiSolver(OutputFlag=0);
        termination=termination)

    # Samples are routed at random, so the training set can still be empty after the
    # first rollouts; skip the optimizer rather than call it with a batch size of zero.
    if !isempty(train_data)
        Nets.adam!(train_loss, params, shuffleobs(train_data),
            Nets.AdamOpts(learning_rate=0.01 * 0.999^i, batch_size=min(10, length(train_data))))
    end
    push!(training_losses, mean_loss(train_data))
    push!(validation_losses, mean_loss(validation_data))

    @show training_losses[end]
    @show validation_losses[end]

    # Overwrite both checkpoint files on every iteration so an interrupted run
    # leaves the most recent data and parameters on disk.
    jldopen(fname, "w") do file
        file["train_data"] = train_data
        file["test_data"] = test_data
        file["validation_data"] = validation_data
    end

    jldopen("cart-pole-dagger-0.04-params.jld2", "w") do file
        file["params"] = params
        file["widths"] = widths
        file["x_to_u"] = x_to_u
        file["v_to_y"] = v_to_y
        file["training_losses"] = training_losses
        file["validation_losses"] = validation_losses
    end;
end

In [None]:
# Evaluate the learned policy on its own: draw a fresh random initial state and
# simulate `net_controller` with a horizon argument of 300 and no termination predicate.
set_configuration!(x0, [2 * (rand() - 0.5), π * (rand() - 0.5)])
set_velocity!(x0, (rand(2) .- 0.5))
results_net = LCPSim.simulate(x0, net_controller, cartpole.environment, cartpole_mpc_params.Δt, 300, GurobiSolver(OutputFlag=0));

In [None]:
# Same rollout with the LQR controller, for comparison.
# Note that this overwrites `results_net`, so the playback cell shows whichever
# rollout cell was run last.
results_net = LCPSim.simulate(x0, lqr_controller, cartpole.environment, cartpole_mpc_params.Δt, 300, GurobiSolver(OutputFlag=0));

In [None]:
# Put `x0` back at the configuration of the first recorded result and show that pose.
set_configuration!(x0, configuration(results_net[1].state))
settransform!(basevis[:robot], x0)

In [None]:
# Replay the stored rollout in the visualizer. The third argument is twice the
# simulation time step (presumably the per-frame duration, i.e. half-speed playback -- confirm).
playback(basevis[:robot], results_net, 2 * cartpole_mpc_params.Δt)

In [None]:
# Report how many training samples are stored in the dataset file.
jldopen(file -> length(file["train_data"]), fname, "r")