In [None]:
# !Convert*first_ADMM_primal.py*w*sh*

In [None]:
import math

import numpy

In [None]:
def init(m, n):
    """Return zero starting values ``(s, s_, e, lamda, eta)`` for ``update``.

    ``s``, ``s_`` and ``e`` are independent ``(m, n)`` zero arrays; ``lamda``
    has length ``m`` and ``eta`` has length ``n``.
    """
    plan_shape = (m, n)
    # Three separate arrays (no aliasing between s, s_ and e).
    s, s_, e = (numpy.zeros(plan_shape) for _ in range(3))
    return s, s_, e, numpy.zeros(m), numpy.zeros(n)

In [None]:
def update(m, n, mu, nu, ubox, c, s, s_, e, lamda, eta, rho, alpha):
    """Perform one relaxed ADMM iteration on the primal problem.

    ``s`` is the unconstrained primal iterate and ``s_`` its copy restricted
    to the box ``[0, ubox]`` (``ubox=None`` means no upper bound).
    ``lamda``, ``eta`` and ``e`` are the multipliers of the row-sum
    (``mu``), column-sum (``nu``) and ``s_ == s`` constraints. ``rho`` is
    the penalty parameter and ``alpha`` scales every dual step.

    The incoming value of ``s`` is not read; it is recomputed from the other
    arguments. Nothing is modified in place. Returns the new
    ``(s, s_, e, lamda, eta)``.
    """
    mu_col = mu.reshape((m, 1))
    nu_row = nu.reshape((1, n))

    # Right-hand side k of the s-subproblem's optimality condition
    #   s + rowsum(s) 1^T + 1 colsum(s)^T = k
    rhs = (
        (e + lamda.reshape((m, 1)) + eta.reshape((1, n)) - c) / rho
        + mu_col
        + nu_row
        + s_
    )

    # Closed-form solution: total(s) = total(k) / (m + n + 1), then the row
    # and column sums of s follow from the row and column sums of k.
    shared = rhs.sum() / (m + n + 1)
    row_shift = (rhs.sum(axis=1) - shared) / (n + 1)
    col_shift = (rhs.sum(axis=0) - shared) / (m + 1)
    s = rhs - row_shift.reshape((m, 1)) - col_shift.reshape((1, n))

    # s_-step: shift by the (old) multiplier e and project onto [0, ubox].
    s_ = numpy.maximum(s - e / rho, 0.)
    if ubox is not None:
        s_ = numpy.minimum(s_, ubox)

    # Dual ascent with step alpha * rho on each constraint residual.
    step = alpha * rho
    lamda = lamda + step * (mu - s.sum(axis=1))
    eta = eta + step * (nu - s.sum(axis=0))
    e = e + step * (s_ - s)

    return s, s_, e, lamda, eta

In [None]:
def solve_ADMM_primal(
    p,
    scale=None, its=(), rhos=(), alphas=(), epss=None, min_its=None,
    fh=None, figs=(), log=None, stat=False,
    *args, **kwargs
):
    """Solve the transport-style LP stored in ``p`` with primal ADMM.

    The problem is read from ``p.c`` (cost matrix, shape ``(m, n)``),
    ``p.mu`` / ``p.nu`` (targets for the row / column sums of the plan) and
    ``p.ubox`` (upper bound on the plan, or ``None``). The marginals and
    ``ubox`` are multiplied by ``scale`` internally, and the resulting plan
    is divided by ``scale`` again. ``p.ubox`` itself is left unchanged.

    Iterations run in stages. Stage ``i`` performs up to ``its[i]`` calls of
    ``update`` with penalty ``rhos[i]`` and dual step factor ``alphas[i]``.

    Parameters
    ----------
    p : problem object
        ``p.s`` is overwritten with the computed (projected) plan.
    scale : float, optional
        Internal scaling factor; defaults to ``sqrt(m * n)``.
    its, rhos, alphas : sequences
        Per-stage iteration budget, penalty parameter and dual step factor.
    epss : sequence, optional
        Per-stage tolerance. A stage stops early once both unscaled L1
        marginal errors of the projected iterate are below ``epss[i]``.
    min_its : sequence, optional
        If given, early stopping in stage ``i`` also requires
        ``j > min_its[i]``, where ``j`` is the 0-based iteration index.
    fh : figure handler, optional
        When given, the curves selected by ``figs`` are recorded and plotted.
    figs : container of str
        May contain ``"error"`` (marginal and consensus errors) and/or
        ``"loss"``.
    log : callable, optional
        Called with a progress string after every iteration that does not
        stop the stage.
    stat : bool
        If true, also return a summary dict.
    *args, **kwargs
        Ignored by this function.

    Returns
    -------
    ``p``, or ``(p, info)`` when ``stat`` is true.
    """
    m, n = p.c.shape

    if scale is None:
        scale = math.sqrt(m * n)

    mu, nu = scale * p.mu, scale * p.nu
    c = p.c
    ubox = p.ubox
    if ubox is not None:
        # Scale out of place: an in-place ``*=`` would rescale the caller's
        # ``p.ubox`` array on every call (and fails for integer arrays).
        ubox = scale * ubox

    record_error = fh is not None and "error" in figs
    record_loss = fh is not None and "loss" in figs
    error_mu, error_nu, error_s, loss = [], [], [], []

    s, s_, e, lamda, eta = init(m, n)
    itc = 0  # total number of iterations over all stages

    for i, stage_its in enumerate(its):
        for j in range(stage_its):
            s, s_, e, lamda, eta = update(
                m, n, mu, nu, ubox, c, s, s_, e, lamda, eta, rhos[i], alphas[i]
            )
            itc += 1

            # Marginal errors of the projected iterate s_, in unscaled units.
            # They are computed once and shared by the recorder and the
            # stopping test.
            if record_error or epss is not None:
                err_mu = numpy.linalg.norm(s_.sum(axis=1) - mu, 1) / scale
                err_nu = numpy.linalg.norm(s_.sum(axis=0) - nu, 1) / scale

            if record_error:
                error_mu.append(err_mu)
                error_nu.append(err_nu)
                error_s.append(numpy.linalg.norm(s_ - s))
            if record_loss:
                # Loss of the unprojected iterate s (the summary below uses s_).
                loss.append((c * s).sum() / scale)

            if (
                epss is not None
                and err_mu < epss[i]
                and err_nu < epss[i]
                and (min_its is None or j > min_its[i])
            ):
                break

            if log is not None:
                log("i, j, itc = {0}, {1}, {2}".format(i, j, itc))

    p.s = s_ / scale

    if record_error:
        fh.new(1, 1, 1)
        fh.ax.semilogy(numpy.array(error_mu), label="Error of mu")
        fh.ax.semilogy(numpy.array(error_nu), label="Error of nu")
        fh.ax.semilogy(numpy.array(error_s), label="Error of s")
        fh.ax.legend()
        fh.show()
        fh.close()
    if record_loss:
        fh.new(1, 1, 1)
        fh.ax.plot(numpy.array(loss), label="Loss")
        fh.ax.legend()
        fh.show()
        fh.close()

    if stat:
        info = {
            "title": "ADMM on primal",
            "size": [m, n],
            "loss": (c * s_).sum() / scale,
            # s, s_ and e are (m, n); lamda has m and eta has n entries.
            "vars": 3*m*n + m + n,
            "iters": itc,
        }
        return p, info
    return p

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*first_ADMM_primal_test.py*w*sehx*

In [None]:
import numpy

# !Switch*
from handler import FigureHandler
from dataset import ot_2d_Caffarelli
from stats import Statistics
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_2d_Caffarelli
# from stats import Statistics
# from first_ADMM_primal import solve_ADMM_primal
# !SwitchEnd*

In [None]:
# Figure/log handler used by every cell below; its `write` method is also
# handed to `Statistics` as the logger.
# !Switch*
fh = FigureHandler(redir=True)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
# Benchmark harness over a single problem instance; progress is logged through
# `fh.write`. NOTE(review): the meaning of the `ot_2d_Caffarelli` arguments is
# defined in `dataset` (not visible here) — presumably problem sizes; confirm.
stat = Statistics(
    probs=[
        ot_2d_Caffarelli(500, 500, 1)
    ],
    prob="Test problems",
    log=fh.write,
)

In [None]:
# ADMM before any `ubox` is assigned in this notebook, recording the error and
# loss curves. Assuming `stat.test` forwards these keyword arguments to the
# solver: one stage of at most 10000 iterations, rho = 3, dual step factor
# 1.618, stopping once both L1 marginal errors drop below 1e-4.
stat.test(
    solve_ADMM_primal,
    its=[10000],
    rhos=[3.],
    alphas=[1.618],
    epss=[1e-4],
    fh=fh, figs={"error", "loss"},
)
stat.output_last()

In [None]:
# Same ADMM configuration without `fh`, so no curves are recorded or plotted
# (presumably to measure the solver without that overhead).
stat.test(
    solve_ADMM_primal,
    its=[10000],
    rhos=[3.],
    alphas=[1.618],
    epss=[1e-4],
)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Reference solution from the MOSEK interior-point solver on the same problem.
stat.test(
    solve_mosek_interior_point,
)
stat.output_last()

In [None]:
# Add an upper bound s_ij <= min(mu_i, nu_j) to the (shared) problem object.
# Any nonnegative plan with row sums mu and column sums nu already satisfies
# it, so this box should not change the optimum. Note that `prob` is mutated
# in place, so every later `stat.test` call on this problem sees the box.
prob = stat.probs[0]
m, n = prob.c.shape
prob.ubox = numpy.minimum(prob.mu.reshape((m, 1)), prob.nu.reshape((1, n)))

In [None]:
# ADMM with the redundant box and a looser tolerance (1e-3).
stat.test(
    solve_ADMM_primal,
    its=[10000],
    rhos=[3.],
    alphas=[1.618],
    epss=[1e-3],
    fh=fh, figs={"error", "loss"},
)
stat.output_last()

In [None]:
# Halve the bound: unlike the previous box, this one can be binding and thus
# changes the problem being solved.
prob.ubox = numpy.minimum(prob.mu.reshape((m, 1)), prob.nu.reshape((1, n))) / 2.

In [None]:
# ADMM with the halved box. NOTE(review): `clean` is not a parameter of
# `solve_ADMM_primal` (it would land in its **kwargs); presumably it is
# consumed by `stat.test` — confirm.
stat.test(
    solve_ADMM_primal,
    its=[10000],
    rhos=[3.],
    alphas=[1.618],
    epss=[1e-3],
    fh=fh, figs={"error", "loss"},
    clean=False
)
stat.output_last()

In [None]:
# Visualise the solution stored on `prob` (presumably the transport links of
# `prob.s`; `plot_link` is defined in `dataset`, not visible here).
fh.fast(prob.plot_link)

In [None]:
# Sanity check: number of plan entries below 0 and above `prob.ubox`.
# This relies on the solver leaving `prob.ubox` in unscaled units. The
# comparisons are strict, but rounding in the solver's scale/unscale step may
# still produce a few spurious counts.
fh.write(str((prob.s < 0.).sum()))
fh.write(str((prob.s > prob.ubox).sum()))

In [None]:
# !ConvertEnd*