In [1]:
using JLD2
using FileIO
import LearningMPC



In [2]:
addprocs(6)  # launch 6 worker processes used by the @parallel cross-validation loops

6-element Array{Int64,1}:
 2
 3
 4
 5
 6
 7

In [None]:
@everywhere using Flux
@everywhere using MLDataPattern
@everywhere using CoordinateTransformations
@everywhere using ProgressMeter
@everywhere import FluxExtensions

In [4]:
# Pull the "samples" entry out of the saved grid-search file; the temporary binding for the
# file contents stays local to the `let`, so no extra global is introduced.
samples = let contents = load("../2018-02-07-hopper-smaller-grid/grid_search.jld2")
    contents["samples"]
end;

In [5]:
# Drop, in place, every sample whose second state coordinate exceeds its first.
filter!(s -> s.state[2] <= s.state[1], samples);

In [6]:
# Turn a sample into an (input, target) pair: the state vector, and the first column of uJ.
features(s::LearningMPC.Sample) = (s.state, s.uJ[:, 1])
data = broadcast(features, samples);
# Shuffle, then hold out the last 15% of observations as the test set.
train_data, test_data = splitobs(shuffleobs(data); at = 0.85);

In [7]:
@everywhere function attention_model()
    # Build an attention-style regressor and its squared-error loss; returns (model, loss).
    #
    # Signal branch: affine map from 4 inputs to 32 numbers, reshaped to a 16×2 matrix.
    # NOTE(review): unlike the weight branch, this branch receives unscaled inputs —
    # confirm that is intentional.
    signal_branch = Chain(
        Dense(4, 2 * 16),
        z -> reshape(z, 16, 2))
    # Weight branch: scale inputs by 0.2, two elu hidden layers, softmax over 16 entries.
    weight_branch = Chain(
        LinearMap(UniformScaling(0.2)),
        Dense(4, 16, elu),
        Dense(16, 16, elu),
        softmax)
    # How the two branches are combined is defined by FluxExtensions.Attention (not shown).
    net = FluxExtensions.Attention(signal_branch, weight_branch)

    # The model output is flattened before comparing it with the target vector.
    loss(x, y) = Flux.mse(vec(net(x)), y)
    return net, loss
end

@everywhere function layered_model()
    # Build a plain feed-forward regressor and its squared-error loss; returns (model, loss).
    # Layers are created in input-to-output order: scale the 4 inputs by 0.2, two 16-unit
    # elu hidden layers, then a linear 2-unit output.
    input_scaling = LinearMap(UniformScaling(0.2))
    hidden_a = Dense(4, 16, elu)
    hidden_b = Dense(16, 16, elu)
    readout = Dense(16, 2)
    net = Chain(input_scaling, hidden_a, hidden_b, readout)

    loss(x, y) = Flux.mse(net(x), y)
    return net, loss
end

In [8]:
# 10-fold cross-validation of the attention model, folds distributed across workers.
# Each fold trains a fresh model for 500 passes, then reports its mean validation loss;
# `vcat` gathers the per-fold losses into a vector.
attention_losses = @parallel (vcat) for (fold_train, fold_valid) in collect(kfolds(train_data; k=10))
    net, objective = attention_model()
    optimizer = Flux.ADADelta(params(net))
    for epoch in 1:500
        Flux.train!(objective, fold_train, optimizer)
    end
    mean(pair -> Flux.Tracker.value(objective(pair...)), fold_valid)
end

10-element Array{Float64,1}:
 10.4491 
  9.05005
 12.0563 
  9.333  
 10.5554 
 10.1823 
  9.94241
  7.99761
  8.29451
  9.03425

In [9]:
# Same 10-fold cross-validation protocol as for the attention model, but with the plain
# layered network, so the two sets of fold losses are computed on identical folds.
layered_losses = @parallel (vcat) for (fold_train, fold_valid) in collect(kfolds(train_data; k=10))
    net, objective = layered_model()
    optimizer = Flux.ADADelta(params(net))
    for epoch in 1:500
        Flux.train!(objective, fold_train, optimizer)
    end
    mean(pair -> Flux.Tracker.value(objective(pair...)), fold_valid)
end

10-element Array{Float64,1}:
 12.2559 
  9.37764
  9.24167
 10.4031 
 10.7871 
 10.033  
 11.1641 
  8.99252
  9.42132
  9.52264

In [10]:
# (mean, sample standard deviation) of the attention model's per-fold validation losses.
mean(attention_losses), std(attention_losses)

(9.689493039288115, 1.2022650368700554)

In [11]:
# (mean, sample standard deviation) of the layered model's per-fold validation losses.
mean(layered_losses), std(layered_losses)

(10.11989090119604, 1.034061778507911)