Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ $(TYPEDEF)

Transform a scalar (real number) to another scalar.

Subtypes mustdefine `transform`, `transform_and_logjac`, and `inverse`; other
methods of of the interface should have the right defaults.
Subtypes must define `transform`, `transform_and_logjac`, and `inverse`.
Other methods of of the interface should have the right defaults.

!!! NOTE
This type is for code organization within the package, and is not part of the public API.
"""
abstract type ScalarTransform <: AbstractTransform end

Expand All @@ -26,7 +29,11 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real)
index + 1
end

inverse_eltype(t::ScalarTransform, y::T) where {T <: Real} = float(T)
function inverse_eltype(t::ScalarTransform, y::Real)
# NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which
# we test for. If it breaks it should be extended accordingly.
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), typeof(y)))
end

_domain_label(::ScalarTransform, index::Int) = ()

Expand Down
31 changes: 31 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,34 @@ end
U = transform(t, x)
@test isfinite(logabsdet(U)[1])
end

@testset "inverse_eltype of scalar transforms with parameters" begin
# `Float64` parameters and `Float32` input
for t in (as(Real, 0.5, ∞), as(Real, -∞, 2.1), as(Real, 0.5, 2.1))
@test @inferred(inverse_eltype(t, 1.1f0)) === Float64
@test @inferred(inverse(t, 1.1f0)) isa Float64
end

# Derivatives wrt parameters of the transforms
d1 = ForwardDiff.derivative(5.3) do x
return @inferred only(inverse(as(Vector, as(Real, x, ∞), 1), [10]))
end
d2 = ForwardDiff.derivative(5.3) do x
return @inferred inverse(as(Real, x, ∞), 10)
end
@test d1 == d2
d1 = ForwardDiff.derivative(-3) do x
return @inferred only(inverse(as(Vector, as(Real, -∞, x), 1), [-6.1]))
end
d2 = ForwardDiff.derivative(-3) do x
return @inferred inverse(as(Real, -∞, x), -6.1)
end
@test d1 == d2
d1 = ForwardDiff.gradient([-0.3, 4.7]) do x
return @inferred only(inverse(as(Vector, as(Real, x[1], x[2]), 1), [2.3]))
end
d2 = ForwardDiff.gradient([-0.3, 4.7]) do x
return @inferred inverse(as(Real, x[1], x[2]), 2.3)
end
@test d1 == d2
end
Loading