diff --git a/src/hyperparameter/autotuning.jl b/src/hyperparameter/autotuning.jl index d9cf4586..3c1977f4 100644 --- a/src/hyperparameter/autotuning.jl +++ b/src/hyperparameter/autotuning.jl @@ -23,7 +23,6 @@ function update_hyperparameters!(m::GP) Δ = ForwardDiff.gradient(θ) do θ ELBO(m, re(θ)...) end - @show Δ end end # end @@ -38,7 +37,7 @@ end # end @traitfn function update_hyperparameters!(m::TGP) where {TGP <: AbstractGP; IsFull{TGP}} - if any((!) ∘ isnothing ∘ opt, m) # Check there is a least one optimiser + if any((!) ∘ isnothing ∘ opt, m.f) # Check there is a least one optimiser μ₀ = pr_means(m) # Get prior means ks = kernels(m) # Get kernels if ADBACKEND[] == :Zygote @@ -60,7 +59,6 @@ end Δ = ForwardDiff.gradient(θ) do θ ELBO(m, re(θ)...) end - @show Δ end end return nothing @@ -68,7 +66,7 @@ end @traitfn function update_hyperparameters!(m::TGP) where {TGP <: AbstractGP; !IsFull{TGP}} # Check that here is least one optimiser - if any((!) ∘ isnothing ∘ opt, m) || any((!) ∘ isnothing ∘ opt ∘ Zview, m) + if any((!) ∘ isnothing ∘ opt, m.f) || any((!) ∘ isnothing ∘ opt ∘ Zview, m.f) μ₀ = pr_means(m) ks = kernels(m) Zs = Zviews(m) @@ -76,6 +74,7 @@ end Δμ₀, Δk, ΔZ = Zygote.gradient(μ₀, ks, Zs) do μ₀, ks, Zs ELBO(m, μ₀, ks, Zs) end + @show ΔZ # Optimize prior mean isnothing(Δμ₀) || update!.(μ₀, Δμ₀, Ref(xview(m))) # Optimize kernel parameters diff --git a/src/hyperparameter/autotuning_utils.jl b/src/hyperparameter/autotuning_utils.jl index f1de8738..5b3efb1b 100644 --- a/src/hyperparameter/autotuning_utils.jl +++ b/src/hyperparameter/autotuning_utils.jl @@ -1,6 +1,6 @@ ### Global constant allowing to chose between forward_diff and zygote_diff for hyperparameter optimization ### -const ADBACKEND = Ref(:zygote) +const ADBACKEND = Ref(:Zygote) function setadbackend(ad_backend::Symbol) (ad_backend == :ForwardDiff || ad_backend == :Zygote) || diff --git a/src/models/SVGP.jl b/src/models/SVGP.jl index 5560a4be..800a3131 100644 --- a/src/models/SVGP.jl +++ b/src/models/SVGP.jl @@ -56,12 +56,13 @@ function SVGP( kernel, likelihood, inference, - KmeansIP(X, nInducingPoints), - verbose = verbose, - optimiser = optimiser, - atfrequency = atfrequency, - mean = mean, - obsdim = obsdim + KmeansIP(X, nInducingPoints); + verbose, + optimiser, + atfrequency, + mean, + Zoptimiser, + obsdim, ) end