Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[weakdeps]
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
ChangesOfVariablesExt = "ChangesOfVariables"
InverseFunctionsExt = "InverseFunctions"
UnitfulExt = "Unitful"

[compat]
ArgCheck = "1, 2"
Expand All @@ -30,4 +32,5 @@ LinearAlgebra = "1.6"
LogExpFunctions = "0.3"
Random = "1.6"
StaticArrays = "1"
Unitful = "1"
julia = "1.10"
105 changes: 105 additions & 0 deletions ext/UnitfulExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
module UnitfulExt

using ArgCheck: @argcheck
using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF
using LogExpFunctions: logistic, logit
import Unitful: Quantity, Units, ustrip, NoUnits
import TransformVariables: ScalarTransform, transform, transform_and_logjac, inverse, as, Infinity, asℝ₊


####
#### shifted exponential with units
####
"""
$(TYPEDEF)

Shifted exponential with units. When `D::Bool == true`, maps to `(shift, ∞)` using `x ↦
shift + eˣ`, otherwise to `(-∞, shift)` using `x ↦ shift - eˣ`.
"""
struct DimShiftedExp{D, T <: Quantity} <: ScalarTransform
shift::T
function DimShiftedExp{D,T}(shift::T) where {D, T <: Quantity}
@argcheck D isa Bool
new(shift)
end
end

DimShiftedExp(D::Bool, shift::T) where {T <: Quantity} = DimShiftedExp{D,T}(shift)

transform(t::DimShiftedExp{D, Quantity{V,DD,U}}, x::Real) where {D, V, DD, U} =
D ? t.shift + exp(x)*U() : t.shift - exp(x)*U()

# NOTE: not sure if or how this should be defined, since units.
# transform_and_logjac(t::DimShiftedExp, x::Real) = transform(t, x), x

function inverse(t::DimShiftedExp{D, Quantity{V,DD,U}}, x::Quantity) where {D, V, DD, U}
(; shift) = t
if D
@argcheck x > shift DomainError
log(ustrip(U(), x - shift))
else
@argcheck x < shift DomainError
log(ustrip(U(), shift - x))
end
end

###
#### scaled-shifted logistic with units
####

"""
$(TYPEDEF)

Maps to `(scale, shift + scale)` using `logistic(x) * scale + shift`.
"""
struct DimScaledShiftedLogistic{T <: Quantity} <: ScalarTransform
scale::T
shift::T
function DimScaledShiftedLogistic{T}(scale::T, shift::T) where {T <: Quantity}
@argcheck scale > zero(typeof(scale))
new(scale, shift)
end
end

DimScaledShiftedLogistic(scale::T, shift::T) where {T <: Quantity} =
DimScaledShiftedLogistic{T}(scale, shift)

function DimScaledShiftedLogistic(scale::T1, shift::T2) where {T1 <: Quantity, T2 <: Quantity}
DimScaledShiftedLogistic(promote(scale, shift)...)
end

# # Switch to muladd and now it does have a DiffRule defined
transform(t::DimScaledShiftedLogistic, x::Real) = muladd(logistic(x), t.scale, t.shift)

# NOTE: not sure if or how this should be defined, since units.
# transform_and_logjac(t::ScaledShiftedLogistic, x) =
# transform(t, x), log(t.scale) + logistic_logjac(x)

function inverse(t::DimScaledShiftedLogistic{Quantity{N,D,U}}, y) where {N,D,U}
@argcheck y > t.shift DomainError
@argcheck y < t.scale + t.shift DomainError
logit(ustrip(NoUnits, (y - t.shift)/t.scale))
end

# NOTE: not sure if or how this should be defined, since units.
# # NOTE: inverse_and_logjac interface experimental and sporadically implemented for now
# function inverse_and_logjac(t::ScaledShiftedLogistic, y)
# @argcheck y > t.shift DomainError
# @argcheck y < t.scale + t.shift DomainError
# z = (y - t.shift) / t.scale
# logit(z), logit_logjac(z) - log(t.scale)
# end

function as(::Type{Real}, left::Quantity, right::Quantity)
@argcheck left < right "the interval ($(left), $(right)) is empty"
DimScaledShiftedLogistic(right - left, left)
end

as(::Type{Real}, left::Quantity, ::Infinity{true}) = DimShiftedExp(true, left)

as(::Type{Real}, ::Infinity{false}, right::Quantity) = DimShiftedExp(false, right)

Base.:(*)(a::typeof(asℝ₊), u::Units) = as(Real, 0.0*u, Infinity{true}())
Base.:(*)(a::typeof(asℝ₋), u::Units) = as(Real, Infinity{false}(), 0.0*u)

end # module
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
41 changes: 41 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using TransformVariables:
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac,
NOLOGJAC, transform_with
import ChangesOfVariables, InverseFunctions
import Unitful: @u_str, ustrip
using Enzyme: autodiff, ReverseWithPrimal, Active, Const

const CIENV = get(ENV, "CI", "") == "true"
Expand Down Expand Up @@ -598,6 +599,46 @@ end
InverseFunctions.test_inverse(inv_f, 1.7)
end

@testset "Unitful" begin
@testset "finite" begin
t = as(Real, 0.0u"s", 2u"hr")
f = transform(t)
inv_f = inverse(t)
q = transform(t, 1.0)
@test ustrip(u"s", q) > 0
@test_throws DomainError inverse(t, -1.0u"s")
@test_throws DomainError inverse(t, 3u"hr")
InverseFunctions.test_inverse(f, -4.2)
InverseFunctions.test_inverse(inv_f, 1.7u"s")
end

@testset "positive real" begin
t = as(Real, -1.0u"s", ∞)
f = transform(t)
inv_f = inverse(t)
q = transform(t, 1.0)
@test q ≈ exp(1.0)*u"s" - 1.0u"s"
@test ustrip(u"s", q) > -1
@test_throws DomainError inverse(t, -2.0u"s")
InverseFunctions.test_inverse(f, -4.2)
InverseFunctions.test_inverse(inv_f, 1.7u"s")
end

@testset "negative real" begin
t = as(Real, -∞, 1.0u"s")
f = transform(t)
inv_f = inverse(t)
q = transform(t, 2.0)
@test q ≈ -exp(2.0)*u"s" + 1.0u"s"
@test ustrip(u"s", q) < 1
@test_throws DomainError inverse(t, 2.0u"s")
InverseFunctions.test_inverse(f, -4.2)
InverseFunctions.test_inverse(inv_f, -1.7u"s")
end
end



@testset "as static array" begin
S = Tuple{2,3,4}
t = as(SArray{S})
Expand Down
Loading