-
Notifications
You must be signed in to change notification settings - Fork 159
/
mvnormal.jl
40 lines (29 loc) · 1.38 KB
/
mvnormal.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import LinearAlgebra
struct MultivariateNormal <: Distribution{Vector{Float64}} end

"""
    mvnormal(mu::AbstractVector{T}, cov::AbstractMatrix{U}) where {T<:Real,U<:Real}

Samples a `Vector{Float64}` value from a multivariate normal distribution with mean
vector `mu` and covariance matrix `cov`.

`cov` is wrapped in `LinearAlgebra.Symmetric` before use, so only its upper triangle
is read.
"""
const mvnormal = MultivariateNormal()
# Log-density of `x` under a multivariate normal with mean `mu` and covariance `cov`.
# The covariance is viewed through `LinearAlgebra.Symmetric`, whose default
# `uplo = :U` means only the upper triangle of `cov` contributes.
function logpdf(::MultivariateNormal, x::AbstractVector{<:Real},
                mu::AbstractVector{<:Real}, cov::AbstractMatrix{<:Real})
    gaussian = Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))
    return Distributions.logpdf(gaussian, x)
end
# Gradient of the log-density with respect to `(x, mu, cov)`, returned as a 3-tuple.
#
# With P = inv(cov) and r = x - mu, the pieces computed below are:
#   d/dx   = -P * r                   (from `Distributions.gradlogpdf`)
#   d/dmu  = -(d/dx)
#   d/dcov = -0.5 * (P - (d/dmu) * (d/dmu)')
# NOTE(review): `Symmetric(cov)` reads only the upper triangle of `cov`, so the
# symmetric covariance gradient returned here is not the literal per-entry partial
# derivative of this method's logpdf; confirm callers expect this convention.
function logpdf_grad(::MultivariateNormal, x::AbstractVector{<:Real},
                     mu::AbstractVector{<:Real}, cov::AbstractMatrix{<:Real})
    gaussian = Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))
    grad_x = Distributions.gradlogpdf(gaussian, x)
    grad_mu = -grad_x
    outer = grad_mu * transpose(grad_mu)
    grad_cov = -0.5 * (Distributions.invcov(gaussian) - outer)
    return (grad_x, grad_mu, grad_cov)
end
# Draw a single sample from N(mu, cov) using the default random number generator.
# As in `logpdf`, only the upper triangle of `cov` is used (via `Symmetric`).
function random(::MultivariateNormal, mu::AbstractVector{<:Real},
                cov::AbstractMatrix{<:Real})
    gaussian = Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))
    return rand(gaussian)
end
# Make the distribution object callable: `mvnormal(mu, cov)` draws a sample.
function (::MultivariateNormal)(mu, cov)
    sampler = MultivariateNormal()
    return random(sampler, mu, cov)
end

# Gradients are provided for the sampled value and for both arguments (mu, cov);
# see `logpdf_grad`.
function has_output_grad(::MultivariateNormal)
    return true
end

function has_argument_grads(::MultivariateNormal)
    return (true, true)
end

export mvnormal