2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ ForwardDiff = "0.10"
GPLikelihoods = "0.2, 0.3, 0.4"
InducingPoints = "0.2, 0.3"
KernelFunctions = "0.8, 0.9, 0.10"
Optimisers = "0.1"
Optimisers = "0.2"
ProgressMeter = "1"
RecipesBase = "1.0, 1.1"
Reexport = "0.2, 1"
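The version bump above is the breaking change that drives the rest of this PR: as the renames below show, Optimisers.jl 0.2 uses `setup` where 0.1 used `state`/`init`, and `apply!` where 0.1 used `apply`. As a rough, self-contained orientation (not this package's code; the rule choice and data are made up), a single array can be stepped under 0.2 like this:

```julia
using Optimisers

x  = rand(3)                    # parameters (made-up)
dx = ones(3)                    # gradient (made-up)
rule = Optimisers.Descent(0.1)

# High-level interface: `setup` builds the optimiser state, `update` returns
# the new state together with the updated parameters.
st = Optimisers.setup(rule, x)
st, x = Optimisers.update(st, x, dx)

# Low-level rule interface: `apply!` turns a gradient into a step and returns the
# updated rule state; the caller decides how to apply the step (Optimisers itself
# subtracts it, while this package adds or exponentiates it as needed).
st2 = Optimisers.init(rule, x)
st2, Δ = Optimisers.apply!(rule, st2, x, dx)
x .-= Δ
```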
6 changes: 3 additions & 3 deletions src/hyperparameter/autotuning_utils.jl
@@ -61,22 +61,22 @@ function update_kernel!(opt, k::Union{Kernel,Transform}, g::NamedTuple, state::N
end

function update_kernel!(opt, x::AbstractArray, g::AbstractArray, state)
-state, Δ = Optimisers.apply(opt, state, x, x .* g)
+state, Δ = Optimisers.apply!(opt, state, x, x .* g)
@. x = exp(log(x) + Δ) # Always assume that parameters need to be positive
return state
end

## Updating inducing points
function update_Z!(opt, Z::AbstractVector, Z_grads::AbstractVector, state)
return map(Z, Z_grads, state) do z, zgrad, st
-st, ΔZ = Optimisers.apply(opt, st, z, zgrad)
+st, ΔZ = Optimisers.apply!(opt, st, z, zgrad)
z .+= ΔZ
return st
end
end

function update_Z!(opt, Z::Union{ColVecs,RowVecs}, Z_grads::NamedTuple, state)
-st, Δ = Optimisers.apply(opt, state, Z.X, Z_grads.X)
+st, Δ = Optimisers.apply!(opt, state, Z.X, Z_grads.X)
Z.X .+= Δ
return st
end
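The unchanged `@. x = exp(log(x) + Δ)` line above keeps kernel parameters strictly positive by taking the optimiser step in log space; the σ² update in analytic.jl below appears to use the same idea, rescaling the gradient by σ² and calling the resulting step Δlogσ². A quick sanity check of the identity, with illustrative values only:

```julia
# exp(log(x) + Δ) == x * exp(Δ), which stays positive for any real step Δ.
x, Δ = 0.5, -2.0
@assert exp(log(x) + Δ) ≈ x * exp(Δ)
@assert x * exp(Δ) > 0
```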
2 changes: 1 addition & 1 deletion src/inference/analytic.jl
@@ -41,7 +41,7 @@ function analytic_updates(m::GP{T}, state, y) where {T}
f.post.α .= cov(f) \ (y - pr_mean(f, input(m.data)))
if !isnothing(l.opt_noise)
g = (norm(mean(f), 2) - tr(inv(cov(f)))) / 2
-state_σ², Δlogσ² = Optimisers.apply(
+state_σ², Δlogσ² = Optimisers.apply!(
l.opt_noise, state.local_vars.state_σ², l.σ², g .* l.σ²
)
local_vars = merge(state.local_vars, (; state_σ²))
4 changes: 2 additions & 2 deletions src/inference/analyticVI.jl
@@ -228,10 +228,10 @@ end

function global_update!(gp::SparseVarLatent, opt::AVIOptimizer, i::AnalyticVI, opt_state)
if is_stochastic(i)
-state_η₁, Δ₁ = Optimisers.apply(
+state_η₁, Δ₁ = Optimisers.apply!(
opt.optimiser, opt_state.state_η₁, nat1(gp), opt_state.∇η₁
)
-state_η₂, Δ₂ = Optimisers.apply(
+state_η₂, Δ₂ = Optimisers.apply!(
opt.optimiser, opt_state.state_η₂, nat2(gp).data, opt_state.∇η₂
)
gp.post.η₁ .+= Δ₁
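For context on the η₁/η₂ fields stepped above: by the usual convention, `nat1` and `nat2` denote the natural parameters of the Gaussian variational posterior, which stochastic VI updates directly. The following is a textbook reminder of those quantities with made-up values, not code from this package:

```julia
using LinearAlgebra

# For q(u) = N(μ, Σ), the natural parameters are η₁ = Σ⁻¹μ and η₂ = -Σ⁻¹/2.
μ = [0.0, 1.0]              # made-up mean
Σ = Matrix(2.0I, 2, 2)      # made-up covariance
η₁ = Σ \ μ
η₂ = -inv(Σ) / 2
```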
6 changes: 3 additions & 3 deletions src/inference/optimisers.jl
@@ -9,9 +9,9 @@ function RobbinsMonro(κ::Real=0.51, τ::Real=1)
return RobbinsMonro{promote_type(typeof(κ), typeof(τ))}(κ, τ)
end

-Optimisers.init(::RobbinsMonro, ::Any) = 1
+Optimisers.setup(::RobbinsMonro, ::Any) = 1

-function Optimisers.apply(o::RobbinsMonro, st, x, Δ)
+function Optimisers.apply!(o::RobbinsMonro, st, x, Δ)
κ = o.κ
τ = o.τ
n = st
@@ -36,7 +36,7 @@ function Optimisers.init(opt::ALRSVI{T}, x::AbstractArray) where {T}
return (; i, g, h, ρ=opt.ρ, τ)
end

-function apply(opt::ALRSVI, state, x::AbstractArray, Δx::AbstractArray)
+function apply(opt::ALRSVI, state, ::AbstractArray, Δx::AbstractArray)
if state.i <= opt.n_mc
g = state.g + Δx
h = state.h + norm(Δx)
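The body of `Optimisers.apply!(o::RobbinsMonro, st, x, Δ)` is only partially visible in this diff, but the κ/τ fields and the integer counter state suggest the classic Robbins-Monro step-size schedule. A plausible sketch of such a rule under the Optimisers 0.2 interface; names and details are illustrative, not this package's implementation:

```julia
using Optimisers

# Illustrative rule: the step size γ = (τ + n)^(-κ) decays with the iteration counter n.
struct RobbinsMonroSketch{T<:Real}
    κ::T   # decay exponent, typically in (0.5, 1]
    τ::T   # offset that delays the decay
end

Optimisers.setup(::RobbinsMonroSketch, ::Any) = 1   # state is just the step counter

function Optimisers.apply!(o::RobbinsMonroSketch, n, x, Δ)
    γ = (o.τ + n)^(-o.κ)
    return n + 1, γ .* Δ     # updated counter and the scaled step
end
```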
6 changes: 3 additions & 3 deletions src/inference/vi_optimizers.jl
@@ -20,9 +20,9 @@ struct SOptimizer{O} <: AOptimizer
optimiser::O
end

-Optimisers.state(opt::InferenceOptimizer, x) = Optimisers.state(opt.optimiser, x)
-function Optimisers.apply(opt::InferenceOptimizer, st, x, dx)
-return Optimisers.apply(opt.optimiser, st, x, dx)
+Optimisers.setup(opt::InferenceOptimizer, x) = Optimisers.setup(opt.optimiser, x)
+function Optimisers.apply!(opt::InferenceOptimizer, st, x, dx)
+return Optimisers.apply!(opt.optimiser, st, x, dx)
end
function Optimisers.update(opt::InferenceOptimizer, st, x, dx)
return Optimisers.update(opt.optimiser, st, x, dx)
4 changes: 2 additions & 2 deletions src/mean/affinemean.jl
@@ -51,8 +51,8 @@ end

function update!(μ₀::AffineMean{T}, hyperopt_state, grad) where {T<:Real}
μ₀_state = hyperopt_state.μ₀_state
-w, Δw = Optimisers.apply(μ₀.opt, μ₀_state.w, μ₀.w, grad.w)
-b, Δb = Optimisers.apply(μ₀.opt, μ₀_state.b, μ₀.b, grad.b)
+w, Δw = Optimisers.apply!(μ₀.opt, μ₀_state.w, μ₀.w, grad.w)
+b, Δb = Optimisers.apply!(μ₀.opt, μ₀_state.b, μ₀.b, grad.b)
μ₀.w .+= Δw
μ₀.b .+= Δb
return merge(hyperopt_state, (; μ₀_state=(; w, b)))
4 changes: 2 additions & 2 deletions src/mean/constantmean.jl
@@ -24,13 +24,13 @@ end
(μ::ConstantMean{T})(x::AbstractVector) where {T<:Real} = fill(only(μ.C), length(x))

function init_priormean_state(hyperopt_state, μ₀::ConstantMean)
-μ₀_state = (; C=Optimisers.state(μ₀.opt, μ₀.C))
+μ₀_state = (; C=Optimisers.setup(μ₀.opt, μ₀.C))
return merge(hyperopt_state, (; μ₀_state))
end

function update!(μ₀::ConstantMean{T}, hyperopt_state, grad) where {T<:Real}
μ₀_state = hyperopt_state.μ₀_state
-C, ΔC = Optimisers.apply(μ₀.opt, μ₀_state.C, μ₀.C, grad)
+C, ΔC = Optimisers.apply!(μ₀.opt, μ₀_state.C, μ₀.C, grad)
μ₀.C .+= ΔC
return merge(hyperopt_state, (; μ₀_state=(; C)))
end
2 changes: 1 addition & 1 deletion src/mean/empiricalmean.jl
@@ -32,7 +32,7 @@ end

function update!(μ₀::EmpiricalMean{T}, hyperopt_state, grad) where {T<:Real}
μ₀_state = hyperopt_state.μ₀_state
-C, ΔC = Optimisers.apply(μ₀.opt, μ₀_state.C, μ₀.C, grad.C)
+C, ΔC = Optimisers.apply!(μ₀.opt, μ₀_state.C, μ₀.C, grad.C)
μ₀.C .+= ΔC
return merge(hyperopt_state, (; μ₀_state=(; C)))
end
2 changes: 1 addition & 1 deletion src/models/single_and_multi_output_utils.jl
@@ -107,7 +107,7 @@ function update_A!(m::TGP, state, ys) where {TGP<:AbstractGPModel}
∇A[q] = x1 - 2 * m.A[t][j][q] * x2
end
state_A = state.A_state[t][j]
-state_A, ΔA = Optimisers.apply(m.A_opt, state_A, m.A[t][j], ∇A)
+state_A, ΔA = Optimisers.apply!(m.A_opt, state_A, m.A[t][j], ∇A)
m.A[t][j] .+= ΔA
m.A[t][j] /= sqrt(sum(abs2, m.A[t][j])) # Projection on the unit circle
state.A_state[t][j] = state_A
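The unchanged context line `m.A[t][j] /= sqrt(sum(abs2, m.A[t][j]))` is an L2 normalization: after the optimiser step each mixing vector is rescaled to unit norm (the "unit circle" comment generalizes to the unit sphere in higher dimensions). A tiny illustration with made-up numbers:

```julia
using LinearAlgebra

a = [3.0, 4.0]               # made-up mixing weights
a ./= sqrt(sum(abs2, a))     # same as a ./ norm(a)
@assert norm(a) ≈ 1
```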
22 changes: 11 additions & 11 deletions src/training/states.jl
@@ -51,8 +51,8 @@ function init_opt_state(gp::Union{VarLatent{T},TVarLatent{T}}, vi::NumericalVI)
return (;
ν=zeros(T, batchsize(vi)), # Derivative -<dv/dx>_qn
λ=zeros(T, batchsize(vi)), # Derivative <d²V/dx²>_qm
-state_μ=Optimisers.state(opt(vi).optimiser, mean(gp)),
-state_Σ=Optimisers.state(opt(vi), cov(gp).data),
+state_μ=Optimisers.setup(opt(vi).optimiser, mean(gp)),
+state_Σ=Optimisers.setup(opt(vi), cov(gp).data),
∇η₁=zero(mean(gp)),
∇η₂=zero(cov(gp).data),
)
@@ -64,8 +64,8 @@ function init_opt_state(gp::SparseVarLatent{T}, vi::VariationalInference) where
state = merge(
state,
(;
-state_η₁=Optimisers.state(opt(vi).optimiser, nat1(gp)),
-state_η₂=Optimisers.state(opt(vi), nat2(gp).data),
+state_η₁=Optimisers.setup(opt(vi).optimiser, nat1(gp)),
+state_η₂=Optimisers.setup(opt(vi), nat2(gp).data),
),
)
end
@@ -75,8 +75,8 @@ function init_opt_state(gp::SparseVarLatent{T}, vi::VariationalInference) where
(;
ν=zeros(T, batchsize(vi)), # Derivative -<dv/dx>_qn
λ=zeros(T, batchsize(vi)), # Derivative <d²V/dx²>_qm
-state_μ=Optimisers.state(opt(vi).optimiser, mean(gp)),
-state_Σ=Optimisers.state(opt(vi), cov(gp).data),
+state_μ=Optimisers.setup(opt(vi).optimiser, mean(gp)),
+state_Σ=Optimisers.setup(opt(vi), cov(gp).data),
),
)
end
@@ -121,7 +121,7 @@ end
hyperopt_state = (;)
if !isnothing(opt(gp))
k = kernel(gp)
-state_k = Optimisers.state(opt(gp), k)
+state_k = Optimisers.setup(opt(gp), k)
hyperopt_state = merge(hyperopt_state, (; state_k))
end
hyperopt_state = init_priormean_state(hyperopt_state, pr_mean(gp))
@@ -132,18 +132,18 @@ end
hyperopt_state = (;)
if !isnothing(opt(gp))
k = kernel(gp)
-state_k = Optimisers.state(opt(gp), k)
+state_k = Optimisers.setup(opt(gp), k)
hyperopt_state = merge(hyperopt_state, (; state_k))
end
if !isnothing(Zopt(gp))
Z = Zview(gp)
-state_Z = Optimisers.state(opt(gp), Z)
+state_Z = Optimisers.setup(opt(gp), Z)
hyperopt_state = merge(hyperopt_state, (; state_Z))
end
hyperopt_state = init_priormean_state(hyperopt_state, pr_mean(gp))
return hyperopt_state
end

-function Optimisers.state(opt, Z::Union{ColVecs,RowVecs})
-return Optimisers.state(opt, Z.X)
+function Optimisers.setup(opt, Z::Union{ColVecs,RowVecs})
+return Optimisers.setup(opt, Z.X)
end
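The new `Optimisers.setup` method above just builds the optimiser state on the matrix backing a `ColVecs`/`RowVecs` collection of inducing points. For orientation, assuming KernelFunctions' vector-of-vectors wrappers and made-up sizes:

```julia
using KernelFunctions: ColVecs

Z = ColVecs(rand(2, 5))   # 5 inducing points in 2 dimensions, stored as the columns of Z.X
size(Z.X)                 # (2, 5): the array the optimiser state is set up on
length(Z)                 # 5 points
```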
8 changes: 4 additions & 4 deletions test/runtests.jl
@@ -1,10 +1,10 @@
-using AugmentedGaussianProcesses
-using Test
-using LinearAlgebra, Distributions
-using Zygote
+using Test, AugmentedGaussianProcesses
+using Distributions
+using LinearAlgebra
using PDMats
using MLDataUtils
using Random: seed!
+using Zygote
seed!(42)

include("testingtools.jl")