-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add norm flows #1
Changes from 22 commits
9758bcd
b8c22db
22b1936
5719e5d
8104a28
494075d
59aeceb
b0a01e9
96c2c7f
2109d6d
b727f5e
6c1a5a0
0c026d4
d3ff9a6
8ca489a
fb2399f
45a6713
3f86c99
d882406
bac200f
f276e57
23a0501
dc73717
4a7b260
b22ef2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,15 +46,15 @@ end | |
Broadcast.broadcastable(b::Bijector) = Ref(b) | ||
|
||
"Computes the log(abs(det(J(x)))) where J is the jacobian of the transform." | ||
logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = | ||
logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = | ||
error("`logabsdetjac(b::$T1, y::$T2)` is not implemented.") | ||
|
||
"Transforms the input using the bijector." | ||
transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = | ||
error("`transform(b::$T1, y::$T2)` is not implemented.") | ||
|
||
"Computes both `transform` and `logabsdetjac` in one forward pass." | ||
forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = | ||
forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = | ||
error("`forward(b::$T1, y::$T2)` is not implemented.") | ||
|
||
|
||
|
@@ -112,7 +112,7 @@ end | |
|
||
function compose(ts...) | ||
res = [] | ||
|
||
for b ∈ ts | ||
if b isa Composed | ||
# "lift" the transformations | ||
|
@@ -151,6 +151,14 @@ function transform(cb::Composed{<: Bijector}, x) | |
return res | ||
end | ||
|
||
function inv(cb::Composed{<: Bijector}, y) | ||
res = y | ||
for b ∈ reverse(cb.ts) | ||
res = inv(b, res) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This inv(b::Bijector, y) = inv(b)(y) as a default implementation just before this function. EDIT: See my comment regarding the |
||
end | ||
return res | ||
end | ||
|
||
(cb::Composed{<: Bijector})(x) = transform(cb, x) | ||
|
||
function forward(cb::Composed{<:Bijector}, x) | ||
|
@@ -162,6 +170,12 @@ function forward(cb::Composed{<:Bijector}, x) | |
return res | ||
end | ||
|
||
function rand(flow::Composed, dims::Integer, shape::Integer=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not quite certain why this is necessary. Is it supposed to be used as the "source distribution" for the transformation? Is there a reason why this cannot be done with a Even if we have to keep this in, this is not supported by any There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Wouldn't this in general work for any composition of bijections? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When I said it wouldn't work, I was referring to the fact that But I think it makes more sense to make this more general, allowing one to attach any With your approach you might run into an issue where a function dispatches on Does this make sense? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree we don't need rand in Bijectors.jl |
||
dims = [dims] | ||
append!(dims, shape) | ||
print(dims) | ||
return transform(flow, randn(dims...)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we keep this in; should this be |
||
end | ||
############################## | ||
# Example bijector: Identity # | ||
############################## | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
using Distributions | ||
using LinearAlgebra | ||
using Random | ||
using StatsFuns: softplus | ||
using Roots # for inverse | ||
|
||
################################################################################ | ||
# Planar and Radial Flows # | ||
# Ref: Variational Inference with Normalizing Flows, # | ||
# D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # | ||
################################################################################ | ||
|
||
(b::Bijector)(x) = transform(b, x) | ||
sharanry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
mutable struct PlanarLayer{T1,T2} <: Bijector | ||
w::T1 | ||
u::T1 | ||
b::T2 | ||
end | ||
|
||
function get_u_hat(u, w) | ||
# To preserve invertibility | ||
return ( | ||
u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] | ||
* w / (norm(w[:,1],2) ^ 2) | ||
) # from A.1 | ||
end | ||
|
||
function PlanarLayer(dims::Int, container=Array) | ||
w = container(randn(dims, 1)) | ||
u = container(randn(dims, 1)) | ||
b = container(randn(1)) | ||
return PlanarLayer(w, u, b) | ||
end | ||
|
||
planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1 | ||
dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow | ||
ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow from eq(11) | ||
|
||
function transform(flow::PlanarLayer, z) | ||
u_hat = get_u_hat(flow.u, flow.w) | ||
return z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) # from eq(10) | ||
end | ||
|
||
function forward(flow::T, z) where {T<:PlanarLayer} | ||
u_hat = get_u_hat(flow.u, flow.w) | ||
# Compute log_det_jacobian | ||
psi = ψ(z, flow.w, flow.b) | ||
log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * u_hat)) # from eq(12) | ||
transformed = z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) | ||
return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10) | ||
end | ||
|
||
function inv(flow::PlanarLayer, y) | ||
sharanry marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason why we cannot do this using function (ib::Inversed{<: PlanarLayer})(y)
flow = ib.orig
# same as in this inv(flow, y) ...
end |
||
u_hat = get_u_hat(flow.u, flow.w) | ||
# Implemented with reference from A.1 | ||
function f(y) | ||
return loss(alpha) = ( | ||
(transpose(flow.w) * y)[1] - alpha | ||
- (transpose(flow.w) * u_hat)[1] | ||
* tanh(alpha+flow.b[1]) | ||
) | ||
end | ||
alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] | ||
alphas = transpose(alphas_) | ||
z_para = (flow.w ./ norm(flow.w,2)) * alphas | ||
z_per = ( | ||
y - z_para - u_hat * tanh.( | ||
transpose(flow.w) * z_para | ||
.+ flow.b | ||
) | ||
) | ||
|
||
return z_para + z_per | ||
end | ||
|
||
mutable struct RadialLayer{T1,T2} <: Bijector | ||
α_::T1 | ||
β::T1 | ||
z_0::T2 | ||
end | ||
|
||
function RadialLayer(dims::Int, container=Array) | ||
α_ = container(randn(1)) | ||
β = container(randn(1)) | ||
z_0 = container(randn(dims, 1)) | ||
return RadialLayer(α_, β, z_0) | ||
end | ||
|
||
h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) | ||
dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() | ||
|
||
function transform(flow::RadialLayer, z) | ||
α = softplus(flow.α_[1]) # from A.2 | ||
β_hat = -α + softplus(flow.β[1]) # from A.2 | ||
r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) | ||
return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) | ||
end | ||
|
||
function forward(flow::T, z) where {T<:RadialLayer} | ||
sharanry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
α = softplus(flow.α_[1]) # from A.2 | ||
β_hat = -α + softplus(flow.β[1]) # from A.2 | ||
r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) | ||
transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) | ||
sharanry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Compute log_det_jacobian | ||
d = size(flow.z_0, 1) | ||
h_ = h(α, r) | ||
sharanry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
log_det_jacobian = @. ( | ||
(d-1) * log(1.0 + β_hat * h_) | ||
+ log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) | ||
) # from eq(14) | ||
return (rv=transformed, logabsdetjac=log_det_jacobian) | ||
end | ||
|
||
function inv(flow::RadialLayer, y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See other comment on |
||
α = softplus(flow.α_[1]) # from A.2 | ||
β_hat = - α + softplus(flow.β[1]) # from A.2 | ||
function f(y) | ||
# From eq(26) | ||
return loss(r) = ( | ||
norm(y - flow.z_0, 2) | ||
- r * (1 + β_hat / (α + r)) | ||
) | ||
end | ||
rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] # A.2 | ||
rs = transpose(rs_) | ||
z_hat = (y .- flow.z_0) ./ (rs .* (1 .+ β_hat ./ (α .+ rs)) ) # from eq(25) | ||
z = flow.z_0 .+ rs .* z_hat # from A.2 | ||
return z | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
using Test | ||
using Bijectors, ForwardDiff, LinearAlgebra | ||
|
||
@testset "planar flows" begin | ||
for i in 1:10 | ||
flow = PlanarLayer(10) | ||
z = randn(10, 100) | ||
forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) | ||
our_method = sum(forward(flow, z).logabsdetjacob) | ||
@test our_method ≈ forward_diff | ||
|
||
# Inverse not accurate enough to pass with `≈` operator. | ||
@test_broken inv(flow, transform(flow, z)) ≈ z | ||
end | ||
|
||
w = ones(10, 1) | ||
u = zeros(10, 1) | ||
b = ones(1) | ||
flow = PlanarLayer(w, u, b) | ||
z = ones(10, 100) | ||
@test inv(flow, transform(flow, z)) ≈ z | ||
end | ||
|
||
@testset "radial flows" begin | ||
sharanry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for i in 1:10 | ||
flow = RadialLayer(2) | ||
z = randn(2, 100) | ||
forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) | ||
our_method = sum(forward(flow, z).logabsdetjacob) | ||
@test our_method ≈ forward_diff | ||
@test inv(flow, transform(flow, z)) ≈ z | ||
end | ||
α_ = ones(1) | ||
β = ones(1) | ||
z_0 = zeros(10, 1) | ||
z = ones(10, 100) | ||
flow = RadialLayer(α_, β, z_0) | ||
@test inv(flow, transform(flow, z)) ≈ z | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sould this be
Composed{<: AbstractArray{<: Bijector}}
?Also, I'd suggest just dispatching on
Composed
if so, as this also allows usingTuple
as the container.EDIT: I see I've probably previously made the same mistake below! This has been fixed in the more recent commits.