Skip to content

Commit

Permalink
Fix element type calculations, especially for Flux. (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
tpapp committed Dec 23, 2018
1 parent 23add70 commit 697c248
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 17 deletions.
7 changes: 4 additions & 3 deletions src/aggregation.jl
Expand Up @@ -59,15 +59,15 @@ function transform_with(flag::LogJacFlag, t::ArrayTransform, x::RealVector)
d = dimension(transformation)
I = reshape(range(firstindex(x); length = prod(dims), step = d), dims)
yℓ = map(i -> transform_with(flag, transformation, view_into(x, i, d)), I)
ℓz = logjac_zero(flag, eltype(x))
ℓz = logjac_zero(flag, extended_eltype(x))
first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ)
end

function transform_with(flag::LogJacFlag, t::ArrayTransform{Identity}, x::RealVector)
# TODO use version below when https://github.com/FluxML/Flux.jl/issues/416 is fixed
# y = reshape(copy(x), t.dims)
y = reshape(map(identity, x), t.dims)
y, logjac_zero(flag, eltype(x))
y, logjac_zero(flag, extended_eltype(x))
end

inverse_eltype(t::ArrayTransform, x::AbstractArray) =
Expand Down Expand Up @@ -152,7 +152,8 @@ $(SIGNATURES)
Helper function for transforming tuples. Used internally, to help type inference. Use via
`transfom_tuple`.
"""
_transform_tuple(flag::LogJacFlag, x::RealVector, index, ::Tuple{}) = (), logjac_zero(flag, eltype(x))
_transform_tuple(flag::LogJacFlag, x::RealVector, index, ::Tuple{}) =
(), logjac_zero(flag, extended_eltype(x))

function _transform_tuple(flag::LogJacFlag, x::RealVector, index, ts)
tfirst = first(ts)
Expand Down
13 changes: 7 additions & 6 deletions src/special_arrays.jl
Expand Up @@ -41,8 +41,9 @@ end

dimension(t::UnitVector) = t.n - 1

function transform_with(flag::LogJacFlag, t::UnitVector, x::RealVector{T}) where T
function transform_with(flag::LogJacFlag, t::UnitVector, x::RealVector)
@unpack n = t
T = extended_eltype(x)
r = one(T)
y = Vector{T}(undef, n)
= logjac_zero(flag, T)
Expand All @@ -57,7 +58,7 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::RealVector{T}) where
y, ℓ
end

inverse_eltype(t::UnitVector, y::RealVector) = float(eltype(y))
inverse_eltype(t::UnitVector, y::RealVector) = extended_eltype(y)

function inverse!(x::RealVector, t::UnitVector, y::RealVector)
@unpack n = t
Expand Down Expand Up @@ -93,11 +94,11 @@ end

dimension(t::CorrCholeskyFactor) = unit_triangular_dimension(t.n)

function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor,
x::RealVector{T}) where T
function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::RealVector)
@unpack n = t
T = extended_eltype(x)
= logjac_zero(flag, T)
U = zeros(typeof(one(T)), n, n)
U = Matrix{T}(undef, n, n)
index = firstindex(x)
@inbounds for col in 1:n
r = one(T)
Expand All @@ -112,7 +113,7 @@ function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor,
UpperTriangular(U), ℓ
end

inverse_eltype(t::CorrCholeskyFactor, U::UpperTriangular) = float(eltype(U))
inverse_eltype(t::CorrCholeskyFactor, U::UpperTriangular) = extended_eltype(U)

function inverse!(x::RealVector, t::CorrCholeskyFactor, U::UpperTriangular)
@unpack n = t
Expand Down
42 changes: 34 additions & 8 deletions src/utilities.jl
@@ -1,5 +1,6 @@

# logistic and logit
###
### logistic and logit
###

logistic(x::Real) = inv(one(x) + exp(-x))

Expand All @@ -12,8 +13,9 @@ logit(x::Real) = log(x / (one(x) - x))

logit_logjac(y) = -log(y) - log1p(-y)


# calculations
###
### calculations
###

"""
$SIGNATURES
Expand All @@ -22,8 +24,27 @@ Number of elements (strictly) above the diagonal in an ``n×n`` matrix.
"""
unit_triangular_dimension(n::Int) = n * (n-1) ÷ 2


# view management
###
### type calculations
###

"""
$(SIGNATURES)
Extend element type of argument so that it is closed under the algebra used by this package.
Pessimistic default for non-real types.
"""
function extended_eltype(::Type{S}) where S
T = eltype(S)
T <: Real ? typeof((one(T))) : Any
end

extended_eltype(x::T) where T = extended_eltype(T)

###
### view management
###

"""
$SIGNATURES
Expand All @@ -32,8 +53,9 @@ A view of `v` starting from `i` for `len` elements, no bounds checking.
"""
view_into(v::AbstractVector, i, len) = @inbounds view(v, i:(i+len-1))


# macros
###
### macros
###

"""
$(SIGNATURES)
Expand All @@ -57,6 +79,10 @@ macro calltrans(ex)
end
end

####
#### random values
####

"Shared part of docstrings for keyword arguments of or passed to [`random_reals`](@ref)."
const _RANDOM_REALS_KWARGS_DOC = """
A standard multivaritate normal or Cauchy is used, depending on `cauchy`, then scaled with
Expand Down
7 changes: 7 additions & 0 deletions test/runtests.jl
Expand Up @@ -288,6 +288,13 @@ end
@test g2.value == v.value
@test g2.gradient g1.gradient

# test element type calculations for Flux
t2 = CorrCholeskyFactor(4)
@test t2(Flux.param(ones(dimension(t2)))) isa UpperTriangular

t3 = UnitVector(3)
@test sum(abs2, t3(Flux.param(ones(dimension(t3))))) Flux.param(1.0)

# ReverseDiff
P3 = ADgradient(:ReverseDiff, P)
g3 = @inferred logdensity(ValueGradient, P3, x)
Expand Down

0 comments on commit 697c248

Please sign in to comment.