Skip to content

Commit

Permalink
Tp/inference fix (#34)
Browse files Browse the repository at this point in the history
* more inference tests

* merge TupleTransform and NamedTupleTransform implementations.

* add test for inference MWE, also fix
  • Loading branch information
tpapp committed Dec 22, 2018
1 parent dac14b7 commit 23add70
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 56 deletions.
92 changes: 44 additions & 48 deletions src/aggregation.jl
Expand Up @@ -59,7 +59,8 @@ function transform_with(flag::LogJacFlag, t::ArrayTransform, x::RealVector)
d = dimension(transformation)
I = reshape(range(firstindex(x); length = prod(dims), step = d), dims)
yℓ = map(i -> transform_with(flag, transformation, view_into(x, i, d)), I)
first.(yℓ), isempty(yℓ) ? logjac_zero(flag, eltype(x)) : sum(last, yℓ)
ℓz = logjac_zero(flag, eltype(x))
first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ)
end

function transform_with(flag::LogJacFlag, t::ArrayTransform{Identity}, x::RealVector)
Expand Down Expand Up @@ -104,11 +105,15 @@ $(TYPEDEF)
Transform consecutive groups of real numbers to a tuple, using the given transformations.
"""
@calltrans struct TransformTuple{K, T <: NTransforms{K}} <: VectorTransform
@calltrans struct TransformTuple{T} <: VectorTransform
transformations::T
dimension::Int
function TransformTuple(transformations::T) where {K, T <: NTransforms{K}}
new{K,T}(transformations, _sum_dimensions(transformations))
function TransformTuple(transformations::T) where {T <: NTransforms}
new{T}(transformations, _sum_dimensions(transformations))
end
function TransformTuple(transformations::T
) where {N, S <: NTransforms, T <: NamedTuple{N, S}}
new{T}(transformations, _sum_dimensions(transformations))
end
end

Expand Down Expand Up @@ -145,14 +150,15 @@ as(transformations::NTransforms) = TransformTuple(transformations)
$(SIGNATURES)
Helper function for transforming tuples. Used internally, to help type inference. Use via
`transfom_tuple` only.
`transfom_tuple`.
"""
_transform_tuple(flag::LogJacFlag, x::RealVector, index) = (), logjac_zero(flag, eltype(x))
_transform_tuple(flag::LogJacFlag, x::RealVector, index, ::Tuple{}) = (), logjac_zero(flag, eltype(x))

function _transform_tuple(flag::LogJacFlag, x::RealVector, index, tfirst, trest...)
function _transform_tuple(flag::LogJacFlag, x::RealVector, index, ts)
tfirst = first(ts)
d = dimension(tfirst)
yfirst, ℓfirst = transform_with(flag, tfirst, view_into(x, index, d))
yrest, ℓrest = _transform_tuple(flag, x, index + d, trest...)
yrest, ℓrest = _transform_tuple(flag, x, index + d, Base.tail(ts))
(yfirst, yrest...), ℓfirst + ℓrest
end

Expand All @@ -162,24 +168,27 @@ $(SIGNATURES)
Helper function for tuple transformations.
"""
transform_tuple(flag::LogJacFlag, tt::NTransforms, x::RealVector) =
_transform_tuple(flag, x, firstindex(x), tt...)
_transform_tuple(flag, x, firstindex(x), tt)

"""
$(SIGNATURES)
Helper function determining element type of inverses from tuples. Used
internally.
*Performs no argument validation, caller should do this.*
"""
_inverse_eltype_tuple(ts::NTransforms{K}, ys::NTuple{K,Any}) where K =
_inverse_eltype_tuple(ts::NTransforms, ys::Tuple) =
mapreduce(((t, y),) -> inverse_eltype(t, y), promote_type, zip(ts, ys))

"""
$(SIGNATURES)
Helper function for inverting tuples of transformations. Used internally.
*Performs no argument validation, caller should do this.*
"""
function _inverse!_tuple(x::RealVector, ts::NTransforms{K},
ys::NTuple{K,Any}) where K
function _inverse!_tuple(x::RealVector, ts::NTransforms, ys::Tuple)
index = firstindex(x)
for (t, y) in zip(ts, ys)
d = dimension(t)
Expand All @@ -189,53 +198,40 @@ function _inverse!_tuple(x::RealVector, ts::NTransforms{K},
x
end

transform_with(flag::LogJacFlag, tt::TransformTuple, x::RealVector) =
transform_with(flag::LogJacFlag, tt::TransformTuple{<:Tuple}, x::RealVector) =
transform_tuple(flag, tt.transformations, x)

inverse_eltype(tt::TransformTuple{K}, y::NTuple{K,Any}) where K =
_inverse_eltype_tuple(tt.transformations, y)
function inverse_eltype(tt::TransformTuple{<:Tuple}, y::Tuple)
@unpack transformations = tt
@argcheck length(transformations) == length(y)
_inverse_eltype_tuple(transformations, y)
end

function inverse!(x::RealVector, tt::TransformTuple{K},
y::NTuple{K,Any}) where K
function inverse!(x::RealVector, tt::TransformTuple{<:Tuple}, y::Tuple)
@unpack transformations = tt
@argcheck length(transformations) == length(y)
@argcheck length(x) == dimension(tt)
_inverse!_tuple(x, tt.transformations, y)
end

"""
$(TYPEDEF)
as(transformations::NamedTuple{N,<:NTransforms}) where N =
TransformTuple(transformations)

Transform consecutive groups of real numbers to a named tuple, using the given
transformations.
"""
@calltrans struct TransformNamedTuple{names, T <: NTransforms} <: VectorTransform
transformations::T
dimension::Int
function TransformNamedTuple(transformations::NamedTuple{names,T}) where
{names, T <: NTransforms}
new{names,T}(values(transformations), _sum_dimensions(transformations))
end
function transform_with(flag::LogJacFlag, tt::TransformTuple{<:NamedTuple}, x::RealVector)
@unpack transformations = tt
y, ℓ = transform_tuple(flag, values(transformations), x)
NamedTuple{keys(transformations)}(y), ℓ
end

"""
$(SIGNATURES)
"""
as(transformations::NamedTuple{T,<:NTransforms}) where T =
TransformNamedTuple(transformations)

dimension(tn::TransformNamedTuple) = tn.dimension

function transform_with(flag::LogJacFlag, tt::TransformNamedTuple{names},
x::RealVector) where {names}
y, ℓ = transform_tuple(flag, tt.transformations, x)
NamedTuple{names}(y), ℓ
function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
@unpack transformations = tt
@argcheck keys(transformations) == keys(y)
_inverse_eltype_tuple(values(transformations), values(y))
end

inverse_eltype(tt::TransformNamedTuple{names},
y::NamedTuple{names}) where names =
_inverse_eltype_tuple(tt.transformations, values(y))

function inverse!(x::RealVector, tt::TransformNamedTuple{names},
y::NamedTuple{names}) where names
function inverse!(x::RealVector, tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
@unpack transformations = tt
@argcheck keys(transformations) == keys(y)
@argcheck length(x) == dimension(tt)
_inverse!_tuple(x, tt.transformations, values(y))
_inverse!_tuple(x, values(transformations), values(y))
end
50 changes: 42 additions & 8 deletions test/runtests.jl
Expand Up @@ -173,30 +173,39 @@ end
znt = as(NamedTuple())
za = as(Array, asℝ₊, 0)
@test dimension(zt) == dimension(znt) == 0
@test transform(zt, Float64[]) == ()
@test @inferred(transform(zt, Float64[])) == ()
@test_skip inverse(zt, ()) == []
@test transform_and_logjac(zt, Float64[]) == ((), 0.0)
@test transform(znt, Float64[]) == NamedTuple()
@test transform_and_logjac(znt, Float64[]) == (NamedTuple(), 0.0)
@test @inferred(transform_and_logjac(zt, Float64[])) == ((), 0.0)
@test @inferred(transform(znt, Float64[])) == NamedTuple()
@test @inferred(transform_and_logjac(znt, Float64[])) == (NamedTuple(), 0.0)
@test_skip inverse(znt, ()) == []
@test transform(za, Float64[]) == Float64[]
@test transform_and_logjac(za, Float64[]) == (Float64[], 0.0)
@test @inferred(transform(za, Float64[])) == Float64[]
@test @inferred(transform_and_logjac(za, Float64[])) == (Float64[], 0.0)
@test_skip inverse(za, []) == []
end

@testset "transform logdensity" begin
@testset "transform logdensity: correctness" begin
# the density is p(σ) = σ⁻³
# let z = log(σ), so σ = exp(z)
# the transformed density is q(z) = -3z + z = -2z
f(σ) = -3*log(σ)
q(z) = -2*z
for _ in 1:1000
z = randn()
qz = transform_logdensity(asℝ₊, f, z)
qz = @inferred transform_logdensity(asℝ₊, f, z)
@test q(z) qz
end
end

@testset "transform logdensity: type inference" begin
t = as((a = asℝ₋, b = as𝕀, c = as((d = UnitVector(7), e = CorrCholeskyFactor(3))),
f = as(Array, 9)))
z = zeros(dimension(t))
f(θ) = θ.a + θ.b + sum(abs2, θ.c.d) + sum(abs2, θ.c.e)
@test (@inferred f(t(z))) isa Float64
@test (@inferred transform_logdensity(t, f, z)) isa Float64
end

@testset "custom transformation: triangle below diagonal in [0,1]²" begin
tfun(y) = y[1], y[1]*y[2] # triangle below diagonal in unit square
t = CustomTransform(as(Array, as𝕀, 2), tfun, collect;)
Expand Down Expand Up @@ -298,3 +307,28 @@ end
@test lj2 -lj
end
end

@testset "inference of nested tuples" begin
# An MWE adapted from a real-life problem
ABOVE1 = as(Real, 1, ∞) # transformation for μ ≥ 1

trans_β̃s = as((asℝ, asℝ)) # a tuple of 2 elements, otherwise identity

PARAMS_TRANSFORMATIONS =
(EE = as((β̃s = trans_β̃s, μs = as((as𝕀, as𝕀)))),
EN = as((w̃₂ = asℝ, β̃s = trans_β̃s, μs = as((as𝕀, ABOVE1)))),
NE = as((w̃₁ = asℝ, β̃s = trans_β̃s, μs = as((ABOVE1, as𝕀)))),
NN = as((w̃s = as((asℝ, asℝ)), β̃s = trans_β̃s, μs = as((ABOVE1, ABOVE1)))))

function make_transformation(ls)
as((hyper_parameters = as((μ = as(Array, 6),
σ = as(Array, asℝ₊, 6),
= CorrCholeskyFactor(6))),
couple_parameters = as(map((t, l) -> as(Array, t, l),
PARAMS_TRANSFORMATIONS, ls))))
end
t = make_transformation((EE = 1, EN = 2 , NE = 3, NN = 4,))
x = zeros(dimension(t))
@test_nowarn @inferred transform(t, x)
@test_nowarn @inferred transform_and_logjac(t, x)
end

0 comments on commit 23add70

Please sign in to comment.