
Commit

Merge bc6651a into a42e714
github-actions[bot] committed Sep 2, 2021
2 parents a42e714 + bc6651a commit 3dd09c1
Showing 37 changed files with 260 additions and 218 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InducingPoints = "b4bd816d-b975-4295-ac05-5f2992945579"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -28,12 +29,12 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 AbstractMCMC = "2, 3"
 AdvancedHMC = "0.2.13, 0.3"
-ChainRulesCore = "0.9"
+ChainRulesCore = "0.9, 1"
 Distributions = "0.21.5, 0.22, 0.23, 0.24, 0.25"
 FastGaussQuadrature = "0.4"
 Flux = "0.10, 0.11, 0.12"
 ForwardDiff = "0.10"
-InducingPoints = "0.1"
+InducingPoints = "0.2"
 KernelFunctions = "0.8, 0.9, 0.10"
 ProgressMeter = "1"
 RecipesBase = "1.0, 1.1"
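Note: Optimisers is the dependency newly added above. As a hedged illustration of the library it brings in, the sketch below uses the rule/state interface from the Optimisers.jl documentation; the names (`Adam`, `setup`, `update`) are assumptions about that API, not code from this commit, and the 0.x release pinned here may spell them differently (e.g. `ADAM`).

```julia
using Optimisers

# Illustrative sketch of the Optimisers.jl rule/state interface (assumed API).
θ = randn(3)                       # parameters to optimise
rule = Optimisers.Adam(1e-3)       # a gradient rule with step size 1e-3
state = Optimisers.setup(rule, θ)  # build per-parameter optimiser state
grad = 2 .* θ                      # stand-in gradient of some loss
state, θ = Optimisers.update(state, θ, grad)  # one optimisation step
```
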
27 changes: 16 additions & 11 deletions docs/Manifest.toml
@@ -1,5 +1,10 @@
 # This file is machine-generated - editing it directly is not advised

+[[ANSIColoredPrinters]]
+git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
+uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9"
+version = "0.0.1"
+
 [[Adapt]]
 deps = ["LinearAlgebra"]
 git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
@@ -29,9 +34,9 @@ version = "0.8.5"

 [[ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "0b0aa9d61456940511416b59a0e902c57b154956"
+git-tree-sha1 = "f53ca8d41e4753c41cdafa6ec5f7ce914b34be54"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.10.12"
+version = "0.10.13"

 [[ColorSchemes]]
 deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random", "StaticArrays"]
@@ -81,9 +86,9 @@ version = "1.7.0"

 [[DataFrames]]
 deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
-git-tree-sha1 = "1dadfca11c0e08e03ab15b63aaeda55266754bad"
+git-tree-sha1 = "a19645616f37a2c2c3077a44bc0d3e73e13441d7"
 uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-version = "1.2.0"
+version = "1.2.1"

 [[DataStructures]]
 deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -121,10 +126,10 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.5"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48"
deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "95265abf7d7bf06dfdb8d58525a23ea5fb0bdeee"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.27.3"
version = "0.27.4"

[[EarCut_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down Expand Up @@ -508,9 +513,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

 [[SentinelArrays]]
 deps = ["Dates", "Random"]
-git-tree-sha1 = "ffae887d0f0222a19c406a11c3831776d1383e3d"
+git-tree-sha1 = "35927c2c11da0a86bcd482464b93dadd09ce420f"
 uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
-version = "1.3.3"
+version = "1.3.5"

 [[Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -546,9 +551,9 @@ version = "1.5.1"

 [[StaticArrays]]
 deps = ["LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e"
+git-tree-sha1 = "1b9a0f17ee0adde9e538227de093467348992397"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.2.6"
+version = "1.2.7"

 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
7 changes: 3 additions & 4 deletions docs/examples/onlinegp.jl
@@ -11,10 +11,9 @@ using MLDataUtils, Distributions
 N = 2000
 σ = 0.1
 X, y = noisy_sin(N, 0, 20; noise=σ)
-X = reshape(X, :, 1) # Put X as a Matrix
-X_train = X[1:2:end, :];
+X_train = X[1:2:end];
 y_train = y[1:2:end]; # We split the data equally
-X_test = X[2:2:end, :];
+X_test = X[2:2:end];
 y_test = y[2:2:end];
 scatter(X_train, y_train)

@@ -33,7 +32,7 @@ k = SqExponentialKernel();
 # ### Create an inducing point selection method
 IP_alg = OIPS(0.8);
 # ### Create the model and stream the data
-model = OnlineSVGP(k, GaussianLikelihood(σ), AnalyticVI(), IP_alg)
+model = OnlineSVGP(k, GaussianLikelihood(σ), AnalyticVI(), IP_alg; optimiser=false)
 anim = Animation()
 size_batch = 100
 for (i, (X_batch, y_batch)) in enumerate(eachbatch((X_train, y_train); obsdim=1, size=size_batch))
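Note: for orientation, the example's streaming loop can be sketched as below. `train!` and `predict_y` are taken from the package's documented online-GP workflow, and the iteration count is an arbitrary placeholder, so treat this as a sketch rather than the file's verbatim continuation.

```julia
# Sketch of the online training loop continued from the diff above (assumed API).
for (i, (X_batch, y_batch)) in enumerate(eachbatch((X_train, y_train); obsdim=1, size=size_batch))
    train!(model, X_batch, y_batch; iterations=5)  # update the variational posterior on this batch
end
y_pred = predict_y(model, X_test)  # posterior predictive mean on the held-out half
```
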
2 changes: 1 addition & 1 deletion src/AugmentedGaussianProcesses.jl
@@ -35,7 +35,7 @@ using Reexport

 using AbstractMCMC
 using AdvancedHMC
-using ChainRulesCore: ChainRulesCore, NO_FIELDS, DoesNotExist
+using ChainRulesCore: ChainRulesCore, NoTangent
 using Distributions:
     Distributions,
     Distribution,
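Note: this import change tracks the ChainRulesCore 1.x API, where the `NO_FIELDS` and `DoesNotExist` sentinels were folded into `NoTangent`. A self-contained sketch of the new style; the `rrule` below is illustrative only, not one defined by this package:

```julia
using ChainRulesCore

# Illustrative rrule in ChainRulesCore 1.x style: NoTangent() now stands in both
# for the function's own tangent (formerly NO_FIELDS) and for non-differentiable
# arguments (formerly DoesNotExist()).
function ChainRulesCore.rrule(::typeof(sum), x::AbstractVector{<:Real})
    y = sum(x)
    sum_pullback(ȳ) = (NoTangent(), fill(ȳ, length(x)))
    return y, sum_pullback
end
```
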
4 changes: 2 additions & 2 deletions src/data/datacontainer.jl
@@ -56,8 +56,8 @@ function wrap_X(X::AbstractMatrix{T}, obsdim::Int=1) where {T<:Real}
     return KernelFunctions.vec_of_vecs(X; obsdim=obsdim), T
 end

-function wrap_X(X::AbstractVector{<:Real}, ::Int=1)
-    return wrap_X(reshape(X, :, 1), 1)
+function wrap_X(X::AbstractVector{T}, ::Int=1) where {T<:Real}
+    return X, T
 end

 function wrap_X(X::AbstractVector{<:AbstractVector{T}}, ::Int=1) where {T<:Real}
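Note: the net effect is that real-valued vectors now pass through untouched as one-dimensional inputs, instead of being reshaped into one-column matrices. A small sketch restating the two methods above; `wrap_X` is an internal helper, so this is for illustration only:

```julia
# Vector input: returned as-is along with its element type.
x = rand(10)
xw, T = wrap_X(x)      # xw === x, T == Float64

# Matrix input: converted to a vector of observation vectors via KernelFunctions.
M = rand(10, 2)
Mw, T = wrap_X(M, 1)   # 10 observations, each a 2-element Vector{Float64}
```
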
20 changes: 10 additions & 10 deletions src/data/utils.jl
@@ -19,13 +19,13 @@ function check_data!(y::AbstractArray, likelihood::Union{Distribution,AbstractLikelihood})
 end

 # Wrap Z in an OptimIP if it is not one already
-function init_Z(Z::AbstractInducingPoints, Zoptimiser)
-    if Zoptimiser isa Bool
-        Zoptimiser = Zoptimiser ? ADAM(1e-3) : nothing
-    end
-    if Z isa OptimIP
-        return Z
-    else
-        return OptimIP(Z, Zoptimiser)
-    end
-end
+# function init_Z(Z::AbstractInducingPoints, Zoptimiser)
+#     if Zoptimiser isa Bool
+#         Zoptimiser = Zoptimiser ? ADAM(1e-3) : nothing
+#     end
+#     if Z isa OptimIP
+#         return Z
+#     else
+#         return OptimIP(Z, Zoptimiser)
+#     end
+# end
9 changes: 3 additions & 6 deletions src/functions/ELBO.jl
@@ -1,26 +1,23 @@
 function ELBO(model::GP, pr_mean, kernel)
-    # setprior!(model, pr_means, kernels, Zs)
     setpr_mean!(model.f, pr_mean)
     setkernel!(model.f, kernel)
-    computeMatrices!(model, true)
+    compute_kernel_matrices!(model, true)
     return log_py(model)
 end

 @traitfn function ELBO(model::TGP, pr_means, kernels) where {TGP <: AbstractGP; IsFull{TGP}}
-    # setprior!(model, pr_means, kernels, Zs)
     setpr_means!(model, pr_means)
     setkernels!(model, kernels)
-    computeMatrices!(model, true)
+    compute_kernel_matrices!(model, true)
     return ELBO(model)
 end

 @traitfn function ELBO(
     model::TGP, pr_means, kernels, Zs
 ) where {TGP <: AbstractGP; !IsFull{TGP}}
-    # setprior!(model, pr_means, kernels, Zs)
     setpr_means!(model, pr_means)
     setkernels!(model, kernels)
     setZs!(model, Zs)
-    computeMatrices!(model, true)
+    compute_kernel_matrices!(model, true)
     return ELBO(model)
 end
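
Note: the three methods dispatch on the SimpleTraits.jl `IsFull` trait, so full models recompute kernel matrices from new means and kernels only, while sparse models also take inducing points `Zs`. A minimal, self-contained sketch of this dispatch pattern; the toy trait and types below are hypothetical, not the package's:

```julia
using SimpleTraits

@traitdef IsToyFull{T}
struct ToyFullGP end
struct ToySparseGP end
@traitimpl IsToyFull{ToyFullGP}

# One name, two trait-selected methods, mirroring the ELBO definitions above.
@traitfn describe(::T) where {T; IsToyFull{T}} = "full model: no inducing points"
@traitfn describe(::T) where {T; !IsToyFull{T}} = "sparse model: inducing points required"

describe(ToyFullGP())    # "full model: no inducing points"
describe(ToySparseGP())  # "sparse model: inducing points required"
```
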
37 changes: 25 additions & 12 deletions src/gpblocks/latentgp.jl
@@ -44,25 +44,27 @@ end
 ## Sparse Variational Gaussian Process

 mutable struct SparseVarLatent{
-    T,Tpr<:GPPrior,Tpo<:VarPosterior{T},TZ<:AbstractInducingPoints,O
+    T,Tpr<:GPPrior,Tpo<:VarPosterior{T},Topt,TZ<:AbstractVector,TZopt
 } <: AbstractVarLatent{T,Tpr,Tpo}
     prior::Tpr
     post::Tpo
     Z::TZ
     Knm::Matrix{T}
     κ::Matrix{T}
     K̃::Vector{T}
-    opt::O
+    opt::Topt
+    Zopt::TZopt
 end

 function SparseVarLatent(
     T::DataType,
     dim::Int,
     S::Int,
-    Z::AbstractInducingPoints,
+    Z::AbstractVector,
     kernel::Kernel,
     mean::PriorMean,
-    opt,
+    opt=nothing,
+    Zopt=nothing
 )
     return SparseVarLatent(
         GPPrior(deepcopy(kernel), deepcopy(mean), cholesky(Matrix{T}(I(dim)))),
@@ -72,6 +74,7 @@ function SparseVarLatent(
         Matrix{T}(undef, S, dim),
         Vector{T}(undef, S),
         deepcopy(opt),
+        deepcopy(Zopt),
     )
 end

@@ -95,16 +98,19 @@

 ## Online Sparse Variational Process

-mutable struct OnlineVarLatent{T,Tpr<:GPPrior,Tpo<:AbstractVarPosterior{T},O} <:
+mutable struct OnlineVarLatent{T,Tpr<:GPPrior,Tpo<:AbstractVarPosterior{T},Topt,TZ<:AbstractVector,
+                               TZalg<:InducingPoints.OnIPSA,TZopt} <:
                AbstractVarLatent{T,Tpo,Tpr}
     prior::Tpr
     post::Tpo
-    Z::InducingPoints.AIP
+    Z::TZ
+    Zalg::TZalg
     Knm::Matrix{T}
     κ::Matrix{T}
     K̃::Vector{T}
     Zupdated::Bool
-    opt::O
+    opt::Topt
+    Zopt::TZopt
     Zₐ::AbstractVector
     Kab::Matrix{T}
     κₐ::Matrix{T}
@@ -118,21 +124,25 @@ function OnlineVarLatent(
     T::DataType,
     dim::Int,
     nSamplesUsed::Int,
-    Z::AbstractInducingPoints,
+    Z::AbstractVector,
+    Zalg::InducingPoints.OnIPSA,
     kernel::Kernel,
     mean::PriorMean,
-    opt,
+    opt=nothing,
+    Zopt=nothing
 )
     return OnlineVarLatent(
         GPPrior(deepcopy(kernel), deepcopy(mean), cholesky(Matrix{T}(I, dim, dim))),
         OnlineVarPosterior{T}(dim),
         Z,
+        Zalg,
         Matrix{T}(undef, nSamplesUsed, dim),
         Matrix{T}(undef, nSamplesUsed, dim),
         Vector{T}(undef, nSamplesUsed),
         false,
         deepcopy(opt),
-        vec(Z),
+        deepcopy(Zopt),
+        deepcopy(Z),
         Matrix{T}(I, dim, dim),
         Matrix{T}(I, dim, dim),
         Matrix{T}(I, dim, dim),
@@ -209,9 +219,12 @@ var_f(Σ::AbstractMatrix, κ::AbstractMatrix, K̃::AbstractVector) = diag_ABt(κ
 Zview(gp::SparseVarLatent) = gp.Z
 Zview(gp::OnlineVarLatent) = gp.Z

-setZ!(gp::AbstractLatent, Z::AbstractInducingPoints) = gp.Z = Z # InducingPoints.setZ!(Zview(gp), Z)
+setZ!(gp::AbstractLatent, Z::AbstractVector) = gp.Z = Z

 opt(gp::AbstractLatent) = gp.opt
+Zopt(::AbstractLatent) = nothing
+Zopt(gp::SparseVarLatent) = gp.Zopt
+Zopt(gp::OnlineVarLatent) = gp.Zopt

 @traitfn function compute_K!(
     gp::TGP, X::AbstractVector, jitt::Real
@@ -241,5 +254,5 @@ function compute_κ!(gp::OnlineVarLatent, X::AbstractVector, jitt::Real)
     gp.Knm = kernelmatrix(kernel(gp), X, gp.Z)
     gp.κ = gp.Knm / pr_cov(gp)
     gp.K̃ = kernelmatrix_diag(kernel(gp), X) .+ jitt - diag_ABt(gp.κ, gp.Knm)
-    @assert all(gp.K̃ .> 0) "K̃ has negative values"
+    all(gp.K̃ .> 0) || error("K̃ has negative values")
 end
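
Note on the last hunk: `@assert` may be elided depending on optimization settings (the Julia manual warns against relying on it for runtime validation), so the explicit `all(...) || error(...)` form guarantees the positivity check always runs. The same pattern in isolation:

```julia
# `@assert cond msg` can legally be skipped by the compiler; this form cannot.
K̃ = [0.3, 1e-9, 0.5]      # a toy jitter-corrected kernel diagonal, for illustration
all(K̃ .> 0) || error("K̃ has negative values")
```
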
10 changes: 5 additions & 5 deletions src/hyperparameter/autotuning.jl
@@ -67,7 +67,7 @@ end

 @traitfn function update_hyperparameters!(m::TGP) where {TGP <: AbstractGP; !IsFull{TGP}}
     # Check that there is at least one optimiser
-    if any((!) ∘ isnothing ∘ opt, m.f) || any((!) ∘ isnothing ∘ opt ∘ Zview, m.f)
+    if any((!) ∘ isnothing ∘ opt, m.f) || any((!) ∘ isnothing ∘ Zopt, m.f)
         μ₀ = pr_means(m)
         ks = kernels(m)
         Zs = Zviews(m)
@@ -82,7 +82,7 @@ end
             @warn "Kernel gradients are equal to zero" maxlog = 1
         else
             for (f, Δ) in zip(m.f, Δk)
-                update!(opt(f), kernel(f), Δ)
+                update_kernel!(opt(f), kernel(f), Δ)
             end
         end

@@ -91,7 +91,7 @@ end
             @warn "Inducing point locations gradients are equal to zero" maxlog = 1
         else
             for (f, Δ) in zip(m.f, ΔZ)
-                update!(opt(f.Z), data(f.Z), Δ)
+                update_Z!(Zopt(f), Zview(f), Δ)
             end
         end
     elseif ADBACKEND[] == :ForwardDiff
@@ -142,14 +142,14 @@ function update_hyperparameters!(
     else
         nothing, nothing
     end
-    if !isnothing(opt(gp.Z))
+    if !isnothing(Zopt(gp))
         ad_backend = Z_ADBACKEND[] == :auto ? ADBACKEND[] : Z_ADBACKEND[]
         Z_grads = if ad_backend == :forward
             Z_gradient_forward(gp, f_Z, X, ∇E_μ, ∇E_Σ, i, vi_opt) # compute the gradient w.r.t. the inducing-point locations
         elseif ad_backend == :zygote
             Z_gradient_zygote(gp, f_Z, X, ∇E_μ, ∇E_Σ, i, vi_opt)
         end
-        update!(opt(gp.Z), gp.Z.Z, Z_grads) # apply the gradients to the locations
+        update_Z!(Zopt(gp), gp.Z, Z_grads) # apply the gradients to the locations
     end
     if !all([isnothing(Δk), isnothing(Δμ₀)])
         apply_Δk!(gp.opt, kernel(gp), Δk) # apply gradients to the kernel parameters
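Note: the bodies of the renamed helpers `update_kernel!` and `update_Z!` are not shown in this diff. A plausible sketch of the inducing-point update, assuming Flux-style `apply!` semantics and gradient ascent on the ELBO; this is explicitly not the package's actual implementation:

```julia
using Flux.Optimise: ADAM, apply!

# Hypothetical sketch only; update_Z! in the package may differ.
function update_Z_sketch!(opt, Z::AbstractVector, ΔZ)
    for (z, dz) in zip(Z, ΔZ)
        z .+= apply!(opt, z, dz)  # apply! rescales dz by the optimiser state
    end
    return Z
end

# Usage: Z as a vector of coordinate vectors, ΔZ the matching gradients.
Z = [rand(2) for _ in 1:5]
ΔZ = [randn(2) for _ in 1:5]
update_Z_sketch!(ADAM(1e-3), Z, ΔZ)
```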
