gradients wrt summary network #93

Open · wants to merge 2 commits into base: master
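This PR makes the condition gradient an explicit input to backward for NetworkConditionalGlow and SummarizedNet, instead of zero-initializing it inside the network, so that gradients can be propagated back through the summary network to the raw condition; the tests add finite-difference checks of that gradient. A minimal sketch of the new calling convention (variable names mirror the tests below and are illustrative, not part of the PR):

    # forward: transform input X under condition C (a SummarizedNet summarizes C first)
    ZX, ZC, logdet = G.forward(X, C)
    # backward: the condition gradient ΔZC is now passed in explicitly;
    # the returned ΔC is the gradient w.r.t. the condition input
    ΔX, X_, ΔC = G.backward(ΔZX, ZX, ΔZC, ZC)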
3 changes: 1 addition & 2 deletions src/networks/invertible_network_conditional_glow.jl
@@ -148,14 +148,13 @@ function inverse(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow;) where {T, N}
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, ΔC::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow;) where {T, N}
# Split data and gradients
if G.split_scales
ΔZ_save, ΔX = split_states(ΔX[:], G.Z_dims)
Z_save, X = split_states(X[:], G.Z_dims)
end

ΔC = T(0) .* C
for i=G.L:-1:1
if G.split_scales && i < G.L
X = tensor_cat(X, Z_save[i])
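Two changes in this hunk: backward gains the explicit ΔC argument, and the internal zero-initialization ΔC = T(0) .* C is dropped, so the caller now supplies the incoming condition gradient. A minimal sketch of a call that reproduces the old behaviour, i.e. a zero incoming condition gradient (the updated input-only test below does the same with 0f0 .* ZC; names here are illustrative):

    ΔX, X_, ΔC = G.backward(ΔZX, ZX, zero(ZC), ZC)   # zero(ZC) takes the place of the removed internal ΔC = T(0) .* C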
4 changes: 2 additions & 2 deletions src/networks/summarized_net.jl
@@ -49,8 +49,8 @@ function inverse(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNe
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
ΔX, X, ΔY = S.cond_net.backward(ΔX,X,Y)
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
ΔX, X, ΔY = S.cond_net.backward(ΔX,X,ΔY,Y)
ΔY = S.sum_net.backward(ΔY, Y_save)
return ΔX, X, ΔY
end
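With this change the summary network receives a meaningful gradient: cond_net.backward now returns the gradient w.r.t. the summarized condition, which sum_net.backward then pulls back to the raw condition. A minimal usage sketch, assuming a SummarizedNet S and the helpers used in the tests below (names are illustrative):

    ZX, ZC, logdet = S.forward(X, Cond)                       # sum_net summarizes Cond, cond_net transforms X
    ΔZX = -∇log_likelihood(ZX)                                # gradient w.r.t. the latent of X
    ΔZC = -∇log_likelihood(ZC)                                # gradient w.r.t. the summarized condition
    ΔX, X_, ΔCond = S.backward(ΔZX, ZX, ΔZC, ZC; Y_save=Cond) # ΔCond: gradient w.r.t. the raw condition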
87 changes: 68 additions & 19 deletions test/test_networks/test_conditional_glow_network.jl
@@ -2,6 +2,7 @@
# Author: Philipp Witte, pwitte3@gatech.edu
# Date: January 2020

using Revise
using InvertibleNetworks, LinearAlgebra, Test, Random
using Flux

@@ -26,16 +27,16 @@ N = (nx,ny)

# Network and input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)
X = rand(Float32, N..., n_in, batchsize);
Cond = rand(Float32, N..., n_cond, batchsize);

Y, Cond = G.forward(X,Cond)
X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
Y, Cond = G.forward(X,Cond);
X_ = G.inverse(Y,Cond); # saving the cond is important in split scales because of reshapes

@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

# Test gradients are set and cleared
G.backward(Y, Y, Cond)
G.backward(Y, Y, Cond, Cond);

P = get_params(G)
gsum = 0
@@ -54,15 +55,54 @@ end
###################################################################################################
# Gradient test


# Gradient test
function loss_y(G, X, Cond)
Y, ZC, logdet = G.forward(X, Cond)
f = -log_likelihood(Y) -log_likelihood(ZC) - logdet
ΔY = -∇log_likelihood(Y)
ΔZC = -∇log_likelihood(ZC)
ΔX, X_, ΔY = G.backward(ΔY, Y, ΔZC, ZC)
return f, ΔX, ΔY, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
end
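For reference, loss_y uses the full change-of-variables objective f = -log_likelihood(Y) - log_likelihood(ZC) - logdet; including the condition-latent term is what gives a nonzero ΔZC = -∇log_likelihood(ZC) to feed into the new backward slot, and that gradient (returned here as ΔY) is what the finite-difference loop below checks against perturbations of Cond.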

# Gradient test w.r.t. condition input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize);
Cond = rand(Float32, N..., n_cond, batchsize);
X0 = rand(Float32, N..., n_in, batchsize);
Cond0 = rand(Float32, N..., n_cond, batchsize);

dCond = Cond - Cond0

f0, ΔX, ΔY = loss_y(G, X0, Cond0)[1:3]
h = 0.1f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss_y(G, X0, Cond0 + h*dCond)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dCond, ΔY))
print(err1[j], "; ", err2[j], "\n")
global h = h/2f0
end

@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
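These loops are the standard first-/second-order Taylor convergence check: writing f(Cond) for the loss above,

    err1(h) = |f(Cond0 + h*dCond) - f(Cond0)|                    = O(h)
    err2(h) = |f(Cond0 + h*dCond) - f(Cond0) - h*dot(dCond, ΔY)| = O(h^2)

so halving h should roughly halve err1 and quarter err2; the two isapprox tests above assert exactly those ratios, err1[end]/(err1[1]/2^(maxiter-1)) and err2[end]/(err2[1]/4^(maxiter-1)), up to atol=1f0.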



function loss(G, X, Cond)
Y, ZC, logdet = G.forward(X, Cond)
f = -log_likelihood(Y) - logdet
ΔY = -∇log_likelihood(Y)
ΔX, X_ = G.backward(ΔY, Y, ZC)
ΔX, X_ = G.backward(ΔY, Y, 0f0 .* ZC, ZC)
return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
end


# Gradient test w.r.t. input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
@@ -91,6 +131,14 @@ end
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)










# Gradient test w.r.t. parameters
X = rand(Float32, N..., n_in, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
Expand Down Expand Up @@ -141,7 +189,7 @@ X_ = G.inverse(Y,ZCond) # saving the cond is important in split scales because o
@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

# Test gradients are set and cleared
G.backward(Y, Y, ZCond; Y_save = Cond)
G.backward(Y, Y, ZCond, ZCond; Y_save = Cond)

P = get_params(G)
gsum = 0
@@ -161,10 +209,11 @@ end
# Gradient test
function loss_sum(G, X, Cond)
Y, ZC, logdet = G.forward(X, Cond)
f = -log_likelihood(Y) - logdet
f = -log_likelihood(Y) -log_likelihood(ZC) - logdet
ΔY = -∇log_likelihood(Y)
ΔX, X_ = G.backward(ΔY, Y, ZC; Y_save=Cond)
return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
ΔZC = -∇log_likelihood(ZC)
ΔX, X_, ΔC = G.backward(ΔY, Y, ΔZC, ZC; Y_save=Cond)
return f, ΔX, ΔC, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
end

# Gradient test w.r.t. input
@@ -173,19 +222,19 @@ Cond = rand(Float32, N..., n_cond, batchsize);
X0 = rand(Float32, N..., n_in, batchsize);
Cond0 = rand(Float32, N..., n_cond, batchsize);

dX = X - X0
dCond = Cond - Cond0

f0, ΔX = loss_sum(G, X0, Cond0)[1:2]
f0, ΔX, ΔC = loss_sum(G, X0, Cond0)[1:3]
h = 0.1f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss_sum(G, X0 + h*dX, Cond0)[1]
f = loss_sum(G, X0, Cond0 + h*dCond)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
err2[j] = abs(f - f0 - h*dot(dCond, ΔC))
print(err1[j], "; ", err2[j], "\n")
global h = h/2f0
end
@@ -204,7 +253,7 @@ Gini = deepcopy(G0)
dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond)
f0, ΔX, ΔC, ΔW, Δv = loss_sum(G0, X, Cond)
h = 0.1f0
maxiter = 4
err3 = zeros(Float32, maxiter)
@@ -241,7 +290,7 @@ X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of
@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

# Test gradients are set and cleared
G.backward(Y, Y, Cond)
G.backward(Y, Y, Cond,Cond)

P = get_params(G)
gsum = 0
@@ -338,7 +387,7 @@ X_ = G.inverse(Y,ZCond); # saving the cond is important in split scales because
@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

# Test gradients are set and cleared
G.backward(Y, Y, ZCond; Y_save=Cond)
G.backward(Y, Y,ZCond,ZCond; Y_save=Cond)

P = get_params(G)
gsum = 0
@@ -394,7 +443,7 @@ Gini = deepcopy(G0)
dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond);
f0, ΔX,ΔC, ΔW, Δv = loss_sum(G0, X, Cond);
h = 0.1f0
maxiter = 4
err3 = zeros(Float32, maxiter)