In [None]:
# !Convert*first_ADMM_dual.py*w*sh*

In [None]:
import numpy

from utils import *

In [None]:
def init(m, n, c):
    """Build the all-zero starting point for the dual ADMM iteration.

    Returns ``(lamda, eta, e, d)``:
      * ``lamda`` -- zeros of length ``m`` (row potentials),
      * ``eta``   -- zeros of length ``n`` (column potentials),
      * ``e``     -- slack ``c - lamda - eta``; since both potentials are zero
                     this is a new floating-point array equal to ``c``,
      * ``d``     -- zeros of shape ``(m, n)`` (multiplier).
    """
    zero_rows = numpy.zeros(m)
    zero_cols = numpy.zeros(n)
    multiplier = numpy.zeros((m, n))
    # Broadcast the potentials as a column / row against c.
    slack = c - zero_rows[:, numpy.newaxis] - zero_cols[numpy.newaxis, :]
    return zero_rows, zero_cols, slack, multiplier

In [None]:
def update(m, n, mu, nu, c, lamda, eta, e, d, eta_sigma, rho, alpha):
    """Perform one relaxed ADMM iteration on the dual of the OT problem.

    The updates below correspond to the splitting

        minimise   -<mu, lamda> - <nu, eta>
        subject to lamda_i + eta_j + e_ij = c_ij,   e >= 0,

    with residual ``r = c - lamda[:, None] - eta[None, :] - e`` and augmented
    Lagrangian ``L = -<mu, lamda> - <nu, eta> + <d, r> + rho/2 * ||r||^2``.
    ``d`` is therefore the *unscaled* multiplier. It enters every primal step
    as ``d / rho`` and is advanced by ``alpha * rho * r``.

    Parameters
    ----------
    m, n : int
        Number of rows / columns, used to reshape the potentials against ``c``.
    mu, nu : arrays of length m / n
        Weights of the linear objective.
    c : (m, n) array
        Cost matrix.
    lamda, eta : arrays
        Previous potentials. Not read: both are recomputed from ``d`` and
        ``e``. Kept in the signature for compatibility.
    e, d : (m, n) arrays
        Previous slack and multiplier.
    eta_sigma : float
        Used in place of ``sum(eta)`` in the lamda step. In exact arithmetic,
        changing it by ``s`` shifts lamda by ``-s/n`` and eta by ``+s/n``.
        ``lamda[:, None] + eta[None, :]`` (and hence ``e`` and ``d``) is
        therefore unaffected.
    rho : float
        Penalty parameter (must be non-zero; it is divided by).
    alpha : float
        Relaxation factor of the multiplier step.

    Returns
    -------
    tuple
        ``(lamda, eta, e, d)`` after one iteration.
    """
    # lamda-step: closed-form minimiser of L over lamda (eta enters only via eta_sigma).
    lamda = (
          (mu + numpy.sum(d, axis=1)) / rho
        - eta_sigma
        - numpy.sum(e, axis=1)
        + numpy.sum(c, axis=1)
    ) / n

    # eta-step: closed-form minimiser of L over eta, given the new lamda.
    eta = (
          (nu + numpy.sum(d, axis=0)) / rho
        - numpy.sum(lamda)
        - numpy.sum(e, axis=0)
        + numpy.sum(c, axis=0)
    ) / m

    # c_ij - lamda_i - eta_j, shared by the e-step and the multiplier step.
    gap = c - lamda.reshape((m, 1)) - eta.reshape((1, n))

    # e-step: exact minimiser of <d, r> + rho/2 ||r||^2 over e >= 0, i.e. the
    # projection of gap + d / rho. The division by rho is required because d
    # is the unscaled multiplier, consistently with the lamda/eta steps above.
    e = numpy.maximum(gap + d / rho, 0.)

    # Relaxed dual-ascent step on the new residual r = gap - e.
    d = d + alpha * rho * (gap - e)

    return lamda, eta, e, d

In [None]:
def _plot_history(fh, method, series):
    """Draw each ``(values, label)`` pair with ``fh.ax.<method>`` on a new figure, then show and close it."""
    fh.new(1, 1, 1)
    draw = getattr(fh.ax, method)
    for values, label in series:
        draw(numpy.array(values), label=label)
    fh.ax.legend()
    fh.show()
    fh.close()


def solve_ADMM_dual(
    p,
    scale=None, its=(), rhos=(), alphas=(), epss=None, min_its=None,
    fh=None, figs=(), log=None, stat=False,
    *args, **kwargs
):
    """Solve the OT problem ``p`` by relaxed ADMM applied to its dual.

    The iteration runs in stages. Stage ``i`` performs up to ``its[i]`` calls
    of ``update`` with penalty ``rhos[i]`` and relaxation ``alphas[i]``.
    Stage ``i`` stops early when all of the following hold:
      * ``epss`` is given,
      * both marginal errors are below ``epss[i]``,
      * ``min_its`` is None, or ``j > min_its[i]``, where ``j`` is the
        0-based iteration index within the stage.

    The marginal errors are the infinity-norm violations of the row and
    column sums of ``-d / scale`` against ``p.mu`` and ``p.nu``, multiplied by
    ``m`` and ``n`` respectively.

    Parameters
    ----------
    p : problem object
        Must provide ``c`` (2-D cost matrix), ``mu`` and ``nu``. On return
        ``p.s`` is set to ``-d / scale`` (presumably the transport plan).
    scale : float, optional
        Factor applied to ``mu`` and ``nu``. Defaults to ``sqrt(m * n)``.
    its, rhos, alphas : sequences
        Per-stage iteration budget, penalty and relaxation factor.
    epss : sequence, optional
        Per-stage tolerance on the marginal errors.
    min_its : sequence, optional
        Per-stage lower bound on ``j`` before an early stop is allowed.
    fh : figure handler, optional
        Object providing ``new``, ``ax``, ``show`` and ``close``. When given,
        the histories named in ``figs`` ("error", "loss") are recorded and
        plotted after the loop.
    figs : container of str
        Which histories to record and plot.
    log : callable, optional
        Called with a progress string after every iteration that does not
        trigger an early stop.
    stat : bool
        If true, also return a dict of summary statistics.
    *args, **kwargs
        Accepted and ignored (presumably so a generic driver can pass extra
        options).

    Returns
    -------
    ``p``, or ``(p, s)`` when ``stat`` is true.
    """
    import math  # explicit: do not rely on ``from utils import *`` re-exporting it

    m, n = p.c.shape

    if scale is None:
        scale = math.sqrt(m * n)

    # Marginals are multiplied by `scale`; the plan and losses are divided back by it.
    mu, nu = scale * p.mu, scale * p.nu
    c = p.c

    track_error = fh is not None and "error" in figs
    track_loss = fh is not None and "loss" in figs
    error_mu, error_nu, error_e = [], [], []
    loss_dual, loss_dual2 = [], []

    lamda, eta, e, d = init(m, n, c)

    itc = 0
    for i in range(len(its)):
        for j in range(its[i]):
            # eta_sigma = 0. only shifts lamda / eta by opposite constants, so
            # lamda + eta (and therefore e and d) are unaffected.
            lamda, eta, e, d = update(m, n, mu, nu, c, lamda, eta, e, d, 0., rhos[i], alphas[i])

            itc += 1

            # Computed once, shared by the error history and the stopping test.
            if track_error or epss is not None:
                err_mu = numpy.linalg.norm(d.sum(axis=1) + mu, numpy.inf) / scale * m
                err_nu = numpy.linalg.norm(d.sum(axis=0) + nu, numpy.inf) / scale * n

            if track_error:
                error_mu.append(err_mu)
                error_nu.append(err_nu)
                error_e.append(numpy.linalg.norm(c - lamda.reshape((m, 1)) - eta.reshape((1, n)) - e))
            if track_loss:
                loss_dual.append(-((c * d).sum() / scale))
                loss_dual2.append(((lamda * mu).sum() + (eta * nu).sum()) / scale)

            if epss is not None and err_mu < epss[i] and err_nu < epss[i]:
                if min_its is None or j > min_its[i]:
                    break

            if log is not None:
                log("i, j, itc = {0}, {1}, {2}".format(i, j, itc))

    # Recover the primal plan from the (scaled, sign-flipped) multiplier.
    p.s = d / (-scale)

    if track_error:
        _plot_history(fh, "semilogy", [
            (error_mu, "Error of mu"),
            (error_nu, "Error of nu"),
            (error_e, "Error of e"),
        ])
    if track_loss:
        _plot_history(fh, "plot", [
            (loss_dual, "Loss from dual"),
            (loss_dual2, "Loss from dual^2"),
        ])

    if stat:
        s = {
            "title": "ADMM on dual",
            "loss": (c * (-d)).sum() / scale,
            "vars": 2*m*n + m + n,
            "iters": itc,
        }
        return p, s
    return p

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*first_ADMM_dual_test.py*w*sehx*

In [None]:
# !Switch*
# !SwitchCase*
# import font
# from utils import *
# from first_ADMM_dual import solve_ADMM_dual
# !SwitchEnd*

In [None]:
# !Switch*
fh = FigureHandler(log=print, sav=False)  # sav=False: presumably display figures without saving them
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", log=print)
# !SwitchEnd*

In [None]:
# Single 500 x 500 test instance: both point clouds come from samp_2d_Caffarelli
# (different last parameter), weights are val_const(), and the cost is dist_2d_euc_2
# (presumably squared Euclidean distance -- confirm in utils).
caffarelli_problem = ot_2d_general(
    m=500, n=500,
    mup_gen=samp_2d_Caffarelli(0., 0., 1., 0.),
    nup_gen=samp_2d_Caffarelli(0., 0., 1., 2.),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2,
)

stat = Statistics(
    probs=[caffarelli_problem],
    merge_config=general_merge_config,
    output_config=general_output_config,
    prob="Test problems",
)

In [None]:
# ADMM run that also records and plots the error and loss histories through fh.
# alpha = 1.618 sits just below the golden ratio (1 + sqrt(5)) / 2, presumably the
# usual upper bound on the relaxed dual step length.
admm_params = dict(its=[8000], rhos=[0.1], alphas=[1.618], epss=[1e-3])
stat.test(solve_ADMM_dual, **admm_params, fh=fh, figs={"error", "loss"})
stat.output_last()

In [None]:
# Same ADMM settings, but with no figure handler, so no histories are recorded or plotted.
admm_params = dict(its=[8000], rhos=[0.1], alphas=[1.618], epss=[1e-3])
stat.test(solve_ADMM_dual, **admm_params)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Reference run with solve_mosek_interior_point, for comparison with the ADMM results.
stat.test(solve_mosek_interior_point)
stat.output_last()

In [None]:
# !ConvertEnd*