-
Notifications
You must be signed in to change notification settings - Fork 9
/
constantmean.jl
36 lines (29 loc) · 1.13 KB
/
constantmean.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Prior mean that returns the same constant value for every input.
# Type parameters: `T` is the element type of the stored constant, `O` is the type of
# the optimiser rule held in `opt`.
mutable struct ConstantMean{T<:Real,O} <: PriorMean{T}
# One-element vector holding the constant (always read with `only`). It is stored as a
# vector so that `update!` can modify the value in place with `.+=`.
C::Vector{T}
# Optimiser rule handed to `Optimisers.state` / `Optimisers.apply` when the constant
# is learned.
opt::O
end
"""
ConstantMean(c::Real = 1.0; opt=ADAM(0.01))
## Arguments
- `c::Real` : Constant value
Construct a prior mean with constant `c`
Optionally set an optimiser `opt` (`ADAM(0.01)` by default)
"""
function ConstantMean(c::T=1.0; opt=ADAM(0.01)) where {T<:Real}
return ConstantMean{T,typeof(opt)}([c], opt)
end
# REPL (`text/plain`) display: prints the current value of the constant.
function Base.show(io::IO, ::MIME"text/plain", μ₀::ConstantMean)
    constant = only(μ₀.C)
    return print(io, "Constant Mean Prior (c = ", constant, ")")
end
# Evaluate the prior mean at a single point: the input is ignored and the stored
# constant is returned.
function (μ::ConstantMean)(::Real)
    return only(μ.C)
end
# Evaluate the prior mean on a collection of inputs: returns a new `Vector` with one
# copy of the constant per element of `x`.
function (μ::ConstantMean)(x::AbstractVector)
    value = only(μ.C)
    return fill(value, length(x))
end
# Create the optimiser state for the constant and add it to `hyperopt_state` under
# the key `μ₀_state`, as the named tuple `(; C = <optimiser state>)`.
function init_priormean_state(hyperopt_state, μ₀::ConstantMean)
    opt_state = Optimisers.state(μ₀.opt, μ₀.C)
    return merge(hyperopt_state, (; μ₀_state=(; C=opt_state)))
end
# Take one optimiser step on the constant `μ₀.C` using `grad`, mutating `μ₀.C` in
# place. Returns `hyperopt_state` with `μ₀_state.C` replaced by the new optimiser state.
# NOTE(review): the step from `Optimisers.apply` is *added* to `μ₀.C`, so `grad` is
# presumably the gradient of an objective being maximised — confirm against callers.
function update!(μ₀::ConstantMean, hyperopt_state, grad)
    old_opt_state = hyperopt_state.μ₀_state.C
    new_opt_state, ΔC = Optimisers.apply(μ₀.opt, old_opt_state, μ₀.C, grad)
    μ₀.C .+= ΔC
    return merge(hyperopt_state, (; μ₀_state=(; C=new_opt_state)))
end