Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Kernel Embeddings #36

Merged
merged 3 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BayesianQuadrature"
uuid = "609f5bd8-aef1-42b2-b90e-083e3346dba9"
authors = ["Theo Galy-Fajou <theo.galyfajou@gmail.com> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
4 changes: 3 additions & 1 deletion src/BayesianQuadrature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ export PriorSampling
export BayesModel
export prior, integrand, logintegrand, logprior, logjoint
export BQ # Short version for calling BayesianQuadrature
export KernelEmbedding
export kernel_mean, kernel_variance

const BQ = BayesianQuadrature

Expand Down Expand Up @@ -42,7 +44,7 @@ abstract type AbstractBQModel{Tp,Ti} <: AbstractMCMC.AbstractModel end

include("bayesquads/abstractbq.jl")
include("samplers/abstractbqsampler.jl")
include("kernelmeans/kernels.jl")
include("kernelembeddings/kernelembedding.jl")
include("interface.jl")
include("models.jl")
include("utils.jl")
Expand Down
6 changes: 1 addition & 5 deletions src/bayesquads/abstractbq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end

function get_kernel_params(k::TransformedKernel; kwargs...)
check_transform(k.transform)
return get_kernel_params(k.kernel; kwargs..., l=param(k.transform))
return get_kernel_params(k.kernel; kwargs..., l=transform_param(k.transform))
end

function check_transform(transform)
Expand All @@ -44,7 +44,3 @@ function kernel(b::AbstractBQ)
return b.σ * (b.kernel ∘ ScaleTransform(inv.(b.l)))
end
end

Λ(l::Real) = abs2(l) * I
Λ(l::AbstractVector) = Diagonal(abs2.(l))
Λ(l::LowerTriangular) = l * l'
5 changes: 3 additions & 2 deletions src/bayesquads/bayesquad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ function quadrature(
isempty(samples) && error("The collection of samples is empty")
y = integrand(model).(samples)
K = kernelpdmat(kernel(bquad), samples)
z = calc_z(samples, p_0(model), bquad)
C = calc_C(p_0(model), bquad)
ke = KernelEmbedding(bquad, p_0(model))
z = kernel_mean(ke, samples)
C = kernel_variance(ke)
var = evaluate_var(z, K, C)
if var < 0
if var > -1e-5
Expand Down
7 changes: 4 additions & 3 deletions src/bayesquads/logbayesquad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,23 @@ function quadrature(
f = exp.(logf) # Compute integrand on samples

x_c = sample_candidates(bquad, samples, bquad.n_candidates) # Sample candidates around the samples
ke = KernelEmbedding(bquad, p_0(model))

gp = create_gp(bquad, samples)
f_c_0 = mean.(predict(gp, f, x_c)) # Predict integrand on x_c
logf_c_0 = mean.(predict(gp, logf, x_c)) # Predict log-integrand on x_c
Δ_c = exp.(logf_c_0) - f_c_0 # Compute difference of predictions

z = calc_z(samples, p_0(model), bquad) # Compute mean for the basic BQ
z = kernel_mean(ke, samples) # Compute mean for the basic BQ
K = kernelpdmat(kernel(bquad), samples) # and the kernel matrix

z_c = calc_z(x_c, p_0(model), bquad) # Compute mean for the ΔlogBQ
z_c = kernel_mean(ke, x_c) # Compute mean for the ΔlogBQ
K_c = kernelpdmat(kernel(bquad), x_c) # and the kernel matrix for the candidates

m_evidence = evaluate_mean(z, K, f) # Compute m(Z|samples)
m_correction = evaluate_mean(z_c, K_c, Δ_c) #

C = calc_C(p_0(model), bquad) # Compute the C component for the variance
C = kernel_variance(ke) # Compute the kernel variance

var_evidence = evaluate_var(z, K, C)
var_correction = evaluate_var(z_c, K_c, C)
Expand Down
62 changes: 62 additions & 0 deletions src/kernelembeddings/kernelembedding.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Root type for kernel embeddings of a measure. Concrete subtypes are expected
# to carry a `measure` field (see `measure` accessor below).
abstract type AbstractKernelEmbedding end

# Return the measure μ that the kernel is integrated against.
measure(ke::AbstractKernelEmbedding) = ke.measure

# Kernel embedding of a measure: bundles a kernel (with its amplitude σ and
# lengthscale l) together with the measure it is integrated against, so that
# the kernel mean ∫ k(x, xᵢ) dμ(x) and variance ∫∫ k(x, x′) dμ(x) dμ(x′)
# can be evaluated in closed form (see `kernel_mean` / `kernel_variance`).
struct KernelEmbedding{Tk<:Kernel,Tm,Tσ<:Real,Tl} <: AbstractKernelEmbedding
    kernel::Tk # Kernel function
    σ::Tσ # Kernel variance (amplitude)
    l::Tl # Kernel lengthscale (scalar, vector, or matrix — see Λ)
    measure::Tm # Measure being embedded (e.g. an AbstractMvNormal prior)
end

# Convenience constructor: lift the kernel, amplitude and lengthscale out of a
# Bayesian-quadrature object and pair them with the given prior measure.
KernelEmbedding(bquad::AbstractBQ, prior) =
    KernelEmbedding(bquad.kernel, bquad.σ, bquad.l, prior)

# Amplitude (variance parameter σ) of the embedded kernel.
function scale(ke::KernelEmbedding)
    return ke.σ
end

@doc raw"""
    kernel_mean(ke::KernelEmbedding{Kernel,Measure}, samples::AbstractVector)

Return the vector of kernel means of the embedding `ke`, with one entry per
sample $$x_i$$:
```math
z_i = \int k(x, x_i)d\mu(x)
```
"""
kernel_mean

# Closed-form kernel mean for a squared-exponential kernel integrated against
# a multivariate normal measure.
function kernel_mean(
    ke::KernelEmbedding{<:SqExponentialKernel,<:AbstractMvNormal}, samples::AbstractVector
)
    μ = mean(measure(ke))
    Σ = cov(measure(ke))
    B = Λ(ke.l)                  # lengthscale turned into a matrix-like object
    centered = samples .- Ref(μ) # xᵢ - μ for every sample
    # Normalisation constant of the SE–Gaussian integral: σ / √det(B⁻¹Σ + I).
    c = scale(ke) / sqrt(det(inv(B) * Σ + I))
    S = Ref(PDMat(B + Σ))        # broadcast-protected quadratic-form matrix
    return c * exp.(-PDMats.invquad.(S, centered) / 2)
end

@doc raw"""
    kernel_variance(ke::KernelEmbedding{Kernel,Measure})

Return the double integral of the kernel of `ke` under its measure:
```math
C = \int\int k(x,x')d\mu(x)d\mu(x')
```
"""
kernel_variance

# Closed-form kernel variance for a squared-exponential kernel under a
# multivariate normal measure.
function kernel_variance(ke::KernelEmbedding{<:SqExponentialKernel,<:AbstractMvNormal})
    Σ = cov(measure(ke))
    B = Λ(ke.l)
    return scale(ke) / sqrt(det(2 * inv(B) * Σ + I))
end


# Turn the lengthscale parameter into a positive-(semi)definite matrix object
# representing Λ = L Lᵀ:
#   * a scalar l      → uniform scaling l² I,
#   * a vector l      → axis-aligned Diagonal(l²),
#   * a LowerTriangular factor L → the factorization L Lᵀ, kept as a Cholesky
#     object so downstream `inv`/`det` can exploit the triangular structure,
#   * any other matrix L → the dense product L Lᵀ.
Λ(l::Real) = abs2(l) * I
Λ(l::AbstractVector) = Diagonal(abs2.(l))
# NOTE: the third Cholesky argument is the LAPACK-style `info` flag; it must be
# 0 to mark the factorization as successful (the previous value 1 flagged it as
# failed, making `issuccess` false and tripping info-checking operations).
Λ(l::LowerTriangular) = Cholesky(l, :L, 0)
Λ(l::AbstractMatrix) = l * l'

# Recover a lengthscale from a KernelFunctions input transform by inverting
# the stored scaling: scalar for ScaleTransform, per-dimension vector for
# ARDTransform, full matrix for LinearTransform.
function transform_param(trans::ScaleTransform)
    return inv(first(trans.s))
end
function transform_param(trans::ARDTransform)
    return inv.(trans.v)
end
function transform_param(trans::LinearTransform)
    return inv(trans.A)
end
5 changes: 0 additions & 5 deletions src/kernelmeans/kernels.jl

This file was deleted.

18 changes: 0 additions & 18 deletions src/kernelmeans/sekernel.jl

This file was deleted.

6 changes: 0 additions & 6 deletions test/bayesquads/bayesquad.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
@testset "bayesquad" begin
s = 2.0
l = [1.0, 2.0]
L = LowerTriangular(rand(2, 2))
k = SqExponentialKernel()
@test BQ.Λ(s) ≈ s^2 * I
@test BQ.Λ(l) ≈ Diagonal(abs2.(l))
@test BQ.Λ(L) ≈ L * L'

σ = 4.0
@test BQ.scale(BayesQuad(σ * k)) == σ
Expand Down
22 changes: 22 additions & 0 deletions test/kernelembeddings/kernelembedding.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@testset "kernelembeddings" begin
    rng = MersenneTwister(42) # seeded for deterministic Monte-Carlo checks
    N = 3                     # dimensionality of the integration domain
    measure = MvNormal(ones(N), ones(N))
    k = SqExponentialKernel()
    l = 2.0
    σ = 0.5
    ke = KernelEmbedding(k, σ, l, measure)
    kernel = σ * with_lengthscale(k, l)
    @test BQ.scale(ke) == σ

    # Use N consistently instead of the hard-coded 3 so the testset stays
    # coherent if the dimensionality is ever changed.
    sample = [rand(rng, N)]
    # Monte-Carlo estimates of the closed-form embeddings (loose atol).
    @test kernel_mean(ke, sample) ≈ [mean(kernel.(sample, eachcol(rand(rng, measure, 10000))))] atol=1e-2
    @test kernel_variance(ke) ≈ mean(kernel.(eachcol(rand(rng, measure, 10000)), eachcol(rand(rng, measure, 10000)))) atol=1e-2

    # Lengthscale-to-matrix conversions.
    s = 2.0
    l = [1.0, 2.0]
    L = LowerTriangular(rand(2, 2))
    @test BQ.Λ(s) ≈ s^2 * I
    @test BQ.Λ(l) ≈ Diagonal(abs2.(l))
    @test Matrix(BQ.Λ(L)) ≈ L * L'
end
2 changes: 0 additions & 2 deletions test/kernelmeans/kernels.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/kernelmeans/sekernel.jl

This file was deleted.

7 changes: 3 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ include("testing_tools.jl")
include(joinpath("samplers", "priorsampling.jl"))
end

@info "Testing kernel means"
@testset "Kernel Means" begin
include(joinpath("kernelmeans", "kernels.jl"))
include(joinpath("kernelmeans", "sekernel.jl"))
@info "Testing kernel embeddings"
@testset "Kernel Embeddings" begin
include(joinpath("kernelembeddings", "kernelembedding.jl"))
end
end