Skip to content

Commit

Permalink
impl MMD ratio estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Sep 11, 2019
1 parent 9631a73 commit 35a1c63
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 35 deletions.
152 changes: 117 additions & 35 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,35 @@ version = "1.1.3"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BenchmarkTools]]
deps = ["JSON", "Printf", "Statistics"]
git-tree-sha1 = "90b73db83791c5f83155016dd1cc1f684d4e1361"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "0.4.3"

[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "SHA"]
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.4"
version = "0.5.6"

[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
deps = ["Tokenize"]
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
version = "0.6.2"

[[Calculus]]
deps = ["Compat"]
git-tree-sha1 = "bd8bbd105ba583a42385bd6dc4a20dad8ab3dc11"
uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
version = "0.5.0"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
Expand All @@ -46,16 +58,22 @@ uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.2"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "10050a24b09e8e41b951e9976b109871ce98d965"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
version = "0.8.0"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
version = "0.9.6"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand All @@ -64,10 +82,10 @@ uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"

[[Conda]]
deps = ["Compat", "JSON", "VersionParsing"]
git-tree-sha1 = "b625d802587c2150c279a40a646fba63f9bd8187"
deps = ["JSON", "VersionParsing"]
git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.2.0"
version = "1.3.0"

[[Crayons]]
deps = ["Test"]
Expand All @@ -76,10 +94,10 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
version = "0.17.0"

[[Dates]]
deps = ["Printf"]
Expand All @@ -89,6 +107,24 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"

[[Distances]]
deps = ["LinearAlgebra", "Statistics"]
git-tree-sha1 = "23717536c81b63e250f682b0e0933769eecd1411"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.8.2"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -104,6 +140,12 @@ git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.6.1"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[GSL]]
deps = ["BinaryProvider", "Compat", "Libdl", "LinearAlgebra", "Markdown", "Printf", "Random", "Test"]
git-tree-sha1 = "96048e9db673b38968fd279db00ee45a16520add"
Expand All @@ -114,10 +156,22 @@ version = "0.5.0"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[Ipopt]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "MathOptInterface", "MathProgBase"]
git-tree-sha1 = "0fee58f35c4acf9011e1652d223791c72258d6b5"
uuid = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
version = "0.6.0"

[[JSON]]
deps = ["Dates", "Distributed", "Mmap", "Sockets", "Test", "Unicode"]
git-tree-sha1 = "1f7a25b53ec67f5e9422f1f551ee216503f4a0fa"
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.0"

[[JuMP]]
deps = ["Calculus", "DataStructures", "ForwardDiff", "LinearAlgebra", "MathOptInterface", "NaNMath", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "a970a86abc924f2c126cdb4978a5e8923d0e7b22"
uuid = "4076af6c-e467-56ae-b986-b466b2749572"
version = "0.20.0"

[[LaTeXStrings]]
Expand All @@ -140,24 +194,41 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test"]
git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162"
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.0"
version = "0.5.1"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MathOptInterface]]
deps = ["BenchmarkTools", "LinearAlgebra", "OrderedCollections", "SparseArrays", "Test", "Unicode"]
git-tree-sha1 = "2772d0090391b4bce23f4da0fa74143b6d0d5939"
uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
version = "0.9.2"

[[MathProgBase]]
deps = ["Compat"]
git-tree-sha1 = "3bf2e534e635df810e5f4b4f1a8b6de9004a0d53"
uuid = "fdba3010-5040-5b88-9595-932c9decdf73"
version = "0.7.7"

[[Missings]]
deps = ["SparseArrays", "Test"]
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.1"
version = "0.4.2"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
Expand All @@ -166,9 +237,15 @@ version = "1.1.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "8b68513175b2dc4023a564cb0e917ce90e74fd69"
git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.7"
version = "0.9.9"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.7"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand All @@ -191,10 +268,10 @@ uuid = "d330b81b-6aea-500a-939a-2ce795aea3ee"
version = "2.8.1"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "389fd27b958a33df5234772cc464b663b208273d"
deps = ["DataStructures", "LinearAlgebra", "Test"]
git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.0.4"
version = "2.0.3"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
Expand Down Expand Up @@ -245,6 +322,12 @@ git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.11.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -282,16 +365,15 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
version = "0.5.6"

[[TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.4"
version = "0.9.5"

[[URIParser]]
deps = ["Test", "Unicode"]
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ version = "0.3.0"
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
AutoGrad = "6710c13c-97f1-543f-91c5-74e8f7d95b35"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GSL = "92c85e6c-cbff-5e0c-80f7-495c94daaecd"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
1 change: 1 addition & 0 deletions src/MLToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ include("distributions/distributions.jl")
include("MonteCarlo/MonteCarlo.jl")
Reexport.@reexport using .MonteCarlo
# include("neural/neural.jl")
include("ratio/ratio.jl")

include("test_util.jl")
export NUM_RANDTESTS, ATOL, ATOL_RAND, include_list_as_module
Expand Down
58 changes: 58 additions & 0 deletions src/ratio/moment_matching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
    pairwise_dot(x)

Return the matrix of pairwise `SqEuclidean` distances between the slices of `x`
taken along `dims=2`, as computed by `Distances.pairwise`.

Despite the name, the entries are squared distances rather than dot products.
"""
function pairwise_dot(x)
    return pairwise(SqEuclidean(), x; dims=2)
end

"""
    pairwise_dot_kai(x)

Return the `n × n` matrix of squared Euclidean distances between the columns of
the matrix `x` (one sample per column, `n = size(x, 2)`).

Entry `(i, j)` is computed with the expansion `‖xᵢ‖² + ‖xⱼ‖² - 2 xᵢᵀxⱼ`, using only
a matrix product and broadcasting. Because of floating-point cancellation the
diagonal and near-duplicate pairs may come out as tiny non-zero (possibly
negative) values.
"""
function pairwise_dot_kai(x)
    gram = x' * x
    sq_norms = sum(x .^ 2; dims=1)  # 1 × n row of squared column norms
    # Broadcasting the (n × 1) adjoint against the (1 × n) row expands to n × n
    # without materialising the two intermediate matrices `repeat` would allocate.
    return sq_norms' .+ sq_norms .- 2 .* gram
end

# pairwise_dot(x::CuArray) = pairwise_dot_kai(x)

"""
    pairwise_dot(x, y)

Return the matrix of pairwise `SqEuclidean` distances between the slices of `x`
and the slices of `y` taken along `dims=2`, as computed by `Distances.pairwise`.

Despite the name, the entries are squared distances rather than dot products.
"""
function pairwise_dot(x, y)
    return pairwise(SqEuclidean(), x, y; dims=2)
end

"""
    pairwise_dot_kai(x, y)

Return the `nx × ny` matrix of squared Euclidean distances between the columns
of `x` and the columns of `y` (one sample per column; both matrices must have
the same number of rows).

Entry `(i, j)` is computed with the expansion `‖xᵢ‖² + ‖yⱼ‖² - 2 xᵢᵀyⱼ`, using only
a matrix product and broadcasting. Because of floating-point cancellation,
near-duplicate pairs may come out as tiny non-zero (possibly negative) values.
"""
function pairwise_dot_kai(x, y)
    cross = x' * y
    x_sq = sum(x .^ 2; dims=1)  # 1 × nx row of squared column norms of x
    y_sq = sum(y .^ 2; dims=1)  # 1 × ny row of squared column norms of y
    # (nx × 1) .+ (1 × ny) broadcasts to nx × ny without the two temporaries
    # that `repeat` would allocate.
    return x_sq' .+ y_sq .- 2 .* cross
end

# pairwise_dot(x::CuArray, y::CuArray) = pairwise_dot_kai(x, y)

"""
    gaussian_gram_by_pairwise_dot(pdot; σ=1)

Turn an array of squared distances `pdot` into Gaussian (RBF) kernel values,
elementwise `exp(-pdot / (2σ²))`, where `σ` is the kernel bandwidth.
"""
gaussian_gram_by_pairwise_dot(pdot; σ=1) = exp.(-pdot ./ (2 * σ^2))
"""
    gaussian_gram(x; σ=1)
    gaussian_gram(x, y; σ=1)

Build a Gaussian Gram matrix with bandwidth `σ` by passing the result of
`pairwise_dot(x)` (or `pairwise_dot(x, y)`) to `gaussian_gram_by_pairwise_dot`.
"""
function gaussian_gram(x; σ=1)
    sq_dists = pairwise_dot(x)
    return gaussian_gram_by_pairwise_dot(sq_dists; σ=σ)
end

function gaussian_gram(x, y; σ=1)
    sq_dists = pairwise_dot(x, y)
    return gaussian_gram_by_pairwise_dot(sq_dists; σ=σ)
end

"""
    estimate_r_de(x_de, x_nu, get_r_hat=get_r_hat_numerically; σs=nothing)

Estimate a density ratio at the denominator samples `x_de`, given numerator
samples `x_nu`, from Gaussian Gram matrices.

# Arguments
- `x_de`, `x_nu`: sample matrices passed to `pairwise_dot_kai` (one sample per
  column).
- `get_r_hat`: solver called as `get_r_hat(Kdede, Kdenu)`; its return value is
  returned unchanged.

# Keywords
- `σs`: collection of kernel bandwidths. The Gram matrices are averaged over
  all of them. When `nothing` (the default) a single bandwidth is chosen by the
  median heuristic: the square root of the median of all squared distances in
  the de–de and de–nu distance matrices.
"""
function estimate_r_de(x_de, x_nu, get_r_hat=get_r_hat_numerically; σs=nothing)
    pdot_dede = pairwise_dot_kai(x_de)
    pdot_denu = pairwise_dot_kai(x_de, x_nu)

    # The heuristic depends on the distance matrices, so it has to be evaluated
    # here: a keyword default expression cannot see locals assigned in the body.
    bandwidths = if σs === nothing
        [sqrt(median(vcat(vec(pdot_dede), vec(pdot_denu))))]
    else
        σs
    end

    Kdede = mean([gaussian_gram_by_pairwise_dot(pdot_dede; σ=σ) for σ in bandwidths])
    Kdenu = mean([gaussian_gram_by_pairwise_dot(pdot_denu; σ=σ) for σ in bandwidths])

    return get_r_hat(Kdede, Kdenu)
end

"""
    get_r_hat_numerically(Kdede, Kdenu; positive=true, normalisation=true)

Solve for the ratio vector `r` (one entry per row of `Kdenu`) with JuMP and the
Ipopt optimizer by minimising

    1/n_de² · Σᵢⱼ rᵢ Kdede[i,j] rⱼ  −  2/(n_de·n_nu) · Σᵢⱼ rᵢ Kdenu[i,j]

where `(n_de, n_nu) = size(Kdenu)`.

# Keywords
- `positive`: when `true`, constrain every `rᵢ ≥ 0`.
- `normalisation`: when `true`, constrain the mean of `r` to equal 1.

Returns `value.(r)` after `JuMP.optimize!`.
"""
function get_r_hat_numerically(Kdede, Kdenu; positive=true, normalisation=true)
    n_de, n_nu = size(Kdenu)
    # print_level=0 is passed through to Ipopt.
    model = Model(with_optimizer(Ipopt.Optimizer; print_level=0))
    @variable(model, r[1:n_de])
    @objective(
        model,
        Min,
        1 / n_de ^ 2 * sum(r[i] * Kdede[i, j] * r[j] for i in 1:n_de, j in 1:n_de) -
        2 / (n_de * n_nu) * sum(r[i] * Kdenu[i, j] for i in 1:n_de, j in 1:n_nu)
    )
    positive && @constraint(model, r .>= 0)
    normalisation && @constraint(model, 1 / n_de * sum(r) == 1)
    JuMP.optimize!(model)
    return value.(r)
end

"""
    get_r_hat_analytical(Kdede, Kdenu; ϵ=1 / 1_000)

Closed-form ratio estimate

    r = (n_de / n_nu) · (Kdede + ϵ·I)⁻¹ · Kdenu · 1

where `(n_de, n_nu) = size(Kdenu)` and `1` is the all-ones vector of length
`n_nu`. `ϵ` is added to the diagonal of `Kdede` before solving.

Returns a vector of length `n_de`.
"""
function get_r_hat_analytical(Kdede, Kdenu; ϵ=1 / 1_000)
    n_de, n_nu = size(Kdenu)
    # `Kdenu * ones(n_nu)` is just the vector of row sums.
    row_sums = vec(sum(Kdenu; dims=2))
    # Solve the linear system instead of forming the explicit inverse.
    return (n_de / n_nu) .* ((Kdede + ϵ * I) \ row_sums)
end
;
5 changes: 5 additions & 0 deletions src/ratio/ratio.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Density-ratio estimation code; this file is `include`d from src/MLToolkit.jl.
# `pairwise` and `SqEuclidean` are used by `pairwise_dot` in moment_matching.jl.
using Distances: pairwise, SqEuclidean
# `I` is used for the `ϵ * I` term in `get_r_hat_analytical`.
using LinearAlgebra: I
# JuMP and Ipopt are used by `get_r_hat_numerically`.
using JuMP, Ipopt
include("moment_matching.jl")
# Public API: the estimator entry point and the two interchangeable solvers.
export estimate_r_de, get_r_hat_numerically, get_r_hat_analytical

0 comments on commit 35a1c63

Please sign in to comment.