diff --git a/Project.toml b/Project.toml
index 32bf227..9a0b3b9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte
", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. herrmann "]
-version = "2.2.4"
+version = "2.2.5"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/layers/invertible_layer_glow.jl b/src/layers/invertible_layer_glow.jl
index 39d8e1c..b261172 100644
--- a/src/layers/invertible_layer_glow.jl
+++ b/src/layers/invertible_layer_glow.jl
@@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
else
ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
- _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS_, S), 0f0.*ΔT;), X2; set_grad=set_grad)
+ _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
ΔX2 += ΔY2
end
ΔX_ = tensor_cat(ΔX1, ΔX2)
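
The fix above replaces `ΔS_`, which is undefined in this scope, with `ΔS` in the call that accumulates the logdet gradient. A quick sanity sketch for the layer itself (channel counts and spatial sizes are illustrative assumptions, not taken from the test suite):

```julia
using InvertibleNetworks, Test

# Hypothetical smoke test: forward/inverse roundtrip of the Glow coupling layer.
L = CouplingLayerGlow(4, 16; logdet=true)   # 4 input channels, 16 hidden channels
X = randn(Float32, 8, 8, 4, 2)
Y, logdet = L.forward(X)
@test L.inverse(Y) ≈ X                      # invertibility is unaffected by the fix
```
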
diff --git a/src/layers/invertible_layer_hint.jl b/src/layers/invertible_layer_hint.jl
index ca21a45..fcac839 100644
--- a/src/layers/invertible_layer_hint.jl
+++ b/src/layers/invertible_layer_hint.jl
@@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
end
# Input are two tensors ΔX, X
-function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N}
+function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N}
isnothing(permute) ? permute = H.permute : permute = permute
# Permutation
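
`backward_inv` (the backward pass of a reversed HINT layer) now accepts the same `set_grad` keyword as `backward`, which the reversed-network rrule in chainrules.jl relies on. A hedged usage sketch (sizes are assumptions):

```julia
using InvertibleNetworks

# Reversing a HINT layer swaps its forward and inverse passes;
# its backward pass then dispatches to backward_inv.
H = CouplingLayerHINT(4, 16; logdet=false, permute="full")
Hrev = reverse(H)
X = randn(Float32, 8, 8, 4, 2)
Y = Hrev.forward(X)
ΔX, X_ = Hrev.backward(randn(Float32, size(Y)...), Y)  # default set_grad=true
```
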
diff --git a/src/utils/chainrules.jl b/src/utils/chainrules.jl
index 36456c7..46bc769 100644
--- a/src/utils/chainrules.jl
+++ b/src/utils/chainrules.jl
@@ -1,8 +1,9 @@
using ChainRulesCore
-export logdetjac
-import ChainRulesCore: frule, rrule
-
+export logdetjac, getrrule
+import ChainRulesCore: frule, rrule, @non_differentiable
+@non_differentiable get_params(::Invertible)
+@non_differentiable get_params(::Reversed)
## Tape types and utilities
"""
@@ -81,7 +82,6 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N},
if logdet isa Float32
state.logdet === nothing ? (state.logdet = logdet) : (state.logdet += logdet)
end
-
end
"""
@@ -97,15 +97,13 @@ function backward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}
state.Y[state.counter_block] = X
state.counter_layer -= 1
end
-
state.counter_block == 0 && reset!(state) # reset state when first block/first layer is reached
-
end
## Chain rules for invertible networks
# General pullback function
function pullback(net::Invertible, ΔY::AbstractArray{T,N};
- state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
+ state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
# Check state coherency
check_coherence(state, net)
@@ -114,19 +112,19 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N};
T2 = typeof(current(state))
ΔY = convert(T2, ΔY)
# Backward pass
- ΔX, X_ = net.backward(ΔY, current(state))
-
+ ΔX, X_ = net.backward(ΔY, current(state); set_grad=true)
+ Δθ = getfield.(get_params(net), :grad)
# Update state
backward_update!(state, X_)
- return nothing, ΔX
+ return NoTangent(), NoTangent(), ΔX, Δθ
end
# Reverse-mode AD rule
-function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N};
+function ChainRulesCore.rrule(::typeof(forward_net), net::Invertible, X::AbstractArray{T, N}, θ...;
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
-
+
# Forward pass
net.logdet ? ((Y, logdet) = net.forward(X)) : (Y = net.forward(X); logdet = nothing)
@@ -142,4 +140,8 @@ end
## Logdet utilities for Zygote pullback
-logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
\ No newline at end of file
+logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
+
+## Utility to get the pullback directly for testing
+
+getrrule(net::Invertible, X::AbstractArray) = rrule(forward_net, net, X, getfield.(get_params(net), :data))
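
With the rrule now keyed on `forward_net(net, X, θ...)` rather than on the network call itself, the pullback's output grows from `(nothing, ΔX)` to `(NoTangent(), NoTangent(), ΔX, Δθ)`: one tangent per rrule argument. A sketch of the new convention via the `getrrule` helper (layer choice and shapes are assumptions):

```julia
using InvertibleNetworks

N = ActNorm(2; logdet=false)
X = randn(Float32, 4, 4, 2, 3)
Y, back = getrrule(N, X)
∂ = back(Y)
ΔX = ∂[3]   # tangent for the input X (hence the index shift from 2 to 3 in the tests)
Δθ = ∂[4]   # tangents for the network parameters
```
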
diff --git a/src/utils/dimensionality_operations.jl b/src/utils/dimensionality_operations.jl
index aa9b2a5..872d958 100644
--- a/src/utils/dimensionality_operations.jl
+++ b/src/utils/dimensionality_operations.jl
@@ -475,7 +475,7 @@ end
# Split and reshape 1D vector Y in latent space back to states Zi
# where Zi is the split tensor at each multiscale level.
-function split_states(Y::AbstractVector{T}, Z_dims) where {T, N}
+function split_states(Y::AbstractVector{T}, Z_dims) where {T}
L = length(Z_dims) + 1
inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...])
Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1]
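
The dropped `N` was an unbound type parameter: it appears in the `where` clause but nowhere in the signature. No behavior changes; such cases can be found mechanically:

```julia
# Methods with unbound type parameters are detected by the Test stdlib:
using Test, InvertibleNetworks
Test.detect_unbound_args(InvertibleNetworks)
```
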
diff --git a/src/utils/neuralnet.jl b/src/utils/neuralnet.jl
index 60b5c39..206e576 100644
--- a/src/utils/neuralnet.jl
+++ b/src/utils/neuralnet.jl
@@ -33,7 +33,7 @@ end
getproperty(I::Invertible, s::Symbol) = _get_property(I, Val{s}())
_get_property(I::Invertible, ::Val{s}) where {s} = getfield(I, s)
-_get_property(R::Reversed, ::Val{:I}) where s = getfield(R, :I)
+_get_property(R::Reversed, ::Val{:I}) = getfield(R, :I)
_get_property(R::Reversed, ::Val{s}) where s = _get_property(R.I, Val{s}())
for m ∈ _INet_modes
@@ -128,4 +128,5 @@ function set_params!(N::Invertible, θnew::Array{Parameter, 1})
end
# Make invertible nets callable objects
-(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X)
+(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data))
+forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)
\ No newline at end of file
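
Routing the callable through `forward_net`, with the parameter data passed explicitly, is what lets the rrule on `forward_net` treat the parameters as differentiable arguments. Under this patch the two calls below compute the same forward pass (a sketch; `forward_net` is assumed unexported):

```julia
using InvertibleNetworks

AN = ActNorm(2; logdet=false)
X = randn(Float32, 4, 4, 2, 3)
Y1 = AN(X)
Y2 = InvertibleNetworks.forward_net(AN, X, getfield.(get_params(AN), :data))
Y1 == Y2   # identical; the θ argument exists only for AD bookkeeping
```
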
diff --git a/test/runtests.jl b/test/runtests.jl
index 368f6f6..227bcaa 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -17,7 +17,8 @@ basics = ["test_utils/test_objectives.jl",
"test_utils/test_activations.jl",
"test_utils/test_squeeze.jl",
"test_utils/test_jacobian.jl",
- "test_utils/test_chainrules.jl"]
+ "test_utils/test_chainrules.jl",
+ "test_utils/test_flux.jl"]
# Layers
layers = ["test_layers/test_residual_block.jl",
diff --git a/test/test_utils/test_chainrules.jl b/test/test_utils/test_chainrules.jl
index 941d628..bde7944 100644
--- a/test/test_utils/test_chainrules.jl
+++ b/test/test_utils/test_chainrules.jl
@@ -23,22 +23,22 @@ N10 = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full")
# Forward pass + gathering pullbacks
function fw(X)
- X1, ∂1 = rrule(N1, X)
- X2, ∂2 = rrule(N2, X1)
- X3, ∂3 = rrule(N3, X2)
+ X1, ∂1 = getrrule(N1, X)
+ X2, ∂2 = getrrule(N2, X1)
+ X3, ∂3 = getrrule(N3, X2)
X5, ∂5 = Flux.Zygote.pullback(Chain(N4, N5), X3)
- X6, ∂6 = rrule(N6, X5)
- X7, ∂7 = rrule(N7, X6)
+ X6, ∂6 = getrrule(N6, X5)
+ X7, ∂7 = getrrule(N7, X6)
X9, ∂9 = Flux.Zygote.pullback(Chain(N8, N9), X7)
- X10, ∂10 = rrule(N10, X9)
- d1 = x -> ∂1(x)[2]
- d2 = x -> ∂2(x)[2]
- d3 = x -> ∂3(x)[2]
+ X10, ∂10 = getrrule(N10, X9)
+ d1 = x -> ∂1(x)[3]
+ d2 = x -> ∂2(x)[3]
+ d3 = x -> ∂3(x)[3]
d5 = x -> ∂5(x)[1]
- d6 = x -> ∂6(x)[2]
- d7 = x -> ∂7(x)[2]
+ d6 = x -> ∂6(x)[3]
+ d7 = x -> ∂7(x)[3]
d9 = x -> ∂9(x)[1]
- d10 = x -> ∂10(x)[2]
+ d10 = x -> ∂10(x)[3]
return X10, ΔY -> d1(d2(d3(d5(d6(d7(d9(d10(ΔY))))))))
end
@@ -65,9 +65,9 @@ g2 = gradient(X -> loss(X), X)
## test Reverse network AD
Nrev = reverse(N10)
-Xrev, ∂rev = rrule(Nrev, X)
+Xrev, ∂rev = getrrule(Nrev, X)
grev = ∂rev(Xrev-Y0)
g2rev = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X)
-@test grev[2] ≈ g2rev[1] rtol=1f-6
+@test grev[3] ≈ g2rev[1] rtol=1f-6
diff --git a/test/test_utils/test_flux.jl b/test/test_utils/test_flux.jl
new file mode 100644
index 0000000..47dad13
--- /dev/null
+++ b/test/test_utils/test_flux.jl
@@ -0,0 +1,50 @@
+using InvertibleNetworks, Flux, Test, LinearAlgebra
+
+# Define network
+nx = 1
+ny = 1
+n_in = 2
+n_hidden = 10
+batchsize = 32
+
+# Network: ActNorm followed by a Glow coupling layer
+AN = ActNorm(n_in; logdet = false)
+C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
+pan, pc = deepcopy(get_params(AN)), deepcopy(get_params(C))
+model = Chain(AN, C)
+
+# dummy input & target
+X = randn(Float32, nx, ny, n_in, batchsize)
+Y = model(X)
+X0 = rand(Float32, nx, ny, n_in, batchsize) .+ 1
+
+# Loss function: mean squared error between model output and target
+loss(model, X, Y) = Flux.mse(Y, model(X))
+
+# old, implicit-style Flux
+θ = Flux.params(model)
+opt = Descent(0.001)
+
+l, grads = Flux.withgradient(θ) do
+ loss(model, X0, Y)
+end
+
+for θi in θ
+ @test θi ∈ keys(grads.grads)
+ @test !isnothing(grads.grads[θi])
+ @test size(grads.grads[θi]) == size(θi)
+end
+
+Flux.update!(opt, θ, grads)
+
+for i = 1:5
+ li, grads = Flux.withgradient(θ) do
+ loss(model, X, Y)
+ end
+
+ @info "Loss: $li"
+ @test li != l
+ global l = li
+
+ Flux.update!(opt, θ, grads)
+end
\ No newline at end of file
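
The test drives Flux's implicit-parameter API end to end. Because both layers are invertible, the trained model can also be inverted exactly by composing the layer inverses in reverse order (a sketch continuing the variables above):

```julia
# Recover the input from the model output (sketch; logdet=false here,
# so each inverse returns the reconstructed tensor directly).
X̂ = AN.inverse(C.inverse(model(X)))
```
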