diff --git a/src/networks/invertible_network_conditional_glow.jl b/src/networks/invertible_network_conditional_glow.jl
index 7bd57c69..8170a045 100644
--- a/src/networks/invertible_network_conditional_glow.jl
+++ b/src/networks/invertible_network_conditional_glow.jl
@@ -148,14 +148,13 @@ function inverse(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi
 end
 
 # Backward pass and compute gradients
-function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow;) where {T, N}
+function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, ΔC::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow;) where {T, N}
     # Split data and gradients
     if G.split_scales
         ΔZ_save, ΔX = split_states(ΔX[:], G.Z_dims)
         Z_save, X = split_states(X[:], G.Z_dims)
     end
 
-    ΔC = T(0) .* C
     for i=G.L:-1:1
         if G.split_scales && i < G.L
             X = tensor_cat(X, Z_save[i])
diff --git a/src/networks/summarized_net.jl b/src/networks/summarized_net.jl
index 701a04f1..ac0fd547 100644
--- a/src/networks/summarized_net.jl
+++ b/src/networks/summarized_net.jl
@@ -49,8 +49,8 @@ function inverse(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNe
 end
 
 # Backward pass and compute gradients
-function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
-    ΔX, X, ΔY = S.cond_net.backward(ΔX,X,Y)
+function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
+    ΔX, X, ΔY = S.cond_net.backward(ΔX,X,ΔY,Y)
     ΔY = S.sum_net.backward(ΔY, Y_save)
     return ΔX, X, ΔY
 end
diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl
index 642f76f2..a7bd4627 100644
--- a/test/test_networks/test_conditional_glow_network.jl
+++ b/test/test_networks/test_conditional_glow_network.jl
@@ -2,6 +2,7 @@
 # Author: Philipp Witte, pwitte3@gatech.edu
 # Date: January 2020
 
+using Revise
 using InvertibleNetworks, LinearAlgebra, Test, Random
 using Flux
 
@@ -26,16 +27,16 @@ N = (nx,ny)
 
 # Network and input
 G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-X = rand(Float32, N..., n_in, batchsize)
-Cond = rand(Float32, N..., n_cond, batchsize)
+X = rand(Float32, N..., n_in, batchsize);
+Cond = rand(Float32, N..., n_cond, batchsize);
 
-Y, Cond = G.forward(X,Cond)
-X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
+Y, Cond = G.forward(X,Cond);
+X_ = G.inverse(Y,Cond); # saving the cond is important in split scales because of reshapes
 @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
 
 # Test gradients are set and cleared
-G.backward(Y, Y, Cond)
+G.backward(Y, Y, Cond, Cond);
 P = get_params(G)
 gsum = 0
@@ -54,15 +55,54 @@ end
 ###################################################################################################
 # Gradient test
+
+# Gradient test
+function loss_y(G, X, Cond)
+    Y, ZC, logdet = G.forward(X, Cond)
+    f = -log_likelihood(Y) -log_likelihood(ZC) - logdet
+    ΔY = -∇log_likelihood(Y)
+    ΔZC = -∇log_likelihood(ZC)
+    ΔX, X_, ΔY = G.backward(ΔY, Y, ΔZC, ZC)
+    return f, ΔX, ΔY, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
+end
+
+# Gradient test w.r.t. condition input
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
+X = rand(Float32, N..., n_in, batchsize);
+Cond = rand(Float32, N..., n_cond, batchsize);
+X0 = rand(Float32, N..., n_in, batchsize);
+Cond0 = rand(Float32, N..., n_cond, batchsize);
+
+dCond = Cond - Cond0
+
+f0, ΔX, ΔY = loss_y(G, X0, Cond0)[1:3]
+h = 0.1f0
+maxiter = 4
+err1 = zeros(Float32, maxiter)
+err2 = zeros(Float32, maxiter)
+
+print("\nGradient test glow: input\n")
+for j=1:maxiter
+    f = loss_y(G, X0, Cond0 + h*dCond)[1]
+    err1[j] = abs(f - f0)
+    err2[j] = abs(f - f0 - h*dot(dCond, ΔY))
+    print(err1[j], "; ", err2[j], "\n")
+    global h = h/2f0
+end
+
+@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
+@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
+
+
+
 function loss(G, X, Cond)
     Y, ZC, logdet = G.forward(X, Cond)
     f = -log_likelihood(Y) - logdet
     ΔY = -∇log_likelihood(Y)
-    ΔX, X_ = G.backward(ΔY, Y, ZC)
+    ΔX, X_ = G.backward(ΔY, Y, 0f0 .* ZC, ZC)
     return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
 end
 
-
 # Gradient test w.r.t. input
 G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
 X = rand(Float32, N..., n_in, batchsize)
@@ -91,6 +131,14 @@ end
 
 @test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
 
+
+
+
+
+
+
+
+
 # Gradient test w.r.t. parameters
 X = rand(Float32, N..., n_in, batchsize)
 G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
@@ -141,7 +189,7 @@ X_ = G.inverse(Y,ZCond) # saving the cond is important in split scales because o
 @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
 
 # Test gradients are set and cleared
-G.backward(Y, Y, ZCond; Y_save = Cond)
+G.backward(Y, Y, ZCond, ZCond; Y_save = Cond)
 P = get_params(G)
 gsum = 0
@@ -161,10 +209,11 @@ end
 # Gradient test
 function loss_sum(G, X, Cond)
     Y, ZC, logdet = G.forward(X, Cond)
-    f = -log_likelihood(Y) - logdet
+    f = -log_likelihood(Y) -log_likelihood(ZC) - logdet
     ΔY = -∇log_likelihood(Y)
-    ΔX, X_ = G.backward(ΔY, Y, ZC; Y_save=Cond)
-    return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
+    ΔZC = -∇log_likelihood(ZC)
+    ΔX, X_, ΔC = G.backward(ΔY, Y, ΔZC, ZC; Y_save=Cond)
+    return f, ΔX, ΔC, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
 end
 
 # Gradient test w.r.t. input
@@ -173,9 +222,9 @@ Cond = rand(Float32, N..., n_cond, batchsize);
 X0 = rand(Float32, N..., n_in, batchsize);
 Cond0 = rand(Float32, N..., n_cond, batchsize);
 
-dX = X - X0
+dCond = Cond - Cond0
 
-f0, ΔX = loss_sum(G, X0, Cond0)[1:2]
+f0, ΔX, ΔC = loss_sum(G, X0, Cond0)[1:3]
 h = 0.1f0
 maxiter = 4
 err1 = zeros(Float32, maxiter)
@@ -183,9 +232,9 @@ err2 = zeros(Float32, maxiter)
 
 print("\nGradient test glow: input\n")
 for j=1:maxiter
-    f = loss_sum(G, X0 + h*dX, Cond0)[1]
+    f = loss_sum(G, X0, Cond0 + h*dCond)[1]
     err1[j] = abs(f - f0)
-    err2[j] = abs(f - f0 - h*dot(dX, ΔX))
+    err2[j] = abs(f - f0 - h*dot(dCond, ΔC))
     print(err1[j], "; ", err2[j], "\n")
     global h = h/2f0
 end
@@ -204,7 +253,7 @@ Gini = deepcopy(G0)
 dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
 dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond)
+f0, ΔX, ΔC, ΔW, Δv = loss_sum(G0, X, Cond)
 h = 0.1f0
 maxiter = 4
 err3 = zeros(Float32, maxiter)
 err2 = zeros(Float32, maxiter)
@@ -241,7 +290,7 @@ X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of
 @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
 
 # Test gradients are set and cleared
-G.backward(Y, Y, Cond)
+G.backward(Y, Y, Cond,Cond)
 P = get_params(G)
 gsum = 0
@@ -338,7 +387,7 @@ X_ = G.inverse(Y,ZCond); # saving the cond is important in split scales because
 @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
 
 # Test gradients are set and cleared
-G.backward(Y, Y, ZCond; Y_save=Cond)
+G.backward(Y, Y,ZCond,ZCond; Y_save=Cond)
 P = get_params(G)
 gsum = 0
@@ -394,7 +443,7 @@ Gini = deepcopy(G0)
 dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
 dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond);
+f0, ΔX,ΔC, ΔW, Δv = loss_sum(G0, X, Cond);
 h = 0.1f0
 maxiter = 4
 err3 = zeros(Float32, maxiter)
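
A minimal usage sketch (not part of the patch) of the new four-argument `backward` interface for `NetworkConditionalGlow`: the caller now supplies the gradient with respect to the transformed condition instead of the network zeroing it internally. The sizes below are made up for illustration; the sketch only uses calls that already appear in this diff (`forward`, `backward`, `log_likelihood`, `∇log_likelihood`).

```julia
using InvertibleNetworks

# Illustrative sizes (assumptions, not taken from the test file)
nx, ny    = 16, 16
n_in      = 2
n_cond    = 2
n_hidden  = 4
L, K      = 2, 2
batchsize = 4

G    = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)
X    = rand(Float32, nx, ny, n_in,   batchsize)
Cond = rand(Float32, nx, ny, n_cond, batchsize)

# Forward pass: latent Z, transformed condition ZC, and the log-determinant
Z, ZC, logdet = G.forward(X, Cond)

# Negative log-likelihood on both outputs, as in the updated tests
f   = -log_likelihood(Z) - log_likelihood(ZC) - logdet
ΔZ  = -∇log_likelihood(Z)
ΔZC = -∇log_likelihood(ZC)

# The caller now passes the gradient w.r.t. the condition (ΔZC) explicitly;
# backward returns the input gradient, the reconstructed input, and ΔC.
ΔX, X_, ΔC = G.backward(ΔZ, Z, ΔZC, ZC)
```

After this call the parameter gradients are populated on `get_params(G)`, which is what the "gradients are set and cleared" checks in the test file exercise. Callers that do not need a condition gradient can pass a zero array, as the updated `loss` test does with `0f0 .* ZC`.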