In [None]:
using Revise

In [None]:
using MeshCatMechanisms
using MeshCat
using RigidBodyDynamics
using RigidBodyDynamics.Contact
using Flux
using ProgressMeter
using MLDataPattern
using JLD2
using Plots; gr()
using LCPSim
using LearningMPC
using LearningMPC.Models
using Blink

In [None]:
# Create the visualizer that every later cell draws into, and open it in a
# new Window() so the scene can be watched while data is collected.
vis = Visualizer()
open(vis, Window())

In [None]:
using GeometryTypes

In [None]:
# Clear anything already in the visualizer so re-running this cell starts clean.
delete!(vis)
# Cart-pole model built with add_contacts=false and one parameter state named
# "wall_distance" (the same joint name is looked up later via findjoint).
robot = CartPole(add_contacts=false, parameter_states=["wall_distance"])
mvis = MechanismVisualizer(robot, vis)

# Both walls use the same thin box geometry; they are positioned afterwards by
# move_walls_in_vis, which sets a translation on :leftwall and :rightwall.
setobject!(vis[:leftwall], HyperRectangle(Vec(-0.001, -0.5, -1.5), Vec(0.002, 1.0, 3.0)))
setobject!(vis[:rightwall], HyperRectangle(Vec(-0.001, -0.5, -1.5), Vec(0.002, 1.0, 3.0)))

In [None]:
using CoordinateTransformations

In [None]:
"""
    move_walls_in_vis(mvis::MechanismVisualizer, distance)

Place the `:leftwall` and `:rightwall` visualizer objects symmetrically about the
origin along x: the left wall at `-distance`, the right wall at `+distance`.

Returns whatever the final `settransform!` call returns.
"""
function move_walls_in_vis(mvis::MechanismVisualizer, distance)
    root = mvis.visualizer
    settransform!(root[:leftwall], Translation(-distance, 0, 0))
    return settransform!(root[:rightwall], Translation(distance, 0, 0))
end

In [None]:
# MPC parameters and the LQR solution for the robot; lqrsol is reused below as
# a cost, as a warm-start entry and when constructing datasets.
params = MPCParams(robot)
lqrsol = LQRSolution(robot, params)
# NOTE(review): [5, 24, 24, 1] is presumably the layer-width list (5 inputs,
# two hidden layers of 24, 1 output) — confirm against LearningMPC.interval_net.
net, loss = LearningMPC.interval_net([5, 24, 24, 1]; regularization=1e-6, penalty = identity)
optimizer = Flux.ADAM(Flux.params(net); decay=1e-8)

# Cost object built from the LQR solution and the network.
net_cost = LearningMPC.LearnedCost(lqrsol, net)

# Controller that uses the learned cost with a horizon of 1.
net_mpc_params = MPCParams(robot)
net_mpc_params.horizon = 1
net_mpc_controller = MPCController(robot, net_mpc_params, net_cost, [lqrsol]);

# Controller over the full `params`, with lqrsol as its cost. The last argument
# is presumably the list of warm-start controllers (the sample sink refers to
# indices 1 and 2 as the LQR and learned warm starts) — TODO confirm.
full_mpc_controller = MPCController(robot, params, lqrsol, [lqrsol, net_mpc_controller]);

In [None]:
"""
    move_wall(contact::Tuple{RigidBody, Point3D, LCPSim.Obstacle}, new_origin::Point3D)

Return a copy of the `(body, point, obstacle)` contact tuple in which only the
obstacle has been relocated to `new_origin`; the body and point are passed through.
"""
function move_wall(contact::Tuple{RigidBody, Point3D, LCPSim.Obstacle}, new_origin::Point3D)
    body, point, obstacle = contact
    return (body, point, move_wall(obstacle, new_origin))
end

"""
    move_wall(obs::LCPSim.Obstacle, new_origin::Point3D)

Build a new `LCPSim.Obstacle` that keeps `obs.interior`, `obs.μ` and
`obs.contact_basis` but whose contact face has been moved to `new_origin`.
"""
function move_wall(obs::LCPSim.Obstacle, new_origin::Point3D)
    relocated_face = move_wall(obs.contact_face, new_origin)
    return LCPSim.Obstacle(obs.interior, relocated_face, obs.μ, obs.contact_basis)
end

"""
    move_wall(halfspace::HalfSpace3D, new_origin::Point3D)

Return a half-space with the same outward normal as `halfspace` but anchored at
`new_origin`.
"""
move_wall(halfspace::HalfSpace3D, new_origin::Point3D) =
    HalfSpace3D(new_origin, halfspace.outward_normal)
 

In [None]:
# Sink that accumulates MPC samples (keep_nulls=false).
# NOTE(review): the two warm-start indices presumably refer to positions in the
# list given to full_mpc_controller ([lqrsol, net_mpc_controller]) — confirm.
sample_sink = LearningMPC.MPCSampleSink(;
    keep_nulls=false,
    lqrsol=lqrsol,
    lqr_warmstart_index=1,
    learned_warmstart_index=2)

# Second sink, attached to the mechanism visualizer.
playback_sink = LearningMPC.PlaybackSink{Float64}(mvis)

# Every callback from the full MPC controller is delivered to both sinks.
full_mpc_controller.callback = LearningMPC.multiplex!(sample_sink, playback_sink)

live_viewer = LearningMPC.live_viewer(mvis)

# Controller handed to the simulator: the dagger controller built from the full
# and net MPC controllers, multiplexed with the live viewer.
dagger_controller = LearningMPC.multiplex!(
    LearningMPC.dagger_controller(full_mpc_controller, net_mpc_controller),
    live_viewer)

# collect_into!(data, σv, wall_distance)
#
# Runs one simulated rollout with the walls placed at ±wall_distance and appends
# the samples captured by the sample sink to `data`. Returns the result of
# `append!`. The let-bound x_init / x0 / sink are captured once and reused across
# calls; x0 is overwritten by randomize! on every call.
collect_into! = let x_init = nominal_state(robot),
                    x0 = MechanismState{Float64}(robot.mechanism),
                    sink = sample_sink
    function (data::Vector{<:LearningMPC.Sample}, σv, wall_distance)
        # NOTE(review): `robot` is constructed with add_contacts=false, yet this
        # indexes contacts[1] and contacts[2] — confirm the environment really
        # holds two wall contacts at this point.
        contacts = robot.environment.contacts
        left_frame = contacts[1][3].contact_face.point.frame
        right_frame = contacts[2][3].contact_face.point.frame
        # The replacement vector is built in full before being written back in place.
        contacts .= [
            move_wall(contacts[1], Point3D(left_frame, -wall_distance, 0.0, 0.0)),
            move_wall(contacts[2], Point3D(right_frame, wall_distance, 0.0, 0.0))
        ]

        # Drop samples left over from the previous rollout.
        empty!(sink)
        move_walls_in_vis(mvis, wall_distance)

        LearningMPC.randomize!(x0, x_init, wall_distance / 3, σv)
        wall_joint = findjoint(mechanism(robot), "wall_distance")
        set_configuration!(x0, wall_joint, [wall_distance])

        # Stop the rollout once the second configuration coordinate leaves [-π/4, π/4].
        out_of_range = x -> !(-π/4 <= configuration(x)[2] <= π/4)
        LCPSim.simulate(x0,
            dagger_controller,
            robot.environment, params.Δt, 100,
            params.lcp_solver;
            termination=out_of_range)

        append!(data, sink.samples)
    end
end

In [None]:
# library_file stores the accumulated training/validation samples (keys
# "training" and "testing"); dataset_file stores the per-iteration datasets,
# the network, the LQR solution, the MPC parameters and the loss history.
library_file = "library.jld2"
dataset_file = "cartpole-interval-wall-params.jld2"

In [None]:
# Resume from a previously saved sample library when one exists; otherwise start
# with two empty vectors whose element type is a 1-tuple of the sink's sample type.
all_training_data, all_validation_data = if isfile(library_file)
    jldopen(library_file) do file
        (file["training"], file["testing"])
    end
else
    (Vector{Tuple{eltype(sample_sink.samples)}}(),
     Vector{Tuple{eltype(sample_sink.samples)}}())
end;


In [None]:
# Data-collection and training loop.
#
# Each outer iteration:
#   1. runs three rollouts through collect_into!, with velocity noise scaled by
#      1, 0.1 and 0.01 and a freshly sampled wall distance for each,
#   2. drops samples with non-finite inputs and splits the rest 80/20 into the
#      accumulated training / validation sets,
#   3. trains for 50 passes, recording (training, validation) mean loss per pass,
#   4. checkpoints everything to dataset_file and library_file,
#   5. redraws the loss curves.
datasets = Vector{LearningMPC.Dataset{Float64}}()
losses = Vector{Tuple{Float64, Float64}}()

N_iter = 100
σv = 5.0

@showprogress for i in 1:N_iter
    dataset = LearningMPC.Dataset(lqrsol)
    # A new wall distance is drawn immediately before each rollout, in the same
    # order as the original three hand-written calls.
    for σ_scale in (1.0, 0.1, 0.01)
        wall_distance = rand(linspace(0.5, 2.0, 100))
        collect_into!(dataset.training_data, σ_scale * σv, wall_distance)
    end

    filter!(dataset.training_data) do sample
        all(isfinite, sample.input)
    end

    # Wrap each sample in a 1-tuple so that `loss(xy...)` receives it as its
    # single argument.
    new_samples = tuple.(dataset.training_data)
    if !isempty(new_samples)
        new_training, new_validation = splitobs(shuffleobs(new_samples); at=0.8)
        append!(all_training_data, new_training)
        append!(all_validation_data, new_validation)
    end

    # `mean` over an empty collection throws, so training and loss logging are
    # skipped until both sets hold at least one sample. Previously this crashed
    # the run when early rollouts produced no usable data and no library existed.
    if !isempty(all_training_data) && !isempty(all_validation_data)
        sample_loss = xy -> Flux.Tracker.data(loss(xy...))
        @time for epoch in 1:50
            Flux.train!(loss, shuffleobs(all_training_data), optimizer)
            push!(losses,
                (mean(sample_loss, all_training_data),
                 mean(sample_loss, all_validation_data)))
        end
    end
    push!(datasets, dataset)

    # Checkpoint after every iteration so an interrupted run can be inspected
    # or resumed.
    jldopen(dataset_file, "w") do file
        file["datasets"] = datasets
        file["net"] = net
        file["lqrsol"] = lqrsol
        file["mpc_params"] = Dict(
            "Δt" => params.Δt,
            "horizon" => params.horizon,
        )
        file["losses"] = losses
    end

    jldopen(library_file, "w") do file
        file["training"] = all_training_data
        file["testing"] = all_validation_data
    end

    # Nothing to draw until at least one loss pair has been recorded.
    if !isempty(losses)
        plt = plot(first.(losses), label="training")
        plot!(plt, last.(losses), label="validation")
        # Pin the lower y-limit to 1 and keep the automatically chosen upper limit.
        ylims!(plt, (1, ylims(plt)[2]))
        display(plt)
    end

end