In [None]:
# !Convert*solver_gurobi.py*w*sh*

In [None]:
import time
import math

import numpy

from gurobipy import *

from utils import *

In [None]:
def gurobi_set_model(mu, nu, c, M):
    """Populate the Gurobi model `M` with the discrete optimal-transport LP.

    Adds one non-negative variable per cost entry, one row-sum constraint
    ``sum_j s[i, j] == mu[i]`` per row and one column-sum constraint
    ``sum_i s[i, j] == nu[j]`` per column, then sets the objective to
    ``sum_{i,j} c[i, j] * s[i, j]``. The objective sense is not set here, so
    the model's own setting applies (Gurobi's default is minimisation).

    Parameters
    ----------
    mu : indexable of length m
        Right-hand sides of the row-sum constraints.
    nu : indexable of length n
        Right-hand sides of the column-sum constraints.
    c : 2-D array
        Cost matrix; only ``c.shape`` (m, n) and ``c[i, j]`` are read.
    M : gurobipy.Model
        Model that is modified in place.

    Returns
    -------
    The variables created by ``M.addVars``, indexed by ``(i, j)``.
    """
    n_rows, n_cols = c.shape
    plan = M.addVars(n_rows, n_cols, lb=0., ub=GRB.INFINITY)

    # Building LinExpr straight from (coefficient, variable) pairs is much
    # faster than tuplelist.prod or quicksum.
    M.addConstrs(
        LinExpr([(1., plan[row, col]) for col in range(n_cols)]) == mu[row]
        for row in range(n_rows)
    )
    M.addConstrs(
        LinExpr([(1., plan[row, col]) for row in range(n_rows)]) == nu[col]
        for col in range(n_cols)
    )

    cost_terms = [
        (c[row, col], plan[row, col])
        for row in range(n_rows)
        for col in range(n_cols)
    ]
    M.setObjective(LinExpr(cost_terms))

    return plan

def solve_gurobi(
    p,
    scale=None,
    mtd=-1, it_name=None,
    log=None, stat=False, title="",
    *args, **kwargs
):
    """Solve the discrete optimal-transport problem `p` as an LP with Gurobi.

    The marginals are multiplied by `scale` before the model is built. The
    resulting plan and objective are divided by `scale` again, so the values
    stored on / reported for `p` are in the original units.

    Parameters
    ----------
    p : problem object
        Must expose ``c`` (cost matrix with 2-D ``shape`` (m, n)) and ``mu`` /
        ``nu`` (marginals that support multiplication by a scalar). The
        optimal plan is written back to ``p.s`` as an (m, n) numpy array.
    scale : float, optional
        Multiplier applied to the marginals; defaults to ``sqrt(m * n)``.
        NOTE(review): presumably chosen to keep plan entries well above
        Gurobi's tolerances -- confirm.
    mtd : int
        Value of Gurobi's ``Method`` parameter (-1 automatic, 0 primal
        simplex, 1 dual simplex, 2 barrier).
    it_name : str, optional
        Model attribute reported under ``"iters"`` when `stat` is true. If
        omitted it is inferred from `mtd` ("IterCount" for 0/1,
        "BarIterCount" for 2); for any other method ``"iters"`` is ``None``.
    log : object, optional
        Only compared against ``None``: when not ``None``, Gurobi's own
        console output is enabled. The object itself is never called.
    stat : bool
        When true, also return a statistics dictionary.
    title : str
        Label stored under ``"title"`` in the statistics dictionary.
    *args, **kwargs
        Accepted and ignored (presumably so every solver in the project can be
        called with the same arguments).

    Returns
    -------
    p, or ``(p, stats)`` when `stat` is true. ``stats`` has the keys
    ``title``, ``loss`` (objective divided by `scale`), ``vars``, ``iters``,
    ``setup`` (model-construction seconds) and ``solve`` (Gurobi Runtime).
    """
    start_time = time.perf_counter()

    m, n = p.c.shape

    if scale is None:
        scale = numpy.sqrt(m * n)

    M = Model("OT")
    M.setParam(GRB.Param.OutputFlag, 0 if log is None else 1)
    M.setParam(GRB.Param.Method, mtd)

    svars = gurobi_set_model(scale * p.mu, scale * p.nu, p.c, M)
    # Gurobi applies model modifications lazily; flushing them here makes the
    # "setup" time cover model construction instead of folding it into optimize().
    M.update()
    setup_time = time.perf_counter() - start_time

    M.optimize()

    sx = M.getAttr("x", svars)
    p.s = numpy.array([sx[i, j] for i in range(m) for j in range(n)]).reshape(m, n) / scale

    if not stat:
        return p

    if it_name is None:
        # Previously M.getAttr(None) crashed here; pick the attribute matching
        # the method, or report no iteration count when it is ambiguous.
        it_name = {0: "IterCount", 1: "IterCount", 2: "BarIterCount"}.get(mtd)

    stats = {
        "title": title,
        "loss": M.getAttr("ObjVal") / scale,
        "vars": M.getAttr("NumVars"),
        "iters": None if it_name is None else M.getAttr(it_name),
        "setup": setup_time,
        "solve": M.getAttr("Runtime"),
    }
    return p, stats

In [None]:
def solve_gurobi_primal_simplex(p, *args, **kwargs):
    """Solve `p` with Gurobi's primal simplex (Method=0).

    Thin wrapper over ``solve_gurobi``. It fixes the method, the
    iteration-count attribute ("IterCount") and the statistics title, and
    forwards every other positional and keyword argument unchanged.
    """
    return solve_gurobi(p, *args, mtd=0, it_name="IterCount",
                        title="Gurobi, primal simplex", **kwargs)

def solve_gurobi_dual_simplex(p, *args, **kwargs):
    """Solve `p` with Gurobi's dual simplex (Method=1).

    Thin wrapper over ``solve_gurobi``. It fixes the method, the
    iteration-count attribute ("IterCount") and the statistics title, and
    forwards every other positional and keyword argument unchanged.
    """
    return solve_gurobi(p, *args, mtd=1, it_name="IterCount",
                        title="Gurobi, dual simplex", **kwargs)

def solve_gurobi_barrier(p, *args, **kwargs):
    """Solve `p` with Gurobi's barrier method (Method=2).

    Thin wrapper over ``solve_gurobi``. It fixes the method, the
    iteration-count attribute ("BarIterCount") and the statistics title, and
    forwards every other positional and keyword argument unchanged.
    """
    return solve_gurobi(p, *args, mtd=2, it_name="BarIterCount",
                        title="Gurobi, barrier", **kwargs)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*solver_gurobi_test.py*w*sehx*

In [None]:
# !Switch*
# !SwitchCase*
# import font
# from utils import *
# from solver_gurobi import solve_gurobi_primal_simplex, solve_gurobi_dual_simplex, solve_gurobi_barrier
# !SwitchEnd*

In [None]:
# !Switch*
fh = FigureHandler(log=print, sav=False)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", log=print)
# !SwitchEnd*

In [None]:
# Test instance: source/target supports generated by samp_2d_Caffarelli
# (differing only in the last argument, 0. vs 2.), constant weights from
# val_const, and ground cost dist_2d_euc_2 (presumably squared Euclidean -- see utils).
# NOTE(review): no seed is fixed in this notebook; if the samplers draw from
# numpy's global RNG, every run builds a different instance.
prob = ot_2d_general(
    m=500,
    n=500,
    mup_gen=samp_2d_Caffarelli(0., 0., 1., 0.),
    nup_gen=samp_2d_Caffarelli(0., 0., 1., 2.),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2,
)

In [None]:
# Solve once so that prob.s holds a transport plan (solve_gurobi writes p.s)
# before it is plotted below. The trailing semicolon suppresses the echo of the
# returned problem object, which carries no information here.
solve_gurobi_primal_simplex(prob);

In [None]:
# Plot prob.plot_link through the figure handler created above.
fh.fast(prob.plot_link)

In [None]:
# Timed primal-simplex run with Gurobi output enabled; displays (prob, stats).
solve_gurobi_primal_simplex(prob, stat=True, log=print)

In [None]:
# Timed dual-simplex run with Gurobi output enabled; displays (prob, stats).
solve_gurobi_dual_simplex(prob, stat=True, log=print)

In [None]:
# Timed barrier run with Gurobi output enabled; displays (prob, stats).
solve_gurobi_barrier(prob, stat=True, log=print)

In [None]:
# !ConvertEnd*