Skip to content

Commit

Permalink
Fixed autotuning
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Mar 30, 2021
1 parent ca5044e commit d0f634b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
7 changes: 3 additions & 4 deletions src/hyperparameter/autotuning.jl
Expand Up @@ -23,7 +23,6 @@ function update_hyperparameters!(m::GP)
Δ = ForwardDiff.gradient(θ) do θ
ELBO(m, re(θ)...)
end
@show Δ
end
end
# end
Expand All @@ -38,7 +37,7 @@ end
# end

@traitfn function update_hyperparameters!(m::TGP) where {TGP <: AbstractGP; IsFull{TGP}}
if any((!) isnothing opt, m) # Check there is a least one optimiser
if any((!) isnothing opt, m.f) # Check there is a least one optimiser
μ₀ = pr_means(m) # Get prior means
ks = kernels(m) # Get kernels
if ADBACKEND[] == :Zygote
Expand All @@ -60,22 +59,22 @@ end
Δ = ForwardDiff.gradient(θ) do θ
ELBO(m, re(θ)...)
end
@show Δ
end
end
return nothing
end

@traitfn function update_hyperparameters!(m::TGP) where {TGP <: AbstractGP; !IsFull{TGP}}
# Check that here is least one optimiser
if any((!) isnothing opt, m) || any((!) isnothing opt Zview, m)
if any((!) isnothing opt, m.f) || any((!) isnothing opt Zview, m.f)
μ₀ = pr_means(m)
ks = kernels(m)
Zs = Zviews(m)
if ADBACKEND[] == :Zygote
Δμ₀, Δk, ΔZ = Zygote.gradient(μ₀, ks, Zs) do μ₀, ks, Zs
ELBO(m, μ₀, ks, Zs)
end
@show ΔZ
# Optimize prior mean
isnothing(Δμ₀) || update!.(μ₀, Δμ₀, Ref(xview(m)))
# Optimize kernel parameters
Expand Down
2 changes: 1 addition & 1 deletion src/hyperparameter/autotuning_utils.jl
@@ -1,6 +1,6 @@

### Global constant allowing to chose between forward_diff and zygote_diff for hyperparameter optimization ###
const ADBACKEND = Ref(:zygote)
const ADBACKEND = Ref(:Zygote)

function setadbackend(ad_backend::Symbol)
(ad_backend == :ForwardDiff || ad_backend == :Zygote) ||
Expand Down
13 changes: 7 additions & 6 deletions src/models/SVGP.jl
Expand Up @@ -56,12 +56,13 @@ function SVGP(
kernel,
likelihood,
inference,
KmeansIP(X, nInducingPoints),
verbose = verbose,
optimiser = optimiser,
atfrequency = atfrequency,
mean = mean,
obsdim = obsdim
KmeansIP(X, nInducingPoints);
verbose,
optimiser,
atfrequency,
mean,
Zoptimiser,
obsdim,
)
end

Expand Down

0 comments on commit d0f634b

Please sign in to comment.