From b2eeb1209b7ff0b7937441381c653d1c331f9ac9 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 20 Apr 2023 11:14:16 -0400 Subject: [PATCH] fix Flux compat --- Project.toml | 2 +- src/layers/invertible_layer_glow.jl | 2 +- src/layers/invertible_layer_hint.jl | 2 +- src/utils/chainrules.jl | 28 ++++++++------- src/utils/dimensionality_operations.jl | 2 +- src/utils/neuralnet.jl | 5 +-- test/runtests.jl | 3 +- test/test_utils/test_chainrules.jl | 28 +++++++-------- test/test_utils/test_flux.jl | 50 ++++++++++++++++++++++++++ 9 files changed, 88 insertions(+), 34 deletions(-) create mode 100644 test/test_utils/test_flux.jl diff --git a/Project.toml b/Project.toml index 32bf227..9a0b3b9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InvertibleNetworks" uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3" authors = ["Philipp Witte ", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. herrmann "] -version = "2.2.4" +version = "2.2.5" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/layers/invertible_layer_glow.jl b/src/layers/invertible_layer_glow.jl index 39d8e1c..b261172 100644 --- a/src/layers/invertible_layer_glow.jl +++ b/src/layers/invertible_layer_glow.jl @@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2 else ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad) - _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS_, S), 0f0.*ΔT;), X2; set_grad=set_grad) + _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad) ΔX2 += ΔY2 end ΔX_ = tensor_cat(ΔX1, ΔX2) diff --git a/src/layers/invertible_layer_hint.jl b/src/layers/invertible_layer_hint.jl index ca21a45..fcac839 100644 --- a/src/layers/invertible_layer_hint.jl +++ b/src/layers/invertible_layer_hint.jl @@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL end # Input are two tensors ΔX, X -function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N} +function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N} isnothing(permute) ? permute = H.permute : permute = permute # Permutation diff --git a/src/utils/chainrules.jl b/src/utils/chainrules.jl index 36456c7..46bc769 100644 --- a/src/utils/chainrules.jl +++ b/src/utils/chainrules.jl @@ -1,8 +1,9 @@ using ChainRulesCore -export logdetjac -import ChainRulesCore: frule, rrule - +export logdetjac, getrrule +import ChainRulesCore: frule, rrule, @non_differentiable +@non_differentiable get_params(::Invertible) +@non_differentiable get_params(::Reversed) ## Tape types and utilities """ @@ -81,7 +82,6 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}, if logdet isa Float32 state.logdet === nothing ? 
(state.logdet = logdet) : (state.logdet += logdet) end - end """ @@ -97,15 +97,13 @@ function backward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N} state.Y[state.counter_block] = X state.counter_layer -= 1 end - state.counter_block == 0 && reset!(state) # reset state when first block/first layer is reached - end ## Chain rules for invertible networks # General pullback function function pullback(net::Invertible, ΔY::AbstractArray{T,N}; - state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N} + state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N} # Check state coherency check_coherence(state, net) @@ -114,19 +112,19 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N}; T2 = typeof(current(state)) ΔY = convert(T2, ΔY) # Backward pass - ΔX, X_ = net.backward(ΔY, current(state)) - + ΔX, X_ = net.backward(ΔY, current(state); set_grad=true) + Δθ = getfield.(get_params(net), :grad) # Update state backward_update!(state, X_) - return nothing, ΔX + return NoTangent(), NoTangent(), ΔX, Δθ end # Reverse-mode AD rule -function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N}; +function ChainRulesCore.rrule(::typeof(forward_net), net::Invertible, X::AbstractArray{T, N}, θ...; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N} - + # Forward pass net.logdet ? ((Y, logdet) = net.forward(X)) : (Y = net.forward(X); logdet = nothing) @@ -142,4 +140,8 @@ end ## Logdet utilities for Zygote pullback -logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet \ No newline at end of file +logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet + +## Utility to get the pullback directly for testing + +getrrule(net::Invertible, X::AbstractArray) = rrule(forward_net, net, X, getfield.(get_params(net), :data)) diff --git a/src/utils/dimensionality_operations.jl b/src/utils/dimensionality_operations.jl index aa9b2a5..872d958 100644 --- a/src/utils/dimensionality_operations.jl +++ b/src/utils/dimensionality_operations.jl @@ -475,7 +475,7 @@ end # Split and reshape 1D vector Y in latent space back to states Zi # where Zi is the split tensor at each multiscale level. 
-function split_states(Y::AbstractVector{T}, Z_dims) where {T, N} +function split_states(Y::AbstractVector{T}, Z_dims) where {T} L = length(Z_dims) + 1 inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...]) Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1] diff --git a/src/utils/neuralnet.jl b/src/utils/neuralnet.jl index 60b5c39..206e576 100644 --- a/src/utils/neuralnet.jl +++ b/src/utils/neuralnet.jl @@ -33,7 +33,7 @@ end getproperty(I::Invertible, s::Symbol) = _get_property(I, Val{s}()) _get_property(I::Invertible, ::Val{s}) where {s} = getfield(I, s) -_get_property(R::Reversed, ::Val{:I}) where s = getfield(R, :I) +_get_property(R::Reversed, ::Val{:I}) = getfield(R, :I) _get_property(R::Reversed, ::Val{s}) where s = _get_property(R.I, Val{s}()) for m ∈ _INet_modes @@ -128,4 +128,5 @@ function set_params!(N::Invertible, θnew::Array{Parameter, 1}) end # Make invertible nets callable objects -(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X) +(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data)) +forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 368f6f6..227bcaa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,8 @@ basics = ["test_utils/test_objectives.jl", "test_utils/test_activations.jl", "test_utils/test_squeeze.jl", "test_utils/test_jacobian.jl", - "test_utils/test_chainrules.jl"] + "test_utils/test_chainrules.jl", + "test_utils/test_flux.jl"] # Layers layers = ["test_layers/test_residual_block.jl", diff --git a/test/test_utils/test_chainrules.jl b/test/test_utils/test_chainrules.jl index 941d628..bde7944 100644 --- a/test/test_utils/test_chainrules.jl +++ b/test/test_utils/test_chainrules.jl @@ -23,22 +23,22 @@ N10 = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full") # Forward pass + gathering pullbacks function fw(X) - X1, ∂1 = rrule(N1, X) - X2, ∂2 = rrule(N2, X1) - X3, ∂3 = rrule(N3, X2) + X1, ∂1 = getrrule(N1, X) + X2, ∂2 = getrrule(N2, X1) + X3, ∂3 = getrrule(N3, X2) X5, ∂5 = Flux.Zygote.pullback(Chain(N4, N5), X3) - X6, ∂6 = rrule(N6, X5) - X7, ∂7 = rrule(N7, X6) + X6, ∂6 = getrrule(N6, X5) + X7, ∂7 = getrrule(N7, X6) X9, ∂9 = Flux.Zygote.pullback(Chain(N8, N9), X7) - X10, ∂10 = rrule(N10, X9) - d1 = x -> ∂1(x)[2] - d2 = x -> ∂2(x)[2] - d3 = x -> ∂3(x)[2] + X10, ∂10 = getrrule(N10, X9) + d1 = x -> ∂1(x)[3] + d2 = x -> ∂2(x)[3] + d3 = x -> ∂3(x)[3] d5 = x -> ∂5(x)[1] - d6 = x -> ∂6(x)[2] - d7 = x -> ∂7(x)[2] + d6 = x -> ∂6(x)[3] + d7 = x -> ∂7(x)[3] d9 = x -> ∂9(x)[1] - d10 = x -> ∂10(x)[2] + d10 = x -> ∂10(x)[3] return X10, ΔY -> d1(d2(d3(d5(d6(d7(d9(d10(ΔY)))))))) end @@ -65,9 +65,9 @@ g2 = gradient(X -> loss(X), X) ## test Reverse network AD Nrev = reverse(N10) -Xrev, ∂rev = rrule(Nrev, X) +Xrev, ∂rev = getrrule(Nrev, X) grev = ∂rev(Xrev-Y0) g2rev = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X) -@test grev[2] ≈ g2rev[1] rtol=1f-6 +@test grev[3] ≈ g2rev[1] rtol=1f-6 diff --git a/test/test_utils/test_flux.jl b/test/test_utils/test_flux.jl new file mode 100644 index 0000000..47dad13 --- /dev/null +++ b/test/test_utils/test_flux.jl @@ -0,0 +1,50 @@ +using InvertibleNetworks, Flux, Test, LinearAlgebra + +# Define network +nx = 1 +ny = 1 +n_in = 2 +n_hidden = 10 +batchsize = 32 + +# net +AN = ActNorm(n_in; logdet = false) +C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 
= 0) +pan, pc = deepcopy(get_params(AN)), deepcopy(get_params(C)) +model = Chain(AN, C) + +# dummy input & target +X = randn(Float32, nx, ny, n_in, batchsize) +Y = model(X) +X0 = rand(Float32, nx, ny, n_in, batchsize) .+ 1 + +# loss fn +loss(model, X, Y) = Flux.mse(Y, model(X)) + +# old, implicit-style Flux +θ = Flux.params(model) +opt = Descent(0.001) + +l, grads = Flux.withgradient(θ) do + loss(model, X0, Y) +end + +for θi in θ + @test θi ∈ keys(grads.grads) + @test !isnothing(grads.grads[θi]) + @test size(grads.grads[θi]) == size(θi) +end + +Flux.update!(opt, θ, grads) + +for i = 1:5 + li, grads = Flux.withgradient(θ) do + loss(model, X, Y) + end + + @info "Loss: $li" + @test li != l + global l = li + + Flux.update!(opt, θ, grads) +end \ No newline at end of file
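
Usage sketch (condensed from the new test/test_utils/test_flux.jl and the updated test_chainrules.jl above; layer sizes, data, and optimizer settings here are illustrative only, while getrrule, forward_net, ActNorm and CouplingLayerGlow are the names introduced or used by this patch):

    using InvertibleNetworks, Flux

    n_in, n_hidden, batchsize = 2, 10, 4
    N = CouplingLayerGlow(n_in, n_hidden; logdet=false, k1=1, k2=1, p1=0, p2=0)
    X = randn(Float32, 1, 1, n_in, batchsize)

    # The rrule is now defined on forward_net(net, X, θ...) and exposed for testing
    # via getrrule. Its pullback returns (NoTangent(), NoTangent(), ΔX, Δθ), so the
    # input gradient sits in slot 3 — hence the index changes in test_chainrules.jl.
    Y, back = getrrule(N, X)
    ΔX = back(Y)[3]

    # Implicit-parameter Flux training, as exercised by the new test_flux.jl.
    model = Chain(ActNorm(n_in; logdet=false), N)
    Ytrue = model(X)
    θ = Flux.params(model)
    l, grads = Flux.withgradient(() -> Flux.mse(Ytrue, model(X .+ 1f0)), θ)
    Flux.update!(Descent(1f-3), θ, grads)

Attaching the rrule to forward_net, with the parameter data passed explicitly, is what lets the pullback hand parameter gradients (Δθ, read off get_params after backward) back to Flux/Zygote; networks remain directly callable because (net::Invertible)(X) now dispatches to forward_net.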