Skip to content

Commit

Permalink
improve code for mmd ratio est
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Sep 12, 2019
1 parent 68d431c commit cc90b8d
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/ratio/moment_matching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,28 @@ gaussian_gram_by_pairwise_dot(pdot; σ=1) = exp.(-pdot ./ 2(σ ^ 2))
gaussian_gram(x; σ=1) = gaussian_gram_by_pairwise_dot(pairwise_dot(x); σ=σ)
gaussian_gram(x, y; σ=1) = gaussian_gram_by_pairwise_dot(pairwise_dot(x, y); σ=σ)

function _estimate_r_de(pdot_dede, pdot_denu, get_r_hat, σ; kwargs...)
Kdede = gaussian_gram_by_pairwise_dot(pdot_dede; σ=σ)
Kdenu = gaussian_gram_by_pairwise_dot(pdot_denu; σ=σ)
return get_r_hat(Kdede, Kdenu; kwargs...)
end

function estimate_r_de(x_de, x_nu; get_r_hat=get_r_hat_analytical, σs=nothing, kwargs...)
pdot_dede = pairwise_dot_kai(x_de)
pdot_denu = pairwise_dot_kai(x_de, x_nu)

if isnothing(σs); σs = [sqrt(median([pdot_dede..., pdot_denu...]))]; end
if isnothing(σs)
σ = sqrt(median([pdot_dede..., pdot_denu...]))
@info "Automatically choose σ using the square root of the median of pairwise distances: ."
σs = [σ]
end

Kdede = mean([gaussian_gram_by_pairwise_dot(pdot_dede; σ=σ) for σ in σs])
Kdenu = mean([gaussian_gram_by_pairwise_dot(pdot_denu; σ=σ) for σ in σs])
r_de = _estimate_r_de(pdot_dede, pdot_denu, get_r_hat, σs[1]; kwargs...)
for σ in σs[2:end]
r_de += _estimate_r_de(pdot_dede, pdot_denu, get_r_hat, σ; kwargs...)
end

return get_r_hat(Kdede, Kdenu; kwargs...)
return r_de ./ length(σs)
end

function get_r_hat_numerically(Kdede, Kdenu; positive=true, normalisation=true)
Expand Down

0 comments on commit cc90b8d

Please sign in to comment.