In [None]:
# !Convert*solve_mosek.py*w*sh*

In [None]:
import time

import numpy

import mosek

from utils import *

In [None]:
def mosek_set_model(p, task):
    """Load the discrete transportation LP described by ``p`` into ``task``.

    With ``rows, cols = p.c.shape``, the plan is flattened row-major: MOSEK
    variable ``r*cols + k`` is plan entry ``(r, k)``; every variable has a
    lower bound of 0.

    Constraints ``0 .. rows-1`` fix each row sum to the matching ``p.mu``
    entry. Constraints ``rows .. rows+cols-1`` fix each column sum to the
    matching ``p.nu`` entry. The objective minimizes the sum of ``p.c``
    times the plan.
    """
    rows, cols = p.c.shape
    num_vars = rows * cols
    all_vars = range(num_vars)

    # For boundkey.lo MOSEK ignores the upper bound, so this is only a
    # placeholder value (the same convention MOSEK's examples use).
    unused_upper = 0.

    task.appendvars(num_vars)
    task.appendcons(rows + cols)

    task.putvarboundlist(
        all_vars,
        [mosek.boundkey.lo] * num_vars,
        [0.] * num_vars,
        [unused_upper] * num_vars,
    )

    # Row-sum constraints: row `row` owns the `cols` consecutive variables
    # starting at row*cols.
    for row in range(rows):
        start = row * cols
        task.putarow(row, range(start, start + cols), [1.] * cols)
    task.putconboundlist(range(rows), [mosek.boundkey.fx] * rows, p.mu, p.mu)

    # Column-sum constraints: column `col` is every cols-th variable
    # beginning at index col.
    for col in range(cols):
        task.putarow(rows + col, range(col, num_vars, cols), [1.] * rows)
    task.putconboundlist(
        range(rows, rows + cols), [mosek.boundkey.fx] * cols, p.nu, p.nu
    )

    task.putclist(all_vars, p.c.reshape(num_vars))
    task.putobjsense(mosek.objsense.minimize)

def solve_mosek(
    p,
    mtd=None, sol=None, var=None, it=None,
    log=None, stat=False,
    *args, **kwargs
):
    """Solve the transportation LP held in ``p`` with MOSEK.

    The model is built by ``mosek_set_model``. The resulting plan is stored
    in ``p.sol`` as a numpy array of shape ``p.c.shape``.

    Parameters
    ----------
    p : problem object
        Must expose ``c`` (2-D cost array), ``mu`` and ``nu``; ``p.sol`` is
        overwritten.
    mtd : mosek.optimizertype, optional
        Optimizer to use. ``None`` keeps MOSEK's default optimizer choice.
    sol : mosek.soltype, optional
        Solution to read back. ``None`` picks the basic solution when MOSEK
        produced one, otherwise the interior-point solution.
    var, it : mosek.iinfitem, optional
        Information items reported under ``"vars"`` / ``"iters"`` when
        ``stat`` is true; ``None`` reports ``None`` for that entry.
    log : callable or None
        Attached to the MOSEK log stream of both the environment and the
        task (e.g. ``print``).
    stat : bool
        When true, also return a statistics dict.
    *args, **kwargs
        Accepted and ignored.

    Returns
    -------
    ``p``, or ``(p, stats)`` when ``stat`` is true. ``stats`` has keys
    ``loss`` (primal objective), ``vars``, ``iters``, ``setup`` (seconds
    spent creating the environment and building the model) and ``solve``
    (MOSEK's reported optimizer time).

    Warns
    -----
    RuntimeWarning
        If the selected solution's status is not ``solsta.optimal``.
    """
    import warnings

    m, n = p.c.shape

    # perf_counter is monotonic, so the measured interval is unaffected by
    # system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()

    with mosek.Env() as env:
        env.set_Stream(mosek.streamtype.log, log)

        with env.Task() as task:
            task.set_Stream(mosek.streamtype.log, log)

            # Passing None to putintparam would fail; leaving the parameter
            # alone keeps MOSEK's default optimizer selection.
            if mtd is not None:
                task.putintparam(mosek.iparam.optimizer, mtd)

            mosek_set_model(p, task)

            # Setup covers environment creation and model construction.
            setup_time = time.perf_counter() - start_time

            task.optimize()

            if sol is None:
                sol = (mosek.soltype.bas
                       if task.solutiondef(mosek.soltype.bas)
                       else mosek.soltype.itr)

            # Surface non-optimal results instead of silently storing them.
            status = task.getsolsta(sol)
            if status != mosek.solsta.optimal:
                warnings.warn(
                    "MOSEK did not report an optimal solution "
                    "(status: {0}); p.sol may be meaningless".format(status),
                    RuntimeWarning,
                )

            xx = [0.] * (m*n)
            task.getxx(sol, xx)
            p.sol = numpy.array(xx).reshape(m, n)

            if not stat:
                return p

            s = {
                "loss": task.getprimalobj(sol),
                "vars": task.getintinf(var) if var is not None else None,
                "iters": task.getintinf(it) if it is not None else None,
                "setup": setup_time,
                "solve": task.getdouinf(mosek.dinfitem.optimizer_time),
            }
            return p, s

In [None]:
def solve_mosek_primal_simplex(p, *args, **kwargs):
    """Solve ``p`` with MOSEK's primal simplex optimizer.

    Reads back the basic solution. With ``stat=True`` it reports
    ``opt_numvar`` and the primal simplex iteration count. Extra arguments
    are forwarded to ``solve_mosek``; pass them by keyword (``log=...``,
    ``stat=...``), because the method-selection arguments are fixed here.
    """
    options = dict(
        mtd=mosek.optimizertype.primal_simplex,
        sol=mosek.soltype.bas,
        var=mosek.iinfitem.opt_numvar,
        it=mosek.iinfitem.sim_primal_iter,
    )
    return solve_mosek(p, *args, **options, **kwargs)

def solve_mosek_dual_simplex(p, *args, **kwargs):
    """Solve ``p`` with MOSEK's dual simplex optimizer.

    Reads back the basic solution. With ``stat=True`` it reports
    ``opt_numvar`` and the dual simplex iteration count. Extra arguments are
    forwarded to ``solve_mosek``; pass them by keyword (``log=...``,
    ``stat=...``), because the method-selection arguments are fixed here.
    """
    options = dict(
        mtd=mosek.optimizertype.dual_simplex,
        sol=mosek.soltype.bas,
        var=mosek.iinfitem.opt_numvar,
        it=mosek.iinfitem.sim_dual_iter,
    )
    return solve_mosek(p, *args, **options, **kwargs)

def solve_mosek_interior_point(p, *args, **kwargs):
    """Solve ``p`` with MOSEK's interior-point optimizer.

    Reads back the interior-point (``itr``) solution. With ``stat=True`` it
    reports ``opt_numvar`` and the interior-point iteration count. Extra
    arguments are forwarded to ``solve_mosek``; pass them by keyword
    (``log=...``, ``stat=...``), because the method-selection arguments are
    fixed here.
    """
    options = dict(
        mtd=mosek.optimizertype.intpnt,
        sol=mosek.soltype.itr,
        var=mosek.iinfitem.opt_numvar,
        it=mosek.iinfitem.intpnt_iter,
    )
    return solve_mosek(p, *args, **options, **kwargs)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*solve_mosek_test.py*w*sehx*

In [None]:
# !Switch*
# !SwitchCase*
# import font
# from utils import *
# from solve_mosek import solve_mosek_primal_simplex, solve_mosek_dual_simplex, solve_mosek_interior_point
# !SwitchEnd*

In [None]:
# !Switch*
fh = FigureHandler(ext=".pdf", log=print, sav=False)  # interactive variant (sav=False)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", log=print)
# !SwitchEnd*

In [None]:
# 500 x 500 instance; the two samplers differ only in their last argument.
prob = ot_2d_general(
    m=500,
    n=500,
    mup_gen=samp_2d_Caffarelli(0.0, 0.0, 1.0, 0.0),
    nup_gen=samp_2d_Caffarelli(0.0, 0.0, 1.0, 2.0),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2
)

In [None]:
# Fills prob.sol with the plan found by the primal simplex method.
solve_mosek_primal_simplex(p=prob)

In [None]:
fh.fast(prob.plot_link, cutoff=0.5, scale=30.0)

In [None]:
# 2000 x 2000 instance used for the solver comparison below.
prob = ot_2d_general(
    m=2000,
    n=2000,
    mup_gen=samp_2d_Caffarelli(0.0, 0.0, 1.0, 0.0),
    nup_gen=samp_2d_Caffarelli(0.0, 0.0, 1.0, 4.0),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2
)

In [None]:
print(str.format("Size = ({0}, {1})", prob.mu.size, prob.nu.size))

In [None]:
# Returns (prob, stats); the MOSEK log is printed as the solve runs.
solve_mosek_primal_simplex(prob, stat=True, log=print)

In [None]:
# Returns (prob, stats); the MOSEK log is printed as the solve runs.
solve_mosek_dual_simplex(prob, stat=True, log=print)

In [None]:
# Returns (prob, stats); the MOSEK log is printed as the solve runs.
solve_mosek_interior_point(prob, stat=True, log=print)

In [None]:
# !ConvertEnd*