In [None]:

#Include
using Plots, LightGraphs, SparseArrays
using Statistics, BenchmarkTools, LinearAlgebra, ProgressMeter
using Distributions, Base.Threads
using Base.GC
using Mosek, MosekTools, OSQP, ECOS, SCS, ProxSDP, CPLEX
using Clustering, JuMP, Roots, PyCall, Base.Threads
plotly()

In [None]:
# Bind two Python modules through PyCall so they can be called from Julia.
# NOTE(review): `ot` is presumably the POT (Python Optimal Transport) package — confirm it
# is installed in the Python environment PyCall is linked against. Neither binding is used
# in the cells visible here.
ot = pyimport("ot")
np = pyimport("numpy")

# Sample Distribution

In [None]:
# Evaluation grid: 5001 evenly spaced points covering [-20, 20].
x = collect(range(-20, 20; length = 5001));

# Unnormalised profiles. F and G are bumps centred at +15 and -15, H is a narrower bump
# centred at -5, and Exp is a slowly decaying exponential. Only F and G are sampled below.
F(x) = exp(-((x - 15)^2 / 20))
G(x) = exp(-((x + 15)^2 / 20))
H(x) = exp(-((x + 5)^2 / 10))
#I(x) = exp(-((x - 5)^2 / 10))   # left disabled, as in the original cell
Exp(x) = exp(-x / 100)

# Sample F and G on the grid, then rescale each sample vector so its entries sum to one.
a = map(F, x)
a = a ./ sum(a);
b = map(G, x)
b = b ./ sum(b);

In [None]:
# Ground-cost matrix: D[i, j] is the squared distance between grid points x[i] and x[j],
# with one row per entry of `a` and one column per entry of `b`.
D = Float64[abs(x[i] - x[j])^2 for i in eachindex(a), j in eachindex(b)]

In [None]:
# Plot both sampled histograms against the grid `x` as overlaid scatter series.
plotly()  # select the Plotly backend of Plots
scatter(x,a, label="a")
scatter!(x,b, label = "b")  # the `!` form adds this series to the plot created just above

# Quadratic regularization

In [None]:
"""
    calc_opt(a, b, C, γ = 1, τ = 1e-4; maxIters = 1000, verbose = false)

Run a cyclic "project and forget" scheme on the pairwise constraints
`f[i] + g[j] <= C[i, j]`, starting from `f = γ .* a` and `g = γ .* b`.

Each outer sweep

1. adds every pair `(i, j)` whose constraint is currently violated to the active set `L`;
2. makes 100 passes over the active set. For every active pair it computes the signed step
   `θ = (C[i, j] - f[i] - g[j]) / (2γ)`, clips it by the stored correction `Z[i, j]`,
   adds `γ * c` to `f[i]` and `g[j]`, and subtracts `c` from `Z[i, j]`. A pair whose
   correction is used up (or falls below `τ / (n * m)`) is zeroed and dropped from `L`.

NOTE(review): this appears intended as a solver for the quadratically regularised
transport problem studied in this notebook, with `Z` playing the role of the plan —
confirm against the derivation before relying on that interpretation.

# Arguments
- `a`, `b`: weight vectors. They are not mutated; integer vectors are accepted.
- `C`: cost array, indexed as `C[i, j]` for `i in 1:length(a)`, `j in 1:length(b)`.
- `γ`: scale applied to the initial `f`, `g` and to every update step.
- `τ`: tolerance. Sweeping stops once the largest violation step seen in a sweep is `<= τ`.

# Keywords
- `maxIters`: the sweep counter starts at 1 and sweeping continues while it is
  `< maxIters`, so at most `maxIters - 1` sweeps are performed.
- `verbose`: when `true`, print `(error, length(L), count)` after every sweep.

# Returns
`(f, g, Z, L)`: the updated vectors `f` and `g`, the sparse `n × m` correction matrix `Z`,
and the active set `L` as a `Dict` keyed by `(i, j)` tuples.
"""
function calc_opt(a, b, C, γ = 1, τ = 1e-4; maxIters = 1000, verbose = false)
    # Work on floating-point copies: the updates below add fractional steps, which would
    # throw an InexactError on integer storage.
    f = float.(γ .* a)
    g = float.(γ .* b)
    n = length(f)
    m = length(g)

    err = float(τ) + 1      # strictly above τ so that the first sweep always runs
    dropTol = τ / (n * m)   # corrections smaller than this are treated as zero

    L = Dict{Tuple{Int,Int},Int}()
    Z = spzeros(n, m)

    count = 1
    while err > τ && count < maxIters
        err = zero(err)

        # Find violated constraints and add them to the active set.
        for i in 1:n, j in 1:m
            if f[i] + g[j] > C[i, j]
                L[(i, j)] = 1
            end
        end

        # Project-and-forget passes over the active set.
        for _ in 1:100
            # Iterate over a snapshot of the keys so entries can be deleted safely.
            for key in collect(keys(L))
                i, j = key
                z = Z[i, j]

                θ = (C[i, j] - f[i] - g[j]) / (2 * γ)
                c = min(θ, z)

                # Track the largest violation step (θ < 0 means the constraint is violated).
                if err < -θ
                    err = -θ
                end

                f[i] += γ * c
                g[j] += γ * c

                if c == z || z - c < dropTol
                    # Correction exhausted: forget this constraint.
                    Z[i, j] = 0
                    delete!(L, key)
                else
                    Z[i, j] = z - c
                end
            end
        end

        if verbose
            println("(error, length(L), count) = ", repr((err, length(L), count)))
        end

        count += 1
    end

    return f, g, Z, L
end

# Method experiment

## Solve Primal using CPLEX and Mosek

In [None]:
# Heatmap of the adjoint (transpose) of `TZ`, drawn with the viridis colour map.
# NOTE(review): `TZ` is not defined in any visible cell — presumably a plan produced by an
# earlier run (e.g. the `Z` returned by `calc_opt`); confirm before re-running this cell.
heatmap(TZ', color=:viridis)

In [None]:
# Build a JuMP model over two length-n variable vectors f and g with one constraint
# f[i] + g[j] <= D[i,j] per index pair. The time spent creating variables and constraints
# is accumulated (via @elapsed) in `t`.
model = Model()
# NOTE(review): n is hard-coded to 501 while the grid cell above samples 5001 points, so
# only the leading 501×501 block of D is constrained here — confirm which grid size this
# experiment is meant to use.
n = 501
t = 0
t += @elapsed @variable(model, f[1:n])
t += @elapsed @variable(model, g[1:n]);
# Constraint references are kept so their dual values can be queried after solving.
cons = Array{Any,2}(undef, n,n)
for i = 1:n
    for j = 1:n
        t += @elapsed cons[i,j] = @constraint(model, f[i]+g[j] <= D[i,j])
    end
end

# Regularisation weight used by the objective cell and by the Convex.jl formulation below.
γ = 1e3

In [None]:
# Objective: maximise sum(f .* a) + sum(g .* b) minus the quadratic penalty
# (sum(f.^2) + sum(g.^2)) / (2γ).
# NOTE(review): `f .* a` and `g .* b` need length(a) == length(b) == n; with the 5001-point
# grid and n = 501 defined in the cells above this broadcast would fail — confirm that `a`
# and `b` were resampled to length n before this cell is run.
@objective(model, Max, sum(f.*a)+sum(g.*b) - (sum(f.^2)+sum(g.^2))/(2*γ));

In [None]:
# Attach the Mosek optimizer to the model built above.
set_optimizer(model, Mosek.Optimizer)

In [None]:
# Solve the model with the currently attached optimizer.
JuMP.optimize!(model)

In [None]:
# Gather the dual value of every pair constraint into a dense n×n matrix, adding the time
# this takes to the running total `t`.
P = zeros(n, n)
t += @elapsed for idx in CartesianIndices(P)
    P[idx] = dual(cons[idx])
end

In [None]:
# Display the accumulated time in seconds. It covers variable creation, constraint creation
# and dual extraction; the call to optimize! itself is not timed.
t

In [None]:
# Re-solve the same JuMP model with a series of other optimizers, one per cell. Each cell
# replaces the optimizer attached to `model` and solves again, so results queried afterwards
# come from whichever cell ran last.
set_optimizer(model, CPLEX.Optimizer)
JuMP.optimize!(model)

In [None]:
set_optimizer(model, OSQP.Optimizer)
JuMP.optimize!(model)

In [None]:
set_optimizer(model, SCS.Optimizer)
JuMP.optimize!(model)

In [None]:
set_optimizer(model, ECOS.Optimizer)
JuMP.optimize!(model)

In [None]:
set_optimizer(model, ProxSDP.Optimizer)
JuMP.optimize!(model)

## Solving the dual with CPLEX and Mosek

In [None]:
using Convex

In [None]:
# Convex.jl formulation over a length(a) × length(b) matrix variable.
# Note: this rebinds `P`, which an earlier cell used for the matrix of constraint duals.
@time P = Variable(length(a),length(b));

In [None]:
# All-ones vector used to form row sums (P*vone) and column sums (P'*vone).
# It has length(a) entries, so both products only conform when length(a) == length(b).
@time vone = ones(length(a));

In [None]:
# Penalised variant: minimise sum(D .* P) plus γ/2 times the squared mismatch between
# a and the row sums, and between b and the column sums, subject to P >= 0.
@time problem = minimize(sum(D.*P)+sumsquares(a-P*vone)*γ/2+sumsquares(b-P'*vone)*γ/2, [P >= 0]);

In [None]:
# Exact-marginal variant: minimise sum(D .* P) subject to P >= 0 with row sums equal to a
# and column sums equal to b.
# NOTE(review): this overwrites `problem` from the previous cell; if both cells are run in
# order, only this variant is solved below.
@time problem = minimize(sum(D.*P), [P >= 0, a==P*vone,b==P'*vone]);

In [None]:
# Solve with CPLEX and record the wall time in `t` (replacing the earlier accumulated value).
t = @elapsed solve!(problem, CPLEX.Optimizer);