
Commit

Merge 8bd9f18 into b86d5fe
theogf committed Oct 29, 2020
2 parents b86d5fe + 8bd9f18 · commit 1796bf3
Showing 36 changed files with 228 additions and 205 deletions.
2 changes: 0 additions & 2 deletions Project.toml
@@ -14,7 +14,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
@@ -41,7 +40,6 @@ Distributions = "0.21.5, 0.22, 0.23, 0.24"
FastGaussQuadrature = "0.4"
Flux = "0.10, 0.11"
ForwardDiff = "0.10"
Functors = "0.1"
KernelFunctions = "0.5, 0.6, 0.7, 0.8"
MCMCChains = "0.3.15, 2.0, 3.0, 4.0"
MLDataUtils = "0.5"
3 changes: 1 addition & 2 deletions src/AugmentedGaussianProcesses.jl
@@ -27,8 +27,7 @@ module AugmentedGaussianProcesses
@reexport using KernelFunctions
using KernelFunctions: ColVecs, RowVecs
using Zygote, ForwardDiff
using Functors
using Flux # Remove full dependency on Flux once params for KernelFunctions is set
using Flux: params, destructure
@reexport using Flux.Optimise
using PDMats: PDMat, invquad
using AdvancedHMC
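
The import change above swaps the Functors dependency for the two Flux utilities the package actually needs. As a rough standalone illustration (a plain Flux Dense layer stands in for a parametrised kernel; this snippet is not part of the commit): params collects the trainable arrays for reverse-mode AD, while destructure flattens them into a single vector together with a closure that rebuilds the object, which is the form forward-mode AD works with.

using Flux: Dense, params, destructure

layer = Dense(2, 3)          # stand-in for a parametrised kernel

ps = params(layer)           # implicit parameter collection, used with Zygote
θ, re = destructure(layer)   # flat vector of all parameters + rebuilder

layer2 = re(θ)               # re(...) rebuilds an equivalent layer from a vector
@show length(θ)              # 2*3 weights + 3 biases = 9
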
2 changes: 1 addition & 1 deletion src/functions/KLdivergences.jl
@@ -25,7 +25,7 @@ function extraKL(model::OnlineSVGP{T}) where {T}
for gp in model.f
κₐμ = gp.κₐ * mean(gp)
KLₐ += gp.prev𝓛ₐ
KLₐ += -0.5 * sum(opt_trace.(Ref(gp.invDₐ), [gp.K̃ₐ, gp.κₐ * cov(gp) * transpose(gp.κₐ)]))
KLₐ += -0.5 * sum(trace_ABt.(Ref(gp.invDₐ), [gp.K̃ₐ, gp.κₐ * cov(gp) * transpose(gp.κₐ)]))
KLₐ += dot(gp.prevη₁, κₐμ) - 0.5 * dot(κₐμ, gp.invDₐ * κₐμ)
end
return KLₐ
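
In extraKL above, the broadcast trace_ABt.(Ref(gp.invDₐ), [gp.K̃ₐ, gp.κₐ * cov(gp) * transpose(gp.κₐ)]) sums tr(invDₐ·K̃ₐ) and tr(invDₐ·κₐΣκₐᵀ) in one pass; Ref pins the first argument while broadcasting over the list. A small standalone check of that pattern with made-up matrices (not package data):

using LinearAlgebra

trace_ABt(A, B) = dot(A, B)   # same one-liner as in src/functions/utils.jl

D, M1, M2 = randn(3, 3), randn(3, 3), randn(3, 3)

# Ref(D) keeps D fixed while broadcasting over [M1, M2], so the sum below is
# tr(D * M1') + tr(D * M2') without forming either product.
s = sum(trace_ABt.(Ref(D), [M1, M2]))
@assert s ≈ tr(D * M1') + tr(D * M2')
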
10 changes: 5 additions & 5 deletions src/functions/utils.jl
@@ -29,20 +29,20 @@ end
end

## Return the trace of A*B' ##
@inline function opt_trace(A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
@inline function trace_ABt(A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
dot(A, B)
end

## Return the diagonal of A*B' ##
@inline function opt_diag(
A::AbstractArray{T,N},
B::AbstractArray{T,N},
@inline function diag_ABt(
A::AbstractMatrix,
B::AbstractMatrix,
) where {T<:Real,N}
vec(sum(A .* B, dims = 2))
end

## Return the multiplication of Diagonal(v)*B ##
function opt_diag_mul_mat(
function diagv_B(
v::AbstractVector{T},
B::AbstractMatrix{T},
) where {T<:Real}
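
For reference, the identities behind these renamed helpers: diag(A·Bᵀ) has entries Σⱼ AᵢⱼBᵢⱼ (row-wise sums of A .* B), and Diagonal(v)·B simply scales the rows of B by v, so neither product ever needs to be materialised. A quick standalone check with toy matrices; the body of diagv_B is collapsed in this view, so the one-liner below is only a plausible stand-in consistent with its comment, not the package's code:

using LinearAlgebra

A, B = randn(4, 3), randn(4, 3)
v = randn(4)

diag_ABt(A, B) = vec(sum(A .* B, dims = 2))   # as defined above
diagv_B(v, B) = v .* B                        # plausible stand-in, see note

@assert diag_ABt(A, B) ≈ diag(A * B')
@assert diagv_B(v, A) ≈ Diagonal(v) * A
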
6 changes: 3 additions & 3 deletions src/gpblocks/latentgp.jl
@@ -250,7 +250,7 @@ mean_f(model::AbstractGP) = mean_f.(model.f)
var_f(model::AbstractGP) = var_f.(model.f)

@traitfn var_f(gp::T) where {T <: AbstractLatent; IsFull{T}} = var(gp)
@traitfn var_f(gp::T) where {T <: AbstractLatent; !IsFull{T}} = opt_diag(gp.κ * cov(gp), gp.κ) + gp.K̃
@traitfn var_f(gp::T) where {T <: AbstractLatent; !IsFull{T}} = diag_ABt(gp.κ * cov(gp), gp.κ) + gp.K̃

Zview(gp::SparseVarLatent) = gp.Z
Zview(gp::OnlineVarLatent) = gp.Z
@@ -270,7 +270,7 @@ function compute_κ!(gp::SparseVarLatent, X::AbstractVector, jitt::Real)
gp.κ .= gp.Knm / pr_cov(gp)
gp.K̃ .=
kerneldiagmatrix(kernel(gp), X) .+ jitt -
opt_diag(gp.κ, gp.Knm)
diag_ABt(gp.κ, gp.Knm)

@assert all(gp.K̃ .> 0) "K̃ has negative values"
end
@@ -285,6 +285,6 @@ function compute_κ!(gp::OnlineVarLatent, X::AbstractVector, jitt::Real)
# Covariance with a new batch
gp.Knm = kernelmatrix(kernel(gp), X, gp.Z)
gp.κ = gp.Knm / pr_cov(gp)
gp.K̃ = kerneldiagmatrix(kernel(gp), X) .+ jitt - opt_diag(gp.κ, gp.Knm)
gp.K̃ = kerneldiagmatrix(kernel(gp), X) .+ jitt - diag_ABt(gp.κ, gp.Knm)
@assert all(gp.K̃ .> 0) "K̃ has negative values"
end
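
Both compute_κ! methods above implement the usual sparse-variational projection: κ = Kₙₘ·Kₘₘ⁻¹ and K̃ = diag(Kₙₙ) + jitter − diag(κ·Kₙₘᵀ). A self-contained sketch with a toy kernel and random data, assuming the KernelFunctions API of this era (kerneldiagmatrix was renamed kernelmatrix_diag in later releases); illustrative only, not package code:

using KernelFunctions, LinearAlgebra

k = SqExponentialKernel()
X = ColVecs(randn(2, 50))    # 50 inputs in 2 dimensions
Z = ColVecs(randn(2, 10))    # 10 inducing points
jitt = 1e-6

Kmm = kernelmatrix(k, Z) + jitt * I
Knm = kernelmatrix(k, X, Z)
κ = Knm / Kmm
K̃ = kerneldiagmatrix(k, X) .+ jitt .- vec(sum(κ .* Knm, dims = 2))

@assert all(K̃ .> 0)          # same sanity check as in compute_κ!
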
38 changes: 20 additions & 18 deletions src/hyperparameter/autotuning.jl
@@ -35,16 +35,16 @@ function update_hyperparameters!(
if !isnothing(gp.opt)
f_l, f_μ₀ = hyperparameter_gradient_function(gp, X)
ad_use = K_ADBACKEND[] == :auto ? ADBACKEND[] : K_ADBACKEND[]
grads = if ad_use == :forward_diff
Δμ₀ = f_μ₀()
Δk = if ad_use == :forward_diff
∇L_ρ_forward(f_l, gp, X)
elseif ad_use == :reverse_diff
∇L_ρ_reverse(f_l, gp, X)
else
error("Uncompatible ADBackend")
end
grads[pr_mean(gp)] = f_μ₀()
apply_grads_kernel_params!(gp.opt, kernel(gp), grads) # Apply gradients to the kernel parameters
apply_gradients_mean_prior!(pr_mean(gp), grads[pr_mean(gp)], X)
apply_Δk!(gp.opt, kernel(gp), Δk) # Apply gradients to the kernel parameters
apply_gradients_mean_prior!(pr_mean(gp), Δμ₀, X)
end
end

@@ -57,18 +57,20 @@ function update_hyperparameters!(
i::Inference,
vi_opt::InferenceOptimizer,
)
if !isnothing(gp.opt)
Δμ₀, Δk = if !isnothing(gp.opt)
f_ρ, f_Z, f_μ₀ = hyperparameter_gradient_function(gp)
k_aduse = K_ADBACKEND[] == :auto ? ADBACKEND[] : K_ADBACKEND[]
grads = if k_aduse == :forward_diff
Δμ₀ = f_μ₀()
Δk = if k_aduse == :forward_diff
∇L_ρ_forward(f_ρ, gp, X, ∇E_μ, ∇E_Σ, i, vi_opt)
elseif k_aduse == :reverse_diff
∇L_ρ_reverse(f_ρ, gp, X, ∇E_μ, ∇E_Σ, i, vi_opt)
end
# @show grads[kernel(gp).transform.s]
grads[pr_mean(gp)] = f_μ₀()
(Δμ₀, Δk)
else
nothing, nothing
end
if !isnothing(opt(gp.Z)) && !isnothing(gp.opt)
if !isnothing(opt(gp.Z))
Z_aduse = Z_ADBACKEND[] == :auto ? ADBACKEND[] : Z_ADBACKEND[]
Z_grads = if Z_aduse == :forward_diff
Z_gradient_forward(gp, f_Z, X, ∇E_μ, ∇E_Σ, i, vi_opt) #Compute the gradient given the inducing points location
@@ -77,16 +79,16 @@ end
end
update!(opt(gp.Z), gp.Z.Z, Z_grads) #Apply the gradients on the location
end
if !isnothing(gp.opt)
apply_grads_kernel_params!(gp.opt, kernel(gp), grads) # Apply gradients to the kernel parameters
apply_gradients_mean_prior!(pr_mean(gp), grads[pr_mean(gp)], X)
if !all([isnothing(Δk), isnothing(Δμ₀)])
apply_Δk!(gp.opt, kernel(gp), Δk) # Apply gradients to the kernel parameters
apply_gradients_mean_prior!(pr_mean(gp), Δμ₀, X)
end
end


## Return the derivative of the KL divergence between the posterior and the GP prior ##
function hyperparameter_KL_gradient(J::AbstractMatrix, A::AbstractMatrix)
return 0.5 * opt_trace(J, A)
return 0.5 * trace_ABt(J, A)
end


@@ -244,10 +246,10 @@ function hyperparameter_expec_gradient(
Jnn::AbstractVector{<:Real},
)
ι = (Jnm - gp.κ * Jmm) / pr_cov(gp)
J̃ = Jnn - (opt_diag(ι, gp.Knm) + opt_diag(gp.κ, Jnm))
J̃ = Jnn - (diag_ABt(ι, gp.Knm) + diag_ABt(gp.κ, Jnm))
dμ = dot(∇E_μ, ι * mean(gp))
dΣ = -dot(∇E_Σ, J̃)
dΣ += -dot(∇E_Σ, 2.0 * (opt_diag(ι, κΣ)))
dΣ += -dot(∇E_Σ, 2.0 * (diag_ABt(ι, κΣ)))
dΣ += -dot(∇E_Σ, 2.0 * (ι * mean(gp)) .* (gp.κ * mean(gp)))
return getρ(i) * (dμ + dΣ)
end
@@ -266,9 +268,9 @@ function hyperparameter_expec_gradient(
Jnn::AbstractVector{<:Real},
)
ι = (Jnm - gp.κ * Jmm) / pr_cov(gp)
J̃ = Jnn - (opt_diag(ι, gp.Knm) + opt_diag(gp.κ, Jnm))
J̃ = Jnn - (diag_ABt(ι, gp.Knm) + diag_ABt(gp.κ, Jnm))
dμ = dot(∇E_μ, ι * mean(gp))
dΣ = dot(∇E_Σ, J̃ + 2.0 * opt_diag(ι, κΣ))
dΣ = dot(∇E_Σ, J̃ + 2.0 * diag_ABt(ι, κΣ))
return getρ(i) * (dμ + dΣ)
end

@@ -281,7 +283,7 @@ function hyperparameter_online_gradient(
)
ιₐ = (Jab - gp.κₐ * Jmm) / pr_cov(gp)
trace_term =
-0.5 * sum(opt_trace.(
-0.5 * sum(trace_ABt.(
[gp.invDₐ],
[
Jaa,
21 changes: 17 additions & 4 deletions src/hyperparameter/autotuning_utils.jl
@@ -21,16 +21,29 @@ function setZadbackend(backend_sym)
Z_ADBACKEND[] = backend_sym
end

##
function apply_grads_kernel_params!(opt, k::Kernel, Δ::IdDict)
ps = Flux.params(k)
## Updating kernel parameters ##
function apply_Δk!(opt, k::Kernel, Δ::IdDict)
ps = params(k)
for p in ps
isnothing(Δ[p]) && continue
Δlogp = Flux.Optimise.apply!(opt, p, p .* vec(Δ[p]))
Δlogp = Optimise.apply!(opt, p, p .* vec(Δ[p]))
p .= exp.(log.(p) + Δlogp)
end
end

function apply_Δk!(opt, k::Kernel, Δ::AbstractVector)
ps = params(k)
i = 1
for p in ps
d = length(p)
Δp = Δ[i:i+d-1]
Δlogp = Optimise.apply!(opt, p, p .* Δp)
@. p = exp(log(p) + Δlogp)
i += d
end
end


function apply_gradients_mean_prior!(μ::PriorMean, g::AbstractVector, X::AbstractVector)
update!(μ, g, X)
end
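
Both apply_Δk! methods update the (strictly positive) kernel parameters in log-space: the Euclidean gradient Δ is converted with the chain rule ∂L/∂log p = p·∂L/∂p, passed to the optimiser, and the result applied multiplicatively through exp, which keeps the parameters positive. A minimal standalone sketch of one such step, with a made-up gradient value (not package data):

using Flux.Optimise: Descent, apply!

opt = Descent(0.01)
p  = [2.0]    # e.g. a kernel lengthscale, kept positive
Δp = [0.5]    # pretend gradient ∂L/∂p

Δlogp = apply!(opt, p, p .* Δp)    # optimiser step on the log-space gradient
p .= exp.(log.(p) .+ Δlogp)        # multiplicative update, p stays > 0
@show p                            # ≈ [2.02]
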
19 changes: 8 additions & 11 deletions src/hyperparameter/forwarddiff_rules.jl
@@ -1,37 +1,34 @@
function ∇L_ρ_forward(f, gp::AbstractLatent, X::AbstractVector)
θ, re = functor(kernel(gp))
g = ForwardDiff.gradient(θ) do x
θ, re = destructure(kernel(gp))
return g = ForwardDiff.gradient(θ) do x
k = re(x)
Knn = kernelmatrix(k, X)
f(Knn)
end
return IdDict{Any,Any}(θ => g)
end

function ∇L_ρ_forward(f, gp::SparseVarLatent, X::AbstractVector, ∇E_μ, ∇E_Σ, i, opt)
θ, re = functor(kernel(gp))
g = ForwardDiff.gradient(θ) do x
θ, re = destructure(kernel(gp))
return g = ForwardDiff.gradient(θ) do x
k = re(x)
Kmm = kernelmatrix(k, Zview(gp))
Knm = kernelmatrix(k, X, Zview(gp))
Knn = kerneldiagmatrix(k, X)
f(Kmm, Knm, Knn, Ref(∇E_μ), Ref(∇E_Σ), Ref(i), Ref(opt))
f(Kmm, Knm, Knn, ∇E_μ, ∇E_Σ, i, opt)
end
return IdDict{Any,Any}(θ => g)
end

function ∇L_ρ_forward(f, gp::OnlineVarLatent, X::AbstractVector, ∇E_μ, ∇E_Σ, i, opt)
θ, re = functor(kernel(gp))
g = ForwardDiff.gradient(θ) do x
θ, re = destructure(kernel(gp))
return g = ForwardDiff.gradient(θ) do x
k = re(x)
Kmm = kernelmatrix(k, Zview(gp))
Knm = kernelmatrix(k, X, Zview(gp))
Knn = kerneldiagmatrix(k, X)
Kaa = kernelmatrix(k, gp.Zₐ)
Kab = kernelmatrix(k, gp.Zₐ, Zview(gp))
f(Kmm, Knm, Jnn, Jab, Jaa, Ref(∇E_μ), Ref(∇E_Σ), Ref(i), Ref(opt))
f(Kmm, Knm, Knn, Kab, Kaa, ∇E_μ, ∇E_Σ, i, opt)
end
return IdDict{Any,Any}(θ => g)
end

## Return a function computing the gradient of the ELBO given the inducing point locations ##
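
The forward-mode path now flattens the kernel with destructure and differentiates a closure that rebuilds it from the flat vector, so the result is a plain gradient vector instead of the IdDict used previously. A self-contained sketch of the same pattern on a toy parametrised struct (not the package's kernels, and assuming destructure's rebuilder accepts the dual-valued vector, as it does for the Flux versions targeted here):

using Flux: destructure, @functor
using ForwardDiff

struct ToyKernel{T}
    ℓ::Vector{T}                 # stand-in for a lengthscale parameter
end
@functor ToyKernel
(k::ToyKernel)(x, y) = exp(-sum(abs2, x .- y) / (2 * k.ℓ[1]^2))

k = ToyKernel([0.7])
X = [randn(2) for _ in 1:10]

θ, re = destructure(k)           # θ == [0.7]
g = ForwardDiff.gradient(θ) do x
    kx = re(x)                   # rebuild the kernel from the flat vector
    sum(kx(xi, xj) for xi in X, xj in X)   # stand-in for f(Knn)
end
@show g                          # one-element gradient with respect to ℓ
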
6 changes: 3 additions & 3 deletions src/hyperparameter/zygote_rules.jl
@@ -1,6 +1,6 @@
function ∇L_ρ_reverse(f, gp::AbstractLatent, X)
k = kernel(gp)
return (Flux.gradient(Flux.params(k)) do
return (Zygote.gradient(params(k)) do
_∇L_ρ_reverse(f, k, X)
end).grads # Zygote gradient
end
@@ -9,7 +9,7 @@ _∇L_ρ_reverse(f, k, X) = f(kernelmatrix(k, X))

function ∇L_ρ_reverse(f, gp::SparseVarLatent, X, ∇E_μ, ∇E_Σ, i, opt)
k = kernel(gp)
return (Zygote.gradient(Flux.params(k)) do
return (Zygote.gradient(params(k)) do
_∇L_ρ_reverse(f, k, gp.Z, X, ∇E_μ, ∇E_Σ, i, opt)
end).grads
end
@@ -26,7 +26,7 @@ function ∇L_ρ_reverse(f, gp::OnlineVarLatent, X, ∇E_μ, ∇E_Σ, i, opt)
k = kernel(gp)
Zrv = RowVecs(copy(hcat(gp.Z...)')) # TODO Fix that once https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/151 is solved
Zarv = RowVecs(copy(hcat(gp.Zₐ...)'))
return (Zygote.gradient(Flux.params(k)) do
return (Zygote.gradient(params(k)) do
_∇L_ρ_reverse(f, k, Zrv, X, Zarv, ∇E_μ, ∇E_Σ, i, opt)
end).grads
end
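
On the reverse-mode side, Zygote.gradient over an implicit params(...) collection returns a Grads object whose .grads field is the IdDict that the callers of ∇L_ρ_reverse index by parameter. A minimal standalone sketch on a raw array (toy objective, not the package ELBO):

using Zygote
using Flux: params

W = randn(3, 2)
ps = params(W)                   # implicit parameter collection holding W

grads = (Zygote.gradient(ps) do
    sum(abs2, W)                 # toy objective in place of _∇L_ρ_reverse
end).grads

@assert grads[W] ≈ 2 .* W        # gradient of Σᵢⱼ Wᵢⱼ² with respect to W
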
2 changes: 1 addition & 1 deletion src/inference/MCVI.jl
@@ -178,7 +178,7 @@ function expec_log_likelihood(
samples .=
raw_samples .* [sqrt(var_f[k][j]) for k = 1:N]' .+ [μ_f[k][j] for k = 1:N]'
loglike += sum(
f -> logpdf(l, getindex.(y, j), f),
f -> loglikelihood(l, getindex.(y, j), f),
eachrow(samples)
) / i.nMC
end
2 changes: 1 addition & 1 deletion src/inference/analytic.jl
@@ -54,7 +54,7 @@ function analytic_updates!(m::GP{T}) where {T}
f.post.α .= cov(f) \ (yview(m) - pr_mean(f, xview(m)))
if !isnothing(l.opt_noise)
g = 0.5 * (norm(mean(f), 2) - tr(inv(cov(f))))
Δlogσ² = Flux.Optimise.apply!(l.opt_noise, l.σ², g .* l.σ²)
Δlogσ² = Optimise.apply!(l.opt_noise, l.σ², g .* l.σ²)
l.σ² .= exp.(log.(l.σ²) .+ Δlogσ²)
end
end
2 changes: 0 additions & 2 deletions src/inference/gibbssampling.jl
@@ -164,8 +164,6 @@ sample_local!(l::Likelihood, y, f::Tuple{<:AbstractVector{T}}) where {T} =
set_ω!(l::Likelihood, ω) = l.θ .= ω
get_ω(l::Likelihood) = l.θ

# logpdf(model::AbstractGP{T,<:Likelihood,<:GibbsSampling}) where {T} = zero(T)

function sample_global!(
∇E_μ::AbstractVector,
∇E_Σ::AbstractVector,
