Merge pull request #5 from willtebbutt/wct/stheno-0.7-upgrades
Upgrade to AbstractGPs API
willtebbutt committed Apr 2, 2021
2 parents 34bde7d + 274940b commit 0eddee2
Showing 7 changed files with 23 additions and 19 deletions.
8 changes: 4 additions & 4 deletions Project.toml
@@ -1,13 +1,13 @@
name = "GPARs"
uuid = "86ee5ac1-314e-4b70-9100-2b12109404ad"
authors = ["WT <wt0881@my.bristol.ac.uk> and contributors"]
version = "0.1.0"
authors = ["Will Tebbutt and contributors"]
version = "0.2.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Stheno = "8188c328-b5d6-583d-959b-9690869a5511"

[compat]
AbstractGPs = "0.2.25"
julia = "1.5"
Stheno = "0.6"
12 changes: 8 additions & 4 deletions README.md
@@ -12,13 +12,15 @@ We also maintain a [Python version of this package](https://github.com/wesselb/g
## Basic Usage

```julia
using AbstractGPs
using GPARs
using Random
using Stheno

# Build a GPAR from a collection of GPs. For more info on how to specify particular
# kernels and their parameters, please see [Stheno.jl](https://github.com/willtebbutt/Stheno.jl). You should think of this as a vector-valued regressor.
f = GPAR([GP(EQ(), GPC()) for _ in 1:3])
# kernels and their parameters, please see [AbstractGPs.jl](https://github.com/willtebbutt/AbstractGPs.jl) or
# [Stheno.jl](https://github.com/willtebbutt/Stheno.jl)
# You should think of this as a vector-valued regressor.
f = GPAR([GP(SEKernel()) for _ in 1:3])

# Specify inputs. `ColVecs` says "interpret this matrix as a vector of column-vectors".
# Inputs are 2 dimensional, and there are 10 of them. This means that the pth GP in f
@@ -40,14 +42,16 @@ logpdf(f(x, Σs), y)

# Generate a new GPAR that is conditioned on these observations. This is just another
# GPAR object (in the simplest case, GPARs are closed under conditioning).
f_post = GPARs.posterior(f(x, Σs), y)
f_post = posterior(f(x, Σs), y)

# Since `f_post` is just another GPAR, we can use it to generate posterior samples
# and to compute log posterior predictive probabilities in the same way as the prior.
x_post = ColVecs(randn(2, 15))
rng = MersenneTwister(123456)
y_post = rand(rng, f_post(x, Σs))
logpdf(f_post(x, Σs), y_post)
```

Using this functionality, you have everything you need to do learning with standard
off-the-shelf tooling ([Zygote.jl](https://github.com/FluxML/Zygote.jl/) to compute gradients, [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) for optimisers such as (L-)BFGS, and [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl/) to make dealing with large numbers of model parameters more straightforward).
See the examples in [Stheno.jl](https://github.com/willtebbutt/Stheno.jl)'s docs for inspiration.
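
As a rough illustration of the workflow that paragraph describes (not part of this commit), here is a minimal sketch of learning per-output lengthscales by maximising the log marginal likelihood. It reuses `x`, `y`, and the per-output noise variances from the snippet above; the helper `build_gpar`, the parameter layout, and the assumption that `with_lengthscale`/`SEKernel` are available via AbstractGPs' re-exports are illustrative assumptions, not part of the GPARs API.

```julia
using AbstractGPs
using GPARs
using Optim
using ParameterHandling
using Zygote

# Hypothetical parameter layout: one positive lengthscale and one positive
# observation-noise variance per output dimension.
θ0 = (
    lengthscales = [positive(1.0) for _ in 1:3],
    noises = [positive(0.1) for _ in 1:3],
)
flat_θ0, unflatten = ParameterHandling.value_flatten(θ0)

# Hypothetical helper: build a GPAR whose p-th GP uses the p-th lengthscale.
build_gpar(θ) = GPAR([GP(with_lengthscale(SEKernel(), l)) for l in θ.lengthscales])

# Negative log marginal likelihood as a function of the flat parameter vector,
# reusing `x` and `y` from the README snippet above.
function objective(flat_θ)
    θ = unflatten(flat_θ)
    f = build_gpar(θ)
    return -logpdf(f(x, θ.noises), y)
end

# L-BFGS with reverse-mode gradients from Zygote.
result = Optim.optimize(
    objective,
    flat_θ -> only(Zygote.gradient(objective, flat_θ)),
    flat_θ0,
    LBFGS();
    inplace=false,
)
θ_opt = unflatten(Optim.minimizer(result))
```

Whether Zygote differentiates through GPARs' `logpdf` end-to-end depends on the package internals, so treat this as a starting point rather than a guaranteed recipe.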
2 changes: 1 addition & 1 deletion src/GPARs.jl
@@ -1,7 +1,7 @@
module GPARs

using LinearAlgebra
using Stheno
using AbstractGPs
using Random

include("gpar.jl")
14 changes: 7 additions & 7 deletions src/gpar.jl
@@ -3,19 +3,19 @@
"""
struct GPAR{Tfs} <: Stheno.AbstractGP # this is a bit of a hack, as a GPAR isn't a GP.
struct GPAR{Tfs} <: AbstractGPs.AbstractGP # this is a bit of a hack, as a GPAR isn't a GP.
fs::Tfs
end

dim_out(f::GPAR) = length(f.fs)

(f::GPAR)(x::ColVecs, Σs) = Stheno.FiniteGP(f, x, Σs)
(f::GPAR)(x::ColVecs, Σs) = AbstractGPs.FiniteGP(f, x, Σs)

const FiniteGPAR = Stheno.FiniteGP{<:GPAR}
const FiniteGPAR = AbstractGPs.FiniteGP{<:GPAR}

extract_data(fx::FiniteGPAR) = fx.f, fx.x.X, fx.Σy

function Stheno.rand(rng::AbstractRNG, fx::FiniteGPAR)
function AbstractGPs.rand(rng::AbstractRNG, fx::FiniteGPAR)
f, X, Σs = extract_data(fx)
Y = Matrix{Float64}(undef, 0, length(fx.x))

@@ -27,7 +27,7 @@ function Stheno.rand(rng::AbstractRNG, fx::FiniteGPAR)
return ColVecs(Y)
end

function Stheno.logpdf(fx::FiniteGPAR, y::ColVecs)
function AbstractGPs.logpdf(fx::FiniteGPAR, y::ColVecs)
f, X, Σs = extract_data(fx)
Y = y.X

@@ -39,12 +39,12 @@ function Stheno.logpdf(fx::FiniteGPAR, y::ColVecs)
return l
end

function posterior(fx::FiniteGPAR, y::ColVecs)
function AbstractGPs.posterior(fx::FiniteGPAR, y::ColVecs)
f, X, Σs = extract_data(fx)
Y = y.X
fs_post = map(enumerate(f.fs)) do (p, f_p)
x_p = ColVecs(vcat(X, Y[1:p-1, :]))
return f_p | (f_p(x_p, Σs[p]) ← Y[p, :])
return posterior(f_p(x_p, Σs[p]), Y[p, :])
end
return GPAR(fs_post)
end
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -1,4 +1,4 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Stheno = "8188c328-b5d6-583d-959b-9690869a5511"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion test/gpar.jl
@@ -4,7 +4,7 @@
D_out = 3

rng = MersenneTwister(123456)
f = GPAR([GP(EQ(), GPC()) for _ in 1:D_out])
f = GPAR([GP(SEKernel()) for _ in 1:D_out])
x = ColVecs(randn(D_in, N))
Σs = [rand(rng) + 0.1 for _ in 1:D_out]

2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -1,6 +1,6 @@
using AbstractGPs
using GPARs
using Random
using Stheno
using Test

using GPARs: posterior

2 comments on commit 0eddee2

@willtebbutt
Owner Author


@JuliaRegistrator register()

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/33431

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.2.0 -m "<description of version>" 0eddee2999ff483b3274b97122b71b35278c34d9
git push origin v0.2.0
