In [None]:
using Revise

In [None]:
using RigidBodyDynamics
using RigidBodyTreeInspector
using DrakeVisualizer
# Open a Drake visualizer window only if none is already open.
DrakeVisualizer.any_open_windows() || DrakeVisualizer.new_window()
using LCPSim
using Polyhedra
using CDDLib
using StaticArrays: SVector
using Gurobi
using JuMP
using LearningMPC
using ExplicitQPs
using Plots; plotly()  # plotly() selects the Plotly backend for Plots
using ProgressMeter
using JLD2
using MLDataPattern

In [None]:
# `Nets` provides the network, rescaling, loss and Adam helpers used below
# (presumably a local module under development). `reload` re-evaluates it so
# edits to its source are picked up without restarting the kernel.
import Nets
reload("Nets")

In [None]:
# Robot model and visualizer: load the cart-pole URDF, clear this visualizer
# subtree from any earlier run, and draw the robot geometry under basevis[:robot].
mechanism = parse_urdf(Float64, "cartpole.urdf")
world = root_body(mechanism)

basevis = Visualizer()[:cartpole_net]
delete!(basevis)
vis = basevis[:robot]
setgeometry!(vis, mechanism, parse_urdf("cartpole.urdf", mechanism))

# Two planar walls in the world frame with friction coefficient μ, one through
# x = -wall_radius and one through x = +wall_radius, with opposite x-directed
# vectors (arguments presumably (frame, normal, point, μ)).
wall_radius = 1.5
μ = 0.5
walls = let frame = default_frame(world)
    [planar_obstacle(frame, [1., 0, 0.], [-wall_radius, 0, 0.], μ),
     planar_obstacle(frame, [-1., 0, 0.], [wall_radius, 0, 0.], μ)]
end

# Axis-aligned box {p : A*p <= b} with A = [I; -I]:
#   x ∈ [-wall_radius - 0.1, wall_radius + 0.1], y ∈ [-0.5, 0.5], z ∈ [-0.1, 2.0].
# Each wall's interior is intersected with it to get a finite polyhedron to draw.
bounds = let upper = [wall_radius + 0.1, 0.5, 2.0], lower = [-wall_radius - 0.1, -0.5, -0.1]
    SimpleHRepresentation(vcat(eye(3), -eye(3)), vcat(upper, -lower))
end

for obstacle in walls
    clipped = intersect(obstacle.interior, bounds)
    addgeometry!(basevis[:environment], CDDPolyhedron{3, Float64}(clipped))
end

# Contact model: the single point (0, 0, 1) in the pole's frame (presumably the
# pole tip) may contact either wall.
pole = findbody(mechanism, "pole")
env = Environment(Dict(
    pole => ContactEnvironment([Point3D(default_frame(pole), SVector(0., 0, 1))], walls)))

In [None]:
# Load the recorded samples and pair each state with its training target.
records = jldopen(file -> file["records"], "cart-pole-data-0.02.jld2", "r");

states = [rec[1] for rec in records]
# Target: only the first row of the horizontal concatenation [rec[2] rec[3]].
outputs = [hcat(rec[2], rec[3])[1:1, :] for rec in records]
data = collect(zip(states, outputs));

In [None]:
# Recorded output (first target entry) against state components 1 and 2.
plot(map(s -> s[1], states), map(s -> s[2], states), map(o -> o[1], outputs); line=nothing, marker=:circle, xlim=(-4, 4), ylim=(-4, 4))

In [None]:
# Recorded output (first target entry) against state components 3 and 4.
plot(map(s -> s[3], states), map(s -> s[4], states), map(o -> o[1], outputs); line=nothing, marker=:circle, xlim=(-4, 4), ylim=(-4, 4))

In [None]:
# Recorded output (first target entry) against state components 1 and 3.
plot(map(s -> s[1], states), map(s -> s[3], states), map(o -> o[1], outputs); line=nothing, marker=:circle, xlim=(-4, 4), ylim=(-4, 4))

In [None]:
# Shuffle, then split 80% train / 20% test.
train_data, test_data = splitobs(shuffleobs(data), 0.8)
# Drop the remainder so the training-set size is a multiple of 10.
# NOTE(review): training uses batch_size=20, so a partial batch may remain;
# confirm whether Nets.adam! expects a multiple of the batch size.
train_data = train_data[1:(div(length(train_data), 10) * 10)]

# Nets.rescale is applied to the (truncated) training set only. It returns the
# scaled data plus the maps x_to_u and v_to_y; their inverses are kept too.
train_data_scaled, x_to_u, v_to_y = Nets.rescale(train_data)
u_to_x = inv(x_to_u)
y_to_v = inv(v_to_y)

In [None]:
# Layer widths: 4 network inputs, five hidden layers, 1 network output.
widths = [4, 64, 32, 32, 32, 32, 1]
activation = Nets.leaky_relu

"""
    sensitive_loss(λ, net_widths=widths, net_activation=activation)

Return a loss closure `(params, x, y) -> value`. The closure builds a `Nets.Net`
from the flat parameter vector `params` and scores
`Nets.predict_sensitivity(net, x) .- y` as a weighted sum of squares.

The weights form the 1×5 row `[1-λ λ λ λ λ]`: the first column is weighted by
`1 - λ` and the remaining four by `λ`. Presumably column 1 is the predicted
output and columns 2-5 are its sensitivities to the 4 inputs (confirm against
`Nets.predict_sensitivity`).

`net_widths` and `net_activation` default to the current global `widths` and
`activation`. They are bound when the closure is created, not re-read from
untyped (non-`const`) globals on every call. This keeps the loss type-stable
inside the training loop and unaffected by later rebinding of those globals.
"""
function sensitive_loss(λ, net_widths=widths, net_activation=activation)
    q = [1.0-λ λ λ λ λ]  # 1×5 row of per-column weights
    return (params, x, y) -> sum(abs2, q .* (Nets.predict_sensitivity(
        Nets.Net(Nets.Params(net_widths, params), net_activation), x) .- y))
end

# Random initial parameters, scaled by 0.1 (unseeded: each run differs).
start_params = 0.1 * randn(Nets.Params{Float64}, widths).data
nepoch = 300;

In [None]:
# Train from a fresh copy of the initial parameters. `net` is built on `params`,
# which Nets.adam! presumably updates in place, so the second @show reflects
# the trained network.
params = copy(start_params)
net = Nets.Net(Nets.Params(widths, params), activation, x_to_u, v_to_y)
train_loss = sensitive_loss(0.1)
# NOTE: validate_loss is the training loss itself and is averaged over
# train_data_scaled, so `losses` is a training curve, not held-out validation.
validate_loss = train_loss
losses = [mean(sample -> validate_loss(params, sample[1], sample[2]), train_data_scaled)]

# Mean squared error against the first target column on the test set, before training.
@show mean(xy -> sum(abs2, net(xy[1]) .- xy[2][:,1]), test_data)

# One Adam pass per iteration with learning rate 0.01 * 0.99^epoch; record the mean loss after each.
@showprogress for epoch in 1:nepoch
    Nets.adam!(train_loss, params, train_data_scaled,
               Nets.AdamOpts(learning_rate=0.01 * 0.99^epoch, batch_size=20))
    push!(losses, mean(sample -> validate_loss(params, sample[1], sample[2]), train_data_scaled))
end

# Same test-set error, after training.
@show mean(xy -> sum(abs2, net(xy[1]) .- xy[2][:,1]), test_data)

plot(losses, ylim=(0, first(losses)))

In [None]:
# Test-set mean squared error of the network against the first target column.
mean(pair -> sum(abs2, net(first(pair)) .- last(pair)[:, 1]), test_data)

In [None]:
# Training-set mean squared error of the network against the first target column.
mean(pair -> sum(abs2, net(first(pair)) .- last(pair)[:, 1]), train_data)

In [None]:
# Learned policy: evaluate the trained network on the mechanism's state vector.
controller = x -> Nets.predict(net, state_vector(x))

In [None]:
# Network output at the all-zero 4-dimensional input.
Nets.predict(net, zeros(4))

In [None]:
# Linearization point: all velocities zero and both joint configurations zero.
x0 = MechanismState{Float64}(mechanism)
set_velocity!(x0, zeros(num_velocities(x0)))
for jointname in ("slider_to_cart", "cart_to_pole")
    set_configuration!(x0, findjoint(mechanism, jointname), [0])
end
q0 = copy(configuration(x0))
v0 = copy(velocity(x0))
u0 = zeros(num_velocities(x0))
xstar = MechanismState(mechanism, q0, v0)

# LQR gain about (x0, u0) with no contact points.
# NOTE(review): Q is an Int-valued matrix and hard-codes a 4-dimensional state,
# while R is sized from num_velocities(x0).
contacts = Point3D[]
Q = diagm([10, 10, 1, 1])
R = 0.1 * eye(num_velocities(x0))
K, S = LCPSim.ContactLQR.contact_lqr(x0, u0, Q, R, contacts)

# Linear state-feedback baseline u = u0 - K*(x - xstar). The rollout in this
# notebook uses the learned `controller` instead.
lqr_controller = x -> u0 .- K * (state_vector(x) - state_vector(xstar))
Δt = 0.01

In [None]:
# Reset x0 to the linearization point.
set_configuration!(x0, q0)
set_velocity!(x0, v0)
# Alternative start: cart at x = -1.5 (the left wall position) moving at -2.
# set_configuration!(x0, findjoint(mechanism, "slider_to_cart"), [-1.5])
# set_velocity!(x0, findjoint(mechanism, "slider_to_cart"), [-2])

# Tilt the pole by a uniformly random angle in [-π/32, π/32).
set_configuration!(x0, findjoint(mechanism, "cart_to_pole"), [(rand() - 0.5) * π / 16])

# Closed-loop rollout under the learned controller: 400 steps of Δt, solved with Gurobi (quiet).
results = LCPSim.simulate(x0, controller, env, Δt, 400, GurobiSolver(OutputFlag=0));

In [None]:
# Animate the simulated trajectory `results` on the robot visualizer `vis`.
playback(vis, results)