In [None]:
# !Convert*multiscale_image_mosek.py*w*sh*

In [None]:
import math

import numpy

import mosek

In [None]:
def point_num(point, ysize):
    '''Flatten a (row, col) pixel coordinate into a row-major index.

    ``ysize`` is the number of columns in a row; this is the inverse of
    ``point_vec``.
    '''
    row, col = point[0], point[1]
    return row * ysize + col

In [None]:
def point_vec(num, ysize):
    '''Split a row-major flat index into its (row, col) coordinate.

    ``ysize`` is the number of columns in a row; this is the inverse of
    ``point_num``.
    '''
    return divmod(num, ysize)

In [None]:
def shrink_u(m, mu, shrink):
    '''Coarsen a flattened image by summing ``shrink x shrink`` pixel blocks.

    ``mu`` is reshaped to the image shape ``m``. Output pixel (i, j) is the sum
    of the input pixels ``(i*shrink + di, j*shrink + dj)`` for
    ``0 <= di, dj < shrink``. The coarse image has shape
    ``(m[0] // shrink, m[1] // shrink)`` and is returned flattened. If
    ``shrink`` does not divide a side, the trailing rows/columns are ignored
    (the caller is expected to pick a dividing ``shrink``).
    '''
    rows, cols = m[0] // shrink, m[1] // shrink
    image = mu.reshape(m)
    coarse = numpy.zeros((rows, cols))

    # Add one block offset at a time: the strided slice selects the pixel at
    # offset (di, dj) of every block, so each coarse entry is accumulated in
    # the same (di outer, dj inner) order as a per-pixel double loop.
    for di in range(shrink):
        for dj in range(shrink):
            coarse += image[di:rows * shrink:shrink, dj:cols * shrink:shrink]

    return coarse.reshape(rows * cols)

In [None]:
def shrink_c(m, n, c, shrink):
    '''Coarsen the cost matrix to the grids produced by ``shrink_u``.

    ``c`` holds one cost per (source pixel, target pixel) pair, with
    ``m[0]*m[1]`` source and ``n[0]*n[1]`` target pixels in row-major order.
    The coarse cost between two blocks is the cost between the top-left
    pixels of those blocks (it is sampled, not averaged).

    Returns
    -------
    (new_m, new_n, new_c)
        Coarse source/target shapes ``(size // shrink, ...)`` and the coarse
        cost matrix of shape ``(new_m[0]*new_m[1], new_n[0]*new_n[1])``.
    '''
    new_m = (m[0] // shrink, m[1] // shrink)
    new_n = (n[0] // shrink, n[1] // shrink)
    new_ms = new_m[0] * new_m[1]
    new_ns = new_n[0] * new_n[1]

    # End each strided slice at new_size * shrink so every axis yields exactly
    # new_size samples. A plain [::shrink] yields ceil(size / shrink) samples
    # when shrink does not divide the side, which breaks the reshape below
    # and disagrees with shrink_u, which drops the trailing pixels.
    new_c = c.reshape(m[0], m[1], n[0], n[1])[
        :new_m[0] * shrink:shrink,
        :new_m[1] * shrink:shrink,
        :new_n[0] * shrink:shrink,
        :new_n[1] * shrink:shrink,
    ]
    return new_m, new_n, new_c.reshape(new_ms, new_ns)

In [None]:
def propagate(m, n, m_shrink, n_shrink, path_coarsen, shrink):
    '''Refine a coarse transport support onto the next finer grid.

    Each coarse flow ``(mass, I, J)`` goes from coarse source pixel ``I``
    (grid ``m_shrink``) to coarse target pixel ``J`` (grid ``n_shrink``). It is
    split evenly over all ``shrink**4`` pairs of fine pixels inside the two
    ``shrink x shrink`` blocks, producing candidate variables for the
    fine-level LP.

    Parameters
    ----------
    m, n : pair of int
        Fine source / target image shapes.
    m_shrink, n_shrink : pair of int
        Coarse source / target image shapes.
    path_coarsen : iterable of (mass, I, J)
        Coarse flows with flat (row-major) coarse pixel indices.
    shrink : int
        Coarsening factor per axis.

    Returns
    -------
    list of (mass / shrink**4, i, j)
        Fine candidate flows with flat fine pixel indices. They are ordered by
        coarse flow, then source row/col offset, then target row/col offset.
    '''
    path = []
    if shrink <= 0:
        # No fine offsets exist, so nothing can be propagated.
        return path

    for p in path_coarsen:
        # Everything here depends only on the coarse flow, so it is computed
        # once per flow rather than once per fine pixel pair.
        share = p[0] / (shrink ** 4)
        begin_x, begin_y = point_vec(p[1], m_shrink[1])
        end_x, end_y = point_vec(p[2], n_shrink[1])
        for i1 in range(shrink):
            for i2 in range(shrink):
                src = point_num((begin_x * shrink + i1, begin_y * shrink + i2), m[1])
                for j1 in range(shrink):
                    for j2 in range(shrink):
                        path.append((
                            share,
                            src,
                            point_num((end_x * shrink + j1, end_y * shrink + j2), n[1])
                        ))
    return path

In [None]:
def small(m, n, mu, nu, c, capacity, error, mtd, solt):
    '''Solve the transport LP directly (dense formulation) with MOSEK.

    This is the base case of the multiscale recursion. Every
    (source pixel, target pixel) pair is a variable: variable ``i * ns + j``
    is the flow from source pixel ``i`` to target pixel ``j``.

    Parameters
    ----------
    m, n : pair of int
        Source / target image shapes.
    mu, nu : numpy array
        Source / target masses (``ms`` and ``ns`` entries), imposed as fixed
        row / column sums.
    c : numpy array with ``ms * ns`` entries
        Cost for each pair, flattened row-major.
    capacity, error : sequence of float
        Only element 0 is read. ``capacity[0]`` caps each flow at
        ``capacity[0] * min(mu_i, nu_j)``. Flows not exceeding
        ``error[0] * min(mu_i, nu_j)`` are left out of the result.
    mtd : MOSEK optimizer type (set as ``iparam.optimizer``).
    solt : MOSEK solution type passed to ``getxx``.

    Returns
    -------
    list of (flow, i, j)
        Retained flows in row-major (i, j) order.
    '''
    ms = m[0] * m[1]
    ns = n[0] * n[1]
    nvars = ms * ns

    # Masses are scaled by sqrt(ms * ns) before being given to MOSEK;
    # returned flows are divided by the same factor.
    scale = math.sqrt(ms * ns)
    mu, nu = mu * scale, nu * scale

    # ris[i, j] = min(mu_i, nu_j), the most pair (i, j) could ever carry.
    ris = numpy.minimum(mu.reshape((ms, 1)), nu.reshape((1, ns)))
    upper = (capacity[0] * ris).reshape((nvars))

    with mosek.Env() as env, env.Task() as task:
        task.putintparam(mosek.iparam.optimizer, mtd)

        task.appendvars(nvars)
        task.appendcons(ms + ns)

        # 0 <= x_ij <= capacity[0] * min(mu_i, nu_j)
        task.putvarboundlist(range(nvars), [mosek.boundkey.ra] * nvars, [0.] * nvars, upper)

        # Source constraint i sums the contiguous variables i*ns .. i*ns + ns - 1.
        for src in range(ms):
            first = src * ns
            task.putarow(src, range(first, first + ns), [1.] * ns)
        task.putconboundlist(range(ms), [mosek.boundkey.fx] * ms, mu, mu)

        # Target constraint ms + j sums the variables j, j + ns, j + 2*ns, ...
        for dst in range(ns):
            task.putarow(ms + dst, range(dst, dst + nvars, ns), [1.] * ms)
        task.putconboundlist(range(ms, ms + ns), [mosek.boundkey.fx] * ns, nu, nu)

        task.putclist(range(nvars), c.reshape(nvars))
        task.putobjsense(mosek.objsense.minimize)
        task.optimize()

        sol = [0.] * nvars
        task.getxx(solt, sol)

    path = []
    for idx, flow in enumerate(sol):
        i, j = divmod(idx, ns)
        if flow > error[0] * ris[i, j]:
            path.append((flow / scale, i, j))
    return path

In [None]:
def multi(m, n, mu, nu, c, stop, shrink, capacity, error, mtd, solt):
    '''Solve the image transport LP with a coarse-to-fine multiscale scheme.

    If the smallest image side is at most ``stop``, the problem is handed to
    ``small``. Otherwise:
    1. Both images and the cost matrix are coarsened by ``shrink``.
    2. The coarse problem is solved recursively.
    3. Its support is propagated back to this level.
    4. The LP at this level is solved with MOSEK, using only the propagated
       (source, target) pairs as variables.

    Parameters
    ----------
    m, n : pair of int
        Source / target image shapes.
    mu, nu : numpy array
        Source / target masses (``ms = m[0]*m[1]`` and ``ns = n[0]*n[1]``
        entries), imposed as fixed row / column sums.
    c : numpy array of shape (ms, ns)
        Cost matrix.
    stop : int
        Recursion stops once ``min(*m, *n) <= stop``.
    shrink : int
        Coarsening factor per axis.
    capacity, error : sequence of float
        Per-level parameters, finest level first. ``capacity[0]`` caps every
        variable at ``capacity[0] * min(mu_i, nu_j)``. Flows not exceeding
        ``error[0] * min(mu_i, nu_j)`` are dropped from the returned support.
        The remaining entries are passed to the coarser levels.
    mtd : MOSEK optimizer type (set as ``iparam.optimizer``).
    solt : MOSEK solution type passed to ``getxx``.

    Returns
    -------
    list of (flow, i, j)
        Retained flows with flat (row-major) source / target pixel indices.
    '''
    if min(*m, *n) <= stop:
        return small(m, n, mu, nu, c, capacity, error, mtd, solt)

    ms = m[0] * m[1]
    ns = n[0] * n[1]

    mu_shrink = shrink_u(m, mu, shrink)
    nu_shrink = shrink_u(n, nu, shrink)
    m_shrink, n_shrink, c_shrink = shrink_c(m, n, c, shrink)

    path_coarsen = multi(
        m_shrink, n_shrink, mu_shrink, nu_shrink, c_shrink,
        stop, shrink, capacity[1:], error[1:], mtd, solt
    )
    path = propagate(m, n, m_shrink, n_shrink, path_coarsen, shrink)
    nvars = len(path)

    # Source / target index of every candidate variable. They are built once
    # and reused for the bounds and the cost vector. The explicit integer
    # dtype keeps indexing valid when the path is empty.
    rows = numpy.fromiter((pa[1] for pa in path), dtype=numpy.intp, count=nvars)
    cols = numpy.fromiter((pa[2] for pa in path), dtype=numpy.intp, count=nvars)

    # min(mu_i, nu_j) for the candidate pairs only, without materializing the
    # dense ms x ns matrix that the multiscale scheme is meant to avoid.
    ris = numpy.minimum(numpy.reshape(mu, ms)[rows], numpy.reshape(nu, ns)[cols])

    # Variables touching each source / target pixel, i.e. the sparse rows of
    # the marginal constraints.
    by_source = [[] for _ in range(ms)]
    by_target = [[] for _ in range(ns)]
    for k, pa in enumerate(path):
        by_source[pa[1]].append(k)
        by_target[pa[2]].append(k)

    with mosek.Env() as env:
        with env.Task() as task:
            task.putintparam(mosek.iparam.optimizer, mtd)

            task.appendvars(nvars)
            task.appendcons(ms + ns)

            # 0 <= x_k <= capacity[0] * min(mu_i, nu_j)
            task.putvarboundlist(
                range(nvars),
                [mosek.boundkey.ra] * nvars,
                [0.] * nvars,
                capacity[0] * ris
            )

            for i in range(ms):
                task.putarow(i, by_source[i], [1.] * len(by_source[i]))
            task.putconboundlist(range(0, ms), [mosek.boundkey.fx] * ms, mu, mu)

            for j in range(ns):
                task.putarow(ms + j, by_target[j], [1.] * len(by_target[j]))
            task.putconboundlist(range(ms, ms + ns), [mosek.boundkey.fx] * ns, nu, nu)

            task.putclist(range(nvars), c[rows, cols])
            task.putobjsense(mosek.objsense.minimize)
            task.optimize()

            xx = [0.] * nvars
            task.getxx(solt, xx)

    return [
        (xx[k], path[k][1], path[k][2])
        for k in range(nvars)
        if xx[k] > error[0] * ris[k]
    ]

In [None]:
def solve_multiscale_image_mosek(
    p,
    stop, shrink, caps, errs, mtd, solt,
    log=None, stat=False, title="",
    *args, **kwargs
):
    '''Solve problem ``p`` with the multiscale MOSEK solver and store the plan.

    ``p`` must provide ``m``, ``n`` (image shapes), ``mu``, ``nu`` (masses) and
    ``c`` (cost matrix). The dense transport plan of shape
    ``(m[0]*m[1], n[0]*n[1])`` is written to ``p.s``.

    ``stop``, ``shrink``, ``caps``, ``errs``, ``mtd`` and ``solt`` are
    forwarded to ``multi``. ``log`` and any extra ``*args`` / ``**kwargs``
    are accepted but not used here.

    Returns ``p``. When ``stat`` is truthy, returns ``(p, s)`` where ``s`` is
    a dict holding ``title`` and the plan's total cost under ``"loss"``.
    '''
    path = multi(p.m, p.n, p.mu, p.nu, p.c, stop, shrink, caps, errs, mtd, solt)

    plan = numpy.zeros((p.m[0] * p.m[1], p.n[0] * p.n[1]))
    for mass, i, j in path:
        plan[i, j] = mass
    p.s = plan

    if not stat:
        return p

    summary = {
        "title": title,
        "loss": (p.c * plan).sum(),
    }
    return p, summary

In [None]:
def solve_multiscale_image_mosek_interior_point(
    p,
    stop, shrink, caps, errs,
    *args, **kwargs
):
    '''Multiscale image solver using MOSEK's interior-point optimizer.

    Uses ``mosek.optimizertype.intpnt`` with the interior solution
    (``mosek.soltype.itr``). Extra positional arguments map onto
    ``log``, ``stat`` and ``title`` of ``solve_multiscale_image_mosek``.
    ``title`` defaults to a descriptive name and may be overridden.
    '''
    # mtd/solt are passed positionally so that extra positional arguments
    # follow them. When *args was placed after keyword arguments, Python
    # still bound it right after ``errs``, i.e. onto mtd/solt, which raised
    # "multiple values for argument 'mtd'".
    if len(args) < 3:
        # Only supply the default title when the caller did not give one
        # (neither as the third extra positional nor as a keyword).
        kwargs.setdefault("title", "Multiscale for image using mosek interior point")
    return solve_multiscale_image_mosek(
        p, stop, shrink, caps, errs,
        mosek.optimizertype.intpnt,
        mosek.soltype.itr,
        *args, **kwargs
    )

def solve_multiscale_image_mosek_primal_simplex(
    p,
    stop, shrink, caps, errs,
    *args, **kwargs
):
    '''Multiscale image solver using MOSEK's primal simplex optimizer.

    Uses ``mosek.optimizertype.primal_simplex`` with the basic solution
    (``mosek.soltype.bas``). Extra positional arguments map onto
    ``log``, ``stat`` and ``title`` of ``solve_multiscale_image_mosek``.
    ``title`` defaults to a descriptive name and may be overridden.
    '''
    # mtd/solt are passed positionally so that extra positional arguments
    # follow them. When *args was placed after keyword arguments, Python
    # still bound it right after ``errs``, i.e. onto mtd/solt, which raised
    # "multiple values for argument 'mtd'".
    if len(args) < 3:
        # Only supply the default title when the caller did not give one
        # (neither as the third extra positional nor as a keyword).
        kwargs.setdefault("title", "Multiscale for image using mosek primal simplex")
    return solve_multiscale_image_mosek(
        p, stop, shrink, caps, errs,
        mosek.optimizertype.primal_simplex,
        mosek.soltype.bas,
        *args, **kwargs
    )

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*multiscale_image_mosek_test.py*w*sehx*

In [None]:
# !Switch*
from handler import FigureHandler
from dataset import ot_im_general, samp_2d_mid, val_unif, dist_2d_euc_2
from stats import Statistics
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_im_general, samp_2d_mid, val_unif, dist_2d_euc_2
# from stats import Statistics
# from multiscale_image_mosek import solve_multiscale_image_mosek_interior_point
# !SwitchEnd*

In [None]:
# !Switch*
fh = FigureHandler(redir=True)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
# A single synthetic 32x32 -> 32x32 image transport problem with a fixed
# seed, so runs are reproducible.
# NOTE(review): dist_2d_euc_2 presumably is the squared Euclidean ground
# cost -- confirm in dataset.py.
stat = Statistics(
    probs=[
        ot_im_general(
            m=[32, 32], n=[32, 32],
            mw=[0., 1., 0., 1.], nw=[0., 1., 0., 1.],
            mup_gen=samp_2d_mid(0., 1., 0., 1.),
            nup_gen=samp_2d_mid(0., 1., 0., 1.),
            mu_gen=val_unif(0., 1.),
            nu_gen=val_unif(0., 1.),
            dist=dist_2d_euc_2,
            seed=1,
        )
    ],
    prob="Test problems",
)

In [None]:
# Multiscale solve. With 32x32 images, shrink=2 and stop=4, the recursion
# visits the 32, 16, 8 and 4 levels, and the 4x4 level is solved densely.
# caps/errs therefore need one entry per level, finest level first:
# - caps[k] bounds every flow at level k by caps[k] * min(mu_i, nu_j).
# - errs[k] drops flows not above errs[k] * min(mu_i, nu_j) before the
#   support is refined. The final 0.00 keeps every strictly positive flow
#   at the coarsest level.
# NOTE(review): clean=False presumably keeps the solution on the problem so
# it can be plotted below; it is cleared explicitly afterwards.
stat.test(
    solve_multiscale_image_mosek_interior_point,
    stop=4, shrink=2,
    caps=[1., 0.3, 0.1, 0.1],
    errs=[0.001, 0.001, 0.001, 0.00],
    clean=False
)
stat.output_last()
fh.fast(stat.probs[0].plot_link)
stat.probs[0].clean()

In [None]:
# Baseline solver for comparison, presumably a direct (single-scale) MOSEK
# interior-point solve.
# NOTE(review): this import belongs in the imports cell at the top.
from solver_mosek import solve_mosek_interior_point

In [None]:
# Same problem solved by the baseline, for comparison with the multiscale result.
stat.test(
    solve_mosek_interior_point,
    clean=False
)
stat.output_last()
fh.fast(stat.probs[0].plot_link)
stat.probs[0].clean()

In [None]:
# !ConvertEnd*