Skip to content

Commit

Permalink
fix Flux compat
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Apr 20, 2023
1 parent 5c092f9 commit ff6c402
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte <p.witte@ymail.com>", "Ali Siahkoohi <alisk@gatech.edu>", "Mathias Louboutin <mlouboutin3@gatech.edu>", "Gabrio Rizzuti <g.rizzuti@umcutrecht.nl>", "Rafael Orozco <rorozco@gatech.edu>", "Felix J. herrmann <fherrmann@gatech.edu>"]
version = "2.2.4"
version = "2.2.5"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_glow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
else
ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS_, S), 0f0.*ΔT;), X2; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
ΔX2 += ΔY2
end
ΔX_ = tensor_cat(ΔX1, ΔX2)
Expand Down
18 changes: 10 additions & 8 deletions src/utils/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using ChainRulesCore
export logdetjac
import ChainRulesCore: frule, rrule

import ChainRulesCore: frule, rrule, @non_differentiable

# Parameter extraction is bookkeeping, not part of the differentiable
# computation — mark it opaque to ChainRules-based ADs (e.g. Zygote) so
# no pullback is attempted through it.
@non_differentiable get_params(::Invertible)
@non_differentiable get_params(::Reversed)
## Tape types and utilities

"""
Expand Down Expand Up @@ -105,7 +106,7 @@ end
## Chain rules for invertible networks
# General pullback function
function pullback(net::Invertible, ΔY::AbstractArray{T,N};
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}

# Check state coherency
check_coherence(state, net)
Expand All @@ -114,17 +115,17 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N};
T2 = typeof(current(state))
ΔY = convert(T2, ΔY)
# Backward pass
ΔX, X_ = net.backward(ΔY, current(state))

ΔX, X_ = net.backward(ΔY, current(state); set_grad=true)
Δθ = getfield.(get_params(net), :grad)
# Update state
backward_update!(state, X_)

return nothing, ΔX
return NoTangent(), NoTangent(), ΔX, Δθ
end


# Reverse-mode AD rule
function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N};
function ChainRulesCore.rrule(::typeof(forward_net), net::Invertible, X::AbstractArray{T, N}, θ...;
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}

# Forward pass
Expand All @@ -142,4 +143,5 @@ end

## Logdet utilities for Zygote pullback

# Return the log-determinant accumulated on the global operations tape
# during the last forward pass (used alongside Zygote pullbacks).
logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet

2 changes: 1 addition & 1 deletion src/utils/dimensionality_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ end

# Split and reshape 1D vector Y in latent space back to states Zi
# where Zi is the split tensor at each multiscale level.
function split_states(Y::AbstractVector{T}, Z_dims) where {T, N}
function split_states(Y::AbstractVector{T}, Z_dims) where {T}
L = length(Z_dims) + 1
inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...])
Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1]
Expand Down
5 changes: 3 additions & 2 deletions src/utils/neuralnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
# Route property access through a Val-typed helper so `Reversed` wrappers
# can forward field lookups to the wrapped network via dispatch.
getproperty(I::Invertible, s::Symbol) = _get_property(I, Val{s}())

# Default: plain field access on the object itself.
_get_property(I::Invertible, ::Val{s}) where {s} = getfield(I, s)
_get_property(R::Reversed, ::Val{:I}) where s = getfield(R, :I)
_get_property(R::Reversed, ::Val{:I}) = getfield(R, :I)
# Any other symbol on a Reversed wrapper is looked up on the wrapped net `R.I`.
_get_property(R::Reversed, ::Val{s}) where s = _get_property(R.I, Val{s}())

for m _INet_modes
Expand Down Expand Up @@ -128,4 +128,5 @@ function set_params!(N::Invertible, θnew::Array{Parameter, 1})
end

# Make invertible nets callable objects
(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X)
# Calling the net routes through `forward_net` with the parameter arrays
# passed explicitly, so the ChainRules rrule defined on `forward_net` can
# return gradients with respect to the parameters (Flux/Zygote training).
(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data))
# Forward evaluation; the trailing argument only threads the params through for AD.
forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ basics = ["test_utils/test_objectives.jl",
"test_utils/test_activations.jl",
"test_utils/test_squeeze.jl",
"test_utils/test_jacobian.jl",
"test_utils/test_chainrules.jl"]
"test_utils/test_chainrules.jl",
"test_utils/test_flux.jl"]

# Layers
layers = ["test_layers/test_residual_block.jl",
Expand Down
41 changes: 41 additions & 0 deletions test/test_utils/test_flux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Flux compatibility smoke test: verify that implicit-parameter
# (Flux.params) gradients and optimizer updates work through
# InvertibleNetworks layers.
using InvertibleNetworks, Flux, Test, LinearAlgebra

# Problem dimensions
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32

# Small net: activation normalization followed by a Glow coupling layer.
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
model = Chain(AN, C)

# Dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = model(X)
X0 = randn(Float32, nx, ny, n_in, batchsize) .+ 1

# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X))
# old, implicit-style Flux
θ = Flux.params(model)
opt = Descent(0.001f0)

l, grads = Flux.withgradient(θ) do
    loss(model, X0, Y)
end

# Every tracked parameter must have a nontrivial gradient.
for θi in θ
    @test θi ∈ keys(grads.grads)   # "∈" restored (lost in page extraction)
    @test !isnothing(grads.grads[θi])
    @test norm(grads.grads[θi]) > 0
end

# Snapshot the parameters (`1 .* θ` copies each array), then take one step.
bck = 1 .* θ
Flux.update!(opt, θ, grads)

# Descent's apply! rescales the stored gradient in place by the learning
# rate, so after update! the step θ_old - θ_new equals grads.grads[θi].
for (bi, θi) in zip(bck, θ)
    @test bi - θi ≈ grads.grads[θi]   # "≈" restored (lost in page extraction)
end

0 comments on commit ff6c402

Please sign in to comment.