# mvnormal.jl
# Field-less (singleton) distribution type whose values are `Vector{Float64}`.
struct MultivariateNormal <: Distribution{Vector{Float64}} end

"""
    mvnormal(mu::AbstractVector{T}, cov::AbstractMatrix{U}) where {T<:Real,U<:Real}

Samples a `Vector{Float64}` value from a multivariate normal distribution with mean
vector `mu` and covariance matrix `cov`.
"""
const mvnormal = MultivariateNormal()
# Log-density of `x` under a multivariate normal with mean `mu` and covariance `cov`.
# The computation is delegated entirely to `Distributions.MvNormal`.
function logpdf(::MultivariateNormal, x::AbstractVector{T}, mu::AbstractVector{U},
                cov::AbstractMatrix{V}) where {T,U,V}
    return Distributions.logpdf(Distributions.MvNormal(mu, cov), x)
end
# Gradients of the multivariate normal log-density, returned as the tuple
# (d/dx, d/dmu, d/dcov).
#
# Writing P = inv(cov) and r = P * (x - mu):
#   d/dx   = -r                    (what `Distributions.gradlogpdf` returns)
#   d/dmu  =  r
#   d/dcov = -0.5 * (P - r * r')
# The cov formula treats each entry of `cov` as an independent variable; no
# symmetrization is applied to the result.
function logpdf_grad(::MultivariateNormal, x::AbstractVector{T}, mu::AbstractVector{U},
                     cov::AbstractMatrix{V}) where {T,U,V}
    mvn = Distributions.MvNormal(mu, cov)
    grad_x = Distributions.gradlogpdf(mvn, x)
    grad_mu = -grad_x
    precision_mat = Distributions.invcov(mvn)
    outer_r = grad_mu * transpose(grad_mu)
    grad_cov = -0.5 * (precision_mat - outer_r)
    return (grad_x, grad_mu, grad_cov)
end
# Draw one sample from a multivariate normal with mean `mu` and covariance `cov`,
# using the default RNG (`rand` is called without an explicit RNG argument).
# Only type variables that actually appear in the signature are declared; an unused
# one makes Julia emit a "declares type variable ... but does not use it" warning.
function random(::MultivariateNormal, mu::AbstractVector{U},
                cov::AbstractMatrix{V}) where {U,V}
    return rand(Distributions.MvNormal(mu, cov))
end
# Calling the distribution object samples from it. `MultivariateNormal` has no fields,
# so forwarding `dist` itself is the same as passing a fresh `MultivariateNormal()`.
function (dist::MultivariateNormal)(mu, cov)
    return random(dist, mu, cov)
end

# Gradients are provided for the sampled value (see `logpdf_grad`).
function has_output_grad(::MultivariateNormal)
    return true
end

# Gradients are provided for both arguments, in order: `mu`, `cov`.
function has_argument_grads(::MultivariateNormal)
    return (true, true)
end

export mvnormal