diff --git a/src/inference/analyticVI.jl b/src/inference/analyticVI.jl index 49d14254..ec429b2f 100644 --- a/src/inference/analyticVI.jl +++ b/src/inference/analyticVI.jl @@ -266,9 +266,9 @@ end state.local_vars, ) tot -= GaussianKL(model, state) - tot -= Zygote.@ignore( + tot -= ChainRulesCore.ignore_derivatives() do ρ(inference(model)) * AugmentedKL(likelihood(model), state.local_vars, y) - ) + end tot -= extraKL(model, state) return tot end @@ -290,9 +290,9 @@ end state.local_vars, ) tot -= GaussianKL(model, state) - tot -= Zygote.@ignore( + tot -= ChainRulesCore.ignore_derivatives() do sum(ρ(inference(model)) * AugmentedKL.(likelihood(model), state.local_vars, y)) - ) + end return tot end diff --git a/src/likelihood/bayesiansvm.jl b/src/likelihood/bayesiansvm.jl index 6bfefbb6..c70e4dd5 100644 --- a/src/likelihood/bayesiansvm.jl +++ b/src/likelihood/bayesiansvm.jl @@ -80,7 +80,7 @@ function expec_loglikelihood( end function AugmentedKL(l::BernoulliLikelihood{<:SVMLink}, state, ::Any) - Zygote.@ignore(GIGEntropy(l, state)) + ChainRulesCore.@ignore_derivatives GIGEntropy(l, state) end function GIGEntropy(::BernoulliLikelihood{<:SVMLink}, state) diff --git a/src/likelihood/laplace.jl b/src/likelihood/laplace.jl index af712034..53b345fb 100644 --- a/src/likelihood/laplace.jl +++ b/src/likelihood/laplace.jl @@ -99,7 +99,7 @@ function expec_loglikelihood( state, ) tot = -length(y) * log(twoπ) / 2 - tot += Zygote.@ignore(sum(log, state.θ)) / 2 + tot += ChainRulesCore.@ignore_derivatives sum(log, state.θ) / 2 tot += -( dot(state.θ, diag_cov) + dot(state.θ, abs2.(μ)) - 2.0 * dot(state.θ, μ .* y) + diff --git a/src/likelihood/negativebinomial.jl b/src/likelihood/negativebinomial.jl index 4e81b5b5..8573b96e 100644 --- a/src/likelihood/negativebinomial.jl +++ b/src/likelihood/negativebinomial.jl @@ -116,7 +116,7 @@ function expec_loglikelihood( diag_cov::AbstractVector, state, ) - tot = Zygote.@ignore(sum(negbin_logconst(y, l.r))) - log(2.0) * sum(y .+ l.r) + tot = ChainRulesCore.@ignore_derivatives(sum(negbin_logconst(y, l.r))) - log(2.0) * sum(y .+ l.r) tot += dot(μ, (y .- l.r)) / 2 - dot(state.θ, μ) / 2 - dot(state.θ, diag_cov) / 2 return tot end diff --git a/src/likelihood/poisson.jl b/src/likelihood/poisson.jl index 73d5e0a0..89ec4aa8 100644 --- a/src/likelihood/poisson.jl +++ b/src/likelihood/poisson.jl @@ -114,9 +114,9 @@ function expec_loglikelihood( state, ) tot = (dot(μ, (y - state.γ)) - dot(state.θ, abs2.(μ)) - dot(state.θ, Σ)) / 2 - tot += Zygote.@ignore( + tot += ChainRulesCore.ignore_derivatives() do sum(y * log(l.invlink.λ[1])) - sum(logfactorial, y) - logtwo * sum((y + state.γ)) - ) + end return tot end diff --git a/test/Project.toml b/test/Project.toml index ad8f0d6c..a4d7ba19 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,3 +7,10 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Distances = "0.10" +Distributions = "0.25" +MLDataUtils = "0.5" +PDMats = "0.11" +Zygote = "0.6" \ No newline at end of file