This repository has been archived by the owner on Jun 26, 2023. It is now read-only.
forked from FluxML/Flux.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
control.jl
49 lines (38 loc) · 1.4 KB
/
control.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# A sequence of layers. `layers` is a `Vector{Any}` so that layer objects of
# any type can be held side by side. The constructor takes the layers as
# separate arguments, e.g. `Chain(a, b, c)`.
type Chain
  layers::Vector{Any}
  function Chain(xs...)
    collected = [xs...]
    return new(collected)
  end
end
# Let a `Chain` behave like its `layers` vector for indexing, `first`/`last`,
# `endof`, `push!`, and the `start`/`next`/`done` iteration protocol.
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
@forward Chain.layers Base.push!, Base.start, Base.next, Base.done
# Calling a chain feeds `x` through each layer in order: the output of one
# layer becomes the input of the next, and the last layer's output is returned.
function (s::Chain)(x)
  for layer in s.layers
    x = layer(x)
  end
  return x
end
# Apply `update!(layer, η)` to every layer of `s`, in order. Returns `nothing`.
function update!(s::Chain, η)
  for layer in s.layers
    update!(layer, η)
  end
  return nothing
end
# Collect the parameters of every layer of `s`, in layer order, into one fresh
# vector. An empty chain yields an empty vector.
#
# This accumulates into a new `Any[]` instead of using
# `mapreduce(params, append!, s.layers)`, for two reasons:
#   - that form throws on an empty chain, because the reduction has no initial
#     value;
#   - it appends into the collection returned by the first layer's `params`,
#     which mutates that collection if the layer hands back a stored container.
function params(s::Chain)
  ps = Any[]
  for layer in s.layers
    append!(ps, params(layer))
  end
  return ps
end
# Backpropagate the output gradient `Δ` through every layer of `s`, where `x`
# is the input the chain was called with.
#
# First, a forward pass recomputes the input seen by each layer. Then
# `back!(layer, Δ, input)` is called from the last layer to the first, and each
# call's return value is passed on as the gradient for the preceding layer.
# Returns the value of the first layer's `back!` call, or `Δ` unchanged for an
# empty chain.
function back!(s::Chain, Δ, x)
  layers = s.layers
  # `Any[x]` rather than `[x]`: `[x]` fixes the element type to `typeof(x)`.
  # A layer whose output has a different type (e.g. another array rank) would
  # then fail on `push!`, or be silently converted (e.g. to another float
  # precision).
  inputs = Any[x]
  for i in 1:length(layers)-1
    push!(inputs, layers[i](inputs[end]))
  end
  for i in length(layers):-1:1
    Δ = back!(layers[i], Δ, inputs[i])
  end
  return Δ
end
# Build the computation graph of `s`. Start from `constant(inputnode(1))` and
# wrap it with one `vertex` per layer, in layer order.
function graph(s::Chain)
  node = constant(inputnode(1))
  for layer in s.layers
    node = vertex(layer, node)
  end
  return node
end
# Indexing a chain with an array of indices (e.g. `c[1:2]`) returns a new
# `Chain` of the selected layers rather than a plain vector. Other index types
# go to the `getindex` forwarded to `layers`.
function Base.getindex(c::Chain, i::AbstractArray)
  selected = c.layers[i]
  return Chain(selected...)
end
# Chain Macros
# Fallback used by `@Chain` to construct a layer: the input-shape argument is
# ignored, and `f` is called with the remaining positional and keyword
# arguments. More specific methods (presumably defined for layer types
# elsewhere) can use the shape to fill in the layer's input size.
function inferred(f, inshape, args...; kws...)
  return f(args...; kws...)
end
# `inferchain` is the shape-inference hook used by the `@Chain` macro. By
# default it simply defers to `infer`. It exists so that inference can be
# overridden for convenience: for example, `infer(Affine(10, 20), nothing)`
# would normally throw a shape error, but an `inferchain` method can ignore
# such errors and return `(1, 20)` instead.
function inferchain(f, xs...)
  return infer(f, xs...)
end
# `@Chain(first, rest...)` builds a `Chain` whose layers after the first can
# omit their input size.
#
# `first` is used as-is. Each later argument written as a call `f(args...)` is
# rewritten to `inferred(f, (shape,), args...)`, where `shape` is the result of
# `inferchain(previous_layer, shape)` (starting from `nothing`). Any other
# expression is pushed onto the chain unchanged. All user expressions are
# escaped, so `shape` and `c` stay private to the expansion.
macro Chain(x, xs...)
  # Turn a layer-constructor call into a shape-aware `inferred` call.
  # `@capture` binds its pattern names by assignment. Inside this closure, an
  # assignment to a name of the enclosing macro (such as `xs`) would rebind
  # the macro's own variable, so the capture names are kept distinct.
  inferconstructor(ex) =
    @capture(ex, f_(args__)) ?
      :(inferred($(esc(f)), (shape,), $(esc.(args)...))) :
      esc(ex)
  @q let
    shape = nothing
    c = Chain($(esc(x)))
    $([:(shape = inferchain(c.layers[end], shape);
         push!(c, $layer)) for layer in inferconstructor.(xs)]...)
    c
  end
end