diff --git a/src/scalar.jl b/src/scalar.jl index 13384cc..db3503e 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -6,8 +6,11 @@ $(TYPEDEF) Transform a scalar (real number) to another scalar. -Subtypes mustdefine `transform`, `transform_and_logjac`, and `inverse`; other -methods of of the interface should have the right defaults. +Subtypes must define `transform`, `transform_and_logjac`, and `inverse`. +Other methods of of the interface should have the right defaults. + +!!! NOTE + This type is for code organization within the package, and is not part of the public API. """ abstract type ScalarTransform <: AbstractTransform end @@ -26,7 +29,11 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real) index + 1 end -inverse_eltype(t::ScalarTransform, y::T) where {T <: Real} = float(T) +function inverse_eltype(t::ScalarTransform, y::Real) + # NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which + # we test for. If it breaks it should be extended accordingly. + return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), typeof(y))) +end _domain_label(::ScalarTransform, index::Int) = () diff --git a/test/runtests.jl b/test/runtests.jl index f972660..73b3e22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -699,3 +699,34 @@ end U = transform(t, x) @test isfinite(logabsdet(U)[1]) end + +@testset "inverse_eltype of scalar transforms with parameters" begin + # `Float64` parameters and `Float32` input + for t in (as(Real, 0.5, ∞), as(Real, -∞, 2.1), as(Real, 0.5, 2.1)) + @test @inferred(inverse_eltype(t, 1.1f0)) === Float64 + @test @inferred(inverse(t, 1.1f0)) isa Float64 + end + + # Derivatives wrt parameters of the transforms + d1 = ForwardDiff.derivative(5.3) do x + return @inferred only(inverse(as(Vector, as(Real, x, ∞), 1), [10])) + end + d2 = ForwardDiff.derivative(5.3) do x + return @inferred inverse(as(Real, x, ∞), 10) + end + @test d1 == d2 + d1 = ForwardDiff.derivative(-3) do x + return @inferred only(inverse(as(Vector, as(Real, -∞, x), 1), [-6.1])) + end + d2 = ForwardDiff.derivative(-3) do x + return @inferred inverse(as(Real, -∞, x), -6.1) + end + @test d1 == d2 + d1 = ForwardDiff.gradient([-0.3, 4.7]) do x + return @inferred only(inverse(as(Vector, as(Real, x[1], x[2]), 1), [2.3])) + end + d2 = ForwardDiff.gradient([-0.3, 4.7]) do x + return @inferred inverse(as(Real, x[1], x[2]), 2.3) + end + @test d1 == d2 +end