In [None]:
# !Convert*multiscale_image_func.py*w*sh*

In [None]:
import math

import numpy

from dataset import OTProblem

In [None]:
def shrink_u(m, mu, shrink):
    '''Coarsen a flattened image by summing non-overlapping shrink x shrink blocks.

    The flat array ``mu`` is reshaped to the image grid ``m``, every block of
    ``shrink ** 2`` neighbouring pixels is summed into one coarse pixel, and the
    coarse image is flattened again. Summing (not averaging) keeps the total
    mass of ``mu`` when ``shrink`` divides both sides of ``m``.

    Parameters
    ----------
    m : pair of int
        Image shape ``(rows, cols)``; ``mu`` must have ``m[0] * m[1]`` entries.
    mu : numpy.ndarray
        Flattened image.
    shrink : int
        Block edge length (coarsening factor).

    Returns
    -------
    numpy.ndarray
        float64 array of length ``(m[0] // shrink) * (m[1] // shrink)``.
        Trailing rows/columns that do not fill a whole block are dropped.
    '''
    mu_matrix = mu.reshape(m)
    new_m = (m[0] // shrink, m[1] // shrink)
    # Crop to a whole number of blocks, then expose each block as its own pair
    # of axes (1 and 3) so all block sums are a single vectorized reduction
    # instead of a four-deep Python loop. dtype=float64 keeps the original's
    # float64 output even for integer or float32 input.
    blocks = mu_matrix[:new_m[0] * shrink, :new_m[1] * shrink].reshape(
        new_m[0], shrink, new_m[1], shrink
    )
    mu_shrink_matrix = blocks.sum(axis=(1, 3), dtype=numpy.float64)
    return mu_shrink_matrix.reshape(new_m[0] * new_m[1])

In [None]:
def shrink_c(m, n, c, shrink):
    '''Subsample a pixel-to-pixel cost matrix onto the coarse grids.

    ``c`` (shape ``(m[0] * m[1], n[0] * n[1])``) is viewed as a 4-D array
    ``(m[0], m[1], n[0], n[1])``. For every pair of coarse blocks, the cost
    between their top-left pixels is kept (stride ``shrink`` on every axis).
    Pixels in trailing rows/columns that do not fill a whole block are dropped,
    the same truncation ``shrink_u`` applies to the marginals, so the result
    always lines up with the coarsened ``mu`` / ``nu``.

    Returns
    -------
    (new_m, new_n, new_c)
        Coarse grid shapes and the coarse cost matrix of shape
        ``(new_m[0] * new_m[1], new_n[0] * new_n[1])``.
    '''
    new_m = (m[0] // shrink, m[1] // shrink)
    new_n = (n[0] // shrink, n[1] // shrink)
    new_ms = new_m[0] * new_m[1]
    new_ns = new_n[0] * new_n[1]

    c4 = c.reshape(m[0], m[1], n[0], n[1])
    # Bound every slice at new_side * shrink. A bare ``::shrink`` keeps
    # ceil(side / shrink) samples, which disagrees with new_m / new_n
    # (floor) and makes the reshape below fail.
    new_c = c4[
        :new_m[0] * shrink:shrink,
        :new_m[1] * shrink:shrink,
        :new_n[0] * shrink:shrink,
        :new_n[1] * shrink:shrink,
    ]
    return new_m, new_n, new_c.reshape(new_ms, new_ns)

In [None]:
def small(m, n, mu, nu, c, cfgs, capacity, error, ris=None, first=False):
    '''Solve one level directly with the solver described by ``cfgs[0]``.

    An ``OTProblem`` is built from marginals ``mu`` / ``nu`` and cost ``c``,
    with the upper box ``ris * capacity[0]`` on the transport plan. If ``ris``
    is not given it defaults to the matrix whose (i, j) entry is
    ``min(mu[i], nu[j])``.

    ``error`` is accepted so the signature mirrors ``multi``; it is not used.

    Returns the plan ``p.s``; with ``first=True`` the solver is called with
    ``stat=True`` and ``(p.s, stats)`` is returned instead.
    '''
    n_src = m[0] * m[1]
    n_dst = n[0] * n[1]

    if ris is None:
        # Column vector vs. row vector -> broadcast elementwise minimum.
        ris = numpy.minimum(mu.reshape((n_src, 1)), nu.reshape((1, n_dst)))

    p = OTProblem()
    p.mu, p.nu, p.c = mu, nu, c
    p.ubox = ris * capacity[0]

    cfg = cfgs[0]
    if not first:
        cfg["func"](p, *cfg["args"], **cfg["kwargs"])
        return p.s

    _, st = cfg["func"](p, *cfg["args"], stat=True, **cfg["kwargs"])
    return p.s, st

In [None]:
def multi(m, n, mu, nu, c, cfgs, stop, shrink, capacity, error, ris=None, first=False):
    '''Multiscale solve: coarsen, solve recursively, refine on a sparse support.

    If any side of either image is already ``<= stop``, the problem is solved
    directly by ``small``. Otherwise:

    1. marginals and cost are coarsened by ``shrink`` (``shrink_u`` / ``shrink_c``);
    2. the coarse problem is solved recursively with the remaining per-level
       settings (``cfgs[1:]``, ``capacity[1:]``, ``error[1:]``);
    3. a 0/1 support mask keeps the coarse entries whose mass exceeds
       ``error[0]`` times the coarse ``min(mu_i, nu_j)`` bound;
    4. the mask is upsampled to the fine grid and multiplied into the fine
       upper box ``capacity[0] * ris``, and the fine problem is solved with
       ``cfgs[0]`` via ``small``.

    ``cfgs``, ``capacity`` and ``error`` are per-level sequences, finest level
    first. The upsampled mask is reshaped back to ``(ms, ns)``, so ``shrink``
    is assumed to divide every side of ``m`` and ``n``.

    Returns ``p.s``, or ``(p.s, stats)`` when ``first=True``.
    '''
    if min(*m, *n) <= stop:
        return small(m, n, mu, nu, c, cfgs, capacity, error, ris=ris, first=first)

    ms = m[0] * m[1]
    ns = n[0] * n[1]

    if ris is None:
        ris = numpy.minimum(mu.reshape((ms, 1)), nu.reshape((1, ns)))

    mu_shrink = shrink_u(m, mu, shrink)
    nu_shrink = shrink_u(n, nu, shrink)
    m_shrink, n_shrink, c_shrink = shrink_c(m, n, c, shrink)

    # Flattened sizes of the coarse source (m) and target (n) grids; each must
    # come from its own shape, or non-square / mismatched problems break.
    ms_shrink = m_shrink[0] * m_shrink[1]
    ns_shrink = n_shrink[0] * n_shrink[1]

    ris_shrink = numpy.minimum(mu_shrink.reshape((ms_shrink, 1)), nu_shrink.reshape((1, ns_shrink)))

    sol_shrink = multi(m_shrink, n_shrink, mu_shrink, nu_shrink, c_shrink, cfgs[1:], stop, shrink, capacity[1:], error[1:], ris=ris_shrink)

    # Coarse support: 1 where the coarse plan carries more than error[0] of
    # the entry's admissible mass, 0 elsewhere.
    cc_shrink = numpy.zeros((ms_shrink, ns_shrink))
    cc_shrink[sol_shrink > error[0] * ris_shrink] = 1.

    # Upsample the mask on all four image axes: each coarse pixel pair covers
    # shrink**4 fine pixel pairs.
    cc = cc_shrink.reshape((*m_shrink, *n_shrink))
    for axis in range(4):
        cc = cc.repeat(shrink, axis=axis)
    cc = cc.reshape((ms, ns))

    # Fine solve restricted to the upsampled support. Because cc is exactly
    # 0/1, (ris * cc) * capacity[0] equals the box ris * capacity[0] * cc.
    return small(m, n, mu, nu, c, cfgs, capacity, error, ris=ris * cc, first=first)

In [None]:
def solve_multiscale_image_func(
    p,
    cfgs, stop, shrink, caps, errs,
    log=None, stat=False,
    *args, **kwargs
):
    '''Solve the image OT problem ``p`` with the multiscale scheme in ``multi``.

    Parameters
    ----------
    p : OTProblem
        Reads image shapes ``p.m`` / ``p.n``, marginals ``p.mu`` / ``p.nu``
        and cost ``p.c``; the computed plan is stored in ``p.s``.
    cfgs : list of dict
        One solver config per level, finest first; each has keys
        ``"func"``, ``"args"`` and ``"kwargs"``.
    stop : int
        Recursion stops once a side of either image is ``<= stop``.
    shrink : int
        Coarsening factor between consecutive levels.
    caps, errs : sequences
        Per-level capacity multipliers and support thresholds (see ``multi``).
    log, *args, **kwargs
        Accepted for interface compatibility with other solvers; unused.
    stat : bool
        If true, also return the finest-level solver's statistics with
        ``"title"`` and ``"loss"`` filled in.

    Returns
    -------
    ``p``, or ``(p, st)`` when ``stat`` is true.
    '''
    # first=True makes the finest level return (plan, stats).
    p.s, st = multi(p.m, p.n, p.mu, p.nu, p.c, cfgs, stop, shrink, caps, errs, first=True)

    if stat:
        # Must be assignments: ``st[key]: value`` is an annotation statement
        # and stores nothing.
        st["title"] = "multiscale for image using func"
        # Transport cost of the computed plan.
        st["loss"] = (p.c * p.s).sum()
        return p, st
    return p

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*multiscale_image_func_test.py*w*sehx*

In [None]:
# !Switch*
from handler import FigureHandler
from dataset import ot_im_general, samp_2d_mid, val_unif, dist_2d_euc_2
from stats import Statistics
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_im_general, samp_2d_mid, val_unif, dist_2d_euc_2
# from stats import Statistics
# from multiscale_image_func import solve_multiscale_image_func
# !SwitchEnd*

In [None]:
# !Switch*
# Shared figure handler for this run; it is passed as ``fh`` to every
# solve_ADMM_primal call below (redir=True: see handler.FigureHandler).
fh = FigureHandler(redir=True)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
# Benchmark set: a single generated image-to-image OT problem on 32x32 grids
# spanning [0, 1] x [0, 1], with a fixed seed so every solver below sees the
# same instance. 32x32 with shrink=2 and stop=4 gives the multiscale run four
# levels (32 -> 16 -> 8 -> 4).
stat = Statistics(
    probs=[
        ot_im_general(
            m=[32, 32], n=[32, 32],
            mw=[0., 1., 0., 1.], nw=[0., 1., 0., 1.],
            mup_gen=samp_2d_mid(0., 1., 0., 1.),
            nup_gen=samp_2d_mid(0., 1., 0., 1.),
            # presumably uniform(0, 1) pixel values — see dataset.val_unif
            mu_gen=val_unif(0., 1.),
            nu_gen=val_unif(0., 1.),
            dist=dist_2d_euc_2,
            seed=1,
        )
    ],
    prob="Test problems",
)

In [None]:
from first_ADMM_primal import solve_ADMM_primal

In [None]:
# Multiscale ADMM run: four levels (32 -> 16 -> 8 -> 4 with shrink=2, stop=4).
# cfgs / caps / errs are per level, finest first. The ADMM settings differ
# only in iteration budget and tolerance, so they come from one template.
# The last errs entry is not used: the coarsest level is solved directly.
stat.test(
    solve_multiscale_image_func,
    cfgs=[
        {
            "func": solve_ADMM_primal,
            "args": (),
            "kwargs": dict(
                its=[its],
                rhos=[3.],
                alphas=[1.618],
                epss=[eps],
                fh=fh, figs={"error", "loss"},
            ),
        }
        for its, eps in [(20000, 1e-3), (2000, 1e-4), (1000, 1e-4), (200, 1e-4)]
    ],
    stop=4,
    shrink=2,
    caps=[1., 0.1, 0.1, 0.1],
    errs=[0.05, 0.05, 0.02, 0.00],
)
stat.output_last()

In [None]:
# Single-scale baseline: plain ADMM on the full 32x32 problem with the same
# settings as the finest multiscale level (its=20000, rho=3, alpha=1.618,
# eps=1e-3), for comparison with the multiscale run above.
stat.test(
    solve_ADMM_primal,
    its=[20000],
    rhos=[3.],
    alphas=[1.618],
    epss=[1e-3],
    fh=fh, figs={"error", "loss"},
)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Reference run with the interior-point solver from solver_mosek (presumably
# MOSEK), on the same problem set.
stat.test(solve_mosek_interior_point)
stat.output_last()

In [None]:
# !ConvertEnd*