From 395036c15def73095f1a0a21975939304ca51f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Fri, 10 Feb 2023 15:16:07 +0100 Subject: [PATCH] Minor fixes for static correlation Cholesky factor. - missing continuation index - fix AD with ReverseDiff (its tracker type is not a bits type) --- Project.toml | 2 +- src/aggregation.jl | 2 +- src/special_arrays.jl | 12 +++++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index f6c0f4c..b7ab7d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TransformVariables" uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff" authors = ["Tamas K. Papp "] -version = "0.8.0" +version = "0.8.1" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/aggregation.jl b/src/aggregation.jl index f1c1130..c125716 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -128,7 +128,7 @@ function transform_with(flag::LogJacFlag, transformation::StaticArrayTransformat ℓ += ℓΔ y end - for _ in 1:D), ℓ + for _ in 1:D), ℓ, index end function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation}, diff --git a/src/special_arrays.jl b/src/special_arrays.jl index cdb5dbc..6c234a4 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -238,9 +238,15 @@ end function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S}, x::AbstractVector{T}, index) where {D,S,T} # NOTE: add an unrolled version for small sizes - U, ℓ, index′ = calculate_corr_cholesky_factor!(zero(MMatrix{S,S,robust_eltype(x)}), - flag, x, index) - UpperTriangular(SMatrix(U)), ℓ, index′ + E = robust_eltype(x) + U = if isbitstype(E) + zero(MMatrix{S,S,robust_eltype(x)}) + else + # NOTE: currently allocating because non-bitstype based AD (eg ReverseDiff) does not work with MMatrix + zeros(E, S, S) + end + U, ℓ, index′ = calculate_corr_cholesky_factor!(U, flag, x, index) + UpperTriangular(SMatrix{S,S}(U)), ℓ, index′ end inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor}, U::UpperTriangular) = robust_eltype(U)