Comparing the compute time for running the interior-point method on three different formulations of the optimization problem: (1) dual, (2) primal, and (3) modified primal.

In [7]:
using RCall
using Mosek
using JuMP
using CSV

In [3]:
include("../code/julia/ash.jl")



ash (generic function with 1 method)

In [8]:
# vanilla interior point method for the primal
function vanilla_primal(L)
    n,m = size(L);
    mod = Model(solver=MosekSolver(QUIET = true));
    @variable(mod, x[1:m] >= 0, start = 1/m);
    @NLobjective(mod, Min, -sum(log(sum(L[i,j]*x[j] for j = 1:m)) for i = 1:n)/n);
    @constraint(mod, sum(x) == 1)
    solve(mod);
    x = getvalue(x);
    x[x.< 1e-3] = 0;
    return sparse(x/sum(x));
end

vanilla_primal (generic function with 1 method)

In [9]:
# vanilla interior point method for the primal
function vanilla_boxed(L)
    n,m = size(L);
    mod = Model(solver=MosekSolver(QUIET = true));
    @variable(mod, x[1:m] >= 0, start = 1/m);
    @NLobjective(mod, Min, -sum(log(sum(L[i,j]*x[j] for j = 1:m)) for i = 1:n)/n + sum(x[i] for i = 1:m));
    solve(mod);
    x = getvalue(x);
    x[x.< 1e-3] = 0;
    return sparse(x/sum(x));
end

vanilla_boxed (generic function with 1 method)

In [10]:
# vanilla interior point method for the primal
function vanilla_dual(L)
    n,m = size(L);
    mod = Model(solver=MosekSolver(QUIET = true));
    @variable(mod, y[1:n] >= 0, start = 1/n);
    @NLobjective(mod, Min, -sum(log(y[i]) for i = 1:n)/n);
    @constraint(mod, ic, L'*y .<= 1)
    solve(mod);
    x = -getdual(ic);
    x[x.< 1e-3] = 0;
    return sparse(x/sum(x));
end

vanilla_dual (generic function with 1 method)

In [11]:
function REBayes(L)
    @rput L;
    R"require(REBayes);
    t_rebayes = system.time(res <- KWDual(L, rep(1,dim(L)[2]), rep(1,dim(L)[1])/dim(L)[1]))[3];
    res$f[res$f < 1e-3] = 0
    x_rebayes = res$f / sum(res$f)"
    @rget x_rebayes;
    return sparse(x_rebayes);
end

REBayes (generic function with 1 method)

In [140]:
L = Array{Float64,2}(CSV.read("../data/sample5000x20.txt", nullable = false, header = false, delim = ' '));

In [51]:
@time vanilla_dual(L)
@time vanilla_primal(L)
@time vanilla_boxed(L)
@time REBayes(L)

  0.632511 seconds (442.58 k allocations: 53.744 MiB, 67.20% gc time)
  0.686826 seconds (1.02 M allocations: 133.468 MiB, 1.73% gc time)
  0.818156 seconds (1.02 M allocations: 133.450 MiB, 18.72% gc time)
  0.399800 seconds (170 allocations: 6.984 KiB)


20-element SparseVector{Float64,Int64} with 3 stored entries:
  [1 ]  =  0.469905
  [13]  =  0.305865
  [14]  =  0.22423

In [64]:
srand(1)
x = [randn(5*10^5);3*randn(5*10^5)];
s = ones(10^6);
iter = 8;
L = ash(x,s, mult = 1e3, lowrank = "nothing")[3];
size(L)

(1000000, 3)

In [49]:
@time vanilla_dual(L)
@time vanilla_primal(L)
@time vanilla_boxed(L)
@time REBayes(L)

 14.036674 seconds (37.00 M allocations: 2.549 GiB, 24.60% gc time)
 10.933699 seconds (51.00 M allocations: 3.004 GiB, 12.16% gc time)
 10.391847 seconds (51.00 M allocations: 3.004 GiB, 9.84% gc time)
 11.840568 seconds (153 allocations: 5.531 KiB)


In [173]:
function sqp_box(L)
    n,m = size(L);
    x = ones(m)/m;
    for i = 1:100
        D = 1./(L*x + 1e-8);
        g = -L'*D/n;
        H = L'*Diagonal(D.^2)*L/n + 1e-8 * speye(m);
        
        if minimum(g+1) >= -1e-6
          break;
        end
        
        mod = Model(solver=MosekSolver(QUIET = true));
        a,b = findn(H);
        @variable(mod, y[1:m] >= 0);
        @objective(mod, Min, QuadExpr(y[a],y[b],H[:]/2,AffExpr(y, 2*g+1, 0)) )
        solve(mod);
        x = getvalue(y);
        x[x .< 0] = 0;
        println(x)
        
    end
    x[x .< 1e-3] = 0;
    return x
end

sqp_box (generic function with 1 method)

In [631]:

for i = 1:100
    D = 1./(L*x + 1e-8);
    g = -L'*D/n;
    x = x .* (g.^2);
end

In [632]:
x

20-element Array{Float64,1}:
 0.448069    
 0.020337    
 0.00157086  
 1.42207e-5  
 2.39921e-9  
 2.34251e-16 
 1.98556e-29 
 1.57267e-53 
 9.24419e-96 
 1.35138e-160
 8.97308e-231
 6.05222e-218
 0.305698    
 0.224312    
 4.94066e-323
 9.88131e-324
 4.94066e-324
 0.0         
 0.0         
 0.0         

In [624]:
D = 1./(L*x + 1e-8);
    g = -L'*D/n;
    x = x .* (g.^2);
    x = x/sum(x)

20-element Array{Float64,1}:
 0.447824    
 0.0205508   
 0.00160193  
 1.47481e-5  
 2.56706e-9  
 2.6572e-16  
 2.51469e-29 
 2.44819e-53 
 2.08009e-95 
 5.42192e-160
 6.85951e-230
 4.2389e-217 
 0.305696    
 0.224312    
 4.94066e-323
 9.88131e-324
 4.94066e-324
 0.0         
 0.0         
 0.0         

In [601]:
[eval_f(L,x) eval_f(L,y)]

1×2 Array{Float64,2}:
 0.467113  0.467113

In [597]:
y = sqp_primal(L)

20-element Array{Float64,1}:
 0.469885
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.305864
 0.22423 
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     

In [535]:
function sqp_primal(L)
    n,m = size(L);
    x = ones(m)/m;
    for i = 1:100
        D = 1./(L*x + 1e-8);
        g = -L'*D/n;
        H = L'*Diagonal(D.^2)*L/n + 1e-8 * speye(m);
        
        if minimum(g+1) >= -1e-8
          break;
        end
        
        mod = Model(solver=MosekSolver(QUIET = true));
        a,b = findn(H);
        @variable(mod, y[1:m] >= 0);
        @objective(mod, Min, QuadExpr(y[a],y[b],H[:]/2,AffExpr(y, 2*g, 0)) )
        @constraint(mod, sum(y) == 1);
        solve(mod);
        x = getvalue(y);
        x[x .< 0] = 0;

    end
    x[x .< 1e-3] = 0;
    return x
end

sqp_primal (generic function with 1 method)

In [132]:
function sqp_dual(L)
    n,m = size(L);
    x = ones(n)/n;
    for i = 1:100
        D = 1./x;
        
        mod = Model(solver=MosekSolver(QUIET = true));
        @variable(mod, y[1:n] >= 0);
        @objective(mod, Min, QuadExpr(y,y,D.^2/2/n,AffExpr(y, -2*D/n, 0)) )
        @constraint(mod, L'*y .<= 1);
        solve(mod);
        x = getvalue(y);
        x[x .< 0] = 0;
        
        println(x[1])

    end
    return x
end

sqp_dual (generic function with 1 method)

In [114]:
L = Array{Float64,2}(CSV.read("../data/sample100000x100.txt", nullable = false, header = false, delim = ' '));

In [174]:
@time sqp_box(L)

[0.050302, 0.0502918, 0.0502833, 0.0502682, 0.0502412, 0.0501956, 0.0501309, 0.0500998, 0.0504027, 0.0521416, 0.0585201, 0.0783786, 0.13232, 0.145536, 0.0463974, 0.0148561, 0.00820402, 0.00568864, 0.00975584, 0.0]
[0.0523327, 0.0523339, 0.0523355, 0.0523398, 0.0523523, 0.0523905, 0.0525099, 0.0528767, 0.0539315, 0.0566371, 0.0635433, 0.0833016, 0.127922, 0.133212, 0.0593565, 0.0218916, 0.0101227, 0.00609943, 0.020397, 0.0]
  0.229389 seconds (106.99 k allocations: 14.326 MiB, 2.38% gc time)


20-element Array{Float64,1}:
 0.0523327 
 0.0523339 
 0.0523355 
 0.0523398 
 0.0523523 
 0.0523905 
 0.0525099 
 0.0528767 
 0.0539315 
 0.0566371 
 0.0635433 
 0.0833016 
 0.127922  
 0.133212  
 0.0593565 
 0.0218916 
 0.0101227 
 0.00609943
 0.020397  
 0.0       

In [177]:
@time sqp_primal(L)

[0.477111, 2.94024e-5, 1.57915e-5, 6.80474e-6, 2.63773e-6, 1.24819e-6, 7.07332e-7, 4.32877e-7, 2.7562e-7, 1.90629e-7, 1.50425e-7, 1.95713e-7, 0.445846, 0.0769853, 4.16843e-8, 1.52489e-8, 9.8944e-9, 0.0, 0.0, 0.0]
[0.461626, 3.72366e-5, 9.12725e-6, 5.33708e-6, 3.38055e-6, 1.99809e-6, 1.13996e-6, 6.6778e-7, 4.21359e-7, 3.24509e-7, 2.96978e-7, 3.93307e-7, 0.370649, 0.167665, 6.24075e-8, 1.61561e-8, 5.8141e-9, 0.0, 0.0, 0.0]
[0.469062, 9.62642e-7, 4.06529e-7, 2.25402e-7, 1.26081e-7, 7.27474e-8, 4.54177e-8, 3.11381e-8, 2.28832e-8, 1.77779e-8, 1.53081e-8, 1.91976e-8, 0.312897, 0.218038, 2.59878e-9, 5.83347e-10, 2.65858e-10, 0.0, 0.0, 0.0]
[0.469855, 2.04808e-5, 9.67103e-6, 5.24604e-6, 2.89619e-6, 1.59013e-6, 8.65718e-7, 4.71995e-7, 2.64519e-7, 1.59462e-7, 1.12786e-7, 1.21442e-7, 0.305941, 0.224161, 1.31551e-8, 3.0742e-9, 1.07856e-9, 0.0, 0.0, 0.0]
[0.469885, 9.83125e-6, 4.68968e-6, 2.56557e-6, 1.41839e-6, 7.77267e-7, 4.22833e-7, 2.31523e-7, 1.31235e-7, 8.04096e-8, 5.78786e-8, 6.31848e-8, 0.3

20-element Array{Float64,1}:
 0.469885
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.305864
 0.22423 
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     
 0.0     

In [127]:
@time sqp_dual(L)

1.4804971720005258e-5


LoadError: [91mInterruptException:[39m