In [None]:
include("../SVGD.jl")

using Plots
using Distributions: MvNormal, pdf
using Test
import Statistics

In [None]:
"""
    pdf_grid(x_, y_, dist_)

Evaluate the density of the bivariate distribution `dist_` at every point of the
Cartesian grid spanned by the coordinate vectors `x_` and `y_`.

Returns a matrix `G` with `size(G) == (length(x_), length(y_))` and
`G[i, j] == pdf(dist_, [x_[i], y_[j]])`, i.e. rows follow `x_` and columns follow
`y_`. The element type is whatever `pdf` returns (not a hard-coded `Float64`).
"""
function pdf_grid(x_, y_, dist_)
    # In a multi-iterator comprehension the first iterator (`x_`) varies fastest,
    # so element [i, j] pairs x_[i] with y_[j] and the fill is column-major.
    return [pdf(dist_, [xi, yj]) for xi in x_, yj in y_]
end
;

In [None]:
# Data
# Ground-truth target distribution: a 2-D Gaussian with correlated components.
# `mean_vec` is deliberately a 1×2 row matrix so it lines up element-wise with
# `Statistics.mean(...; dims=1)` in the moment checks; MvNormal gets a plain vector.
cov_mat = [0.333 0.357
           0.357 0.666]
mean_vec = [0.0 0.0]
mvn = MvNormal(mean_vec[1, :], copy(cov_mat))

# Evaluation grid for the heatmap: the same axis (step 0.01 on [-3, 3]) in both
# directions, materialized as two independent vectors.
xx, yy = let axis_ = -3.0:0.01:3.0
    collect(axis_), collect(axis_)
end
mvn_grid = pdf_grid(xx, yy, mvn)

# Initial particles: 50 points tracing a heart shape.
#   upper lobes: y =  sqrt(1 - (|x| - 1)^2)
#   lower tip:   y = -3 * sqrt(1 - (|x| / 2)^0.5)
x = collect(range(-2.0, 2; length=25))
y1 = @. sqrt(1 - (abs(x) - 1)^2)
y2 = @. -3 * sqrt(1 - (abs(x) / 2)^0.5)

# Column 1: heart y-values shifted up by one; column 2: the matching x-values.
init_particles = [vcat(y1, y2) .+ 1  vcat(x, x)]
;

In [None]:
# Plot the target PDF as a heatmap with the initial heart-shaped particles on top.
# `mvn_grid[i, j]` is the density at (xx[i], yy[j]) and `heatmap` draws matrix rows
# along the vertical axis, so the first coordinate appears vertically. The scatter
# therefore puts particle column 2 on the horizontal axis and column 1 on the
# vertical one, keeping both layers in the same frame.
heatmap(xx, yy, mvn_grid;
        legend=false, border=:none,
        background_color_subplot="black", background_color=:transparent)
scatter!(init_particles[:, 2], init_particles[:, 1];
         legend=false, color="White",
         aspect_ratio=:equal, axis=nothing)


In [None]:
# Score function for the target: wraps `ana_dlogmvn` (defined in SVGD.jl —
# presumably the analytic gradient of the MVN log-density; confirm there) with the
# ground-truth mean and covariance fixed, so `update` can call it with particles only.
ana_dlogmvn_eval(x) = ana_dlogmvn(mean_vec, cov_mat, x)

# Transport the heart-shaped particles with SVGD.
trans_parts = update(init_particles, ana_dlogmvn_eval;
                     n_epochs=2000, dt=0.002, opt="adagrad")

# The transported particles should reproduce the target's first two moments.
# `mean(...; dims=1)` yields a 1×2 row, matching the 1×2 `mean_vec`.
# Grouping the checks in a testset reports both results even if the first one fails.
@testset "SVGD particles match target moments" begin
    @test all(isapprox.(Statistics.mean(trans_parts; dims=1), mean_vec; atol=0.1))
    @test all(isapprox.(Statistics.cov(trans_parts), cov_mat; atol=0.1))
end


In [None]:
# Plot the target PDF as a heatmap with the SVGD-transported particles on top.
# As with the initial plot, `mvn_grid`'s first coordinate is drawn vertically, so the
# scatter puts particle column 2 on the horizontal axis and column 1 on the vertical.
heatmap(xx, yy, mvn_grid;
        legend=false, border=:none,
        background_color_subplot="black", background_color=:transparent)
scatter!(trans_parts[:, 2], trans_parts[:, 1];
         legend=false, color="White",
         aspect_ratio=:equal, axis=nothing)
