In [None]:
# !Convert*first_grad_primal.py*w*sh*

In [None]:
import numpy

from utils import *

In [None]:
def init(m, n, s_refine, scale, refine=False):
    """Build the starting iterate for the gradient solver.

    Parameters
    ----------
    m, n : int
        Shape of the zero array created for a cold start.
    s_refine :
        Previous solution; only read when ``refine`` is truthy.
    scale :
        Factor multiplied onto ``s_refine`` for a warm start.
    refine : bool, optional
        If truthy, start from ``s_refine * scale``; otherwise start from
        ``numpy.zeros((m, n))``.

    Returns
    -------
    tuple
        ``(s, s_old)``. Both entries are the *same* object, so the first
        momentum step sees ``s - s_old == 0``.
    """
    start = s_refine * scale if refine else numpy.zeros((m, n))
    return start, start

In [None]:
def update(m, n, mu, nu, c, s, j, pi0, pi1, pi2, lr):
    """Perform one plain (non-accelerated) step on the iterate ``s``.

    The step has three parts:

    1. move ``s`` along ``lr * (pi1 * mu[:, None] + pi2 * nu[None, :] - c)``;
    2. subtract a per-row and a per-column correction built from the row sums,
       column sums and grand total of the moved array;
    3. apply an element-wise shrink: values that are negative are pushed
       towards zero by ``pi0 * lr`` (clamped at zero), the rest are kept
       (for ``pi0 * lr >= 0``).

    Parameters
    ----------
    m, n : int
        Number of rows / columns of ``s`` and ``c``.
    mu, nu :
        Arrays reshaped to ``(m, 1)`` and ``(1, n)`` respectively.
    c :
        Array subtracted from the broadcast ``mu``/``nu`` term.
    s :
        Current iterate.
    j :
        Unused here; accepted so the signature parallels ``update_nesterov``.
    pi0, pi1, pi2 :
        Weights of the shrink, the row term and the column term.
    lr :
        Step size.

    Returns
    -------
    The new iterate.
    """
    row_w = pi1 * lr
    col_w = pi2 * lr
    joint_den = 1 + row_w * n + col_w * m

    drift = (pi1 * mu).reshape((m, 1)) + (pi2 * nu).reshape((1, n)) - c
    k = s + lr * drift

    total = k.sum()
    row_corr = row_w / (1 + row_w * n) * (k.sum(axis=1) - col_w / joint_den * total)
    col_corr = col_w / (1 + col_w * m) * (k.sum(axis=0) - row_w / joint_den * total)
    moved = k - row_corr.reshape((m, 1)) - col_corr.reshape((1, n))

    kept_positive = numpy.maximum(moved, 0.)
    shrunk_negative = numpy.minimum(moved + pi0 * lr, 0.)
    return kept_positive + shrunk_negative

In [None]:
def update_nesterov(m, n, mu, nu, c, s, s_old, j, pi0, pi1, pi2, lr):
    """Perform one momentum-accelerated step on the iterate ``s``.

    An extrapolated point ``s + (j - 1) / (j + 2) * (s - s_old)`` is formed
    first; the same move / row-column correction / shrink sequence as in
    ``update`` is then applied to that point.

    Parameters
    ----------
    m, n : int
        Number of rows / columns of ``s`` and ``c``.
    mu, nu :
        Arrays reshaped to ``(m, 1)`` and ``(1, n)`` respectively.
    c :
        Array subtracted from the broadcast ``mu``/``nu`` term.
    s, s_old :
        Current and previous iterates.
    j : int
        Iteration counter used for the momentum coefficient.
    pi0, pi1, pi2 :
        Weights of the shrink, the row term and the column term.
    lr :
        Step size.

    Returns
    -------
    tuple
        ``(new_iterate, s)`` -- the second entry is the incoming ``s`` object,
        to be passed back as ``s_old`` on the next call.
    """
    # NOTE(review): with a zero-based j the coefficient is -1/2 at j == 0.
    # That is harmless when s is s_old, but not when a new stage restarts
    # j at 0 with s != s_old -- confirm this is intended.
    momentum = (j - 1) / (j + 2)
    lookahead = s + momentum * (s - s_old)

    row_w = pi1 * lr
    col_w = pi2 * lr
    joint_den = 1 + row_w * n + col_w * m

    drift = (pi1 * mu).reshape((m, 1)) + (pi2 * nu).reshape((1, n)) - c
    k = lookahead + lr * drift

    total = k.sum()
    row_corr = row_w / (1 + row_w * n) * (k.sum(axis=1) - col_w / joint_den * total)
    col_corr = col_w / (1 + col_w * m) * (k.sum(axis=0) - row_w / joint_den * total)
    moved = k - row_corr.reshape((m, 1)) - col_corr.reshape((1, n))

    stepped = numpy.maximum(moved, 0.) + numpy.minimum(moved + pi0 * lr, 0.)
    return stepped, s

In [None]:
def solve_grad_primal(
    p,
    scale=None, its=[], pi0s=[], pi1s=[], pi2s=[], lrs=[], nests=None, pures=None, epss=None, min_its=None,
    refine=False, fh=None, figs={}, log=None, stat=False,
    *args, **kwargs
):
    """Run the staged (optionally Nesterov-accelerated) gradient solver on ``p``.

    The solver works on a scaled copy of the problem (``scale * p.mu``,
    ``scale * p.nu``) and writes the unscaled result back to ``p.s``.

    Parameters
    ----------
    p :
        Problem object. ``p.c`` (2-D array), ``p.mu`` and ``p.nu`` are read;
        ``p.s`` is read only when ``refine`` is truthy and is always written.
    scale : float, optional
        Scaling factor; defaults to ``sqrt(m * n)`` with ``m, n = p.c.shape``.
    its, pi0s, pi1s, pi2s, lrs : sequences
        Per-stage settings: iteration count, the three weights and the step
        size. Stage ``i`` uses the ``i``-th entry of each. The list defaults
        are never mutated.
    nests : sequence of bool, optional
        Per-stage switch; ``None`` (default) uses the Nesterov update for
        every stage, otherwise stage ``i`` uses it iff ``nests[i] == True``.
    pures :
        Accepted but unused.
    epss : sequence of float, optional
        Per-stage tolerance. A stage stops early once both marginal errors
        are below ``epss[i]`` (subject to ``min_its``).
    min_its : sequence of int, optional
        Early stopping in stage ``i`` is allowed only when ``j > min_its[i]``.
    refine : bool
        If truthy, start from ``p.s * scale`` instead of zeros.
    fh :
        Figure handler; when given, the curves selected by ``figs`` are
        recorded every iteration and plotted at the end.
    figs : container of str
        May contain ``"error"`` and/or ``"loss"``.
    log : callable, optional
        Called with a progress string after every iteration that does not
        trigger an early stop.
    stat : bool
        If truthy, also return a summary dict.
    *args, **kwargs :
        Accepted and ignored.

    Returns
    -------
    ``p`` or, if ``stat`` is truthy, ``(p, summary)`` where ``summary`` has
    the keys ``"title"``, ``"loss"``, ``"vars"`` and ``"iters"``.
    """
    m, n = p.c.shape

    if scale is None:
        # Use numpy (imported by this file) rather than `math`, which was only
        # reachable if the star-import happened to re-export it.
        scale = float(numpy.sqrt(m * n))

    mu, nu = scale*p.mu, scale*p.nu
    c = p.c

    track_error = fh is not None and "error" in figs
    track_loss = fh is not None and "loss" in figs
    error_mu, error_nu, loss = [], [], []

    # Only touch p.s on a warm start, so a fresh problem without a stored
    # solution can still be solved from zeros.
    s, s_old = init(m, n, p.s if refine else None, scale, refine)

    itc = 0

    for i in range(len(its)):
        for j in range(its[i]):
            if nests is None or nests[i] == True:
                s, s_old = update_nesterov(m, n, mu, nu, c, s, s_old, j, pi0s[i], pi1s[i], pi2s[i], lrs[i])
            else:
                s = update(m, n, mu, nu, c, s, j, pi0s[i], pi1s[i], pi2s[i], lrs[i])

            itc += 1

            # Marginal errors are needed for plotting and/or early stopping;
            # compute them once. `numpy.inf` replaces the alias `numpy.infty`,
            # which no longer exists in NumPy 2.0.
            if track_error or epss is not None:
                err_mu = numpy.linalg.norm(s.sum(axis=1) - mu, numpy.inf) / scale * m
                err_nu = numpy.linalg.norm(s.sum(axis=0) - nu, numpy.inf) / scale * n

            if track_error:
                error_mu.append(err_mu)
                error_nu.append(err_nu)
            if track_loss:
                loss.append((c * s).sum() / scale)

            if epss is not None and err_mu < epss[i] and err_nu < epss[i]:
                if min_its is None or j > min_its[i]:
                    break

            if log is not None:
                log("i, j, itc = {0}, {1}, {2}".format(i, j, itc))

    p.s = s / scale

    if track_error:
        fh.new(1, 1, 1)
        fh.ax.semilogy(numpy.array(error_mu), label="Error of mu")
        fh.ax.semilogy(numpy.array(error_nu), label="Error of nu")
        fh.ax.legend()
        fh.show()
        fh.close()
    if track_loss:
        fh.new(1, 1, 1)
        fh.ax.plot(numpy.array(loss), label="Loss")
        fh.ax.legend()
        fh.show()
        fh.close()

    if stat:
        # Separate name so the iterate `s` is not shadowed by the summary.
        summary = {
            "title": "2-step proximal gradient on primal",
            "loss": (c * s).sum() / scale,
            "vars": 2*m*n,
            "iters": itc,
        }
        return p, summary
    return p

In [None]:
def solve_combine(
    p,
    cfgs,
    *args, **kwargs
):
    """Run several solvers one after another on the same problem ``p``.

    Parameters
    ----------
    p :
        Problem object handed to every solver as first positional argument.
    cfgs : sequence of dict
        One entry per solver, each with the keys ``"func"`` (callable),
        ``"args"`` (positional arguments) and ``"kwargs"`` (keyword arguments).
    *args, **kwargs :
        Extra arguments appended to every solver call.

    Every solver except the last is called with ``clean=False``; every solver
    except the first is called with ``refine=True``. Each configuration is
    executed exactly once, so a single-entry ``cfgs`` runs that solver once
    with neither flag added (previously it was executed twice, because the
    first and the last entry are the same).

    Returns
    -------
    The return value of the last solver.

    Raises
    ------
    IndexError
        If ``cfgs`` is empty.
    """
    if len(cfgs) == 0:
        # Same exception type the former cfgs[0] lookup produced, with a clearer message.
        raise IndexError("cfgs must contain at least one solver configuration")

    last = len(cfgs) - 1
    result = None

    for idx, cfg in enumerate(cfgs):
        stage_flags = {}
        if idx < last:
            stage_flags["clean"] = False
        if idx > 0:
            stage_flags["refine"] = True

        result = cfg["func"](p, *cfg["args"], *args, **stage_flags, **cfg["kwargs"], **kwargs)

    return result

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*first_grad_primal_test.py*w*sehx*

In [None]:
# !Switch*
# !SwitchCase*
# import font
# from utils import *
# from first_grad_primal import solve_grad_primal, solve_combine
# !SwitchEnd*

In [None]:
# !Switch*
# Figure handler shared by all stat.test(...) calls below; created with sav=False and
# print as its log callback. NOTE(review): the surrounding `# !Switch*` comments look
# like directives for an external converter that selects the other case -- confirm.
fh = FigureHandler(sav=False, log=print)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", log=print)
# !SwitchEnd*

In [None]:
# Benchmark harness used by every stat.test(...) call below. It holds a single
# problem built by ot_2d_general with m = n = 500 from the generators and the
# distance function passed here.
# NOTE(review): Statistics, ot_2d_general and the samp_*/val_*/dist_*/general_*
# names are not defined in this notebook -- presumably provided by `utils`; confirm.
stat = Statistics(
    probs=[
        ot_2d_general(
            m=500, n=500,
            mup_gen=samp_2d_Caffarelli(0., 0., 1., 0.),
            nup_gen=samp_2d_Caffarelli(0., 0., 1., 2.),
            mu_gen=val_const(),
            nu_gen=val_const(),
            dist=dist_2d_euc_2,
        )
    ],
    merge_config=general_merge_config,
    output_config=general_output_config,
    prob="Test problems",
)

In [None]:
# Five-stage run of solve_grad_primal: entry i of its/pi0s/pi1s/pi2s/lrs configures
# stage i (iteration count, the three weights, step size). `nests` is not given, so
# every stage uses the Nesterov update; `refine` is not given, so the run starts
# from zeros.
# NOTE(review): assumes stat.test forwards these keyword arguments to the solver.
# solve_grad_primal ignores `clean` (it falls into **kwargs), so clean=False is
# presumably interpreted by stat.test itself -- confirm.
stat.test(
    solve_grad_primal,
    its=[200, 100, 200, 100, 200],
    pi0s=[10., 10., 10., 10., 10.],
    pi1s=[10., 10., 100., 100., 1000.],
    pi2s=[10., 10., 100., 100., 1000.],
    lrs=[1e-4, 1e-5, 1e-5, 1e-6, 1e-6],
    fh=fh, figs={"error", "loss"},
    clean=False,
)
stat.output_last()

In [None]:
# Four-stage continuation with larger pi1/pi2 weights and smaller step sizes.
# refine=True makes solve_grad_primal start from the solution stored in p.s
# instead of zeros, so this cell depends on an earlier run having stored one.
# NOTE(review): assumes stat.test passes the same problem object as the previous
# cell (which was run with clean=False) -- confirm.
stat.test(
    solve_grad_primal,
    its=[100, 200, 100, 200],
    pi0s=[10., 10., 10., 10.],
    pi1s=[1000., 10000., 10000., 100000.],
    pi2s=[1000., 10000., 10000., 100000.],
    lrs=[1e-7, 1e-7, 1e-8, 1e-8],
    fh=fh, figs={"error", "loss"},
    refine=True,
)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Run the imported solve_mosek_interior_point solver on the same harness with no
# extra options and print its result (presumably as a reference for the
# first-order runs -- confirm).
stat.test(
    solve_mosek_interior_point,
)
stat.output_last()

In [None]:
from first_ADMM_primal import solve_ADMM_primal

In [None]:
# Single-stage run of the imported solve_ADMM_primal solver (one entry in each
# per-stage list). clean=False is passed as in the first gradient run, and the
# gradient cell that follows uses refine=True.
# NOTE(review): the meaning of rhos/alphas/epss is defined in first_ADMM_primal,
# which is not visible here.
stat.test(
    solve_ADMM_primal,
    its=[1000],
    rhos=[3.],
    alphas=[1.618],
    epss=[1e-4],
    fh=fh, figs={"error", "loss"},
    clean=False,
)
stat.output_last()

In [None]:
# Single-stage (2000 iterations) solve_grad_primal run with refine=True, i.e. it
# starts from the solution currently stored in p.s rather than from zeros.
# NOTE(review): presumably that stored solution is the ADMM result from the
# previous cell -- this depends on how stat.test handles clean=False; confirm.
stat.test(
    solve_grad_primal,
    its=[2000],
    pi0s=[100.],
    pi1s=[10000.],
    pi2s=[10000.],
    lrs=[1e-7],
    fh=fh, figs={"error", "loss"},
    refine=True,
)
stat.output_last()

In [None]:
# Same ADMM -> gradient sequence as the two cells above, but chained through
# solve_combine in a single stat.test call. solve_combine calls the first entry
# with clean=False and the last entry with refine=True, and forwards its own
# extra keyword arguments to every entry.
# NOTE(review): assumes stat.test forwards cfgs/fh/figs to solve_combine, in which
# case fh and figs reach both solvers -- confirm.
stat.test(
    solve_combine,
    cfgs=[
        {
            "func": solve_ADMM_primal,
            "args": (),
            "kwargs": dict(
                its=[1000],
                rhos=[3.],
                alphas=[1.618],
                epss=[1e-4],
            )
        },
        {
            "func": solve_grad_primal,
            "args": (),
            "kwargs": dict(
                its=[2000],
                pi0s=[100.],
                pi1s=[10000.],
                pi2s=[10000.],
                lrs=[1e-7],
            )
        },
    ],
    fh=fh, figs={"error", "loss"},
)
stat.output_last()

In [None]:
# !ConvertEnd*