In [None]:
using Revise

In [None]:
using MeshCatMechanisms
using MeshCat
using RigidBodyDynamics
using RigidBodySim
using Flux
using JLD2
using Plots; gr()
using LCPSim
using LearningMPC
using LearningMPC.Models
using DataFrames
using Blink
using DataFrames

In [None]:
# Build the Box Atlas model (with contact points) and open a MeshCat visualizer in a Blink window.
robot = BoxAtlas(add_contacts=true)
# Remove the position limits on the floating-base joint.
let base_joint = findjoint(mechanism(robot), "floating_base")
    position_bounds(base_joint) .= RigidBodyDynamics.Bounds(-Inf, Inf)
end
mvis = MechanismVisualizer(robot)
open(mvis, Window())

In [None]:
using CoordinateTransformations

In [None]:
# Set a fixed camera pose and zoom for the visualizer, and hide the ground grid.
let vis = mvis.visualizer
    camera = vis["/Cameras/default"]
    camera_object = vis["/Cameras/default/rotated/<object>"]
    settransform!(camera, Translation(3.5, -0.45, 1.1) ∘ LinearMap(RotZ(π/2)))
    settransform!(camera_object, Translation(0., 0, 0))
    setprop!(camera_object, "zoom", 2.0)
    setprop!(vis["/Grid"], "visible", 0)
end

In [None]:
# Interval-trained network together with the LQR solution stored alongside it.
interval_net, lqrsol = jldopen("boxatlas-regularized.jld2") do file
    file["net"], file["lqrsol"]
end

# Read only the "net" entry of a saved JLD2 file.
load_net(path) = jldopen(file -> file["net"], path)

# MPC controller whose cost is `LearnedCost(lqrsol, net)`, with a 0.05 s Δt.
learned_mpc(net) = LearningMPC.mpc_controller(robot, LearningMPC.LearnedCost(lqrsol, net), Δt=0.05)

interval_net_mpc = learned_mpc(interval_net)

upperbound_net = load_net("boxatlas-regularized-upperbound.jld2")
upperbound_net_mpc = learned_mpc(upperbound_net)

lowerbound_net = load_net("boxatlas-regularized-lowerbound.jld2")
lowerbound_net_mpc = learned_mpc(lowerbound_net)

policy_net = load_net("boxatlas-regularized-mimic.jld2")

# Direct policy: write the network's (untracked) output into the torque buffer τ.
policy_net_controller = let net = policy_net
    (τ, t, x) -> (τ .= Flux.Tracker.data(net(LearningMPC.qv(x))))
end

lqr_mpc = LearningMPC.mpc_controller(robot, lqrsol, Δt=0.05);

In [None]:
# Simulate the interval-network MPC from the nominal state with an initial velocity push,
# then play the result back in the visualizer.
state = nominal_state(robot)
set_velocity!(state, [0.5, 0])
# NOTE(review): 0.01 and 4.0 presumably are the controller period and final time — confirm
# against LearningMPC.simulation_problem.
problem = LearningMPC.simulation_problem(state, interval_net_mpc, 0.01, 4.0)
# `abstol` is the DifferentialEquations.jl keyword for absolute tolerance; the previous
# `abs_tol` is not a recognized solver option, so the tolerance was never applied.
@time solution = RigidBodySim.solve(problem, Tsit5(), abstol=1e-8, dt=1e-6)
setanimation!(mvis, solution)

In [None]:
# Grid of initial x velocities and rotational velocities shared by the cost-to-go heatmaps.
xx = linspace(-1.5, 1.5, 51)
yy = linspace(-π, π, 51)

# Copy of the reference state `lqrsol.x0` with entries 11+1 and 11+3 replaced by x and y.
# NOTE(review): 11 presumably is the number of position coordinates, making these velocity
# components 1 and 3 — confirm against num_positions(mechanism(robot)).
function getstate(x, y)
    s = copy(lqrsol.x0)
    s[11 + 1] = x
    s[11 + 3] = y
    return s
end

# Quadratic LQR cost-to-go (x - x0)' S (x - x0) evaluated over the grid.
plt = heatmap(xx, yy,
    (x, y) -> begin
        dx = getstate(x, y) - lqrsol.x0
        dx' * lqrsol.S * dx
    end;
    color=:coolwarm, clim=(0, 300))
title!(plt, "LQR Cost-to-Go")
xlabel!(plt, "Initial x velocity")
ylabel!(plt, "Initial rotational velocity")
savefig(plt, "lqr_cost_to_go.pdf")
plt

In [None]:
# Same grid as the LQR heatmap, scored by the interval-trained value network instead.
plt = heatmap(xx, yy,
    (x, y) -> begin
        prediction = interval_net(getstate(x, y))
        # strip Flux tracking and take the single element of the output
        Flux.Tracker.data(prediction)[]
    end;
    color=:coolwarm, clim=(0, 300))
title!(plt, "Learned Cost-to-Go")
xlabel!(plt, "Initial x velocity")
ylabel!(plt, "Initial rotational velocity")
savefig(plt, "learned_cost_to_go.pdf")
plt

In [None]:
# (controller, label) pairs; the label is what identifies each policy's rows in the table.
policies = [
    (lqrsol, "LQR"),
    (lqr_mpc, "MPC + LQR cost"),
    (policy_net_controller, "Policy Net"),
    (interval_net_mpc, "MPC + Learned Interval"),
    (upperbound_net_mpc, "MPC + Learned Upper Bound"),
    (lowerbound_net_mpc, "MPC + Learned Lower Bound"),
]

# Run every policy over the same sweep: velocity index 1 over ±1.5 and index 3 over ±π,
# 15 points each. NOTE(review): the `[(1, [0])]` argument's meaning is defined by
# LearningMPC.run_evaluations — confirm there.
tables = [LearningMPC.run_evaluations(
              controller, label, robot, lqrsol,
              [(1, [0])],
              [(1, linspace(-1.5, 1.5, 15)), (3, linspace(-π, π, 15))];
              mvis=mvis, horizon=400) for (controller, label) in policies]
cost_table = vcat(tables...)

In [None]:
# Persist the combined evaluation table so later sessions can reload it without re-running the evaluations.
@save "cost_table-3.jld2" cost_table

In [None]:
# Reload the evaluation table written by the `@save` cell; this (re)defines the global `cost_table`.
@load "cost_table-3.jld2" cost_table

In [None]:
"""
    cost_heatmap(table, label; clim=(0, 30000), feature=:running_cost)

Plot `row[feature][1]` for every row of `table` whose `:controller` column equals
`label`, as a heatmap over the initial velocities `row[:v0][1]` (x axis, "Initial x
velocity") and `row[:v0][3]` (y axis, "Initial rotational velocity"). Returns the plot.

Grid points with no matching row are left as `NaN` rather than 0, so missing data is not
drawn as the lowest possible cost. Throws `ArgumentError` if no row matches `label`.
"""
function cost_heatmap(table, label; clim=(0, 30000), feature=:running_cost)
    # Filter once instead of re-scanning the whole table for every pass.
    rows = [row for row in eachrow(table) if row[:controller] == label]
    if isempty(rows)
        throw(ArgumentError("no rows with controller == $(repr(label)) in table"))
    end
    xs = sort(unique(row[:v0][1] for row in rows))
    ys = sort(unique(row[:v0][3] for row in rows))
    # Hash lookup of each coordinate's grid index (replaces a linear findfirst per row).
    xindex = Dict(x => i for (i, x) in enumerate(xs))
    yindex = Dict(y => i for (i, y) in enumerate(ys))
    zs = fill(NaN, length(ys), length(xs))  # rows index y, columns index x
    for row in rows
        zs[yindex[row[:v0][3]], xindex[row[:v0][1]]] = row[feature][1]
    end

    # 1.5/π presumably matches the ±1.5 x range to the ±π y range so the plot is square.
    plt = heatmap(xs, ys, zs, clim=clim, color=:coolwarm, aspect_ratio=1.5/π)
    xlabel!(plt, "Initial x velocity")
    ylabel!(plt, "Initial rotational velocity")
    title!(plt, label)
    return plt
end
                                


In [None]:
# Nominal state; later cells read `state`, `z` and `g`.
state = nominal_state(robot)
# Height (third coordinate) of the center of mass in the nominal state. The previous
# stand-alone `center_of_mass(state)` statement was dead: its value was never used or shown.
z = center_of_mass(state).v[3]
# Gravitational acceleration (presumably m/s^2).
g = 9.81

In [None]:
# Second (y) coordinate of the left foot sole's origin in the root frame, for the current
# `state`; used as the offset of the lines drawn by cost_heatmap_annotated.
xcapture = let foot = findbody(mechanism(robot), "l_foot_sole")
    foot_to_root = transform_to_root(state, foot)
    translation(foot_to_root)[2]
end

In [None]:
state = nominal_state(robot)
# Sum the spatial inertias of the pelvis, hands and feet, all re-expressed in the pelvis
# inertia's frame, and extract J (the (1,1) moment entry) and the total mass m.
# The accumulator is no longer named `I`, which shadowed Base's exported `I`.
# NOTE(review): RigidBodyDynamics stores `moment` about the frame origin, not the CoM —
# confirm that is the intended pivot for the capture-line estimate.
J, m = let mech = mechanism(robot)
    total = get(findbody(mech, "pelvis").inertia)
    for name in ("r_hand_mount", "l_hand_mount", "r_foot_sole", "l_foot_sole")
        body_inertia = get(findbody(mech, name).inertia)
        # Inertias can only be added when expressed in the same frame.
        total += transform(body_inertia, relative_transform(state, body_inertia.frame, total.frame))
    end
    total.moment[1], total.mass
end

In [None]:
"""
    tofilename(x)

Turn a label into a file-name stem: spaces become `_`, every character other than ASCII
letters, digits and `_` is dropped, and the result is lowercased.
E.g. `tofilename("MPC + LQR cost") == "mpc__lqr_cost"`.

Uses `map`/`filter` over characters instead of the three-argument `replace(s, pat, r)`,
which was deprecated in Julia 0.7 and removed in 1.0, so it works on old and new Julia.
"""
function tofilename(x)
    isfilenamechar(c) = ('a' <= c <= 'z') || ('A' <= c <= 'Z') || ('0' <= c <= '9') || c == '_'
    underscored = map(c -> c == ' ' ? '_' : c, x)
    return lowercase(filter(isfilenamechar, underscored))
end

"""
    savefig_formats(plt, stem)

Save `plt` as `stem.svg`, `stem.pdf` and `stem.png`, in that order.
"""
function savefig_formats(plt, stem)
    for ext in ("svg", "pdf", "png")
        savefig(plt, "$(stem).$(ext)")
    end
end

"""
    cost_heatmap_annotated(label; table=cost_table)

Plot the running-cost heatmap of `label` from `table` (default: the global `cost_table`)
and save it as `<stem>_running_cost.{svg,pdf,png}`, where `<stem>` is `tofilename(label)`.
Then overlay two yellow lines and save `<stem>_running_cost_with_capture.{svg,pdf,png}`.
Returns the annotated plot.

The lines satisfy `v + (J / (m z)) * ω = ±xcapture * sqrt(g / z)` for initial x velocity
`v` and rotational velocity `ω`. They use the globals `xcapture`, `z`, `g`, `J` and `m`.
"""
function cost_heatmap_annotated(label; table=cost_table)
    stem = "$(tofilename(label))_running_cost"
    plt = cost_heatmap(table, label)
    title!(plt, "$label: Running Cost")
    savefig_formats(plt, stem)

    xx = linspace(-1.5, 1.5, 5)
    for side in (-1, 1)  # lower boundary first, then upper, as before
        yy = @. (xx + side * xcapture / sqrt(z / g)) / (J / (m * z))
        plot!(plt, xx, yy, color=:yellow, linewidth=3, label="")
    end
    # Restore the heatmap's extents, which the overlaid lines may otherwise expand.
    xlims!(plt, -1.5, 1.5)
    ylims!(plt, -π, π)
    savefig_formats(plt, "$(stem)_with_capture")
    return plt
end

In [None]:
# Running-cost heatmap with overlay lines for the "LQR" policy; also writes the svg/pdf/png files.
cost_heatmap_annotated("LQR")

In [None]:
# Running-cost heatmap with overlay lines for MPC using the LQR cost; also writes the svg/pdf/png files.
cost_heatmap_annotated("MPC + LQR cost")

In [None]:
# Running-cost heatmap with overlay lines for the direct policy network; also writes the svg/pdf/png files.
cost_heatmap_annotated("Policy Net")

In [None]:
# Running-cost heatmap with overlay lines for MPC with the interval-trained cost; also writes the svg/pdf/png files.
cost_heatmap_annotated("MPC + Learned Interval")

In [None]:
# Running-cost heatmap with overlay lines for MPC with the upper-bound-trained cost; also writes the svg/pdf/png files.
cost_heatmap_annotated("MPC + Learned Upper Bound")

In [None]:
# Running-cost heatmap with overlay lines for MPC with the lower-bound-trained cost; also writes the svg/pdf/png files.
cost_heatmap_annotated("MPC + Learned Lower Bound")