In [28]:
using Plots; gr()
using JLD2, FileIO
using Optim
using ReverseDiff
using ReverseDiff: @forward

In [9]:
import Nets
import LearningMPC

[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/rdeits/locomotion/explorations/learning-mpc/packages/lib/v0.6/LearningMPC.ji for module LearningMPC.
[39m

In [193]:
"""
    interval_hinge_loss(x)

Return how far `x` lies outside the closed interval `[-1, 1]`.

The result is `-1 - x` when `x < -1`, `x - 1` when `x > 1`, and `zero(T)`
(with `T` the type of `x`) otherwise, so values inside the interval incur no
loss. Because both comparisons are false for `NaN`, a `NaN` input also yields
`zero(T)`.
"""
function interval_hinge_loss(x::T) where T
    # Below the interval: distance to the lower edge.
    x < -1 && return -1 - x
    # Above the interval: distance to the upper edge.
    x > 1 && return x - 1
    # Inside the interval (or not comparable): no penalty.
    return zero(T)
end

# Loss of `net` on a single `(input, targets)` sample.
#
# Each entry of `targets` is reduced to a centre (`mean`) and a half-width
# (`(last - first) / 2`) -- presumably every entry is a (lower, upper) pair;
# confirm against the code that builds the training data. The prediction is
# shifted by the centre and scaled by the half-width, so that the hinge loss on
# [-1, 1] is zero whenever the prediction lies between `first` and `last`.
#
# NOTE(review): an entry with `last == first` has a half-width of zero, so the
# division below yields Inf or NaN for that entry.
sample_loss = (net, sample) -> begin
    input, targets = sample
    prediction = Nets.predict(net, input)
    centers = mean.(targets)
    halfwidths = (last.(targets) .- first.(targets)) ./ 2
    sum(@forward(interval_hinge_loss).((prediction .- centers) ./ halfwidths))
end

"""
    cost_function(sample_loss, net, training_data)

Build the objective and in-place gradient used to fit the parameters of `net`.

Returns a tuple `(f, g!)`:

- `f(params)` builds a network with `similar(net, params)` and returns the mean
  of `sample_loss(network, sample)` over every `sample` in `training_data`.
- `g!(∇, params)` replays a compiled ReverseDiff gradient tape of `f` at
  `params` and stores the gradient in `∇`.

The tape is recorded once, at the current value of `net.params.data`, when this
function is called.

NOTE(review): a compiled ReverseDiff tape replays the operations recorded at
construction time, so value-dependent branches inside `sample_loss` are only
re-evaluated if they are wrapped accordingly (e.g. with `@forward`) -- confirm
against the ReverseDiff documentation.
"""
function cost_function(sample_loss::Function, net::Nets.Net, training_data)
    # Mean per-sample loss of the network obtained from `params`.
    function f(params)
        candidate = similar(net, params)
        return mean(sample -> sample_loss(candidate, sample), training_data)
    end

    # Record `f` once, then compile the recording for repeated replay.
    recorded_tape = ReverseDiff.GradientTape(f, (net.params.data,))
    compiled_tape = ReverseDiff.compile(recorded_tape)

    g! = (∇, params) -> begin
        ReverseDiff.gradient!((∇,), compiled_tape, (params,))
    end

    return f, g!
end

cost_function (generic function with 1 method)

In [195]:
# Load the stored samples from the grid-search results file.
samples = load("2017-12-20-hopper-grid/grid_search.jld2")["samples"];

# Keep (in place) only the samples whose first two state entries are exactly
# 0.25 and whose third and fourth state entries are equal.
filter!(samples) do s
    st = s.state
    st[1] == 0.25 && st[2] == 0.25 && st[3] == st[4]
end

# samples = samples[1:100:length(samples)]

# Network input: the third state entry, wrapped in a one-element vector.
x = [[smp.state[3]] for smp in samples]
# Network target: a one-element vector holding the pair
# (MIP objective bound, MIP objective value) of the sample.
y = [[(smp.mip.objective_bound, smp.mip.objective_value)] for smp in samples];
# Pair every input with its target.
training_data = collect(zip(x, y));

In [196]:
# training_data = [([0.0], [(1.0, 2.0)])]

In [202]:
# Create a network from zero-initialised parameters for the layer widths
# [1, 10, 10, 1], using `Nets.elu` as its activation argument.
net = Nets.Net(zeros(Nets.Params{Float64}, [1, 10, 10, 1]), Nets.elu)

# Perturb every parameter with independent Gaussian noise (standard deviation
# 0.1). The random number generator is not seeded here, so the starting point
# differs from run to run.
let weights = net.params.data
    for idx in eachindex(weights)
        weights[idx] += 0.1 * randn()
    end
end

# Objective and in-place gradient for the optimiser.
f, g! = cost_function(sample_loss, net, training_data)

(f, #578)

In [203]:
# Scatter the first element of every target pair against the scalar input ...
plt = scatter([first(input) for (input, target) in training_data],
              [first(first(target)) for (input, target) in training_data])
# ... and the last element of every target pair, in red.
scatter!(plt,
         [first(input) for (input, target) in training_data],
         [last(first(target)) for (input, target) in training_data],
         markercolor=colorant"red")
# Overlay the current network output on 100 evenly spaced inputs in [-5, 5].
xs = linspace(-5, 5, 100)
plot!(plt, xs, [first(net([xi])) for xi in xs])

In [204]:
# Minimise the objective `f` with L-BFGS, using `g!` for gradients and starting
# from a copy of the network's current parameters (so the optimiser does not
# work directly on the network's own array).
solver = Optim.LBFGS()
options = Optim.Options(allow_f_increases = false)
results = Optim.optimize(f, g!, copy(net.params.data), solver, options)
@show results
# Write the optimiser's final point back into the network, in place.
net.params.data .= Optim.minimizer(results);

results = Results of Optimization Algorithm
 * Algorithm: L-BFGS
 * Starting Point: [-0.01937387598201116,0.07915712857896127, ...]
 * Minimizer: [-1.3007787589894435,2.4727211412770465, ...]
 * Minimum: 9.420851e-02
 * Iterations: 243
 * Convergence: false
   * |x - x'| < 1.0e-32: false 
     |x - x'| = 1.44e-04 
   * |f(x) - f(x')| / |f(x)| < 1.0e-32: false
     |f(x) - f(x')| / |f(x)| = 6.95e-05 
   * |g(x)| < 1.0e-08: false 
     |g(x)| = 6.79e+00 
   * stopped by an increasing objective: false
   * Reached Maximum Number of Iterations: false
 * Objective Calls: 1750
 * Gradient Calls: 1750




In [205]:
# Same plot as before the optimisation, drawn again now that the fitted
# parameters have been written into `net`.
# First element of every target pair against the scalar input ...
plt = scatter([first(input) for (input, target) in training_data],
              [first(first(target)) for (input, target) in training_data])
# ... and the last element of every target pair, in red.
scatter!(plt,
         [first(input) for (input, target) in training_data],
         [last(first(target)) for (input, target) in training_data],
         markercolor=colorant"red")
# Overlay the network output on 100 evenly spaced inputs in [-5, 5].
xs = linspace(-5, 5, 100)
plot!(plt, xs, [first(net([xi])) for xi in xs])