Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add norm flows #1

Merged
merged 25 commits into from
Aug 22, 2019
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9758bcd
Add norm flows
sharanry Jul 23, 2019
b8c22db
minor changes
sharanry Jul 23, 2019
22b1936
fix bugs
sharanry Jul 23, 2019
5719e5d
fix more bugs
sharanry Jul 23, 2019
8104a28
fix bugs and follow style guide
sharanry Jul 25, 2019
494075d
add tests on logabsdetjacob
sharanry Jul 25, 2019
59aeceb
fix spaces
sharanry Jul 25, 2019
b0a01e9
add iterative norm for planar flows
sharanry Jul 31, 2019
96c2c7f
fix minor bug
sharanry Jul 31, 2019
2109d6d
adhere to stylecode, add radial inverse, remove tracker dependency, r…
sharanry Aug 1, 2019
b727f5e
fix radius bug
sharanry Aug 7, 2019
6c1a5a0
fix param dependency
sharanry Aug 7, 2019
0c026d4
fix inv() for radial flow; follow style guidelines
sharanry Aug 13, 2019
d3ff9a6
minor change to test
sharanry Aug 13, 2019
8ca489a
minor fix
sharanry Aug 14, 2019
fb2399f
add ref to paper for each equation
sharanry Aug 14, 2019
45a6713
fix forward and remove redundant
sharanry Aug 14, 2019
3f86c99
remove update_u_hat!() requirement
sharanry Aug 15, 2019
d882406
update tests and ifx bug
sharanry Aug 17, 2019
bac200f
update tests
sharanry Aug 17, 2019
f276e57
implement Bijector call. We can now transform using BijectorName(x)
sharanry Aug 17, 2019
23a0501
Add inv and rand functions for composed bijectors
sharanry Aug 17, 2019
dc73717
Add direct calls for norm flows
sharanry Aug 18, 2019
4a7b260
Fix the two remaining issues in https://github.com/torfjelde/Bijector…
xukai92 Aug 19, 2019
b22ef2a
Merge branch 'tor/interface' into norm_flow
torfjelde Aug 22, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using Reexport, Requires
using StatsFuns
using LinearAlgebra
using MappedArrays
using Roots

export TransformDistribution,
export TransformDistribution,
RealDistribution,
PositiveDistribution,
UnitDistribution,
Expand All @@ -28,7 +29,9 @@ export TransformDistribution,
bijector,
transformed,
UnivariateTransformed,
MultivariateTransformed
MultivariateTransformed,
PlanarLayer,
RadialLayer

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0")))

Expand Down Expand Up @@ -177,8 +180,8 @@ function _clamp(x::T, dist::SimplexDistribution) where T
end

function link(
d::SimplexDistribution,
x::AbstractVector{T},
d::SimplexDistribution,
x::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
y, K = similar(x), length(x)
Expand Down Expand Up @@ -206,8 +209,8 @@ end

# Vectorised implementation of the above.
function link(
d::SimplexDistribution,
X::AbstractMatrix{T},
d::SimplexDistribution,
X::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
Y, K, N = similar(X), size(X, 1), size(X, 2)
Expand All @@ -234,8 +237,8 @@ function link(
end

function invlink(
d::SimplexDistribution,
y::AbstractVector{T},
d::SimplexDistribution,
y::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
x, K = similar(y), length(y)
Expand All @@ -260,8 +263,8 @@ end

# Vectorised implementation of the above.
function invlink(
d::SimplexDistribution,
Y::AbstractMatrix{T},
d::SimplexDistribution,
Y::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
X, K, N = similar(Y), size(Y, 1), size(Y, 2)
Expand Down Expand Up @@ -355,8 +358,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{T}) where {T<:Real}
end

function logpdf_with_trans(
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
transform::Bool
)
T = eltype(X)
Expand Down Expand Up @@ -436,4 +439,6 @@ end

include("interface.jl")

include("norm_flows.jl")

end # module
20 changes: 17 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ end
Broadcast.broadcastable(b::Bijector) = Ref(b)

"Computes the log(abs(det(J(x)))) where J is the jacobian of the transform."
logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} =
logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} =
error("`logabsdetjac(b::$T1, y::$T2)` is not implemented.")

"Transforms the input using the bijector."
transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} =
error("`transform(b::$T1, y::$T2)` is not implemented.")

"Computes both `transform` and `logabsdetjac` in one forward pass."
forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} =
forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} =
error("`forward(b::$T1, y::$T2)` is not implemented.")


Expand Down Expand Up @@ -112,7 +112,7 @@ end

function compose(ts...)
res = []

for b ∈ ts
if b isa Composed
# "lift" the transformations
Expand Down Expand Up @@ -151,6 +151,14 @@ function transform(cb::Composed{<: Bijector}, x)
return res
end

function inv(cb::Composed{<: Bijector}, y)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sould this be Composed{<: AbstractArray{<: Bijector}}?
Also, I'd suggest just dispatching on Composed if so, as this also allows using Tuple as the container.

EDIT: I see I've probably previously made the same mistake below! This has been fixed in the more recent commits.

res = y
for b ∈ reverse(cb.ts)
res = inv(b, res)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This inv(b, res) does not exist for other bijectors. Just add something like

inv(b::Bijector, y) = inv(b)(y)

as a default implementation just before this function.

EDIT: See my comment regarding the inv for the flows. If that is true, then we don't even need this function.

end
return res
end

(cb::Composed{<: Bijector})(x) = transform(cb, x)

function forward(cb::Composed{<:Bijector}, x)
Expand All @@ -162,6 +170,12 @@ function forward(cb::Composed{<:Bijector}, x)
return res
end

function rand(flow::Composed, dims::Integer, shape::Integer=1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite certain why this is necessary. Is it supposed to be used as the "source distribution" for the transformation? Is there a reason why this cannot be done with a MultivariateTransformed <: Distribution?

Even if we have to keep this in, this is not supported by any Bijector but is specific to flows, right? If so, I'd suggest adding an abstract type Flows and let the normalizing flows subtype from this. This way we could make this method dispatch on Composed{<: AbstractArray{<: Flow}} or something.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rand for the transformed distribution is a requirement of Normalised Flows VI.

Wouldn't this in general work for any composition of bijections?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I said it wouldn't work, I was referring to the fact that transform is no longer supported.
But I agree, with the change below, it would work not just for any composition, but for any Bijector I believe.

But I think it makes more sense to make this more general, allowing one to attach any Distribution to a Bijector, composed or not. The MultivariateTransformed is the approach I've taken for this, and it seems to work nicely. Your function would then simply be a Composed bijector attached to a MvNormal.

With your approach you might run into an issue where a function dispatches on d::Distribution, but PlanarFlow() isa Distribution is false. I think what I'm proposing with the transformed distributions is more modular and will "just work" in a lot of packages that uses Distributions.jl. For example, if you look at my PR for Bijectors.jl you can see an example where because of this MultivariateTransformed I could simply plug a transformed distribution into another package I'd implemend ages ago, and it just worked :)

Does this make sense?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree we don't need rand in Bijectors.jl

dims = [dims]
append!(dims, shape)
print(dims)
return transform(flow, randn(dims...))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we keep this in; should this be flow(randn(dims...)) now?

end
##############################
# Example bijector: Identity #
##############################
Expand Down
130 changes: 130 additions & 0 deletions src/norm_flows.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
using Distributions
using LinearAlgebra
using Random
using StatsFuns: softplus
using Roots # for inverse

################################################################################
# Planar and Radial Flows #
# Ref: Variational Inference with Normalizing Flows, #
# D. Rezende, S. Mohamed(2015) arXiv:1505.05770 #
################################################################################

(b::Bijector)(x) = transform(b, x)
sharanry marked this conversation as resolved.
Show resolved Hide resolved

mutable struct PlanarLayer{T1,T2} <: Bijector
w::T1
u::T1
b::T2
end

function get_u_hat(u, w)
# To preserve invertibility
return (
u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1]
* w / (norm(w[:,1],2) ^ 2)
) # from A.1
end

function PlanarLayer(dims::Int, container=Array)
w = container(randn(dims, 1))
u = container(randn(dims, 1))
b = container(randn(1))
return PlanarLayer(w, u, b)
end

planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1
dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow
ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow from eq(11)

function transform(flow::PlanarLayer, z)
u_hat = get_u_hat(flow.u, flow.w)
return z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) # from eq(10)
end

function forward(flow::T, z) where {T<:PlanarLayer}
u_hat = get_u_hat(flow.u, flow.w)
# Compute log_det_jacobian
psi = ψ(z, flow.w, flow.b)
log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * u_hat)) # from eq(12)
transformed = z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b)
return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10)
end

function inv(flow::PlanarLayer, y)
sharanry marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we cannot do this using Inverted{<: PlanarLayer}? Seems to me that we could just implement this as

function (ib::Inversed{<: PlanarLayer})(y)
    flow = ib.orig
    # same as in this inv(flow, y) ...
end

u_hat = get_u_hat(flow.u, flow.w)
# Implemented with reference from A.1
function f(y)
return loss(alpha) = (
(transpose(flow.w) * y)[1] - alpha
- (transpose(flow.w) * u_hat)[1]
* tanh(alpha+flow.b[1])
)
end
alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)]
alphas = transpose(alphas_)
z_para = (flow.w ./ norm(flow.w,2)) * alphas
z_per = (
y - z_para - u_hat * tanh.(
transpose(flow.w) * z_para
.+ flow.b
)
)

return z_para + z_per
end

mutable struct RadialLayer{T1,T2} <: Bijector
α_::T1
β::T1
z_0::T2
end

function RadialLayer(dims::Int, container=Array)
α_ = container(randn(1))
β = container(randn(1))
z_0 = container(randn(dims, 1))
return RadialLayer(α_, β, z_0)
end

h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14)
dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h()

function transform(flow::RadialLayer, z)
α = softplus(flow.α_[1]) # from A.2
β_hat = -α + softplus(flow.β[1]) # from A.2
r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2))
return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14)
end

function forward(flow::T, z) where {T<:RadialLayer}
sharanry marked this conversation as resolved.
Show resolved Hide resolved
α = softplus(flow.α_[1]) # from A.2
β_hat = -α + softplus(flow.β[1]) # from A.2
r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2))
transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14)
sharanry marked this conversation as resolved.
Show resolved Hide resolved
# Compute log_det_jacobian
d = size(flow.z_0, 1)
h_ = h(α, r)
sharanry marked this conversation as resolved.
Show resolved Hide resolved
log_det_jacobian = @. (
(d-1) * log(1.0 + β_hat * h_)
+ log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r)
) # from eq(14)
return (rv=transformed, logabsdetjac=log_det_jacobian)
end

function inv(flow::RadialLayer, y)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See other comment on inv; same applies here.

α = softplus(flow.α_[1]) # from A.2
β_hat = - α + softplus(flow.β[1]) # from A.2
function f(y)
# From eq(26)
return loss(r) = (
norm(y - flow.z_0, 2)
- r * (1 + β_hat / (α + r))
)
end
rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] # A.2
rs = transpose(rs_)
z_hat = (y .- flow.z_0) ./ (rs .* (1 .+ β_hat ./ (α .+ rs)) ) # from eq(25)
z = flow.z_0 .+ rs .* z_hat # from A.2
return z
end
39 changes: 39 additions & 0 deletions test/norm_flows.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using Test
using Bijectors, ForwardDiff, LinearAlgebra

@testset "planar flows" begin
for i in 1:10
flow = PlanarLayer(10)
z = randn(10, 100)
forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z))))
our_method = sum(forward(flow, z).logabsdetjacob)
@test our_method ≈ forward_diff

# Inverse not accurate enough to pass with `≈` operator.
@test_broken inv(flow, transform(flow, z)) ≈ z
end

w = ones(10, 1)
u = zeros(10, 1)
b = ones(1)
flow = PlanarLayer(w, u, b)
z = ones(10, 100)
@test inv(flow, transform(flow, z)) ≈ z
end

@testset "radial flows" begin
sharanry marked this conversation as resolved.
Show resolved Hide resolved
for i in 1:10
flow = RadialLayer(2)
z = randn(2, 100)
forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z))))
our_method = sum(forward(flow, z).logabsdetjacob)
@test our_method ≈ forward_diff
@test inv(flow, transform(flow, z)) ≈ z
end
α_ = ones(1)
β = ones(1)
z_0 = zeros(10, 1)
z = ones(10, 100)
flow = RadialLayer(α_, β, z_0)
@test inv(flow, transform(flow, z)) ≈ z
end