In [1]:
using Revise

In [11]:
using Flux
using JLD2
using FileIO
using MLDataPattern
using CoordinateTransformations
using ProgressMeter
using RigidBodyDynamics
using Gurobi
using DrakeVisualizer
DrakeVisualizer.any_open_windows() || DrakeVisualizer.new_window()
import FluxExtensions
import LearningMPC
import LCPSim
import Hoppers

[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/Hoppers.ji for module Hoppers.
[39m

In [4]:
samples = load("2017-12-20-hopper-grid/grid_search.jld2")["samples"];

In [9]:
minimum(samples) do sample
    sample.mip.objective_bound
end

2.2224779361445144

In [8]:
maximum(samples) do sample
    sample.mip.objective_value
end

1459.1910382673889

In [10]:
samples[1]

LearningMPC.Sample{Float64}([0.25, 0.25, -5.0, -5.0], [-0.0 -3.44884e-13 … 7.031e-15 -1.62567e-21; 40.0 -5.19332e-13 … 8.27324e-15 9.80654e-27], [445.908], LearningMPC.MIPResults
  solvetime_s: Float64 0.0306549072265625
  objective_value: Float64 232.33189199391802
  objective_bound: Float64 210.50321396361096
)

In [12]:
robot = Hoppers.Hopper()
xstar = Hoppers.nominal_state(robot)
ustar = zeros(num_velocities(xstar))
basevis = Visualizer()[:hopper]
setgeometry!(basevis, robot)
settransform!(basevis[:robot], xstar)

Q, R = Hoppers.default_costs(robot)
foot = findbody(robot.mechanism, "foot")
Δt = 0.05
Jc = LCPSim.ContactLQR.contact_jacobian(xstar, [Point3D(default_frame(foot), 0., 0., 0.)])
A, B, c = LCPSim.ContactLQR.contact_linearize(xstar, ustar, Jc)

([0.0 0.0 1.0 0.0; 0.0 0.0 0.0 1.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0 0.0; 0.0 0.0; 0.0 1.0; 0.0 1.0], [0.0, 0.0, -9.81, -9.81])

In [16]:
features(sample::LearningMPC.Sample) = (sample.state, sample.mip.objective_bound, sample.mip.objective_value, sample.uJ[:, 1])
data = features.(samples);
train_data, test_data = splitobs(shuffleobs(data), at=0.85);

In [33]:
function setup_model(R, B)
    model = FluxExtensions.TangentPropagator(Chain(
        LinearMap(UniformScaling(0.2)),
        Dense(4, 32, elu),
        Dense(32, 32, elu),
        Dense(32, 1, elu),
        AffineMap(fill(750, 1, 1), [750])
        ))
    RiBt = inv(R) * B' 

    function sample_loss(w_tangent)
        (x, lb, ub, input) -> begin
            y, J = model(x)
            value_cost = sum(ifelse.(y .< lb, lb .- y, ifelse.(y .> ub, y .- ub, 0 .* y)))
            tangent_cost = w_tangent * sum(abs2.(-RiBt * J' .- input))
            value_cost + tangent_cost
        end
    end
    function batch_loss(w_tangent)
        loss = sample_loss(w_tangent)
        (samples) -> sum((sample -> loss(sample...)).(samples)) / length(samples)
    end
    model, batch_loss
end

model, loss = setup_model(R, B)
opt = Flux.ADADelta(params(model))

(::#71) (generic function with 1 method)

In [38]:
loss(0.0)(train_data[1:1000])

Tracked 0-dimensional Array{Float64,0}:
551.6

In [39]:
train_loss = loss(1.0)
@showprogress for i in 1:200
    batches = eachbatch(shuffleobs(train_data), size=50)
    for batch in batches
        l = train_loss(batch)
        isinf(Flux.Tracker.value(l)) && error("Loss is Inf")
        isnan(Flux.Tracker.value(l)) && error("Loss is NaN")
        Flux.back!(l)
        opt()
    end
    @show loss(1.0)(train_data) loss(0.0)(train_data)
end

(loss(1.0))(train_data) = param(6.87789e5)
(loss(0.0))(train_data) = param(479.016)


[32mProgress:   0%|                                         |  ETA: 15:26:52[39m

(loss(1.0))(train_data) = param(9846.29)
(loss(0.0))(train_data) = param(169.417)


[32mProgress:   1%|                                         |  ETA: 15:22:47[39m

(loss(1.0))(train_data) = param(291.665)
(loss(0.0))(train_data) = param(227.037)


[32mProgress:   2%|█                                        |  ETA: 15:14:37[39m

(loss(1.0))(train_data) = param(277.14)
(loss(0.0))(train_data) = param(228.902)


[32mProgress:   2%|█                                        |  ETA: 15:07:15[39m

(loss(1.0))(train_data) = param(280.612)
(loss(0.0))(train_data) = param(229.049)


[32mProgress:   2%|█                                        |  ETA: 14:58:30[39m

(loss(1.0))(train_data) = param(274.471)
(loss(0.0))(train_data) = param(228.942)


[32mProgress:   3%|█                                        |  ETA: 14:51:12[39m

(loss(1.0))(train_data) = param(285.61)
(loss(0.0))(train_data) = param(228.725)


[32mProgress:   4%|█                                        |  ETA: 14:45:33[39m

(loss(1.0))(train_data) = param(272.539)
(loss(0.0))(train_data) = param(228.369)


[32mProgress:   4%|██                                       |  ETA: 14:39:19[39m

(loss(1.0))(train_data) = param(265.245)
(loss(0.0))(train_data) = param(228.059)


[32mProgress:   4%|██                                       |  ETA: 14:33:35[39m

(loss(1.0))(train_data) = param(262.619)
(loss(0.0))(train_data) = param(227.701)
(loss(1.0))(train_data) = param(258.863)
(loss(0.0))(train_data) = param(227.213)


[32mProgress:   5%|██                                       |  ETA: 14:28:28[39m[32mProgress:   6%|██                                       |  ETA: 14:24:05[39m

(loss(1.0))(train_data) = param(266.511)
(loss(0.0))(train_data) = param(226.726)


[32mProgress:   6%|██                                       |  ETA: 14:20:00[39m

(loss(1.0))(train_data) = param(271.107)
(loss(0.0))(train_data) = param(225.898)


[32mProgress:   6%|███                                      |  ETA: 14:15:56[39m

(loss(1.0))(train_data) = param(262.841)
(loss(0.0))(train_data) = param(224.866)


[32mProgress:   7%|███                                      |  ETA: 14:11:39[39m

(loss(1.0))(train_data) = param(250.537)
(loss(0.0))(train_data) = param(223.777)


[32mProgress:   8%|███                                      |  ETA: 14:10:52[39m

(loss(1.0))(train_data) = param(245.456)
(loss(0.0))(train_data) = param(222.489)


[32mProgress:   8%|███                                      |  ETA: 14:06:58[39m

(loss(1.0))(train_data) = param(245.061)
(loss(0.0))(train_data) = param(220.857)


[32mProgress:   8%|███                                      |  ETA: 14:02:28[39m

(loss(1.0))(train_data) = param(239.942)
(loss(0.0))(train_data) = param(218.63)
(loss(1.0))(train_data) = param(236.675)
(loss(0.0))(train_data) = param(215.76)


[32mProgress:   9%|████                                     |  ETA: 13:57:50[39m[32mProgress:  10%|████                                     |  ETA: 13:52:45[39m

(loss(1.0))(train_data) = param(230.327)
(loss(0.0))(train_data) = param(211.129)


[32mProgress:  10%|████                                     |  ETA: 13:48:07[39m

(loss(1.0))(train_data) = param(243.954)
(loss(0.0))(train_data) = param(203.38)


[32mProgress:  10%|████                                     |  ETA: 13:42:52[39m

(loss(1.0))(train_data) = param(238.041)
(loss(0.0))(train_data) = param(198.026)
(loss(1.0))(train_data) = param(221.965)
(loss(0.0))(train_data) = param(189.18)


[32mProgress:  11%|█████                                    |  ETA: 13:37:58[39m[32mProgress:  12%|█████                                    |  ETA: 13:32:53[39m

(loss(1.0))(train_data) = param(247.496)
(loss(0.0))(train_data) = param(173.964)


[32mProgress:  12%|█████                                    |  ETA: 13:27:55[39m

(loss(1.0))(train_data) = param(239.808)
(loss(0.0))(train_data) = param(154.554)


[32mProgress:  12%|█████                                    |  ETA: 13:23:14[39m

(loss(1.0))(train_data) = param(164.029)
(loss(0.0))(train_data) = param(130.356)


[32mProgress:  13%|█████                                    |  ETA: 13:18:14[39m

(loss(1.0))(train_data) = param(140.717)
(loss(0.0))(train_data) = param(114.961)


[32mProgress:  14%|██████                                   |  ETA: 13:13:28[39m

(loss(1.0))(train_data) = param(141.437)
(loss(0.0))(train_data) = param(112.366)


[32mProgress:  14%|██████                                   |  ETA: 13:08:49[39m

(loss(1.0))(train_data) = param(141.488)
(loss(0.0))(train_data) = param(112.277)


[32mProgress:  14%|██████                                   |  ETA: 13:05:17[39m

(loss(1.0))(train_data) = param(140.624)
(loss(0.0))(train_data) = param(111.967)


[32mProgress:  15%|██████                                   |  ETA: 13:01:45[39m

(loss(1.0))(train_data) = param(136.972)
(loss(0.0))(train_data) = param(111.954)


[32mProgress:  16%|██████                                   |  ETA: 12:57:44[39m

(loss(1.0))(train_data) = param(140.95)
(loss(0.0))(train_data) = param(111.979)


[32mProgress:  16%|███████                                  |  ETA: 12:53:01[39m

(loss(1.0))(train_data) = param(138.88)
(loss(0.0))(train_data) = param(111.776)


[32mProgress:  16%|███████                                  |  ETA: 12:48:10[39m

(loss(1.0))(train_data) = param(138.329)
(loss(0.0))(train_data) = param(111.875)


[32mProgress:  17%|███████                                  |  ETA: 12:43:34[39m

(loss(1.0))(train_data) = param(137.249)
(loss(0.0))(train_data) = param(111.771)


[32mProgress:  18%|███████                                  |  ETA: 12:39:21[39m

(loss(1.0))(train_data) = param(135.704)
(loss(0.0))(train_data) = param(111.666)


[32mProgress:  18%|███████                                  |  ETA: 12:34:45[39m

(loss(1.0))(train_data) = param(136.821)
(loss(0.0))(train_data) = param(111.587)


[32mProgress:  18%|████████                                 |  ETA: 12:30:17[39m

(loss(1.0))(train_data) = param(136.607)
(loss(0.0))(train_data) = param(111.766)


[32mProgress:  19%|████████                                 |  ETA: 12:25:43[39m

(loss(1.0))(train_data) = param(137.005)
(loss(0.0))(train_data) = param(111.784)


[32mProgress:  20%|████████                                 |  ETA: 12:21:02[39m

(loss(1.0))(train_data) = param(140.104)
(loss(0.0))(train_data) = param(111.721)


[32mProgress:  20%|████████                                 |  ETA: 12:16:33[39m

(loss(1.0))(train_data) = param(139.144)
(loss(0.0))(train_data) = param(111.545)


[32mProgress:  20%|████████                                 |  ETA: 12:11:36[39m

(loss(1.0))(train_data) = param(139.915)
(loss(0.0))(train_data) = param(111.767)


[32mProgress:  21%|█████████                                |  ETA: 12:06:42[39m

(loss(1.0))(train_data) = param(139.904)
(loss(0.0))(train_data) = param(111.663)


[32mProgress:  22%|█████████                                |  ETA: 12:01:55[39m

(loss(1.0))(train_data) = param(135.416)
(loss(0.0))(train_data) = param(111.679)


[32mProgress:  22%|█████████                                |  ETA: 11:57:07[39m

(loss(1.0))(train_data) = param(135.404)
(loss(0.0))(train_data) = param(111.609)


[32mProgress:  22%|█████████                                |  ETA: 11:52:22[39m

(loss(1.0))(train_data) = param(134.701)
(loss(0.0))(train_data) = param(111.671)


[32mProgress:  23%|█████████                                |  ETA: 11:47:29[39m

(loss(1.0))(train_data) = param(138.534)
(loss(0.0))(train_data) = param(111.593)
(loss(1.0))(train_data) = param(139.997)
(loss(0.0))(train_data) = param(111.7)


[32mProgress:  24%|██████████                               |  ETA: 11:42:44[39m[32mProgress:  24%|██████████                               |  ETA: 11:37:53[39m

(loss(1.0))(train_data) = param(147.31)
(loss(0.0))(train_data) = param(111.706)


[32mProgress:  24%|██████████                               |  ETA: 11:33:05[39m

(loss(1.0))(train_data) = param(134.748)
(loss(0.0))(train_data) = param(111.537)


[32mProgress:  25%|██████████                               |  ETA: 11:28:18[39m

(loss(1.0))(train_data) = param(143.494)
(loss(0.0))(train_data) = param(111.677)


[32mProgress:  26%|██████████                               |  ETA: 11:23:35[39m

(loss(1.0))(train_data) = param(138.571)
(loss(0.0))(train_data) = param(111.744)
(loss(1.0))(train_data) = param(134.57)
(loss(0.0))(train_data) = param(111.544)


[32mProgress:  26%|███████████                              |  ETA: 11:18:47[39m[32mProgress:  26%|███████████                              |  ETA: 11:14:01[39m

(loss(1.0))(train_data) = param(136.414)
(loss(0.0))(train_data) = param(111.661)


[32mProgress:  27%|███████████                              |  ETA: 11:09:25[39m

(loss(1.0))(train_data) = param(134.489)
(loss(0.0))(train_data) = param(111.724)


[32mProgress:  28%|███████████                              |  ETA: 11:04:40[39m

(loss(1.0))(train_data) = param(137.101)
(loss(0.0))(train_data) = param(111.517)


[32mProgress:  28%|███████████                              |  ETA: 10:59:56[39m

(loss(1.0))(train_data) = param(141.721)
(loss(0.0))(train_data) = param(111.652)


[32mProgress:  28%|████████████                             |  ETA: 10:55:14[39m

(loss(1.0))(train_data) = param(133.768)
(loss(0.0))(train_data) = param(111.65)


[32mProgress:  29%|████████████                             |  ETA: 10:50:33[39m

(loss(1.0))(train_data) = param(134.751)
(loss(0.0))(train_data) = param(111.627)


[32mProgress:  30%|████████████                             |  ETA: 10:45:58[39m

(loss(1.0))(train_data) = param(136.144)
(loss(0.0))(train_data) = param(111.654)


[32mProgress:  30%|████████████                             |  ETA: 10:41:13[39m

(loss(1.0))(train_data) = param(134.231)
(loss(0.0))(train_data) = param(111.706)


[32mProgress:  30%|█████████████                            |  ETA: 10:36:36[39m

(loss(1.0))(train_data) = param(134.52)
(loss(0.0))(train_data) = param(111.63)


[32mProgress:  31%|█████████████                            |  ETA: 10:31:51[39m

(loss(1.0))(train_data) = param(133.555)
(loss(0.0))(train_data) = param(111.747)
(loss(1.0))(train_data) = param(133.303)
(loss(0.0))(train_data) = param(111.735)


[32mProgress:  32%|█████████████                            |  ETA: 10:27:10[39m[32mProgress:  32%|█████████████                            |  ETA: 10:22:31[39m

(loss(1.0))(train_data) = param(135.48)
(loss(0.0))(train_data) = param(111.72)


[32mProgress:  32%|█████████████                            |  ETA: 10:17:49[39m

(loss(1.0))(train_data) = param(134.401)
(loss(0.0))(train_data) = param(111.527)


[32mProgress:  33%|██████████████                           |  ETA: 10:13:10[39m

(loss(1.0))(train_data) = param(134.126)
(loss(0.0))(train_data) = param(111.739)


[32mProgress:  34%|██████████████                           |  ETA: 10:08:27[39m

(loss(1.0))(train_data) = param(133.878)
(loss(0.0))(train_data) = param(111.632)


[32mProgress:  34%|██████████████                           |  ETA: 10:03:53[39m

(loss(1.0))(train_data) = param(138.571)
(loss(0.0))(train_data) = param(111.857)


[32mProgress:  34%|██████████████                           |  ETA: 9:59:10[39m

(loss(1.0))(train_data) = param(138.48)
(loss(0.0))(train_data) = param(111.612)


[32mProgress:  35%|██████████████                           |  ETA: 9:54:32[39m

(loss(1.0))(train_data) = param(140.44)
(loss(0.0))(train_data) = param(111.604)


[32mProgress:  36%|███████████████                          |  ETA: 9:49:53[39m

(loss(1.0))(train_data) = param(146.97)
(loss(0.0))(train_data) = param(111.479)


[32mProgress:  36%|███████████████                          |  ETA: 9:45:15[39m

(loss(1.0))(train_data) = param(136.55)
(loss(0.0))(train_data) = param(111.529)


[32mProgress:  36%|███████████████                          |  ETA: 9:40:34[39m

(loss(1.0))(train_data) = param(141.584)
(loss(0.0))(train_data) = param(111.527)


[32mProgress:  37%|███████████████                          |  ETA: 9:35:54[39m

(loss(1.0))(train_data) = param(134.156)
(loss(0.0))(train_data) = param(111.923)


[32mProgress:  38%|███████████████                          |  ETA: 9:31:17[39m

(loss(1.0))(train_data) = param(141.741)
(loss(0.0))(train_data) = param(111.757)


[32mProgress:  38%|████████████████                         |  ETA: 9:26:36[39m

(loss(1.0))(train_data) = param(143.579)
(loss(0.0))(train_data) = param(111.679)


[32mProgress:  38%|████████████████                         |  ETA: 9:22:00[39m

(loss(1.0))(train_data) = param(133.334)
(loss(0.0))(train_data) = param(111.574)


[32mProgress:  39%|████████████████                         |  ETA: 9:17:23[39m

(loss(1.0))(train_data) = param(133.97)
(loss(0.0))(train_data) = param(111.298)


[32mProgress:  40%|████████████████                         |  ETA: 9:12:45[39m

(loss(1.0))(train_data) = param(134.044)
(loss(0.0))(train_data) = param(111.538)


[32mProgress:  40%|████████████████                         |  ETA: 9:08:05[39m

(loss(1.0))(train_data) = param(135.121)
(loss(0.0))(train_data) = param(111.46)


[32mProgress:  40%|█████████████████                        |  ETA: 9:03:29[39m

(loss(1.0))(train_data) = param(132.961)
(loss(0.0))(train_data) = param(111.664)


[32mProgress:  41%|█████████████████                        |  ETA: 8:58:52[39m

(loss(1.0))(train_data) = param(132.333)
(loss(0.0))(train_data) = param(111.474)


[32mProgress:  42%|█████████████████                        |  ETA: 8:54:12[39m

(loss(1.0))(train_data) = param(132.931)
(loss(0.0))(train_data) = param(111.359)


[32mProgress:  42%|█████████████████                        |  ETA: 8:49:37[39m

(loss(1.0))(train_data) = param(132.041)
(loss(0.0))(train_data) = param(111.351)


[32mProgress:  42%|█████████████████                        |  ETA: 8:45:01[39m

(loss(1.0))(train_data) = param(140.351)
(loss(0.0))(train_data) = param(111.467)
(loss(1.0))(train_data) = param(133.317)
(loss(0.0))(train_data) = param(111.37)


[32mProgress:  43%|██████████████████                       |  ETA: 8:40:22[39m[32mProgress:  44%|██████████████████                       |  ETA: 8:35:44[39m

(loss(1.0))(train_data) = param(137.348)
(loss(0.0))(train_data) = param(111.27)


[32mProgress:  44%|██████████████████                       |  ETA: 8:31:08[39m

(loss(1.0))(train_data) = param(132.533)
(loss(0.0))(train_data) = param(111.517)


[32mProgress:  44%|██████████████████                       |  ETA: 8:26:31[39m

(loss(1.0))(train_data) = param(134.266)
(loss(0.0))(train_data) = param(111.376)


[32mProgress:  45%|██████████████████                       |  ETA: 8:21:54[39m

(loss(1.0))(train_data) = param(131.333)
(loss(0.0))(train_data) = param(111.402)


[32mProgress:  46%|███████████████████                      |  ETA: 8:17:20[39m

(loss(1.0))(train_data) = param(131.474)
(loss(0.0))(train_data) = param(111.5)


[32mProgress:  46%|███████████████████                      |  ETA: 8:12:43[39m

(loss(1.0))(train_data) = param(131.661)
(loss(0.0))(train_data) = param(111.34)


[32mProgress:  46%|███████████████████                      |  ETA: 8:08:08[39m

(loss(1.0))(train_data) = param(132.527)
(loss(0.0))(train_data) = param(111.511)


[32mProgress:  47%|███████████████████                      |  ETA: 8:03:30[39m

(loss(1.0))(train_data) = param(133.817)
(loss(0.0))(train_data) = param(111.311)


[32mProgress:  48%|███████████████████                      |  ETA: 7:58:56[39m

(loss(1.0))(train_data) = param(130.869)
(loss(0.0))(train_data) = param(111.323)


[32mProgress:  48%|████████████████████                     |  ETA: 7:54:19[39m

(loss(1.0))(train_data) = param(131.179)
(loss(0.0))(train_data) = param(111.28)


[32mProgress:  48%|████████████████████                     |  ETA: 7:49:42[39m

(loss(1.0))(train_data) = param(131.854)
(loss(0.0))(train_data) = param(111.168)


[32mProgress:  49%|████████████████████                     |  ETA: 7:45:09[39m

(loss(1.0))(train_data) = param(131.974)
(loss(0.0))(train_data) = param(111.122)


[32mProgress:  50%|████████████████████                     |  ETA: 7:40:32[39m

(loss(1.0))(train_data) = param(131.388)
(loss(0.0))(train_data) = param(111.488)


[32mProgress:  50%|████████████████████                     |  ETA: 7:35:55[39m

(loss(1.0))(train_data) = param(132.146)
(loss(0.0))(train_data) = param(111.534)


[32mProgress:  50%|█████████████████████                    |  ETA: 7:31:18[39m

(loss(1.0))(train_data) = param(131.234)
(loss(0.0))(train_data) = param(111.551)
(loss(1.0))(train_data) = param(133.583)
(loss(0.0))(train_data) = param(111.431)


[32mProgress:  51%|█████████████████████                    |  ETA: 7:26:46[39m[32mProgress:  52%|█████████████████████                    |  ETA: 7:22:09[39m

(loss(1.0))(train_data) = param(131.928)
(loss(0.0))(train_data) = param(111.57)


[32mProgress:  52%|█████████████████████                    |  ETA: 7:17:32[39m

(loss(1.0))(train_data) = param(130.549)
(loss(0.0))(train_data) = param(110.975)


[32mProgress:  52%|██████████████████████                   |  ETA: 7:12:59[39m

(loss(1.0))(train_data) = param(134.082)
(loss(0.0))(train_data) = param(111.063)


[32mProgress:  53%|██████████████████████                   |  ETA: 7:08:23[39m

(loss(1.0))(train_data) = param(130.84)
(loss(0.0))(train_data) = param(111.294)


[32mProgress:  54%|██████████████████████                   |  ETA: 7:03:48[39m

(loss(1.0))(train_data) = param(130.684)
(loss(0.0))(train_data) = param(111.152)


[32mProgress:  54%|██████████████████████                   |  ETA: 6:59:13[39m

(loss(1.0))(train_data) = param(131.157)
(loss(0.0))(train_data) = param(111.169)


[32mProgress:  54%|██████████████████████                   |  ETA: 6:54:38[39m

(loss(1.0))(train_data) = param(130.81)
(loss(0.0))(train_data) = param(111.112)


[32mProgress:  55%|███████████████████████                  |  ETA: 6:50:02[39m

(loss(1.0))(train_data) = param(130.726)
(loss(0.0))(train_data) = param(111.364)


[32mProgress:  56%|███████████████████████                  |  ETA: 6:45:27[39m

(loss(1.0))(train_data) = param(134.848)
(loss(0.0))(train_data) = param(111.191)


[32mProgress:  56%|███████████████████████                  |  ETA: 6:40:55[39m

(loss(1.0))(train_data) = param(131.857)
(loss(0.0))(train_data) = param(111.165)


[32mProgress:  56%|███████████████████████                  |  ETA: 6:36:19[39m

(loss(1.0))(train_data) = param(131.591)
(loss(0.0))(train_data) = param(110.979)


[32mProgress:  57%|███████████████████████                  |  ETA: 6:31:45[39m

(loss(1.0))(train_data) = param(131.199)
(loss(0.0))(train_data) = param(110.973)


[32mProgress:  58%|████████████████████████                 |  ETA: 6:27:10[39m

(loss(1.0))(train_data) = param(130.084)
(loss(0.0))(train_data) = param(110.977)


[32mProgress:  58%|████████████████████████                 |  ETA: 6:22:36[39m

(loss(1.0))(train_data) = param(130.117)
(loss(0.0))(train_data) = param(111.05)


[32mProgress:  58%|████████████████████████                 |  ETA: 6:18:01[39m

(loss(1.0))(train_data) = param(130.776)
(loss(0.0))(train_data) = param(111.085)
(loss(1.0))(train_data) = param(134.93)
(loss(0.0))(train_data) = param(110.979)


[32mProgress:  59%|████████████████████████                 |  ETA: 6:13:26[39m[32mProgress:  60%|████████████████████████                 |  ETA: 6:08:52[39m

(loss(1.0))(train_data) = param(134.294)
(loss(0.0))(train_data) = param(111.133)


[32mProgress:  60%|█████████████████████████                |  ETA: 6:04:17[39m

(loss(1.0))(train_data) = param(133.087)
(loss(0.0))(train_data) = param(110.881)


[32mProgress:  60%|█████████████████████████                |  ETA: 5:59:43[39m

(loss(1.0))(train_data) = param(132.965)
(loss(0.0))(train_data) = param(110.928)


[32mProgress:  61%|█████████████████████████                |  ETA: 5:55:10[39m

(loss(1.0))(train_data) = param(131.45)
(loss(0.0))(train_data) = param(111.03)


[32mProgress:  62%|█████████████████████████                |  ETA: 5:50:35[39m

(loss(1.0))(train_data) = param(135.683)
(loss(0.0))(train_data) = param(110.813)


[32mProgress:  62%|█████████████████████████                |  ETA: 5:46:00[39m

(loss(1.0))(train_data) = param(132.748)
(loss(0.0))(train_data) = param(111.12)


[32mProgress:  62%|██████████████████████████               |  ETA: 5:41:26[39m

(loss(1.0))(train_data) = param(129.897)
(loss(0.0))(train_data) = param(110.971)


[32mProgress:  63%|██████████████████████████               |  ETA: 5:36:51[39m

(loss(1.0))(train_data) = param(130.392)
(loss(0.0))(train_data) = param(110.945)
(loss(1.0))(train_data) = param(132.975)
(loss(0.0))(train_data) = param(110.531)


[32mProgress:  64%|██████████████████████████               |  ETA: 5:32:16[39m[32mProgress:  64%|██████████████████████████               |  ETA: 5:27:43[39m

(loss(1.0))(train_data) = param(129.747)
(loss(0.0))(train_data) = param(110.697)


[32mProgress:  64%|██████████████████████████               |  ETA: 5:23:09[39m

(loss(1.0))(train_data) = param(130.255)
(loss(0.0))(train_data) = param(110.87)


[32mProgress:  65%|███████████████████████████              |  ETA: 5:18:35[39m

(loss(1.0))(train_data) = param(129.856)
(loss(0.0))(train_data) = param(110.771)


[32mProgress:  66%|███████████████████████████              |  ETA: 5:14:01[39m

(loss(1.0))(train_data) = param(131.564)
(loss(0.0))(train_data) = param(110.766)


[32mProgress:  66%|███████████████████████████              |  ETA: 5:09:27[39m

(loss(1.0))(train_data) = param(134.943)
(loss(0.0))(train_data) = param(110.508)


[32mProgress:  66%|███████████████████████████              |  ETA: 5:04:52[39m

(loss(1.0))(train_data) = param(129.837)
(loss(0.0))(train_data) = param(110.951)


[32mProgress:  67%|███████████████████████████              |  ETA: 5:00:18[39m

(loss(1.0))(train_data) = param(130.037)
(loss(0.0))(train_data) = param(110.515)


[32mProgress:  68%|████████████████████████████             |  ETA: 4:55:45[39m

(loss(1.0))(train_data) = param(130.916)
(loss(0.0))(train_data) = param(110.946)


[32mProgress:  68%|████████████████████████████             |  ETA: 4:51:11[39m

(loss(1.0))(train_data) = param(130.166)
(loss(0.0))(train_data) = param(110.562)


[32mProgress:  68%|████████████████████████████             |  ETA: 4:46:37[39m

(loss(1.0))(train_data) = param(129.759)
(loss(0.0))(train_data) = param(110.821)


[32mProgress:  69%|████████████████████████████             |  ETA: 4:42:03[39m

(loss(1.0))(train_data) = param(129.861)
(loss(0.0))(train_data) = param(110.635)


[32mProgress:  70%|████████████████████████████             |  ETA: 4:37:31[39m

(loss(1.0))(train_data) = param(130.224)
(loss(0.0))(train_data) = param(110.836)


[32mProgress:  70%|█████████████████████████████            |  ETA: 4:32:56[39m

(loss(1.0))(train_data) = param(130.614)
(loss(0.0))(train_data) = param(110.665)


[32mProgress:  70%|█████████████████████████████            |  ETA: 4:28:23[39m

(loss(1.0))(train_data) = param(137.673)
(loss(0.0))(train_data) = param(110.637)


[32mProgress:  71%|█████████████████████████████            |  ETA: 4:23:50[39m

(loss(1.0))(train_data) = param(131.946)
(loss(0.0))(train_data) = param(110.484)


[32mProgress:  72%|█████████████████████████████            |  ETA: 4:19:16[39m

(loss(1.0))(train_data) = param(129.682)
(loss(0.0))(train_data) = param(110.635)


[32mProgress:  72%|██████████████████████████████           |  ETA: 4:14:43[39m

(loss(1.0))(train_data) = param(131.023)
(loss(0.0))(train_data) = param(110.661)


[32mProgress:  72%|██████████████████████████████           |  ETA: 4:10:09[39m

(loss(1.0))(train_data) = param(131.016)
(loss(0.0))(train_data) = param(110.612)


[32mProgress:  73%|██████████████████████████████           |  ETA: 4:05:36[39m

(loss(1.0))(train_data) = param(130.454)
(loss(0.0))(train_data) = param(110.61)


[32mProgress:  74%|██████████████████████████████           |  ETA: 4:01:02[39m

(loss(1.0))(train_data) = param(129.093)
(loss(0.0))(train_data) = param(110.541)
(loss(1.0))(train_data) = param(130.469)
(loss(0.0))(train_data) = param(110.642)


[32mProgress:  74%|██████████████████████████████           |  ETA: 3:56:29[39m[32mProgress:  74%|███████████████████████████████          |  ETA: 3:51:55[39m

(loss(1.0))(train_data) = param(131.105)
(loss(0.0))(train_data) = param(110.696)


[32mProgress:  75%|███████████████████████████████          |  ETA: 3:47:22[39m

(loss(1.0))(train_data) = param(130.209)
(loss(0.0))(train_data) = param(110.3)


[32mProgress:  76%|███████████████████████████████          |  ETA: 3:42:49[39m

(loss(1.0))(train_data) = param(129.167)
(loss(0.0))(train_data) = param(110.439)


[32mProgress:  76%|███████████████████████████████          |  ETA: 3:38:15[39m

(loss(1.0))(train_data) = param(129.123)
(loss(0.0))(train_data) = param(110.23)


[32mProgress:  76%|███████████████████████████████          |  ETA: 3:33:42[39m

(loss(1.0))(train_data) = param(130.078)
(loss(0.0))(train_data) = param(110.293)


[32mProgress:  77%|████████████████████████████████         |  ETA: 3:29:08[39m

(loss(1.0))(train_data) = param(130.103)
(loss(0.0))(train_data) = param(110.029)
(loss(1.0))(train_data) = param(129.752)
(loss(0.0))(train_data) = param(110.02)


[32mProgress:  78%|████████████████████████████████         |  ETA: 3:24:36[39m[32mProgress:  78%|████████████████████████████████         |  ETA: 3:20:03[39m

(loss(1.0))(train_data) = param(130.29)
(loss(0.0))(train_data) = param(110.321)


[32mProgress:  78%|████████████████████████████████         |  ETA: 3:15:30[39m

(loss(1.0))(train_data) = param(128.834)
(loss(0.0))(train_data) = param(110.255)


[32mProgress:  79%|████████████████████████████████         |  ETA: 3:10:56[39m

(loss(1.0))(train_data) = param(129.232)
(loss(0.0))(train_data) = param(110.276)
(loss(1.0))(train_data) = param(129.01)
(loss(0.0))(train_data) = param(110.158)


[32mProgress:  80%|█████████████████████████████████        |  ETA: 3:06:23[39m[32mProgress:  80%|█████████████████████████████████        |  ETA: 3:01:50[39m

(loss(1.0))(train_data) = param(129.318)
(loss(0.0))(train_data) = param(109.82)


[32mProgress:  80%|█████████████████████████████████        |  ETA: 2:57:17[39m

(loss(1.0))(train_data) = param(130.168)
(loss(0.0))(train_data) = param(110.067)


[32mProgress:  81%|█████████████████████████████████        |  ETA: 2:52:44[39m

(loss(1.0))(train_data) = param(128.248)
(loss(0.0))(train_data) = param(109.847)
(loss(1.0))(train_data) = param(131.646)
(loss(0.0))(train_data) = param(110.162)


[32mProgress:  82%|█████████████████████████████████        |  ETA: 2:48:11[39m[32mProgress:  82%|██████████████████████████████████       |  ETA: 2:43:38[39m

(loss(1.0))(train_data) = param(129.872)
(loss(0.0))(train_data) = param(109.944)


[32mProgress:  82%|██████████████████████████████████       |  ETA: 2:39:05[39m

(loss(1.0))(train_data) = param(130.369)
(loss(0.0))(train_data) = param(110.388)


[32mProgress:  83%|██████████████████████████████████       |  ETA: 2:34:32[39m

(loss(1.0))(train_data) = param(129.99)
(loss(0.0))(train_data) = param(110.316)


[32mProgress:  84%|██████████████████████████████████       |  ETA: 2:29:59[39m

(loss(1.0))(train_data) = param(128.964)
(loss(0.0))(train_data) = param(109.961)


[32mProgress:  84%|██████████████████████████████████       |  ETA: 2:25:26[39m

(loss(1.0))(train_data) = param(128.783)
(loss(0.0))(train_data) = param(109.881)


[32mProgress:  84%|███████████████████████████████████      |  ETA: 2:20:53[39m

(loss(1.0))(train_data) = param(130.59)
(loss(0.0))(train_data) = param(109.736)


[32mProgress:  85%|███████████████████████████████████      |  ETA: 2:16:20[39m

(loss(1.0))(train_data) = param(130.896)
(loss(0.0))(train_data) = param(109.786)


[32mProgress:  86%|███████████████████████████████████      |  ETA: 2:11:47[39m

(loss(1.0))(train_data) = param(129.853)
(loss(0.0))(train_data) = param(109.599)


[32mProgress:  86%|███████████████████████████████████      |  ETA: 2:07:14[39m

(loss(1.0))(train_data) = param(129.902)
(loss(0.0))(train_data) = param(109.702)


[32mProgress:  86%|███████████████████████████████████      |  ETA: 2:02:42[39m

(loss(1.0))(train_data) = param(139.074)
(loss(0.0))(train_data) = param(109.9)


[32mProgress:  87%|████████████████████████████████████     |  ETA: 1:58:09[39m

(loss(1.0))(train_data) = param(131.473)
(loss(0.0))(train_data) = param(109.513)


[32mProgress:  88%|████████████████████████████████████     |  ETA: 1:53:36[39m

(loss(1.0))(train_data) = param(127.984)
(loss(0.0))(train_data) = param(109.717)
(loss(1.0))(train_data) = param(134.202)
(loss(0.0))(train_data) = param(109.762)


[32mProgress:  88%|████████████████████████████████████     |  ETA: 1:49:03[39m[32mProgress:  88%|████████████████████████████████████     |  ETA: 1:44:30[39m

(loss(1.0))(train_data) = param(130.057)
(loss(0.0))(train_data) = param(109.509)


[32mProgress:  89%|████████████████████████████████████     |  ETA: 1:39:58[39m

(loss(1.0))(train_data) = param(129.404)
(loss(0.0))(train_data) = param(109.432)


[32mProgress:  90%|█████████████████████████████████████    |  ETA: 1:35:25[39m

(loss(1.0))(train_data) = param(128.28)
(loss(0.0))(train_data) = param(109.49)


[32mProgress:  90%|█████████████████████████████████████    |  ETA: 1:30:52[39m

(loss(1.0))(train_data) = param(128.712)
(loss(0.0))(train_data) = param(109.758)


[32mProgress:  90%|█████████████████████████████████████    |  ETA: 1:26:19[39m

(loss(1.0))(train_data) = param(129.822)
(loss(0.0))(train_data) = param(109.316)


[32mProgress:  91%|█████████████████████████████████████    |  ETA: 1:21:46[39m

(loss(1.0))(train_data) = param(128.198)
(loss(0.0))(train_data) = param(109.428)


[32mProgress:  92%|██████████████████████████████████████   |  ETA: 1:17:14[39m

(loss(1.0))(train_data) = param(133.028)
(loss(0.0))(train_data) = param(109.027)


[32mProgress:  92%|██████████████████████████████████████   |  ETA: 1:12:41[39m

(loss(1.0))(train_data) = param(131.449)
(loss(0.0))(train_data) = param(109.019)
(loss(1.0))(train_data) = param(128.203)
(loss(0.0))(train_data) = param(109.673)


[32mProgress:  92%|██████████████████████████████████████   |  ETA: 1:08:09[39m[32mProgress:  93%|██████████████████████████████████████   |  ETA: 1:03:36[39m

(loss(1.0))(train_data) = param(128.0)
(loss(0.0))(train_data) = param(109.538)


[32mProgress:  94%|██████████████████████████████████████   |  ETA: 0:59:03[39m

(loss(1.0))(train_data) = param(127.786)
(loss(0.0))(train_data) = param(109.239)


[32mProgress:  94%|███████████████████████████████████████  |  ETA: 0:54:31[39m

(loss(1.0))(train_data) = param(128.188)
(loss(0.0))(train_data) = param(109.251)


[32mProgress:  94%|███████████████████████████████████████  |  ETA: 0:49:58[39m

(loss(1.0))(train_data) = param(127.501)
(loss(0.0))(train_data) = param(108.956)


[32mProgress:  95%|███████████████████████████████████████  |  ETA: 0:45:26[39m

(loss(1.0))(train_data) = param(133.137)
(loss(0.0))(train_data) = param(108.958)


[32mProgress:  96%|███████████████████████████████████████  |  ETA: 0:40:53[39m

(loss(1.0))(train_data) = param(131.252)
(loss(0.0))(train_data) = param(109.177)
(loss(1.0))(train_data) = param(127.962)
(loss(0.0))(train_data) = param(109.021)


[32mProgress:  96%|███████████████████████████████████████  |  ETA: 0:36:20[39m[32mProgress:  96%|████████████████████████████████████████ |  ETA: 0:31:48[39m

(loss(1.0))(train_data) = param(128.05)
(loss(0.0))(train_data) = param(108.938)


[32mProgress:  97%|████████████████████████████████████████ |  ETA: 0:27:15[39m

(loss(1.0))(train_data) = param(127.233)
(loss(0.0))(train_data) = param(108.693)


[32mProgress:  98%|████████████████████████████████████████ |  ETA: 0:22:43[39m

(loss(1.0))(train_data) = param(128.559)
(loss(0.0))(train_data) = param(108.985)
(loss(1.0))(train_data) = param(126.595)
(loss(0.0))(train_data) = param(108.847)


[32mProgress:  98%|████████████████████████████████████████ |  ETA: 0:18:10[39m[32mProgress:  98%|████████████████████████████████████████ |  ETA: 0:13:38[39m

(loss(1.0))(train_data) = param(127.839)
(loss(0.0))(train_data) = param(108.577)


[32mProgress:  99%|█████████████████████████████████████████|  ETA: 0:09:05[39m

(loss(1.0))(train_data) = param(127.183)
(loss(0.0))(train_data) = param(108.805)


[32mProgress: 100%|█████████████████████████████████████████|  ETA: 0:04:32[39m

(loss(1.0))(train_data) = param(135.504)
(loss(0.0))(train_data) = param(108.217)


[32mProgress: 100%|█████████████████████████████████████████| Time: 15:08:19[39m


In [40]:
jldopen("2018-01-25-value-tangents-data.jld2", "w") do file
    file["model"] = model
    file["params"] = params(model)
end

6-element Array{Any,1}:
 param([-0.874318 -0.894906 0.449782 -0.29218; -0.0491876 -0.916839 0.0816046 -0.0822522; … ; -0.753808 -2.2484 0.0164334 -0.0220471; -0.135119 -0.0400157 0.206078 -0.206259])                                              
 param([-0.685669, -0.508003, 0.0230763, 0.150286, -0.280697, -0.0709213, 0.370353, 0.528606, -0.17281, -1.45689  …  -0.12054, 0.0365872, 0.5142, 0.171768, -1.13385, 0.0225197, -1.77688, -0.431941, -0.724202, -0.0487545])
 param([0.0827438 0.181199 … 0.0714366 0.429206; 0.296025 0.318421 … 0.190893 0.398564; … ; -0.350043 -0.300057 … -0.0530053 -0.973669; 0.331874 0.488173 … 0.620462 0.636777])                                              
 param([-0.876633, -0.984719, -0.130419, -1.70156, 1.24424, -0.45569, 0.0369737, -1.02634, -1.16589, -0.681811  …  -0.480377, -1.24427, -1.37558, -0.426286, -1.05001, -0.729481, -0.587845, 2.35657, 1.83911, -0.917382])   
 param([0.0186335 0.091613 … 0.00885737 0.00890532])                                    

In [45]:
import LCPSim
import Hoppers
using RigidBodyDynamics
using Gurobi
using DrakeVisualizer
DrakeVisualizer.any_open_windows() || DrakeVisualizer.new_window()

Process(`/home/rdeits/locomotion/explorations/learning-mpc/packages/v0.6/DrakeVisualizer/src/../deps/usr/bin/drake-visualizer`, ProcessRunning)

In [46]:
robot = Hoppers.Hopper()
xstar = Hoppers.nominal_state(robot)
ustar = zeros(num_velocities(xstar))
basevis = Visualizer()[:hopper]
setgeometry!(basevis, robot)
settransform!(basevis[:robot], xstar)

Q, R = Hoppers.default_costs(robot)
foot = findbody(robot.mechanism, "foot")
Δt = 0.05
Jc = LCPSim.ContactLQR.contact_jacobian(xstar, [Point3D(default_frame(foot), 0., 0., 0.)])
A, B, c = LCPSim.ContactLQR.contact_linearize(xstar, ustar, Jc)

([0.0 0.0 1.0 0.0; 0.0 0.0 0.0 1.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0 0.0; 0.0 0.0; 0.0 1.0; 0.0 1.0], [0.0, 0.0, -9.81, -9.81])

In [47]:
model(zeros(4))

(param([223.087]), param([-32.3593 3.97868 -0.706242 0.638184]))

In [48]:
net_value_controller = state -> begin
    x = state_vector(state)
    value, jac = model(x)
    u = vec(-inv(R) * B' * Flux.Tracker.value(jac)')
end

(::#99) (generic function with 1 method)

In [51]:
x_init = MechanismState{Float64}(robot.mechanism)
set_configuration!(x_init, [1.0, 1.0])
set_velocity!(x_init, [0., 0.])
LearningMPC.randomize!(x_init, x_init, 0.5, 1.0)
results = LCPSim.simulate(x_init, net_value_controller,
    robot.environment,
    Δt,
    100,
    GurobiSolver(Gurobi.Env(), OutputFlag=0));

In [52]:
LearningMPC.playback(basevis[:robot], results, Δt)