In [1]:
using Distributions
using Random
using LinearAlgebra
using PyPlot
using Optim
using JLD

In [None]:
"""
    grad!(∇, sorted_X, μ1, μ2, Z1, Z2, σ)

Overwrite `∇[1]` and `∇[2]` with the partial derivatives, with respect to `μ1` and
`μ2`, of `sum((sort(Y) .- sorted_X).^2) / length(sorted_X)`, where
`Y = [μ1 .+ sqrt(1 + σ^2) * Z1; μ2 .+ sqrt(1 + σ^2) * Z2]`.

# Arguments
- `∇`: mutable vector with at least two entries; entries 1 and 2 are overwritten.
- `sorted_X`: target samples, expected to be sorted ascending (not checked) and to
  have `length(Z1) + length(Z2)` entries.
- `μ1`, `μ2`: location parameters of the first and second component.
- `Z1`, `Z2`: fixed base draws for the first and second component.
- `σ`: smoothing level; every base draw is scaled by `sqrt(1 + σ^2)`.

# Returns
The mutated `∇`.
"""
function grad!(∇, sorted_X, μ1, μ2, Z1, Z2, σ)
    scale = sqrt(1 + σ^2)
    Y = [μ1 .+ scale * Z1; μ2 .+ scale * Z2]

    # perm[k] is the position in the unsorted `Y` of the k-th smallest draw.
    perm = sortperm(Y)
    residual = Y[perm] .- sorted_X

    # Sorted slot k moves one-for-one with μ1 exactly when it holds a draw from the
    # first component, i.e. when its unsorted position lies in the `Z1` part of `Y`.
    # The mask must therefore be built from `perm`, not from its inverse.
    from_first = perm .<= length(Z1)

    m = length(sorted_X)
    ∇[1] = 2 * sum(residual[from_first]) / m
    # Sum over the complementary slots; each partial derivative is normalised by
    # `m` exactly once.
    ∇[2] = 2 * sum(residual[.!from_first]) / m
    return ∇
end

In [None]:
"""
    squared_W2(sorted_X, μ1, μ2, Z1, Z2, σ)

Return the mean of the squared element-wise differences between `sorted_X` and the
ascending-sorted vector `[μ1 .+ sqrt(1 + σ^2) * Z1; μ2 .+ sqrt(1 + σ^2) * Z2]`.

`sorted_X` is used as given (it is not re-sorted here) and must have as many entries
as `Z1` and `Z2` combined.
"""
function squared_W2(sorted_X, μ1, μ2, Z1, Z2, σ)
    scale = sqrt(1 + σ^2)
    draws = vcat(μ1 .+ scale * Z1, μ2 .+ scale * Z2)
    sort!(draws)
    gap = sorted_X .- draws
    return dot(gap, gap) / length(draws)
end

In [None]:
# Ground-truth model: an equal-weight mixture of two unit-scale normals located at
# true_μ1 and true_μ2.
true_μ1 = -1;
true_μ2 = 1;
true_dist = MixtureModel(Normal[Normal(μ, 1) for μ in (true_μ1, true_μ2)], fill(1/2, 2));

In [None]:
# Experiment grid: smoothing levels and target sample sizes (both kept as 1×k row
# matrices, which is how later cells consume them).
sigmas = [0.1 0.2 0.5]
sample_sizes = [25 50 100 200 400]
T = 20 # num noise samples per original sample from target dist
N = 1 # num iterations per sample size

# GD details
K = 2000 # maximum number of gradient steps per fit
learning_rate = 0.01

# Results indexed by (σ index, sample-size index, repetition).
μ1_ests = zeros(length(sigmas), length(sample_sizes), N)
μ2_ests = zeros(length(sigmas), length(sample_sizes), N)
sw2_errs = zeros(length(sigmas), length(sample_sizes), N)

for (i,σ) in enumerate(sigmas)
    println(σ)
    for (j,n) in enumerate(sample_sizes)
        println(n)
        for k = 1:N
            #target dist samples
            model_samples = rand(true_dist, n)
            # Each target draw is jittered 2T times, so sorted_X has 2*n*T points and
            # matches the 2*m model draws that grad!/squared_W2 build from Z1 and Z2.
            smoothed_model_samples = zeros(n,T*2)
            for l = 1:n
                smoothed_model_samples[l,:] = model_samples[l] .+ σ * randn(T*2)
            end
            sorted_X = sort(vec(smoothed_model_samples));

            # input noise
            m = n*T
            Z1 = randn(m)
            Z2 = randn(m)

            # estimate parameters which minimize distance to empirical approximation of target distribution
            θ = [2.0, 2.0]
            ∇ = [1.0, 1.0] # placeholder so the first loop test passes; grad! overwrites it
            l = 0
            # Stop on a small gradient, or after K steps so that a run which fails to
            # converge cannot loop forever.
            while norm(∇) > 0.001 && l < K
                grad!(∇, sorted_X, θ[1], θ[2], Z1, Z2, σ)
                θ .-= learning_rate .* ∇
                l += 1
            end
            println(l)
            # The component labels are interchangeable, so store the estimates ordered.
            μ1_ests[i,j,k] = minimum(θ)
            μ2_ests[i,j,k] = maximum(θ)
            println(θ)
            # squared_W2 is the objective defined in this file; the stored value is the
            # squared distance at the fitted parameters.
            sw2_errs[i,j,k] = squared_W2(sorted_X, θ[1], θ[2], Z1, Z2, σ)
        end
    end
end

In [None]:
fig, axs = plt.subplots(2,2, figsize=(6,5))

# One scatter panel per σ (axs[1], axs[2], ...): √n-scaled errors of the two
# location estimates, one marker series per plotted sample size.
for (i,σ) in enumerate(sigmas)
    panel = axs[i]
    for (j,n) in enumerate(sample_sizes)
        # the 2nd and 4th sample sizes are left out of the scatter panels
        j in (2, 4) && continue

        est1 = μ1_ests[i,j,:]
        est2 = μ2_ests[i,j,:]
        # keep only repetitions with est1^2 + est2^2 > 0.1
        keep = est1.^2 + est2.^2 .> 0.1

        scaled_err1 = sqrt(n)*(est1[keep] .- true_μ1)
        scaled_err2 = sqrt(n)*(est2[keep] .- true_μ2)
        panel.scatter(scaled_err1, scaled_err2, marker="x", label="n=$(n)")
    end
    panel.set_xlabel(L"\sqrt{n}(\hat{a}_1 - a_1)")
    panel.set_ylabel(L"\sqrt{n}(\hat{a}_2 - a_2)")
    panel.set_xlim([-6,6])
    panel.set_ylim([-6,6])
    panel.set_title("σ = $(σ)")
    for tick_kind in ("major", "minor")
        panel.tick_params(axis="both", which=tick_kind, labelsize=6)
    end
end
axs[1].legend()

# Last panel: square root of the mean stored squared distance against sample size,
# one curve per σ, on log-log axes. The same > 0.1 filter as above selects the
# repetitions that enter the mean.
for (i,σ) in enumerate(sigmas)
    curve_panel = axs[end]
    mean_sw2_err = Float64[
        sqrt(mean(sw2_errs[i, j, μ1_ests[i,j,:].^2 + μ2_ests[i,j,:].^2 .> 0.1]))
        for j in eachindex(sample_sizes)
    ]
    curve_panel.plot(vec(sample_sizes), mean_sw2_err,label="σ=$(σ)")
    curve_panel.set_xscale("log")
    curve_panel.set_yscale("log")
    curve_panel.set_xlabel("# samples n")
    curve_panel.set_ylabel(L"W_2^{(\sigma)}")
    for tick_kind in ("major", "minor")
        curve_panel.tick_params(axis="both", which=tick_kind, labelsize=6)
    end
end
axs[end].legend()
axs[end].set_title(L"$W_2^{(\sigma)}$ convergence")

subplots_adjust(wspace=0.4, hspace=0.45)