In [52]:
using ForwardDiff
using Flux
using Plots; gr()

Plots.GRBackend()

In [47]:
module Att

using Flux
using ForwardDiff

struct Attention{T1, T2}
    signals::T1
    weights::T2
end

function (a::Attention)(x)
    sum(a.weights(x) .* a.signals(x), 1)
end

Flux.params(a::Attention) = vcat(params(a.signals), params(a.weights))

struct TangentPropagator{F <: Function, C}
    f::F
    layer::C
end

function TangentPropagator(chain::Chain)
    f = reduce(∘, identity, _propagate_tangent.(reverse(chain.layers)))
    TangentPropagator(x -> f((x, eye(length(x)))), chain)
end

(p::TangentPropagator)(x) = p.f(x)

Flux.params(p::TangentPropagator) = Flux.params(p.layer)

function _propagate_tangent(f)
    (xJ) -> begin
        (f(xJ[1]), ForwardDiff.jacobian(f, Flux.Tracker.value(xJ[1])) * xJ[2])
    end
end

function _propagate_tangent(f::Dense)
    xJ -> begin
        x, J = xJ
        y = f.W * x + f.b
        gσ = ForwardDiff.derivative.(f.σ, y)
        (f(x), gσ .* f.W * J)
    end
end

function TangentPropagator(a::Attention)
    t1 = TangentPropagator(a.signals)
    t2 = TangentPropagator(a.weights)
    function f(x)
        x1, J1 = t1(x)
        x2, J2 = t2(x)
        sum(x1 .* x2), sum(x1 .* J2, 1) .+ sum(x2 .* J1, 1)
    end
    TangentPropagator(f, a)
end

TangentPropagator(d::Dense) = TangentPropagator(Chain(d))
        

# tangent(f, x) = ForwardDiff.jacobian(f, x)

# function tangent(f::Dense, x)
#     x, J = xJ
#     y = f.W * x + f.b
#     gσ = ForwardDiff.derivative.(f.σ, y)
#     gσ .* f.W * J
# end

end



Att

In [48]:
signals = Dense(1, 2)
weights = Chain(Dense(1, 4, elu), Dense(4, 2, elu), softmax)
model = Att.Attention(signals, weights)

t = Att.TangentPropagator(model)

Att.TangentPropagator{Att.#f#7{Att.TangentPropagator{Att.##1#2{Base.##55#56{Base.#identity,Att.##5#6{Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}}}},Flux.Chain},Att.TangentPropagator{Att.##1#2{Base.##55#56{Base.##55#56{Base.##55#56{Base.#identity,Att.##3#4{NNlib.#softmax}},Att.##5#6{Flux.Dense{NNlib.#elu,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}}},Att.##5#6{Flux.Dense{NNlib.#elu,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}}}},Flux.Chain}},Att.Attention{Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},Flux.Chain}}(Att.f, Att.Attention{Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},Flux.Chain}(Dense(1, 2), Chain(Dense(1, 4, NNlib.elu), Dense(4, 2, NNlib.elu), NNlib.softmax)))

In [56]:
eps = 1e-6
x = [1.5]
model(x), (model(x + eps) - model(x)) / eps

(param([-1.23506]), param([-0.835811]))

In [57]:
t(x)

(param(-1.23506), param([-0.835811]))

In [88]:
function oracle(x)
    if x < 2
        -0.5x + 3
    else
        2x - 7
    end
end

xx = linspace(0, 5, 100)
plot(xx, oracle.(xx))

train_x = linspace(0, 5, 5)
train_y = oracle.(train_x)

wrap(x) = [x]
unwrap(x) = (@assert length(x) == 1; Flux.Tracker.value(first(x)))

train_data = [(wrap.(train_x), wrap.(train_y), fill.(ForwardDiff.derivative.(oracle, train_x), 1, 1))]

1-element Array{Tuple{Array{Array{Float64,1},1},Array{Array{Float64,1},1},Array{Array{Float64,2},1}},1}:
 (Array{Float64,1}[[0.0], [1.25], [2.5], [3.75], [5.0]], Array{Float64,1}[[3.0], [2.375], [-2.0], [0.5], [3.0]], Array{Float64,2}[[-0.5], [-0.5], [2.0], [2.0], [2.0]])

In [107]:
signals = Dense(1, 2)
weights = Chain(Dense(1, 4, elu), Dense(4, 2, elu), softmax)
model = Att.Attention(signals, weights)
tmodel = Att.TangentPropagator(model)

λ = 0.05

loss = (x, y, J) -> begin
    yJ = collect(zip(y, J))
    yJhat = tmodel.(x)
    
    function sample_error(yJi, yJhati)
        yi, Ji = yJi
        yhati, Jhati = yJhati
        (1 - λ) * sum(abs2.(yi .- yhati)) + λ * sum(abs2.(Ji .- Jhati))
    end
    sum(sample_error.(yJ, yJhat)) / length(x)
end
opt = Flux.Optimise.ADADelta(params(tmodel))

(::#71) (generic function with 1 method)

In [108]:
xx = linspace(0, 5, 100)
plt = plot(xx, oracle.(xx))
plot!(plt, xx, unwrap.(model.(wrap.(xx))))
scatter!(plt, train_x, train_y)

In [103]:
for i in 1:5000
    Flux.train!(loss, train_data, opt)
end

In [104]:
xx = linspace(0, 5, 100)
plt = plot(xx, oracle.(xx))
plot!(plt, xx, unwrap.(model.(wrap.(xx))))
scatter!(plt, train_x, train_y)

In [113]:
signals = Dense(1, 2)
weights = Chain(Dense(1, 4, elu), Dense(4, 2, elu), softmax)
model = Att.Attention(signals, weights)

loss = (x, y, J) -> sum(z -> sum(abs2.(z)), model.(x) .- y) / length(x)
opt = Flux.Optimise.ADADelta(params(model))

(::#71) (generic function with 1 method)

In [114]:
xx = linspace(0, 5, 100)
plt = plot(xx, oracle.(xx))
plot!(plt, xx, unwrap.(model.(wrap.(xx))))
scatter!(plt, train_x, train_y)

In [121]:
for i in 1:5000
    Flux.train!(loss, train_data, opt)
end

In [122]:
xx = linspace(0, 5, 100)
plt = plot(xx, oracle.(xx))
plot!(plt, xx, unwrap.(model.(wrap.(xx))))
scatter!(plt, train_x, train_y)