diff --git a/Project.toml b/Project.toml index a442357..280bead 100644 --- a/Project.toml +++ b/Project.toml @@ -15,10 +15,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] ChangesOfVariablesExt = "ChangesOfVariables" InverseFunctionsExt = "InverseFunctions" +UnitfulExt = "Unitful" [compat] ArgCheck = "1, 2" @@ -30,4 +32,5 @@ LinearAlgebra = "1.6" LogExpFunctions = "0.3" Random = "1.6" StaticArrays = "1" +Unitful = "1" julia = "1.10" diff --git a/ext/UnitfulExt.jl b/ext/UnitfulExt.jl new file mode 100644 index 0000000..8f2c9ba --- /dev/null +++ b/ext/UnitfulExt.jl @@ -0,0 +1,105 @@ +module UnitfulExt + +using ArgCheck: @argcheck +using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF +using LogExpFunctions: logistic, logit +import Unitful: Quantity, Units, ustrip, NoUnits +import TransformVariables: ScalarTransform, transform, transform_and_logjac, inverse, as, Infinity, asℝ₊ + + +#### +#### shifted exponential with units +#### +""" +$(TYPEDEF) + +Shifted exponential with units. When `D::Bool == true`, maps to `(shift, ∞)` using `x ↦ +shift + eˣ`, otherwise to `(-∞, shift)` using `x ↦ shift - eˣ`. +""" +struct DimShiftedExp{D, T <: Quantity} <: ScalarTransform + shift::T + function DimShiftedExp{D,T}(shift::T) where {D, T <: Quantity} + @argcheck D isa Bool + new(shift) + end +end + +DimShiftedExp(D::Bool, shift::T) where {T <: Quantity} = DimShiftedExp{D,T}(shift) + +transform(t::DimShiftedExp{D, Quantity{V,DD,U}}, x::Real) where {D, V, DD, U} = + D ? t.shift + exp(x)*U() : t.shift - exp(x)*U() + +# NOTE: not sure if or how this should be defined, since units. +# transform_and_logjac(t::DimShiftedExp, x::Real) = transform(t, x), x + +function inverse(t::DimShiftedExp{D, Quantity{V,DD,U}}, x::Quantity) where {D, V, DD, U} + (; shift) = t + if D + @argcheck x > shift DomainError + log(ustrip(U(), x - shift)) + else + @argcheck x < shift DomainError + log(ustrip(U(), shift - x)) + end +end + +### +#### scaled-shifted logistic with units +#### + +""" +$(TYPEDEF) + +Maps to `(scale, shift + scale)` using `logistic(x) * scale + shift`. +""" +struct DimScaledShiftedLogistic{T <: Quantity} <: ScalarTransform + scale::T + shift::T + function DimScaledShiftedLogistic{T}(scale::T, shift::T) where {T <: Quantity} + @argcheck scale > zero(typeof(scale)) + new(scale, shift) + end +end + +DimScaledShiftedLogistic(scale::T, shift::T) where {T <: Quantity} = + DimScaledShiftedLogistic{T}(scale, shift) + +function DimScaledShiftedLogistic(scale::T1, shift::T2) where {T1 <: Quantity, T2 <: Quantity} + DimScaledShiftedLogistic(promote(scale, shift)...) +end + +# # Switch to muladd and now it does have a DiffRule defined +transform(t::DimScaledShiftedLogistic, x::Real) = muladd(logistic(x), t.scale, t.shift) + +# NOTE: not sure if or how this should be defined, since units. +# transform_and_logjac(t::ScaledShiftedLogistic, x) = +# transform(t, x), log(t.scale) + logistic_logjac(x) + +function inverse(t::DimScaledShiftedLogistic{Quantity{N,D,U}}, y) where {N,D,U} + @argcheck y > t.shift DomainError + @argcheck y < t.scale + t.shift DomainError + logit(ustrip(NoUnits, (y - t.shift)/t.scale)) +end + +# NOTE: not sure if or how this should be defined, since units. +# # NOTE: inverse_and_logjac interface experimental and sporadically implemented for now +# function inverse_and_logjac(t::ScaledShiftedLogistic, y) +# @argcheck y > t.shift DomainError +# @argcheck y < t.scale + t.shift DomainError +# z = (y - t.shift) / t.scale +# logit(z), logit_logjac(z) - log(t.scale) +# end + +function as(::Type{Real}, left::Quantity, right::Quantity) + @argcheck left < right "the interval ($(left), $(right)) is empty" + DimScaledShiftedLogistic(right - left, left) +end + +as(::Type{Real}, left::Quantity, ::Infinity{true}) = DimShiftedExp(true, left) + +as(::Type{Real}, ::Infinity{false}, right::Quantity) = DimShiftedExp(false, right) + +Base.:(*)(a::typeof(asℝ₊), u::Units) = as(Real, 0.0*u, Infinity{true}()) +Base.:(*)(a::typeof(asℝ₋), u::Units) = as(Real, Infinity{false}(), 0.0*u) + +end # module \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index bfaad77..e26b067 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,3 +13,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/test/runtests.jl b/test/runtests.jl index f972660..a0edccb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using TransformVariables: unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with import ChangesOfVariables, InverseFunctions +import Unitful: @u_str, ustrip using Enzyme: autodiff, ReverseWithPrimal, Active, Const const CIENV = get(ENV, "CI", "") == "true" @@ -598,6 +599,46 @@ end InverseFunctions.test_inverse(inv_f, 1.7) end +@testset "Unitful" begin + @testset "finite" begin + t = as(Real, 0.0u"s", 2u"hr") + f = transform(t) + inv_f = inverse(t) + q = transform(t, 1.0) + @test ustrip(u"s", q) > 0 + @test_throws DomainError inverse(t, -1.0u"s") + @test_throws DomainError inverse(t, 3u"hr") + InverseFunctions.test_inverse(f, -4.2) + InverseFunctions.test_inverse(inv_f, 1.7u"s") + end + + @testset "positive real" begin + t = as(Real, -1.0u"s", ∞) + f = transform(t) + inv_f = inverse(t) + q = transform(t, 1.0) + @test q ≈ exp(1.0)*u"s" - 1.0u"s" + @test ustrip(u"s", q) > -1 + @test_throws DomainError inverse(t, -2.0u"s") + InverseFunctions.test_inverse(f, -4.2) + InverseFunctions.test_inverse(inv_f, 1.7u"s") + end + + @testset "negative real" begin + t = as(Real, -∞, 1.0u"s") + f = transform(t) + inv_f = inverse(t) + q = transform(t, 2.0) + @test q ≈ -exp(2.0)*u"s" + 1.0u"s" + @test ustrip(u"s", q) < 1 + @test_throws DomainError inverse(t, 2.0u"s") + InverseFunctions.test_inverse(f, -4.2) + InverseFunctions.test_inverse(inv_f, -1.7u"s") + end +end + + + @testset "as static array" begin S = Tuple{2,3,4} t = as(SArray{S})