In [8]:
using Flux
using Plots; gr()

Plots.GRBackend()

In [9]:
"""
    oracle(x)

Piecewise-linear target function: `-0.5x + 3` for `x < 2`, and `2x - 7` otherwise.
The two pieces do not meet at `x = 2` (the value jumps from 2 down to -3), so the
target is discontinuous there.

All coefficients are `Float64` literals so that both branches return the same type
for a given input type (e.g. an `Int` argument yields a `Float64` on either side of
the breakpoint instead of `Float64` on one side and `Int` on the other).
"""
function oracle(x)
    if x < 2
        return -0.5x + 3.0
    else
        return 2.0x - 7.0
    end
end

oracle (generic function with 1 method)

In [10]:
# Dense grid over [0, 5] used to draw the target curve.
xx = linspace(0, 5, 100)
plot(xx, [oracle(x) for x in xx])

# Ten evenly spaced training inputs on the same interval, with their targets.
train_x = linspace(0, 5, 10)
train_y = [oracle(x) for x in train_x]

# Box a single value as a one-element vector.
function wrap(x)
    return [x]
end
# Inverse of boxing a scalar: check that `x` holds exactly one element and
# return that element with Flux's tracking stripped off.
function unwrap(x)
    @assert length(x) == 1
    return Flux.Tracker.value(first(x))
end

# One batch holding every sample: a single (inputs, targets) tuple in which each
# scalar is boxed as a one-element vector.
train_data = [([wrap(x) for x in train_x], [wrap(y) for y in train_y])]

1-element Array{Tuple{Array{Array{Float64,1},1},Array{Array{Float64,1},1}},1}:
 (Array{Float64,1}[[0.0], [0.555556], [1.11111], [1.66667], [2.22222], [2.77778], [3.33333], [3.88889], [4.44444], [5.0]], Array{Float64,1}[[3.0], [2.72222], [2.44444], [2.16667], [-2.55556], [-1.44444], [-0.333333], [0.777778], [1.88889], [3.0]])

In [11]:
module Att

using Flux

"""
    Attention(signals, weights)

Container for two callables that are both applied to the same input:
`signals` and `weights`. Calling the container returns the sum of the
element-wise product of their two outputs.
"""
struct Attention{S, W}
    signals::S
    weights::W
end

# Forward pass: evaluate the weighting branch, then the signal branch, and
# reduce their element-wise product to a single value.
function (att::Attention)(x)
    mix = att.weights(x)
    candidates = att.signals(x)
    return sum(mix .* candidates)
end

# Expose the parameters of both branches to Flux: `signals` first, then `weights`.
function Flux.params(att::Attention)
    return vcat(params(att.signals), params(att.weights))
end

end

Att

In [12]:
# Signal branch: one linear layer mapping the 1-element input to two values.
signals = Dense(1, 2)
# Weight branch: small elu network whose two outputs are normalised by softmax.
weights = Chain(
    Dense(1, 4, elu),
    Dense(4, 2, elu),
    softmax
)
model = Att.Attention(signals, weights)

# Sum of squared prediction errors divided by the number of samples.
loss = function (x, y)
    residuals = model.(x) .- y
    return sum(z -> sum(abs2.(z)), residuals) / length(x)
end
opt = Flux.Optimise.ADADelta(params(model))

(::#71) (generic function with 1 method)

In [15]:
# Draw the target curve and overlay the attention model's current predictions.
xx = linspace(0, 5, 100)
plt = plot(xx, [oracle(x) for x in xx])
plot!(plt, xx, [unwrap(model(wrap(x))) for x in xx])

In [40]:
# Run 1000 passes over the training data with the attention model's loss and optimiser.
foreach(1:1000) do _
    Flux.train!(loss, train_data, opt)
end

In [41]:
# Target curve, the attention model's predictions, and the training points on one figure.
xx = linspace(0, 5, 100)
plt = plot(xx, [oracle(x) for x in xx])
plot!(plt, xx, [unwrap(model(wrap(x))) for x in xx])
scatter!(plt, train_x, train_y)

In [39]:
# Show the attention model's parameter arrays: those of the `signals` branch
# first, followed by those of the `weights` branch.
params(model)

6-element Array{Any,1}:
 param([-0.487664; 1.89043])                                                    
 param([2.99542, -6.57796])                                                     
 param([1.75288; -0.745934; 1.49678; -2.21352])                                 
 param([-1.30016, 1.74908, -0.990811, 5.27682])                                 
 param([-0.644879 2.71131 -0.484501 3.47082; 1.80944 -2.31656 2.41192 -2.52793])
 param([1.29037, -0.475701])                                                    

In [42]:
# Total number of scalar parameters in the attention model.
sum(length(p) for p in params(model))

22

In [57]:
# Plain feed-forward baseline to compare against the attention model.
# The output layer is linear (no activation): `elu` is bounded below by -1,
# while the training targets go down to about -2.56, so an `elu` output layer
# could never reproduce the lower part of the target.
chainmodel = Chain(
    Dense(1, 4, elu),
    Dense(4, 4, elu),
    Dense(4, 1)
)

# Sum of squared prediction errors divided by the number of samples.
chainloss = (x, y) -> sum(z -> sum(abs2.(z)), chainmodel.(x) .- y) / length(x)
chainopt = Flux.Optimise.ADADelta(params(chainmodel))

(::#71) (generic function with 1 method)

In [58]:
# Draw the target curve and overlay the feed-forward baseline's current predictions.
xx = linspace(0, 5, 100)
plt = plot(xx, [oracle(x) for x in xx])
plot!(plt, xx, [unwrap(chainmodel(wrap(x))) for x in xx])

In [61]:
# Run 5000 passes over the training data with the baseline's loss and optimiser.
foreach(1:5000) do _
    Flux.train!(chainloss, train_data, chainopt)
end

In [62]:
# Target curve with the feed-forward baseline's predictions overlaid.
xx = linspace(0, 5, 100)
plt = plot(xx, [oracle(x) for x in xx])
plot!(plt, xx, [unwrap(chainmodel(wrap(x))) for x in xx])