Skip to content
This repository has been archived by the owner on Nov 4, 2021. It is now read-only.

Commit

Permalink
Merge pull request #8 from theogf/Addfix
Browse files Browse the repository at this point in the history
Loosen MvNormal and add StatsBase inheritance
  • Loading branch information
theogf committed Jul 19, 2021
2 parents cdc031c + 42fe5fa commit 0dd8d6c
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 232 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1 +1,3 @@
/Manifest.toml
test/Manifest.toml
.vscode/settings.json
8 changes: 5 additions & 3 deletions Project.toml
@@ -1,20 +1,22 @@
name = "KLDivergences"
uuid = "3c9cd921-3d3f-41e2-830c-e020174918cc"
authors = ["Theo Galy-Fajou <theo.galyfajou@gmail.com> and contributors"]
version = "0.1.1"
version = "0.1.2"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Distances = "0.9, 0.10"
Distributions = "0.23, 0.24"
PDMats = "0.9, 0.10"
Distributions = "0.23, 0.24, 0.25"
PDMats = "0.9, 0.10, 0.11"
SpecialFunctions = "0.10, 1.2"
StatsBase = "0.33"
julia = "1"

[extras]
Expand Down
13 changes: 10 additions & 3 deletions src/KLDivergences.jl
@@ -1,10 +1,15 @@
module KLDivergences

using Distributions: StatsBase
using Distributions
using LinearAlgebra, PDMats
using Distances, SpecialFunctions
using LinearAlgebra
using PDMats
using Distances
using SpecialFunctions
using StatsBase: StatsBase, kldivergence

export KL

export KL, kldivergence

"""
KL(p::Distribution, q::Distribution) -> T
Expand All @@ -14,6 +19,8 @@ Return the KL divergence of KL(p||q), either by sampling or analytically
"""
KL

StatsBase.kldivergence(p::Sampleable, q::Sampleable) = KL(p, q)

KLbase(p, q, x) = logpdf(p, x) - logpdf(q, x)

## Generic fallback for multivariate Distributions
Expand Down
12 changes: 10 additions & 2 deletions src/multivariate.jl
@@ -1,6 +1,14 @@
function KL(p::AbstractMvNormal, q::AbstractMvNormal)
length(p) == length(q) ||
throw(DimensionMismatch("Distributions p and q have different dimensions $(length(p)) and $(length(q))"))
Σp = cov(p)
Σq = cov(q)
Δμ = mean(p) - mean(q)
0.5 * (tr(Σq \ Σp) + dot(Δμ / Σq, Δμ) - length(p) + logdet(Σq) - logdet(Σp))
end

function KL(p::MvNormal, q::MvNormal)
length(p) == length(q) ||
throw(DimensionMismatch("Distributions p and q have different dimensions $(length(p)) and $(length(q))"))
Σp = p.Σ; Σq = q.Σ
0.5 * (tr(Σq \ Σp) + invquad(Σq, mean(p) - mean(q)) - length(p) + logdet(Σq) - logdet(Σp))
0.5 * (tr(q.Σ \ p.Σ) + invquad(q.Σ, mean(p) - mean(q)) - length(p) + logdet(q.Σ) - logdet(p.Σ))
end
221 changes: 0 additions & 221 deletions test/Manifest.toml

This file was deleted.

1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,4 +1,5 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,6 +1,7 @@
using KLDivergences
using Distributions
using LinearAlgebra
using Random
using Test

@testset "KLDivergences.jl" begin
Expand Down
6 changes: 3 additions & 3 deletions test/univariate.jl
Expand Up @@ -30,10 +30,10 @@
@test KL(p, q) KL(p, q, 100_000) atol = 0.1
end
@testset "Poisson" begin
p = Poisson(4)
q = Normal(5.0)
p = Poisson(4.0)
q = Poisson(3.0)
@test KL(p, q) > 0
@test KL(p, q) KL(p, q, 100_000) atol = 0.2
@test KL(p, q) KL(p, q, 100_000) atol = 0.1
end

end

2 comments on commit 0dd8d6c

@theogf
Copy link
Owner Author

@theogf theogf commented on 0dd8d6c Jul 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/41176

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" 0dd8d6cd75efdca9528d877456dd1c7ed876aef2
git push origin v0.1.2

Please sign in to comment.