diff --git a/Project.toml b/Project.toml index 9ad97054c..08e12b6ff 100644 --- a/Project.toml +++ b/Project.toml @@ -9,21 +9,22 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Compat = "2.2" Distances = "0.8" PDMats = "0.9" -SpecialFunctions = "0" +SpecialFunctions = "0.8, 0.9" StatsFuns = "0.8" -Zygote = "0.4" +ZygoteRules = "0.2" julia = "1.0" [extras] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Random", "Test", "FiniteDifferences"] +test = ["Random", "Test", "FiniteDifferences", "Zygote"] diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index d6f18dcd0..847c27d33 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -16,8 +16,8 @@ export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransf using Compat using Distances, LinearAlgebra -using SpecialFunctions: lgamma, besselk -using Zygote: @adjoint +using SpecialFunctions: logabsgamma, besselk +using ZygoteRules: @adjoint using StatsFuns: logtwo using PDMats: PDMat diff --git a/src/kernels/matern.jl b/src/kernels/matern.jl index 63b856c91..8d41f4edb 100644 --- a/src/kernels/matern.jl +++ b/src/kernels/matern.jl @@ -33,7 +33,7 @@ end params(k::MaternKernel) = (params(transform(k)),k.ν) opt_params(k::MaternKernel) = (opt_params(transform(k)),k.ν) -@inline kappa(κ::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((1.0-κ.ν)*logtwo-lgamma(κ.ν) + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk(κ.ν,sqrt(2κ.ν)*d))) +@inline kappa(κ::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((1.0-κ.ν)*logtwo-logabsgamma(κ.ν)[1] + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk(κ.ν,sqrt(2κ.ν)*d))) """ `Matern32Kernel([ρ=1.0])`