Skip to content

Commit

Permalink
fix Flux compat
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Apr 20, 2023
1 parent 5c092f9 commit b2eeb12
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte <p.witte@ymail.com>", "Ali Siahkoohi <alisk@gatech.edu>", "Mathias Louboutin <mlouboutin3@gatech.edu>", "Gabrio Rizzuti <g.rizzuti@umcutrecht.nl>", "Rafael Orozco <rorozco@gatech.edu>", "Felix J. herrmann <fherrmann@gatech.edu>"]
version = "2.2.4"
version = "2.2.5"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_glow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
else
ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS_, S), 0f0.*ΔT;), X2; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
ΔX2 += ΔY2
end
ΔX_ = tensor_cat(ΔX1, ΔX2)
Expand Down
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_hint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
end

# Input are two tensors ΔX, X
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N}
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N}
isnothing(permute) ? permute = H.permute : permute = permute

# Permutation
Expand Down
28 changes: 15 additions & 13 deletions src/utils/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using ChainRulesCore
export logdetjac
import ChainRulesCore: frule, rrule

export logdetjac, getrrule
import ChainRulesCore: frule, rrule, @non_differentiable

@non_differentiable get_params(::Invertible)
@non_differentiable get_params(::Reversed)
## Tape types and utilities

"""
Expand Down Expand Up @@ -81,7 +82,6 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N},
if logdet isa Float32
state.logdet === nothing ? (state.logdet = logdet) : (state.logdet += logdet)
end

end

"""
Expand All @@ -97,15 +97,13 @@ function backward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}
state.Y[state.counter_block] = X
state.counter_layer -= 1
end

state.counter_block == 0 && reset!(state) # reset state when first block/first layer is reached

end

## Chain rules for invertible networks
# General pullback function
function pullback(net::Invertible, ΔY::AbstractArray{T,N};
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}

# Check state coherency
check_coherence(state, net)
Expand All @@ -114,19 +112,19 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N};
T2 = typeof(current(state))
ΔY = convert(T2, ΔY)
# Backward pass
ΔX, X_ = net.backward(ΔY, current(state))

ΔX, X_ = net.backward(ΔY, current(state); set_grad=true)
Δθ = getfield.(get_params(net), :grad)
# Update state
backward_update!(state, X_)

return nothing, ΔX
return NoTangent(), NoTangent(), ΔX, Δθ
end


# Reverse-mode AD rule
function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N};
function ChainRulesCore.rrule(::typeof(forward_net), net::Invertible, X::AbstractArray{T, N}, θ...;
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}

# Forward pass
net.logdet ? ((Y, logdet) = net.forward(X)) : (Y = net.forward(X); logdet = nothing)

Expand All @@ -142,4 +140,8 @@ end

## Logdet utilities for Zygote pullback

logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet

## Utility to get the pullback directly for testing

getrrule(net::Invertible, X::AbstractArray) = rrule(forward_net, net, X, getfield.(get_params(net), :data))
2 changes: 1 addition & 1 deletion src/utils/dimensionality_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ end

# Split and reshape 1D vector Y in latent space back to states Zi
# where Zi is the split tensor at each multiscale level.
function split_states(Y::AbstractVector{T}, Z_dims) where {T, N}
function split_states(Y::AbstractVector{T}, Z_dims) where {T}
L = length(Z_dims) + 1
inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...])
Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1]
Expand Down
5 changes: 3 additions & 2 deletions src/utils/neuralnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
getproperty(I::Invertible, s::Symbol) = _get_property(I, Val{s}())

_get_property(I::Invertible, ::Val{s}) where {s} = getfield(I, s)
_get_property(R::Reversed, ::Val{:I}) where s = getfield(R, :I)
_get_property(R::Reversed, ::Val{:I}) = getfield(R, :I)
_get_property(R::Reversed, ::Val{s}) where s = _get_property(R.I, Val{s}())

for m _INet_modes
Expand Down Expand Up @@ -128,4 +128,5 @@ function set_params!(N::Invertible, θnew::Array{Parameter, 1})
end

# Make invertible nets callable objects
(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X)
(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data))
forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ basics = ["test_utils/test_objectives.jl",
"test_utils/test_activations.jl",
"test_utils/test_squeeze.jl",
"test_utils/test_jacobian.jl",
"test_utils/test_chainrules.jl"]
"test_utils/test_chainrules.jl",
"test_utils/test_flux.jl"]

# Layers
layers = ["test_layers/test_residual_block.jl",
Expand Down
28 changes: 14 additions & 14 deletions test/test_utils/test_chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@ N10 = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full")

# Forward pass + gathering pullbacks
function fw(X)
X1, ∂1 = rrule(N1, X)
X2, ∂2 = rrule(N2, X1)
X3, ∂3 = rrule(N3, X2)
X1, ∂1 = getrrule(N1, X)
X2, ∂2 = getrrule(N2, X1)
X3, ∂3 = getrrule(N3, X2)
X5, ∂5 = Flux.Zygote.pullback(Chain(N4, N5), X3)
X6, ∂6 = rrule(N6, X5)
X7, ∂7 = rrule(N7, X6)
X6, ∂6 = getrrule(N6, X5)
X7, ∂7 = getrrule(N7, X6)
X9, ∂9 = Flux.Zygote.pullback(Chain(N8, N9), X7)
X10, ∂10 = rrule(N10, X9)
d1 = x -> ∂1(x)[2]
d2 = x -> ∂2(x)[2]
d3 = x -> ∂3(x)[2]
X10, ∂10 = getrrule(N10, X9)
d1 = x -> ∂1(x)[3]
d2 = x -> ∂2(x)[3]
d3 = x -> ∂3(x)[3]
d5 = x -> ∂5(x)[1]
d6 = x -> ∂6(x)[2]
d7 = x -> ∂7(x)[2]
d6 = x -> ∂6(x)[3]
d7 = x -> ∂7(x)[3]
d9 = x -> ∂9(x)[1]
d10 = x -> ∂10(x)[2]
d10 = x -> ∂10(x)[3]
return X10, ΔY -> d1(d2(d3(d5(d6(d7(d9(d10(ΔY))))))))
end

Expand All @@ -65,9 +65,9 @@ g2 = gradient(X -> loss(X), X)
## test Reverse network AD

Nrev = reverse(N10)
Xrev, ∂rev = rrule(Nrev, X)
Xrev, ∂rev = getrrule(Nrev, X)
grev = ∂rev(Xrev-Y0)

g2rev = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X)

@test grev[2] g2rev[1] rtol=1f-6
@test grev[3] g2rev[1] rtol=1f-6
50 changes: 50 additions & 0 deletions test/test_utils/test_flux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using InvertibleNetworks, Flux, Test, LinearAlgebra

# Define network
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32

# net
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
pan, pc = deepcopy(get_params(AN)), deepcopy(get_params(C))
model = Chain(AN, C)

# dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = model(X)
X0 = rand(Float32, nx, ny, n_in, batchsize) .+ 1

# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X))

# old, implicit-style Flux
θ = Flux.params(model)
opt = Descent(0.001)

l, grads = Flux.withgradient(θ) do
loss(model, X0, Y)
end

for θi in θ
@test θi keys(grads.grads)
@test !isnothing(grads.grads[θi])
@test size(grads.grads[θi]) == size(θi)
end

Flux.update!(opt, θ, grads)

for i = 1:5
li, grads = Flux.withgradient(θ) do
loss(model, X, Y)
end

@info "Loss: $li"
@test li != l
global l = li

Flux.update!(opt, θ, grads)
end

0 comments on commit b2eeb12

Please sign in to comment.