# MapIn functionality

In [1]:
using Flux

In [82]:
"""
    MapIn(layers...)

Wrap `layers` so that they are applied, in order, to every element of an input
collection independently (the same results as broadcasting a `Chain` of those
layers over the elements).

Calling a `MapIn` on a collection of feature vectors returns the collection of
per-element outputs. Calling it on a flat vector of real scalars treats each
scalar as a one-feature input and returns a flat vector of scalars.
"""
struct MapIn{T<:Tuple}
    layers::T
    MapIn(xs...) = new{typeof(xs)}(xs)
end

# Apply a tuple of callables left to right; the empty tuple is the identity.
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(Base.tail(fs), first(fs)(x))

# Run the whole layer sequence on each element of `x` separately.
(a::MapIn)(x) = broadcast(y -> applychain(a.layers, y), x)

# Flat vector of scalars: wrap each value as a one-element feature vector, run the
# layers, then unwrap each result with `first`. Accepting any real-valued
# AbstractVector (Float64, Int, views, ...) keeps such inputs from falling through
# to the element-wise method above, which would feed bare scalars to the layers.
# NOTE(review): `first` keeps only the first output per element; if the layers
# produce more than one output per element the rest are silently dropped.
(a::MapIn)(x::AbstractVector{<:Real}) = broadcast(first, a(map(y -> [y], x)))

# Only the wrapped layers carry trainable parameters.
Flux.trainable(a::MapIn) = a.layers


In [119]:
# Collection interface for MapIn, mirroring the methods Flux forwards for `Chain`
# (getindex, length, first, last, iterate, lastindex). Each method delegates to the
# wrapped `layers` tuple and passes any extra arguments (index, count, iteration
# state) through unchanged. Previously the extra arguments were dropped, so
# `iterate(m, state)` always restarted at the first layer (a `for` loop over a
# MapIn never terminated) and `m[i]` ignored `i`.
Base.getindex(x::MapIn, args...) = Base.getindex(x.layers, args...)
Base.length(x::MapIn) = Base.length(x.layers)
Base.first(x::MapIn, args...) = Base.first(x.layers, args...)
Base.last(x::MapIn, args...) = Base.last(x.layers, args...)
Base.iterate(x::MapIn, args...) = Base.iterate(x.layers, args...)
Base.firstindex(x::MapIn) = Base.firstindex(x.layers)
Base.lastindex(x::MapIn) = Base.lastindex(x.layers)

In [118]:
# Layer counts: a plain Chain, a MapIn wrapping one whole Chain, and a MapIn
# wrapping the two Dense layers individually.
for model in (Chain(Dense(1,4),Dense(4,1)),
              MapIn(Chain(Dense(1,4),Dense(4,1))),
              MapIn(Dense(1,4),Dense(4,1)))
    println(length(model))
end

2
1
2


## Testing single-feature inputs

In [83]:
# Single-feature models (1 input -> 1 output) in several shapes: flat Chains, a
# nested Chain, and bare layer tuples that are splatted into Chain/MapIn later.
layerExample1 = Chain(Dense(1, 4),
                      Dense(4, 1))
layerExample2 = Chain(Dense(1, 4),
                      Dense(4, 4),
                      Dense(4, 1))
layerExample3 = Chain(Chain(Dense(1, 4), Dense(4, 4)),
                      Dense(4, 4),
                      Dense(4, 1))
layerExample4 = (Dense(1, 4),
                 Dense(4, 1))
layerExample5 = (Chain(Dense(1, 4), Dense(4, 4)),
                 Dense(4, 1));

In [84]:
# Chains are wrapped whole (a single layer); the bare tuples are splatted.
mapExample1, mapExample2, mapExample3 =
    map(MapIn, (layerExample1, layerExample2, layerExample3))
mapExample4 = MapIn(layerExample4...)
mapExample5 = MapIn(layerExample5...);

In [86]:
testInput = [Float32[v] for v in (10, 150)]  # two one-feature samples

2-element Array{Array{Float32,1},1}:
 [10.0]
 [150.0]

In [87]:
# Reference results: broadcast each plain model over the per-sample inputs.
foreach(m -> println(m.(testInput)),
        (layerExample1, layerExample2, layerExample3,
         Chain(layerExample4...), Chain(layerExample5...)))

Array{Float32,1}[[13.179995], [197.69992]]
Array{Float32,1}[[-0.22761732], [-3.4142616]]
Array{Float32,1}[[4.955084], [74.32627]]
Array{Float32,1}[[4.214313], [63.214706]]
Array{Float32,1}[[-0.04624009], [-0.6935978]]


In [88]:
# MapIn applied directly to the sample collection; should match the broadcast results.
for wrapped in (mapExample1, mapExample2, mapExample3, mapExample4, mapExample5)
    println(wrapped(testInput))
end

Array{Float32,1}[[13.179995], [197.69992]]
Array{Float32,1}[[-0.22761732], [-3.4142616]]
Array{Float32,1}[[4.955084], [74.32627]]
Array{Float32,1}[[4.214313], [63.214706]]
Array{Float32,1}[[-0.04624009], [-0.6935978]]


In [89]:
# Parameters collected from the plain models.
for layers in (layerExample1, layerExample2, layerExample3, layerExample4, layerExample5)
    println(params(layers))
end

Params([Float32[-1.0064486; -0.93504226; -0.8504779; 0.72340864], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.0804373 0.3765047 -0.31961435 0.4296571], Float32[0.0]])
Params([Float32[-0.19757561; -0.56147265; -0.91904217; 0.5786279], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.047808196 -0.6531654 0.15053648 0.7152996; 0.28748927 0.3779386 0.091013 0.5178525; -0.17343782 0.09136649 -0.07390766 -0.742634; 0.6729056 -0.828108 0.76132566 0.5507739], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.1832609 -0.21096186 0.42679384 -0.24006322], Float32[0.0]])
Params([Float32[0.2938482; -0.10416369; 0.34470356; 0.85389453], Float32[0.0, 0.0, 0.0, 0.0], Float32[-0.77527004 0.55330014 0.8563348 0.6562407; -0.31134906 0.23269184 -0.68226826 -0.10193954; 0.19654316 -0.4510282 0.30097693 -0.8477545; -0.76206505 0.59767276 0.08153965 -0.30116215], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.79407823 -0.17102307 0.35262477 0.46181247; -0.7917516 0.31127062 0.12796839 -0.7307273; 0.29722732 -0.19282556 0.13755983 0.25414476

In [90]:
# Parameters collected through MapIn; should be the same arrays as above.
foreach(m -> println(params(m)),
        (mapExample1, mapExample2, mapExample3, mapExample4, mapExample5))

Params([Float32[-1.0064486; -0.93504226; -0.8504779; 0.72340864], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.0804373 0.3765047 -0.31961435 0.4296571], Float32[0.0]])
Params([Float32[-0.19757561; -0.56147265; -0.91904217; 0.5786279], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.047808196 -0.6531654 0.15053648 0.7152996; 0.28748927 0.3779386 0.091013 0.5178525; -0.17343782 0.09136649 -0.07390766 -0.742634; 0.6729056 -0.828108 0.76132566 0.5507739], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.1832609 -0.21096186 0.42679384 -0.24006322], Float32[0.0]])
Params([Float32[0.2938482; -0.10416369; 0.34470356; 0.85389453], Float32[0.0, 0.0, 0.0, 0.0], Float32[-0.77527004 0.55330014 0.8563348 0.6562407; -0.31134906 0.23269184 -0.68226826 -0.10193954; 0.19654316 -0.4510282 0.30097693 -0.8477545; -0.76206505 0.59767276 0.08153965 -0.30116215], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.79407823 -0.17102307 0.35262477 0.46181247; -0.7917516 0.31127062 0.12796839 -0.7307273; 0.29722732 -0.19282556 0.13755983 0.25414476

## Testing multi-feature inputs

In [91]:
# Two-feature models (2 inputs -> 3 outputs), mirroring the single-feature examples.
layerExample6 = Chain(Dense(2, 4),
                      Dense(4, 3))
layerExample7 = Chain(Dense(2, 4),
                      Dense(4, 4),
                      Dense(4, 3))
layerExample8 = Chain(Chain(Dense(2, 4), Dense(4, 4)),
                      Dense(4, 4),
                      Dense(4, 3))
layerExample9 = (Dense(2, 4),
                 Dense(4, 3));

In [92]:
# Chains are wrapped whole (a single layer); the bare tuple is splatted.
mapExample6, mapExample7, mapExample8 =
    map(MapIn, (layerExample6, layerExample7, layerExample8))
mapExample9 = MapIn(layerExample9...);

In [93]:
# Two two-feature samples; the second is also shown on its own.
testInput2 = [Float32[10, 150], Float32[10, 90]]
println(testInput2)
println(last(testInput2))

Array{Float32,1}[[10.0, 150.0], [10.0, 90.0]]
Float32[10.0, 90.0]


In [95]:
# Plain models applied to the second sample only.
for model in (layerExample6, layerExample7, layerExample8, Chain(layerExample9...))
    println(model(testInput2[2]))
end

Float32[47.735157, -32.752495, -11.39335]
Float32[0.7560291, 23.671745, 80.52094]
Float32[-8.827183, -47.86567, -75.4666]
Float32[-31.583313, 29.015354, -14.199995]


In [96]:
# Reference results: broadcast each plain model over both samples.
foreach(m -> println(m.(testInput2)),
        (layerExample6, layerExample7, layerExample8, Chain(layerExample9...)))

Array{Float32,1}[[81.01882, -51.16385, -14.6287565], [47.735157, -32.752495, -11.39335]]
Array{Float32,1}[[1.7693024, 42.105507, 133.95842], [0.7560291, 23.671745, 80.52094]]
Array{Float32,1}[[-11.751091, -90.60091, -125.74857], [-8.827183, -47.86567, -75.4666]]
Array{Float32,1}[[-54.348114, 53.44255, -26.874184], [-31.583313, 29.015354, -14.199995]]


In [97]:
# MapIn applied to both samples; should match the broadcast results.
for wrapped in (mapExample6, mapExample7, mapExample8, mapExample9)
    println(wrapped(testInput2))
end

Array{Float32,1}[[81.01882, -51.16385, -14.6287565], [47.735157, -32.752495, -11.39335]]
Array{Float32,1}[[1.7693024, 42.105507, 133.95842], [0.7560291, 23.671745, 80.52094]]
Array{Float32,1}[[-11.751091, -90.60091, -125.74857], [-8.827183, -47.86567, -75.4666]]
Array{Float32,1}[[-54.348114, 53.44255, -26.874184], [-31.583313, 29.015354, -14.199995]]


## Parameters are trainable

In [98]:
# Trainable children reported by Flux for each plain model.
for layers in (layerExample1, layerExample2, layerExample3,
               layerExample4, layerExample5, layerExample6,
               layerExample7, layerExample8, layerExample9)
    println(Flux.trainable(layers))
end

(Dense(1, 4), Dense(4, 1))
(Dense(1, 4), Dense(4, 4), Dense(4, 1))
(Chain(Dense(1, 4), Dense(4, 4)), Dense(4, 4), Dense(4, 1))
(Dense(1, 4), Dense(4, 1))
(Chain(Dense(1, 4), Dense(4, 4)), Dense(4, 1))
(Dense(2, 4), Dense(4, 3))
(Dense(2, 4), Dense(4, 4), Dense(4, 3))
(Chain(Dense(2, 4), Dense(4, 4)), Dense(4, 4), Dense(4, 3))
(Dense(2, 4), Dense(4, 3))


In [99]:
# Trainable children reported through MapIn (its wrapped layer tuple).
foreach(m -> println(Flux.trainable(m)),
        (mapExample1, mapExample2, mapExample3,
         mapExample4, mapExample5, mapExample6,
         mapExample7, mapExample8, mapExample9))

(Chain(Dense(1, 4), Dense(4, 1)),)
(Chain(Dense(1, 4), Dense(4, 4), Dense(4, 1)),)
(Chain(Chain(Dense(1, 4), Dense(4, 4)), Dense(4, 4), Dense(4, 1)),)
(Dense(1, 4), Dense(4, 1))
(Chain(Dense(1, 4), Dense(4, 4)), Dense(4, 1))
(Chain(Dense(2, 4), Dense(4, 3)),)
(Chain(Dense(2, 4), Dense(4, 4), Dense(4, 3)),)
(Chain(Chain(Dense(2, 4), Dense(4, 4)), Dense(4, 4), Dense(4, 3)),)
(Dense(2, 4), Dense(4, 3))


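## Flat-vector inputs (scalars wrapped and unwrapped)
Testing the `MapIn` method for a flat vector of scalars. It wraps each value as a one-element vector with `map(y -> [y], x)`, runs the layers, and unwraps each result with `broadcast(first, ...)`.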

In [100]:
testInputArray = Float32.([10, 150])  # flat vector: one scalar per sample

2-element Array{Float32,1}:
  10.0
 150.0

In [106]:
# Baseline: MapIn on the nested (vector-of-vectors) input.
for wrapped in (mapExample1, mapExample2, mapExample3, mapExample4, mapExample5)
    println(wrapped(testInput))
end

Array{Float32,1}[[13.179995], [197.69992]]
Array{Float32,1}[[-0.22761732], [-3.4142616]]
Array{Float32,1}[[4.955084], [74.32627]]
Array{Float32,1}[[4.214313], [63.214706]]
Array{Float32,1}[[-0.04624009], [-0.6935978]]


In [105]:
# Same models on the flat scalar vector; values should match the baseline, unwrapped.
foreach(m -> println(m(testInputArray)),
        (mapExample1, mapExample2, mapExample3, mapExample4, mapExample5))

Float32[13.179995, 197.69992]
Float32[-0.22761732, -3.4142616]
Float32[4.955084, 74.32627]
Float32[4.214313, 63.214706]
Float32[-0.04624009, -0.6935978]


In [107]:
# Trainable children are unchanged after the flat-vector calls.
for wrapped in (mapExample1, mapExample2, mapExample3,
                mapExample4, mapExample5, mapExample6)
    println(Flux.trainable(wrapped))
end

(Chain(Dense(1, 4), Dense(4, 1)),)
(Chain(Dense(1, 4), Dense(4, 4), Dense(4, 1)),)
(Chain(Chain(Dense(1, 4), Dense(4, 4)), Dense(4, 4), Dense(4, 1)),)
(Dense(1, 4), Dense(4, 1))
(Chain(Dense(1, 4), Dense(4, 4)), Dense(4, 1))
(Chain(Dense(2, 4), Dense(4, 3)),)
