In [None]:
using Revise

In [None]:
using MeshCatMechanisms
using MeshCat
using RigidBodyDynamics
using RigidBodySim
using Flux
using JLD2
using Plots; gr()
using LCPSim
using LearningMPC
using LearningMPC.Models
using DataFrames
using Blink

In [None]:
# Create the MeshCat visualizer and show it in a standalone Blink window.
vis = Visualizer()
open(vis, Blink.Window())

In [None]:
# Remove whatever is currently attached to the visualizer before adding the robot.
delete!(vis)

# Cart-pole model built with contacts enabled, plus a visualizer bound to `vis`.
robot = CartPole(; add_contacts=true)
mvis = MechanismVisualizer(robot, vis)

In [None]:
# Read the interval network and the LQR solution stored together in one file.
interval_net, lqrsol = jldopen(
    file -> (file["net"], file["lqrsol"]), "cartpole-interval.jld2")

# MPC controller whose cost is a `LearnedCost` built from the LQR solution and
# the interval network.
interval_net_mpc = LearningMPC.mpc_controller(
    robot, LearningMPC.LearnedCost(lqrsol, interval_net); Δt=0.025)

# MPC controller constructed from the default `MPCParams` for this robot.
# (The trailing `;` suppresses the cell output.)
full_mpc_controller = MPCController(robot, MPCParams(robot), lqrsol, [lqrsol]);

# Policy network stored in its own file.
policy_net = jldopen(file -> file["net"], "cartpole-mimic.jld2")

# Controller callback: evaluate the policy net on `qv(x)` and write the
# untracked result into `τ` in place. `t` is accepted but unused.
policy_net_controller = let net = policy_net
    (τ, t, x) -> (τ .= Flux.Tracker.data(net(LearningMPC.qv(x))))
end

# MPC controller that uses the LQR solution directly as its cost.
lqr_mpc = LearningMPC.mpc_controller(robot, lqrsol; Δt=0.025);

In [None]:
# Bring the `training` variable stored in library.jld2 into scope (JLD2 `@load`).
@load "library.jld2" training

In [None]:
# Display the training entry whose first element has the largest MIP objective bound.
training[indmax(map(sample -> sample[1].mip.objective_bound, training))]

In [None]:
# Axis grids shared by the cost-to-go heatmaps in this notebook.
xx = linspace(-8, 8, 51)
yy = linspace(-π, π, 51)

# Build the 4-element state `[0, 0, x, y]`; the axis labels below treat `x` and
# `y` as the two initial velocities.
getstate(x, y) = [0, 0, x, y]

# Quadratic LQR cost-to-go (state - x0)' * S * (state - x0) evaluated on the grid.
plt = heatmap(xx, yy,
    function (x, y)
        δ = getstate(x, y) - lqrsol.x0
        δ' * lqrsol.S * δ
    end,
    color=:coolwarm,
    clim=(0, 15000))
plot!(plt;
    title="LQR Cost-to-Go",
    xlabel="Initial x velocity",
    ylabel="Initial rotational velocity")
savefig(plt, "lqr_cost_to_go.pdf")
plt

In [None]:
# Cost-to-go predicted by the interval network on the same grid and colour scale
# as the LQR heatmap. `[]` unwraps the single-element network output.
plt = heatmap(xx, yy,
    function (x, y)
        prediction = Flux.Tracker.data(interval_net(getstate(x, y)))
        prediction[]
    end,
    color=:coolwarm,
    clim=(0, 15000))
plot!(plt;
    title="Learned Cost-to-Go",
    xlabel="Initial x velocity",
    ylabel="Initial rotational velocity")
savefig(plt, "learned_cost_to_go.pdf")
plt

In [None]:
# 1-D slice of the cost-to-go: sweep the third state entry (labelled cart
# velocity below) with every other state entry held at zero, and compare the
# interval network against the LQR quadratic form.
#
# Optional overlay of training samples close to this slice (kept for reference;
# use it in place of `plt = plot()` below):
# d = [x for x in training if norm(x[1].state[[1,2,4]]) < 0.1]
# plt = scatter([s[1].state[3] for s in d], [s[1].mip.objective_value for s in d])

# Dedicated names are used here so the shared `xx`/`yy` heatmap grids are not
# overwritten (which would break re-running the heatmap cells afterwards).
cart_velocities = linspace(-8, 8, 101)

learned_costs = [
    Flux.Tracker.data(interval_net(getstate(v, 0)))[1] for v in cart_velocities]

# Same quadratic form as the LQR heatmap, including the offset by `lqrsol.x0`.
lqr_costs = map(cart_velocities) do v
    δ = getstate(v, 0) - lqrsol.x0
    δ' * lqrsol.S * δ
end

plt = plot()
# Labels now match the data: the network curve is the learned cost and the
# quadratic form is the LQR cost (they were previously swapped).
plot!(plt, cart_velocities, learned_costs, label="Learned cost-to-go")
plot!(plt, cart_velocities, lqr_costs, label="LQR cost-to-go")
xlabel!(plt, "Initial cart velocity (m/s)")
ylabel!(plt, "Cost")
savefig(plt, "lqr_vs_learned_cost_1d.pdf")
savefig(plt, "lqr_vs_learned_cost_1d.png")
plt

In [None]:
# Simulate the learned-interval MPC controller from the nominal state with the
# second velocity coordinate set to 8, then load the result into the visualizer.
state = nominal_state(robot)
set_velocity!(state, [0.0, 8])
# NOTE(review): 0.01 and 6.0 are presumably the controller period and the final
# time -- confirm against `LearningMPC.simulation_problem`.
problem = LearningMPC.simulation_problem(state, interval_net_mpc, 0.01, 6.0)
# `abstol` is the DifferentialEquations keyword for the absolute tolerance; the
# previous spelling `abs_tol` is not a solver option, so the tolerance was not
# being applied.
solution = RigidBodySim.solve(problem, Tsit5(), abstol=1e-8, dt=1e-6)
setanimation!(mvis, solution)

In [None]:
# Controllers to compare, each paired with the label stored in the results table.
policies = [
    (interval_net_mpc, "MPC + Learned Interval"),
    (lqrsol, "LQR"),
    (lqr_mpc, "MPC + LQR cost"),
    (policy_net_controller, "Policy Net"),
]

# Run the same evaluation sweep for every controller; each call yields one table.
tables = map(policies) do entry
    controller, name = entry
    # NOTE(review): presumably (coordinate index, values) pairs for the initial
    # configuration and initial velocity sweeps -- confirm against
    # `LearningMPC.run_evaluations`.
    q_grid = [(1, [0])]
    v_grid = [(1, linspace(-8, 8, 11)), (2, linspace(-π, π, 11))]
    LearningMPC.run_evaluations(
        controller, name, robot, lqrsol, q_grid, v_grid;
        mvis=mvis,
        horizon=400)
end

# Stack the per-controller tables into a single table.
cost_table = reduce(vcat, tables)

In [None]:
# Store `cost_table` in cost_table.jld2 so the evaluation above need not be re-run.
@save "cost_table.jld2" cost_table

In [None]:
# Bring the `cost_table` variable stored in cost_table.jld2 back into scope.
@load "cost_table.jld2" cost_table

In [None]:
"""
    cost_heatmap(table, label; clim=(0, 5000), feature=:running_cost)

Plot a heatmap of `feature` for the rows of `table` whose `:controller` equals
`label`, indexed by the first (x axis) and second (y axis) entries of `:v0`.

# Keywords
- `clim`: colour limits passed to `heatmap`.
- `feature`: column to plot; the first element of each row's value is used.

Grid cells with no matching row keep the value 0. Returns the plot.
"""
function cost_heatmap(table, label; clim=(0, 5000), feature=:running_cost)
    # Sorted distinct values of the i-th `:v0` component for this controller.
    axisvalues(i) = sort(unique(
        row[:v0][i] for row in eachrow(table) if row[:controller] == label))
    xs = axisvalues(1)
    ys = axisvalues(2)

    # Rows of `zs` follow `ys`, columns follow `xs`.
    zs = fill(0.0, length(ys), length(xs))
    for row in eachrow(table)
        row[:controller] == label || continue
        v0 = row[:v0]
        col = findfirst(xs .== v0[1])
        r = findfirst(ys .== v0[2])
        zs[r, col] = row[feature][1]
    end

    plt = heatmap(xs, ys, zs, clim=clim, color=:coolwarm, aspect_ratio=8/π)
    title!(plt, label)
    xlabel!(plt, "Cart velocity")
    ylabel!(plt, "Pole velocity")
    return plt
end
                                
                                

In [None]:
"""
    success_heatmap(table, label, q_threshold=π/8, v_threshold=π/8)

Plot a boolean heatmap of which runs of controller `label` in `table` succeeded,
indexed by the first (x axis) and second (y axis) entries of `:v0`.

A run counts as a success when the second entry of `:qf` is within `q_threshold`
of zero and the second entry of `:vf` is within `v_threshold` of zero. Grid
cells with no matching row stay `false`. Returns the plot.
"""
function success_heatmap(table, label, q_threshold=π/8, v_threshold=π/8)
    # Sorted distinct values of the i-th `:v0` component for this controller.
    axisvalues(i) = sort(unique(
        row[:v0][i] for row in eachrow(table) if row[:controller] == label))
    xs = axisvalues(1)
    ys = axisvalues(2)

    # Rows of `zs` follow `ys`, columns follow `xs`.
    zs = fill(false, length(ys), length(xs))
    for row in eachrow(table)
        row[:controller] == label || continue
        v0 = row[:v0]
        col = findfirst(xs .== v0[1])
        r = findfirst(ys .== v0[2])
        zs[r, col] = isapprox(row[:qf][2], 0, atol=q_threshold) &&
            isapprox(row[:vf][2], 0, atol=v_threshold)
    end

    plt = heatmap(xs, ys, zs, colorbar=false, aspect_ratio=8/π)
    title!(plt, label)
    xlabel!(plt, "Cart velocity")
    ylabel!(plt, "Pole velocity")
    return plt
end
                                
   

In [None]:
# Running-cost heatmap for the policy-net controller, saved as SVG and PDF.
label = "Policy Net"
plt = cost_heatmap(cost_table, label)
plot!(plt; title="Policy Net: Running Cost")
for path in ("policy_net_running_cost.svg", "policy_net_running_cost.pdf")
    savefig(plt, path)
end
plt

In [None]:
# Success map for the policy-net controller.
# `label` is assigned here so the cell does not depend on which other cell last
# set the global; otherwise the hard-coded title and file name below could end
# up describing a different controller's data.
label = "Policy Net"
plt = success_heatmap(cost_table, label)
title!(plt, "Policy Net: Successes")
savefig(plt, "policy_net_successes.pdf")
plt

In [None]:
# Running-cost heatmap for MPC with the learned interval cost, saved as SVG and PDF.
label = "MPC + Learned Interval"
plt = cost_heatmap(cost_table, label)
plot!(plt; title="MPC + Learned Interval: Running Cost")
for path in ("mpc_interval_running_cost.svg", "mpc_interval_running_cost.pdf")
    savefig(plt, path)
end
plt

In [None]:
# Success map for MPC with the learned interval cost.
# `label` is assigned here so the cell does not depend on which other cell last
# set the global; otherwise the hard-coded title and file name below could end
# up describing a different controller's data.
label = "MPC + Learned Interval"
plt = success_heatmap(cost_table, label)
title!(plt, "MPC + Learned Interval: Successes")
savefig(plt, "mpc_interval_successes.pdf")
plt

In [None]:
# Running-cost heatmap for MPC with the LQR cost, saved as SVG and PDF.
label = "MPC + LQR cost"
plt = cost_heatmap(cost_table, label)
plot!(plt; title="MPC + LQR: Running Cost")
for path in ("mpc_lqr_running_cost.svg", "mpc_lqr_running_cost.pdf")
    savefig(plt, path)
end
plt

In [None]:
# Success map for MPC with the LQR cost.
# `label` is assigned here so the cell does not depend on which other cell last
# set the global; otherwise the hard-coded title and file name below could end
# up describing a different controller's data.
label = "MPC + LQR cost"
plt = success_heatmap(cost_table, label)
title!(plt, "MPC + LQR: Successes")
savefig(plt, "mpc_lqr_successes.pdf")
plt