In [None]:
# !Convert*utils.py*w*sh*

In [None]:
import math
import time
import statistics

import numpy
import scipy.stats
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
class FigureHandler(object):
    """Thin wrapper around a matplotlib figure/axes pair for notebook plotting.

    Parameters
    ----------
    fig, ax : optional
        Initial figure and axes; normally created later by ``new`` or by
        ``refresh`` followed by ``subplot``.
    sav : bool
        When True, ``show`` saves the figure to a fresh file first.
    disp : bool
        When True, ``show`` calls ``plt.show()``.
    ext : str
        Extension (including the leading dot) appended to saved file names.
    log : callable or None
        Called with a ``"<file> saved"`` message after every save; ``None``
        disables the message.
    """

    def __init__(self, fig=None, ax=None, sav=False, disp=True, ext=".pdf", log=None):
        self.fig = fig
        self.ax = ax
        self.sav = sav
        self.disp = disp
        self.ext = ext
        self.log = log
        self.ctr = 0  # number of file names handed out so far

    def filename(self):
        """Return a new file name ``Figure-<4-digit counter>-<unix time><ext>``.

        The counter is incremented on every call, so names stay distinct
        within one handler even when several figures are saved in the same
        second.
        """
        self.ctr += 1
        fn = "Figure-{0:04}-{1:}".format(self.ctr, int(time.time())) + self.ext
        return fn

    def new(self, *args, **kwargs):
        """Open a new figure with one subplot; arguments go to ``add_subplot``."""
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(*args, **kwargs)

    def refresh(self, *args, **kwargs):
        """Open a new empty figure (arguments go to ``plt.figure``); add axes with ``subplot``."""
        self.fig = plt.figure(*args, **kwargs)

    def subplot(self, *args, **kwargs):
        """Add a subplot to the current figure and make it the active axes."""
        self.ax = self.fig.add_subplot(*args, **kwargs)

    def colorbar(self, mpbl, *args, **kwargs):
        """Attach a colorbar for the mappable ``mpbl`` to the current figure."""
        self.fig.colorbar(mpbl, *args, **kwargs)

    def close(self, *args, **kwargs):
        """Close the handler's figure via ``plt.close``."""
        plt.close(self.fig, *args, **kwargs)

    def save(self):
        """Save the handler's figure under a fresh file name.

        Reports the file through ``log`` (if set) and returns the file name.
        """
        fn = self.filename()
        if self.fig is not None:
            # Save the figure this handler owns, not whatever pyplot considers current.
            self.fig.savefig(fn)
        else:
            # No figure tracked yet: fall back to pyplot's current figure.
            plt.savefig(fn)
        if self.log is not None:
            # Use the configured logger (previously this always printed).
            self.log("{} saved".format(fn))
        return fn

    def show(self):
        """Save (if ``sav``) and then display (if ``disp``) the current figure."""
        if self.sav:
            self.save()
        if self.disp:
            plt.show()

    def fast(self, func, new_pos=(), new_kw=None, *args, **kwargs):
        """One-shot plot: new single-axes figure, ``func(self, ...)``, show, close.

        ``new_pos`` / ``new_kw`` are extra positional / keyword arguments for
        ``add_subplot`` (e.g. ``new_kw={"projection": "3d"}``). Remaining
        arguments are forwarded to ``func``. The figure is closed even if
        ``func`` raises, so failed cells do not accumulate open figures.
        """
        if new_kw is None:
            new_kw = {}
        self.new(1, 1, 1, *new_pos, **new_kw)
        try:
            func(self, *args, **kwargs)
            self.show()
        finally:
            self.close()

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*w*sehx*

In [None]:
# !Switch*
# !SwitchCase*
# import font
# from utils import *
# from solve_mosek import solve_mosek_primal_simplex, solve_mosek_dual_simplex, solve_mosek_interior_point
# !SwitchEnd*

In [None]:
# !Switch*
fh = FigureHandler(sav=False, ext=".pdf", log=print)  # interactive: display figures, do not save them
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", log=print)
# !SwitchEnd*

In [None]:
# Smoke test: draw a two-point 3-D scatter through FigureHandler, display it, close it.
fh.new(1, 1, 1, projection="3d")
fh.ax.scatter([1., 2.], [3., 4.], [5., 6.])
fh.show()
fh.close()

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils.py*a*sh*

In [None]:
class OTProblem(object):
    """Container for a discrete optimal-transport problem and its solution.

    Attributes
    ----------
    mu, nu : source / target masses (set by the caller).
    c : cost matrix; solvers in this file read ``c.shape`` as ``(m, n)``.
    sol : transport plan written by a solver, or ``None`` when unsolved.
    cx : reference plan recorded by ``set_cx``; ``None`` until then.
    """

    def __init__(self, mu=None, nu=None, c=None):
        self.mu = mu
        self.nu = nu
        self.c = c
        self.sol = None
        # Initialised so callers can test ``cx is not None`` before set_cx runs.
        self.cx = None

    def set_cx(self):
        """Record the current solution as the reference plan ``cx``."""
        self.cx = self.sol

    def clean(self):
        """Forget the current solution (``cx`` is kept)."""
        self.sol = None

    def plot_hotline(self, fh, colorbar=True, *args, **kwargs):
        """Show ``sol`` as an ``imshow`` heat map on ``fh.ax``.

        Extra arguments go to ``imshow``. Adds a colorbar to ``fh.fig`` when
        ``colorbar`` is true. Returns the image mappable.
        """
        mpbl = fh.ax.imshow(self.sol, *args, **kwargs)
        if colorbar:
            fh.fig.colorbar(mpbl)
        return mpbl

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*a*sh*

In [None]:
# Synthetic 50x50 "plan" (no solver involved), used only to exercise plot_hotline.
prob = OTProblem()
prob.sol = numpy.fromfunction(lambda i, j: numpy.sin(i / 3) * j**2, (50, 50))

In [None]:
# Heat map of prob.sol (imshow) with a colorbar.
fh.fast(prob.plot_hotline)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils.py*a*sh*

In [None]:
class OTProblem1d(OTProblem):
    """One-dimensional OT problem with plotting helpers.

    Besides the ``OTProblem`` attributes, the plotting methods read
    ``mup`` / ``nup``, the 1-D support positions of ``mu`` / ``nu``. The
    caller assigns them (e.g. ``ot_1d_general``).
    """

    def __init__(self, *args, **kwargs):
        super(OTProblem1d, self).__init__(*args, **kwargs)

    def _support_indices(self, cutoff):
        """Row/column indices of plan entries above the plotting threshold.

        The threshold is ``cutoff * (sol.mean() * sqrt(m * n))`` with
        ``(m, n) = c.shape``. Indices are returned in row-major order, which
        matches boolean-mask selection order.
        """
        m, n = self.c.shape
        mean = self.sol.mean() * numpy.sqrt(m * n)
        return numpy.nonzero(self.sol > cutoff * mean)

    def plot_mu_scatter(self, fh, *args, **kwargs):
        """Scatter ``mu`` (y) against its support positions ``mup`` (x)."""
        fh.ax.scatter(self.mup, self.mu, *args, **kwargs)

    def plot_nu_scatter(self, fh, *args, **kwargs):
        """Scatter ``nu`` (y) against its support positions ``nup`` (x)."""
        fh.ax.scatter(self.nup, self.nu, *args, **kwargs)

    def plot_scatter(
        self, fh,
        cutoff=0.2,
        aspect="auto", colorbar=True,
        *args, **kwargs
    ):
        """Scatter the significant plan entries at ``(mup[i], nup[j])``.

        Only entries above the ``cutoff`` threshold (see ``_support_indices``)
        are drawn, coloured by their plan value. Extra arguments go to
        ``scatter``. Returns the scatter mappable.
        """
        rows, cols = self._support_indices(cutoff)

        mpbl = fh.ax.scatter(self.mup[rows], self.nup[cols], c=self.sol[rows, cols], *args, **kwargs)
        if colorbar:
            fh.fig.colorbar(mpbl)

        fh.ax.set_aspect(aspect)
        return mpbl

    def plot_link(
        self, fh,
        off_bx=0., off_by=0., off_ex=0., off_ey=1.,
        scatter=True,
        cutoff=0.2, scale=30.,
        aspect="auto", colorbar=True,
        sca_pos=(), sca_kw=None,
        *args, **kwargs
    ):
        """Draw arrows from source to target points for significant plan entries.

        Source points sit at ``(mup + off_bx, off_by)`` and target points at
        ``(nup + off_ex, off_ey)``. Arrows are coloured by plan value. With
        ``scatter`` both supports are also drawn, with marker area
        proportional to mass (mean mass -> ``scale``); ``sca_pos`` /
        ``sca_kw`` go to those scatter calls. Extra arguments go to
        ``quiver``. Returns the quiver mappable.
        """
        if sca_kw is None:
            sca_kw = {}

        if scatter:
            mus = self.mu / self.mu.mean() * scale
            nus = self.nu / self.nu.mean() * scale
            fh.ax.scatter(self.mup + off_bx, numpy.zeros_like(self.mup) + off_by, s=mus, *sca_pos, **sca_kw)
            fh.ax.scatter(self.nup + off_ex, numpy.zeros_like(self.nup) + off_ey, s=nus, *sca_pos, **sca_kw)

        rows, cols = self._support_indices(cutoff)

        arr_bx = self.mup[rows] + off_bx
        arr_by = numpy.zeros_like(arr_bx) + off_by

        arr_ex = self.nup[cols] + off_ex
        arr_ey = numpy.zeros_like(arr_ex) + off_ey

        arr_dx = arr_ex - arr_bx
        arr_dy = arr_ey - arr_by

        arr_c = self.sol[rows, cols]

        # scale_units="xy", scale=1 makes each arrow end exactly at its target point.
        mpbl = fh.ax.quiver(arr_bx, arr_by, arr_dx, arr_dy, arr_c, angles="xy", scale_units="xy", scale=1., *args, **kwargs)
        if colorbar:
            fh.fig.colorbar(mpbl)

        fh.ax.set_aspect(aspect)
        return mpbl

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*a*sh*

In [None]:
# Hand-built 1-D problem: the identity "plan" sends point i of mu to point i of nu.
# Masses are not normalised; only the plotting helpers are exercised here.
prob = OTProblem1d()
prob.mup = numpy.linspace(0., 1., 50)
prob.nup = numpy.linspace(2., 3., 50)
prob.mu = numpy.linspace(2., 1., 50)
prob.nu = numpy.linspace(4., 3., 50)
prob.c = numpy.eye(50)
prob.sol = numpy.eye(50)

In [None]:
# Mass mu plotted against position mup.
fh.fast(prob.plot_mu_scatter, alpha=0.5)

In [None]:
# Plan entries above the cutoff, drawn at (mup[i], nup[j]) and coloured by value.
fh.fast(prob.plot_scatter, alpha=0.5)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils.py*a*sh*

In [None]:
def samp_1d_grid(start, stop):
    """Build a sampler returning ``num`` evenly spaced points on [start, stop]."""
    def samp_1d_grid_gen(num):
        return numpy.linspace(start, stop, num)
    return samp_1d_grid_gen

In [None]:
def samp_1d_norm(loc, scale):
    """Build a sampler returning ``num`` draws from N(loc, scale**2), sorted ascending."""
    def samp_1d_norm_gen(num):
        return numpy.sort(numpy.random.normal(loc, scale, num))
    return samp_1d_norm_gen

In [None]:
def val_const(sigma=1.):
    """Build a mass generator that gives each of the points equal mass.

    For ``p`` with ``len(p) == num`` the generator returns ``num`` entries
    of ``sigma / num`` (total mass ``sigma``).
    """
    def val_const_gen(p):
        count = p.shape[0]
        return numpy.full(count, sigma, dtype=float) / count
    return val_const_gen

In [None]:
def val_unif(low, high, sigma=1.):
    """Build a mass generator with i.i.d. Uniform(low, high) weights rescaled to total ``sigma``."""
    def val_unif_gen(p):
        draws = numpy.random.uniform(low, high, p.shape[0])
        return draws * sigma / numpy.sum(draws)
    return val_unif_gen

In [None]:
def val_1d_norm_pdf(loc, scale, sigma=1.):
    """Build a mass generator: N(loc, scale**2) density at each point, rescaled to total ``sigma``."""
    density = scipy.stats.norm(loc=loc, scale=scale)
    def val_1d_norm_pdf_gen(p):
        weights = density.pdf(p)
        return weights * sigma / numpy.sum(weights)
    return val_1d_norm_pdf_gen

In [None]:
def dist_1d_euc_2(mup, nup):
    """Squared-distance cost matrix between two 1-D point sets.

    Parameters
    ----------
    mup : array_like, shape (m,)
    nup : array_like, shape (n,)

    Returns
    -------
    ndarray, shape (m, n)
        ``c[i, j] = (mup[i] - nup[j]) ** 2``.
    """
    mup = numpy.asarray(mup)
    nup = numpy.asarray(nup)
    # Broadcasting (m, 1) - (1, n) avoids materialising m*n index grids.
    diff = mup[:, numpy.newaxis] - nup[numpy.newaxis, :]
    return diff ** 2

In [None]:
def ot_1d_general(m, n, mu_gen, nu_gen, mup_gen, nup_gen, dist):
    """Assemble an ``OTProblem1d`` from generator callables.

    ``mup_gen(m)`` / ``nup_gen(n)`` produce the support positions,
    ``mu_gen`` / ``nu_gen`` map positions to masses, and ``dist`` maps the
    two supports to the cost matrix. Generators are called in that order.
    """
    prob = OTProblem1d()
    prob.mup, prob.nup = mup_gen(m), nup_gen(n)
    prob.mu, prob.nu = mu_gen(prob.mup), nu_gen(prob.nup)
    prob.c = dist(prob.mup, prob.nup)
    return prob

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*solve_mosek.py*w*sh*

In [None]:
import time

import numpy

import mosek

In [None]:
def mosek_set_model(p, task):
    """Load the discrete optimal-transport LP of ``p`` into a MOSEK task.

    Variables: the ``m*n`` plan entries, flattened row-major (variable
    ``i*n + j`` is entry ``(i, j)``), each constrained to be >= 0.
    Constraints: rows ``0..m-1`` fix the row sums to ``p.mu``; rows
    ``m..m+n-1`` fix the column sums to ``p.nu``.
    Objective: minimize ``sum(p.c * plan)``.

    Parameters
    ----------
    p : problem with ``c`` of shape ``(m, n)``, ``mu`` of length m and
        ``nu`` of length n.
    task : an open MOSEK task. Variables and constraints are appended and
        then addressed from index 0, so the task is presumably empty.

    Raises
    ------
    ValueError
        If ``len(p.mu)`` / ``len(p.nu)`` do not match ``p.c.shape``.
    """
    m, n = p.c.shape
    if len(p.mu) != m or len(p.nu) != n:
        raise ValueError(
            "marginal sizes ({0}, {1}) do not match cost matrix shape ({2}, {3})".format(
                len(p.mu), len(p.nu), m, n
            )
        )

    # With boundkey.lo the upper bound is ignored, so this is only a placeholder.
    inf = 0.

    task.appendvars(m*n)
    task.appendcons(m+n)

    # Every plan entry is bounded below by 0.
    task.putvarboundlist(
        range(m*n),
        [mosek.boundkey.lo]*(m*n),
        [0.]*(m*n),
        [inf]*(m*n)
    )

    # Row i: sum_j x[i*n + j] == mu[i].
    for i in range(m):
        task.putarow(
            i,
            range(i*n, (i+1)*n),
            [1.]*n
        )
    task.putconboundlist(
        range(0, m),
        [mosek.boundkey.fx]*m,
        p.mu,
        p.mu
    )

    # Row m + j: sum_i x[i*n + j] == nu[j]; the stride-n range walks column j.
    for i in range(n):
        task.putarow(
            i+m,
            range(i, i+m*n, n),
            [1.]*m
        )
    task.putconboundlist(
        range(m, m+n),
        [mosek.boundkey.fx]*n,
        p.nu,
        p.nu
    )

    task.putclist(range(m*n), p.c.reshape(m*n))

    task.putobjsense(mosek.objsense.minimize)

def solve_mosek(
    p,
    mtd=None, sol=None, var=None, it=None,
    log=None, stat=False,
    *args, **kwargs
):
    """Solve OT problem ``p`` with MOSEK and store the plan in ``p.sol``.

    Parameters
    ----------
    mtd : value for ``mosek.iparam.optimizer`` (which optimizer to run).
    sol : MOSEK solution type to read the plan and objective from.
    var, it : ``mosek.iinfitem`` keys reported as "vars" / "iters".
    log : stream callback for MOSEK's log (``set_Stream``).
    stat : when true, also return a stats dict.
    *args, **kwargs : accepted and ignored.

    Returns
    -------
    ``p``, or ``(p, stats)`` when ``stat`` is true. ``stats`` has "loss",
    "vars", "iters", "setup" (seconds from start until the model was
    loaded) and "solve" (MOSEK's reported optimizer time).
    """
    m, n = p.c.shape
    t_begin = time.time() if stat else None

    with mosek.Env() as env:
        env.set_Stream(mosek.streamtype.log, log)

        with env.Task() as task:
            task.set_Stream(mosek.streamtype.log, log)
            task.putintparam(mosek.iparam.optimizer, mtd)
            mosek_set_model(p, task)
            t_model = time.time() if stat else None

            task.optimize()

            plan = [0.] * (m * n)
            task.getxx(sol, plan)
            p.sol = numpy.array(plan).reshape(m, n)

            if not stat:
                return p
            return p, {
                "loss": task.getprimalobj(sol),
                "vars": task.getintinf(var),
                "iters": task.getintinf(it),
                "setup": t_model - t_begin,
                "solve": task.getdouinf(mosek.dinfitem.optimizer_time),
            }

In [None]:
def solve_mosek_primal_simplex(p, *args, **kwargs):
    """Solve ``p`` with MOSEK's primal simplex optimizer (basic solution).

    Extra arguments (e.g. ``log``, ``stat``) are forwarded to ``solve_mosek``.
    """
    return solve_mosek(
        p, *args,
        mtd=mosek.optimizertype.primal_simplex,
        sol=mosek.soltype.bas,
        it=mosek.iinfitem.sim_primal_iter,
        var=mosek.iinfitem.opt_numvar,
        **kwargs
    )

def solve_mosek_dual_simplex(p, *args, **kwargs):
    """Solve ``p`` with MOSEK's dual simplex optimizer (basic solution).

    Extra arguments (e.g. ``log``, ``stat``) are forwarded to ``solve_mosek``.
    """
    return solve_mosek(
        p, *args,
        mtd=mosek.optimizertype.dual_simplex,
        sol=mosek.soltype.bas,
        it=mosek.iinfitem.sim_dual_iter,
        var=mosek.iinfitem.opt_numvar,
        **kwargs
    )

def solve_mosek_interior_point(p, *args, **kwargs):
    """Solve ``p`` with MOSEK's interior-point optimizer (interior solution).

    Extra arguments (e.g. ``log``, ``stat``) are forwarded to ``solve_mosek``.
    """
    return solve_mosek(
        p, *args,
        mtd=mosek.optimizertype.intpnt,
        sol=mosek.soltype.itr,
        it=mosek.iinfitem.intpnt_iter,
        var=mosek.iinfitem.opt_numvar,
        **kwargs
    )

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*a*sh*

In [None]:
# 1-D example: 50 sorted N(-5, 1) samples -> a 50-point grid on [-5, -4],
# uniform-random masses normalised to 1, squared-distance cost.
# NOTE(review): no random seed is set, so results change between runs.
prob = ot_1d_general(
    m=50, n=50,
    mup_gen=samp_1d_norm(-5., 1.),
    nup_gen=samp_1d_grid(-5., -4.),
    mu_gen=val_unif(0., 1.),
    nu_gen=val_unif(0., 1.),
    dist=dist_1d_euc_2,
)

In [None]:
fh.fast(prob.plot_mu_scatter)

In [None]:
# stat=True makes the solver return (prob, stats); the plan is stored in prob.sol.
_, s = solve_mosek_primal_simplex(prob, stat=True)
print(s)

In [None]:
fh.fast(prob.plot_hotline)

In [None]:
fh.fast(prob.plot_scatter, alpha=0.5)

In [None]:
# Arrows from mu's support (y = 0) to nu's support (y = 1, the default off_ey).
fh.fast(prob.plot_link)

In [None]:
# Gaussian-shaped masses on two grids: N(-5, 0.8) on [-7, -3] (50 points)
# and N(-2, 0.4) on [-3, -1] (100 points).
prob = ot_1d_general(
    m=50, n=100,
    mup_gen=samp_1d_grid(-7., -3.),
    nup_gen=samp_1d_grid(-3., -1.),
    mu_gen=val_1d_norm_pdf(-5., 0.8),
    nu_gen=val_1d_norm_pdf(-2., 0.4),
    dist=dist_1d_euc_2,
)

In [None]:
fh.fast(prob.plot_mu_scatter)

In [None]:
solve_mosek_primal_simplex(prob)

In [None]:
fh.fast(prob.plot_link)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils.py*a*sh*

In [None]:
class OTProblem2d(OTProblem):
    """Two-dimensional OT problem with plotting helpers.

    Besides the ``OTProblem`` attributes, the plotting methods read
    ``mup`` / ``nup``, the support points of ``mu`` / ``nu``, indexed as
    ``[:, 0]`` (x) and ``[:, 1]`` (y). The caller assigns them
    (e.g. ``ot_2d_general``).
    """

    def __init__(self, *args, **kwargs):
        super(OTProblem2d, self).__init__(*args, **kwargs)

    def plot_mu_scatter(self, fh, *args, **kwargs):
        """Scatter ``mu`` as the third coordinate over its support points (meant for 3-D axes)."""
        fh.ax.scatter(self.mup[:, 0], self.mup[:, 1], self.mu, *args, **kwargs)

    def plot_nu_scatter(self, fh, *args, **kwargs):
        """Scatter ``nu`` as the third coordinate over its support points (meant for 3-D axes)."""
        fh.ax.scatter(self.nup[:, 0], self.nup[:, 1], self.nu, *args, **kwargs)

    def plot_mu_scatter_plain(self, fh, scale=30., *args, **kwargs):
        """2-D scatter of mu's support; marker area is proportional to mass (mean mass -> ``scale``)."""
        s = self.mu / self.mu.mean() * scale
        fh.ax.scatter(self.mup[:, 0], self.mup[:, 1], s=s, *args, **kwargs)

    def plot_nu_scatter_plain(self, fh, scale=30., *args, **kwargs):
        """2-D scatter of nu's support; marker area is proportional to mass (mean mass -> ``scale``)."""
        s = self.nu / self.nu.mean() * scale
        fh.ax.scatter(self.nup[:, 0], self.nup[:, 1], s=s, *args, **kwargs)

    def plot_link(
        self, fh,
        off_bx=0., off_by=0., off_ex=0., off_ey=0.,
        scatter=True,
        cutoff=0.2, scale=30.,
        aspect="auto", colorbar=True,
        sca_pos=(), sca_kw=None,
        *args, **kwargs
    ):
        """Draw arrows from source to target points for significant plan entries.

        An entry ``(i, j)`` is drawn when ``sol[i, j] > cutoff * (sol.mean() *
        sqrt(m * n))``. The arrow runs from ``mup[i] + (off_bx, off_by)`` to
        ``nup[j] + (off_ex, off_ey)`` and is coloured by plan value. With
        ``scatter`` both supports are drawn too, with marker area
        proportional to mass; ``sca_pos`` / ``sca_kw`` go to those calls.
        Extra arguments go to ``quiver``. Returns the quiver mappable.
        """
        if sca_kw is None:
            sca_kw = {}
        m, n = self.c.shape

        if scatter:
            mus = self.mu / self.mu.mean() * scale
            nus = self.nu / self.nu.mean() * scale
            fh.ax.scatter(self.mup[:, 0] + off_bx, self.mup[:, 1] + off_by, s=mus, *sca_pos, **sca_kw)
            fh.ax.scatter(self.nup[:, 0] + off_ex, self.nup[:, 1] + off_ey, s=nus, *sca_pos, **sca_kw)

        mean = self.sol.mean() * numpy.sqrt(m * n)
        # Row-major indices of the kept entries (same order as boolean-mask selection).
        rows, cols = numpy.nonzero(self.sol > cutoff * mean)

        arr_b = self.mup[rows]
        arr_bx, arr_by = arr_b[:, 0] + off_bx, arr_b[:, 1] + off_by

        arr_e = self.nup[cols]
        arr_ex, arr_ey = arr_e[:, 0] + off_ex, arr_e[:, 1] + off_ey

        arr_dx = arr_ex - arr_bx
        arr_dy = arr_ey - arr_by

        arr_c = self.sol[rows, cols]

        # scale_units="xy", scale=1 makes each arrow end exactly at its target point.
        mpbl = fh.ax.quiver(arr_bx, arr_by, arr_dx, arr_dy, arr_c, angles="xy", scale_units="xy", scale=1., *args, **kwargs)
        if colorbar:
            fh.fig.colorbar(mpbl)

        fh.ax.set_aspect(aspect)
        return mpbl

In [None]:
def samp_norm(mean, cov):
    """Build a sampler returning ``num`` multivariate-normal draws, shape (num, d)."""
    def samp_norm_gen(num):
        return numpy.random.multivariate_normal(mean, cov, num)
    return samp_norm_gen

In [None]:
def samp_2d_grid(start_x, end_x, start_y, end_y):
    """Build a sampler for the nodes of a regular grid over a rectangle.

    ``gen([nx, ny])`` returns an ``(nx * ny, 2)`` array of grid nodes
    spanning [start_x, end_x] x [start_y, end_y] (endpoints included),
    ordered with x varying fastest (meshgrid order).
    """
    def samp_2d_grid_gen(num):
        nx, ny = num[0], num[1]
        xs = numpy.linspace(start_x, end_x, nx)
        ys = numpy.linspace(start_y, end_y, ny)
        grid_x, grid_y = numpy.meshgrid(xs, ys)
        return numpy.column_stack((grid_x.ravel(), grid_y.ravel()))
    return samp_2d_grid_gen

In [None]:
def samp_2d_mid(start_x, end_x, start_y, end_y):
    """Build a sampler for the cell midpoints of a regular grid over a rectangle.

    ``gen([nx, ny])`` splits [start_x, end_x] x [start_y, end_y] into
    ``nx * ny`` equal cells and returns their centres as an
    ``(nx * ny, 2)`` array, ordered with x varying fastest (meshgrid order).
    """
    # The closure was previously (mis)named samp_2d_grid_gen, copied from
    # samp_2d_grid, which made reprs and tracebacks misleading.
    def samp_2d_mid_gen(num):
        nx, ny = num[0], num[1]
        step_x = (end_x - start_x) / nx
        xs = numpy.linspace(start_x + step_x / 2., end_x - step_x / 2., nx)
        step_y = (end_y - start_y) / ny
        ys = numpy.linspace(start_y + step_y / 2., end_y - step_y / 2., ny)
        grid_x, grid_y = numpy.meshgrid(xs, ys)
        return numpy.column_stack((grid_x.ravel(), grid_y.ravel()))
    return samp_2d_mid_gen

In [None]:
def samp_2d_ellipse(cen_x, r_x, cen_y, r_y, noi):
    """Build a sampler for noisy points on an axis-aligned ellipse.

    Each point takes a uniform angle on the unit circle and adds Gaussian
    noise with std ``noi / sqrt(2)`` per coordinate. It is then scaled by
    ``(r_x, r_y)`` and shifted to ``(cen_x, cen_y)``. Returns ``(num, 2)``.
    """
    def samp_2d_ellipse_gen(num):
        theta = numpy.random.uniform(0, 2.*math.pi, num)
        noise_std = noi / math.sqrt(2.)
        unit_x = numpy.cos(theta) + noise_std * numpy.random.randn(num)
        unit_y = numpy.sin(theta) + noise_std * numpy.random.randn(num)
        return numpy.column_stack((r_x * unit_x + cen_x, r_y * unit_y + cen_y))
    return samp_2d_ellipse_gen

In [None]:
def samp_2d_Caffarelli(cen_x, cen_y, r, d):
    """Build a sampler for a disk split into two halves separated by a gap.

    Candidates are drawn uniformly from the square [-r, r]^2 and only those
    strictly inside the disk of radius ``r`` are kept (rejection sampling).
    The returned array therefore has a *random* number of rows, about
    ``pi/4 * num``, not ``num``. Points with x < 0 are shifted by ``-d/2``
    and points with x >= 0 by ``+d/2``. The result is then translated to
    ``(cen_x, cen_y)``.

    Returns a sampler ``gen(num) -> (k, 2)`` array with ``k <= num``.
    """
    def samp_2d_Caffarelli_gen(num):
        ox = numpy.random.uniform(-r, r, num)
        oy = numpy.random.uniform(-r, r, num)
        inside = ox**2 + oy**2 < r**2
        fx, fy = ox[inside], oy[inside]
        # Decide each point's half once, before shifting. Re-testing the sign
        # after the first shift double-shifted points whenever d < 0.
        left = fx < 0.
        fx = numpy.where(left, fx - d / 2., fx + d / 2.)
        return numpy.column_stack((fx + cen_x, fy + cen_y))
    return samp_2d_Caffarelli_gen

In [None]:
def val_norm_pdf(mean, cov, sigma=1.):
    """Build a mass generator: multivariate-normal density at each point, rescaled to total ``sigma``."""
    density = scipy.stats.multivariate_normal(mean=mean, cov=cov)
    def val_norm_pdf_gen(p):
        weights = density.pdf(p)
        return weights * sigma / numpy.sum(weights)
    return val_norm_pdf_gen

In [None]:
def dist_2d_euc_2(mup, nup):
    """Squared Euclidean cost matrix between two point clouds.

    Parameters
    ----------
    mup : array_like, shape (m, d)
    nup : array_like, shape (n, d)

    Returns
    -------
    ndarray, shape (m, n)
        ``c[i, j] = sum_k (mup[i, k] - nup[j, k]) ** 2``.
    """
    mup = numpy.asarray(mup)
    nup = numpy.asarray(nup)
    # Broadcasting (m, 1, d) - (1, n, d) avoids building m*n index grids.
    diff = mup[:, numpy.newaxis, :] - nup[numpy.newaxis, :, :]
    return (diff ** 2).sum(axis=2)

In [None]:
def ot_2d_general(m, n, mu_gen, nu_gen, mup_gen, nup_gen, dist):
    """Assemble an ``OTProblem2d`` from generator callables.

    ``mup_gen(m)`` / ``nup_gen(n)`` produce the support points (``m`` / ``n``
    are whatever the samplers expect, e.g. a count or a grid shape),
    ``mu_gen`` / ``nu_gen`` map points to masses, and ``dist`` builds the
    cost matrix. Generators are called in that order.
    """
    prob = OTProblem2d()
    prob.mup, prob.nup = mup_gen(m), nup_gen(n)
    prob.mu, prob.nu = mu_gen(prob.mup), nu_gen(prob.nup)
    prob.c = dist(prob.mup, prob.nup)
    return prob

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*a*sh*

In [None]:
# 2-D example: m=[5, 5] is a grid shape for samp_2d_mid (25 cell midpoints of
# the unit square) -> 50 correlated normal samples around (2, 2).
prob = ot_2d_general(
    m=[5, 5], n=50,
    mup_gen=samp_2d_mid(0., 1., 0., 1.),
    nup_gen=samp_norm([2., 2.], [[1., 0.5], [0.5, 1.]]),
    mu_gen=val_unif(0., 1.),
    nu_gen=val_unif(0., 1.),
    dist=dist_2d_euc_2,
)

In [None]:
solve_mosek_primal_simplex(prob)

In [None]:
fh.fast(prob.plot_hotline)

In [None]:
# 3-D axes: mass mu plotted as height over its support points.
fh.fast(prob.plot_mu_scatter, new_kw={"projection": "3d"})

In [None]:
fh.fast(prob.plot_mu_scatter_plain)

In [None]:
fh.fast(prob.plot_nu_scatter_plain)

In [None]:
fh.fast(prob.plot_link, scale=30., cutoff=0.5)

In [None]:
# Ellipse to rotated ellipse: radii (2, 0.5) -> (0.5, 2), noise 0.1, equal masses.
prob = ot_2d_general(
    m=100, n=100,
    mup_gen=samp_2d_ellipse(0., 2., 0., 0.5, 0.1),
    nup_gen=samp_2d_ellipse(0., 0.5, 0., 2., 0.1),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2,
)

In [None]:
fh.fast(prob.plot_mu_scatter_plain)

In [None]:
solve_mosek_primal_simplex(prob)

In [None]:
fh.fast(prob.plot_link, scale=30., cutoff=0.5)

In [None]:
# Unit disk -> two half-disks separated by a gap of 4. samp_2d_Caffarelli
# rejection-samples, so the actual point counts are random (about pi/4 * 500).
prob = ot_2d_general(
    m=500, n=500,
    mup_gen=samp_2d_Caffarelli(0., 0., 1., 0.),
    nup_gen=samp_2d_Caffarelli(0., 0., 1., 4.),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2,
)

In [None]:
fh.fast(prob.plot_nu_scatter_plain)

In [None]:
solve_mosek_primal_simplex(prob)

In [None]:
fh.fast(prob.plot_link, scale=30., cutoff=0.5)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils.py*a*sh*

In [None]:
class OTProblemImage(OTProblem2d):
    """2-D OT problem whose marginals live on regular grids and can be shown as images.

    Extra attributes are assigned by the caller (e.g. ``ot_im_general``):
    ``m`` / ``n`` are the grid shapes ``[nx, ny]`` given to the point
    samplers, and ``mw`` / ``nw`` are the ``imshow`` extents
    ``[x0, x1, y0, y1]``. Images are built as ``mu.reshape(m[1], m[0])``
    (rows = y, columns = x). This assumes the points were generated in
    meshgrid order with x varying fastest, as ``samp_2d_grid`` /
    ``samp_2d_mid`` do; confirm this for other samplers.
    """

    def __init__(self, *args, **kwargs):
        super(OTProblemImage, self).__init__(*args, **kwargs)

    def _show_image(self, fh, values, shape, extent, origin, colorbar, *args, **kwargs):
        """``imshow`` flattened grid ``values`` of grid ``shape`` ``[nx, ny]``; returns the image."""
        image = values.reshape(shape[1], shape[0])  # (ny, nx): rows are y
        mpbl = fh.ax.imshow(image, extent=extent, origin=origin, *args, **kwargs)
        if colorbar:
            fh.fig.colorbar(mpbl)
        return mpbl

    def plot_mu_image(
        self, fh,
        origin="lower",
        colorbar=False,
        *args, **kwargs
    ):
        """Show ``mu`` as an image over extent ``mw``; returns the image mappable."""
        mwn = [self.mw[0], self.mw[1], self.mw[2], self.mw[3]]
        return self._show_image(fh, self.mu, self.m, mwn, origin, colorbar, *args, **kwargs)

    def plot_nu_image(
        self, fh,
        origin="lower",
        colorbar=False,
        *args, **kwargs
    ):
        """Show ``nu`` as an image over extent ``nw``; returns the image mappable."""
        nwn = [self.nw[0], self.nw[1], self.nw[2], self.nw[3]]
        return self._show_image(fh, self.nu, self.n, nwn, origin, colorbar, *args, **kwargs)

    def plot_link(
        self, fh,
        off_bx=0., off_by=0., off_ex=0., off_ey=0.,
        image_m=False, image_n=False,
        origin="lower",
        cutoff=0.2, scale=30.,
        aspect="auto", colorbar=False,
        im_pos=(), im_kw=None,
        *args, **kwargs
    ):
        """Draw transport arrows, optionally over the (offset) marginal images.

        ``image_m`` / ``image_n`` draw ``mu`` / ``nu`` as images shifted by
        the source / target offsets; ``im_pos`` / ``im_kw`` go to those
        ``imshow`` calls. An entry ``(i, j)`` gets an arrow when
        ``sol[i, j] > cutoff * (sol.mean() * sqrt(m * n))``. ``scale`` is
        unused here and is kept for signature compatibility. Extra arguments
        go to ``quiver``. Returns the quiver mappable.
        """
        if im_kw is None:
            im_kw = {}
        # Point counts from the cost matrix (not the grid shapes self.m / self.n).
        num_m, num_n = self.c.shape

        if image_m:
            mwn = [self.mw[0] + off_bx, self.mw[1] + off_bx, self.mw[2] + off_by, self.mw[3] + off_by]
            # Same (ny, nx) layout as plot_mu_image; reshape(self.m) transposed the image.
            self._show_image(fh, self.mu, self.m, mwn, origin, False, *im_pos, **im_kw)

        if image_n:
            nwn = [self.nw[0] + off_ex, self.nw[1] + off_ex, self.nw[2] + off_ey, self.nw[3] + off_ey]
            self._show_image(fh, self.nu, self.n, nwn, origin, False, *im_pos, **im_kw)

        mean = self.sol.mean() * numpy.sqrt(num_m * num_n)
        # Row-major indices of the kept entries (same order as boolean-mask selection).
        rows, cols = numpy.nonzero(self.sol > cutoff * mean)

        arr_b = self.mup[rows]
        arr_bx, arr_by = arr_b[:, 0] + off_bx, arr_b[:, 1] + off_by

        arr_e = self.nup[cols]
        arr_ex, arr_ey = arr_e[:, 0] + off_ex, arr_e[:, 1] + off_ey

        arr_dx = arr_ex - arr_bx
        arr_dy = arr_ey - arr_by

        arr_c = self.sol[rows, cols]

        # scale_units="xy", scale=1 makes each arrow end exactly at its target point.
        mpbl = fh.ax.quiver(arr_bx, arr_by, arr_dx, arr_dy, arr_c, angles="xy", scale_units="xy", scale=1., *args, **kwargs)
        if colorbar:
            fh.fig.colorbar(mpbl)

        fh.ax.set_aspect(aspect)
        return mpbl

In [None]:
def ot_im_general(m, n, mw, nw, mu_gen, nu_gen, mup_gen, nup_gen, dist):
    """Assemble an ``OTProblemImage`` from grid shapes, extents and generators.

    ``m`` / ``n`` are the grid shapes given to ``mup_gen`` / ``nup_gen``;
    ``mw`` / ``nw`` are the image extents. Masses come from ``mu_gen`` /
    ``nu_gen`` and the cost from ``dist``. Generators are called in that
    order.
    """
    prob = OTProblemImage()
    prob.m, prob.n = m, n
    prob.mw, prob.nw = mw, nw
    prob.mup, prob.nup = mup_gen(m), nup_gen(n)
    prob.mu, prob.nu = mu_gen(prob.mup), nu_gen(prob.nup)
    prob.c = dist(prob.mup, prob.nup)
    return prob

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*a*sh*

In [None]:
# Image example: two 10x10 grids of cell midpoints on the unit square with
# uniform-random masses.
prob = ot_im_general(
    m=[10, 10], n=[10, 10],
    mw=[0., 1., 0., 1.], nw=[0., 1., 0., 1.],
    mup_gen=samp_2d_mid(0., 1., 0., 1.),
    nup_gen=samp_2d_mid(0., 1., 0., 1.),
    mu_gen=val_unif(0., 1.),
    nu_gen=val_unif(0., 1.),
    dist=dist_2d_euc_2,
)

In [None]:
solve_mosek_primal_simplex(prob)

In [None]:
fh.fast(prob.plot_link, aspect="equal")

In [None]:
# Three panels: mu image | transport arrows | nu image.
fh.refresh()
fh.subplot(1, 3, 1)
prob.plot_mu_image(fh)
fh.subplot(1, 3, 2)
prob.plot_link(fh, aspect="equal")
fh.subplot(1, 3, 3)
prob.plot_nu_image(fh)
fh.show()
fh.close()

In [None]:
# Non-square grids: 20x10 on [0, 2]x[0, 1] -> 10x30 on [0, 1]x[0, 2], with
# Gaussian-shaped masses centred at (0.2, 0.8) and (0.8, 0.2).
prob = ot_im_general(
    m=[20, 10], n=[10, 30],
    mw=[0., 2., 0., 1.], nw=[0., 1., 0., 2.],
    mup_gen=samp_2d_mid(0., 2., 0., 1.),
    nup_gen=samp_2d_mid(0., 1., 0., 2.),
    mu_gen=val_norm_pdf([0.2, 0.8], [[0.2, 0.], [0., 0.2]]),
    nu_gen=val_norm_pdf([0.8, 0.2], [[0.2, 0.], [0., 0.2]]),
    dist=dist_2d_euc_2,
)

In [None]:
solve_mosek_primal_simplex(prob)

In [None]:
fh.fast(prob.plot_mu_image)

In [None]:
# Middle panel overlays both marginal images (alpha=0.3) under the arrows.
fh.refresh()
fh.subplot(1, 3, 1)
prob.plot_mu_image(fh)
fh.subplot(1, 3, 2)
prob.plot_mu_image(fh, alpha=0.3)
prob.plot_nu_image(fh, alpha=0.3)
prob.plot_link(fh, aspect="equal")
fh.subplot(1, 3, 3)
prob.plot_nu_image(fh)
fh.show()
fh.close()

In [None]:
# Larger split-disk problem (gap 4). Point counts are random because of
# rejection sampling; the next cell prints them.
prob = ot_2d_general(
    m=1000, n=1000,
    mup_gen=samp_2d_Caffarelli(0., 0., 1., 0.),
    nup_gen=samp_2d_Caffarelli(0., 0., 1., 4.),
    mu_gen=val_const(),
    nu_gen=val_const(),
    dist=dist_2d_euc_2,
)

In [None]:
print("Size = ({0}, {1})".format(prob.mu.size, prob.nu.size))

In [None]:
# Compare the three MOSEK optimizers on the same problem; log=print streams
# MOSEK's log and stat=True returns (prob, stats).
solve_mosek_primal_simplex(prob, log=print, stat=True)

In [None]:
solve_mosek_dual_simplex(prob, log=print, stat=True)

In [None]:
solve_mosek_interior_point(prob, log=print, stat=True)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils.py*a*sh*

In [None]:
def relative_error(u, v):
    """Relative error ``||u - v|| / ||u||`` (Frobenius / 2-norm) of ``v`` w.r.t. reference ``u``.

    Accepts arrays or nested lists of matching shape. A zero reference
    ``u`` gives a division by zero (inf/nan with a numpy warning).
    """
    u = numpy.asarray(u)
    v = numpy.asarray(v)
    return numpy.linalg.norm(u - v) / numpy.linalg.norm(u)

In [None]:
# How merge_stats combines each stat key across repeated runs (key -> mode).
general_merge_config = dict(
    title="first",
    name="first",
    time="mean+-stdev",
    setup="mean+-stdev",
    solve="mean+-stdev",
    vars="set",  # a set of distinct values, since instances can differ in size
    iters="mean",
    loss="mean+-stdev",
    check="mean+-stdev",
    error_xx="mean+-stdev",
)

In [None]:
# key -> [label, format string] used by format_output. The "{0[0]}...{0[1]}"
# formats unpack the [mean, stdev] pairs produced by "mean+-stdev" merging.
general_output_config = {
    "title": ["Title", "{}"],
    "name": ["Function name", "{}"],
    "time": ["Time", "{0[0]:.5f}+-{0[1]:.5f}"],
    "setup": ["Setup time", "{0[0]:.5f}+-{0[1]:.5f}"],
    "solve": ["Solve time", "{0[0]:.5f}+-{0[1]:.5f}"],
    "vars": ["Variables", "{}"],
    "iters": ["Average iterations", "{:.3f}"],
    "loss": ["Loss", "{0[0]:.7e}+-{0[1]:.7e}"],
    "check": ["Check loss", "{0[0]:.7e}+-{0[1]:.7e}"],
    "error_xx": ["Error to known", "{0[0]:.7e}+-{0[1]:.7e}"],
}

In [None]:
def merge_stats(stats, config):
    """Combine a list of per-run stat dicts into one summary dict.

    Parameters
    ----------
    stats : list of dict
        One dict per run; which keys exist is decided by ``stats[0]``.
    config : dict
        key -> merge mode, one of "mean", "stdev", "mean+-stdev"
        (-> ``[mean, stdev]``), "first" (value of the first run), "list"
        (all values) or "set" (distinct values).

    Returns
    -------
    dict
        Merged values for every config key present in ``stats[0]``; ``{}``
        when ``stats`` is empty. The standard deviation of a single run is
        reported as 0.0.

    Raises
    ------
    ValueError
        If a present key has an unknown merge mode.
    """
    if not stats:
        return {}

    def _stdev(values):
        # statistics.stdev raises StatisticsError for fewer than two values.
        return statistics.stdev(values) if len(values) > 1 else 0.0

    merged = {}
    for key, mode in config.items():
        if key not in stats[0]:
            continue
        if mode == "first":
            merged[key] = stats[0][key]
            continue
        values = [s[key] for s in stats]
        if mode == "mean":
            merged[key] = statistics.mean(values)
        elif mode == "stdev":
            merged[key] = _stdev(values)
        elif mode == "mean+-stdev":
            merged[key] = [statistics.mean(values), _stdev(values)]
        elif mode == "list":
            merged[key] = values
        elif mode == "set":
            merged[key] = set(values)
        else:
            raise ValueError("unknown merge mode {0!r} for key {1!r}".format(mode, key))
    return merged

In [None]:
def format_output(res, config):
    """Print one ``"<label>: <formatted value>"`` line per config key present in ``res``.

    ``config`` maps key -> ``[label, format string]``; the format string is
    applied to ``res[key]``. Lines follow the order of ``config``.
    """
    present = (key for key in config if key in res)
    for key in present:
        spec = config[key]
        print("{0}: {1}".format(spec[0], spec[1].format(res[key])))

In [None]:
class Statistics(object):
    """Benchmark harness: run solvers over a fixed list of OT problems.

    Each ``test`` call runs one solver on every problem, collects a stat
    dict per problem (``test_piece``), merges them with ``merge_config``
    (``merge_stats``) and appends the raw list to ``stats`` and the merged
    result to ``ress``.
    """

    def __init__(
        self,
        probs=None,
        merge_config=general_merge_config,
        output_config=general_output_config
    ):
        # None means "no problems" (previously len(None) raised TypeError).
        self.probs = probs if probs is not None else []
        self.len = len(self.probs)
        self.merge_config = merge_config
        self.output_config = output_config
        self.stats = []  # one list of per-problem stat dicts per test() call
        self.ress = []   # merged stats, parallel to ``stats``

    def set_cx(self, func, log=None, *args, **kwargs):
        """Solve every problem with ``func`` and record the result as its reference plan.

        ``func(prob, *args, **kwargs)`` is expected to fill ``prob.sol``; then
        ``prob.set_cx()`` and ``prob.clean()`` are called. Progress is
        reported 1-based through ``log`` when given.
        """
        for i in range(self.len):
            prob = self.probs[i]
            if log is not None:
                log("Setting {0}/{1}".format(i + 1, self.len))
            func(prob, *args, **kwargs)
            prob.set_cx()
            prob.clean()

    def test_piece(self, prob, func, title="", clean=True, *args, **kwargs):
        """Run ``func`` on one problem and return its stat dict.

        ``func`` is called as ``func(prob, stat=True, *args, **kwargs)`` and
        must return ``(prob, stats)``. Adds "title", "name"
        (``func.__name__``), "time" (wall-clock seconds of the call),
        "check" (``sum(sol * c)``) and, when the problem has a reference
        plan ``cx``, "error_xx" (relative error of ``sol`` against it).
        Calls ``prob.clean()`` afterwards if ``clean``.
        """
        start_time = time.time()
        prob, stat = func(prob, stat=True, *args, **kwargs)
        elapsed_time = time.time() - start_time

        stat["title"] = title
        stat["name"] = func.__name__
        stat["time"] = elapsed_time
        stat["check"] = numpy.sum(prob.sol * prob.c)
        # getattr: a problem that never went through set_cx may lack ``cx``.
        reference = getattr(prob, "cx", None)
        if reference is not None:
            stat["error_xx"] = relative_error(reference, prob.sol)

        if clean:
            prob.clean()
        return stat

    def test(self, func, title="", log=None, *args, **kwargs):
        """Run ``func`` on every problem and record raw and merged stats.

        Extra arguments (e.g. ``clean=False``) are forwarded to
        ``test_piece``. Progress is reported 1-based through ``log`` when
        given.
        """
        per_problem = []
        for i in range(self.len):
            if log is not None:
                log("Testing {0}/{1}".format(i + 1, self.len))
            per_problem.append(self.test_piece(self.probs[i], func, title, *args, **kwargs))
        merged = merge_stats(per_problem, self.merge_config)
        self.stats.append(per_problem)
        self.ress.append(merged)

    def clean_last(self):
        """Call ``clean()`` on every problem (drops solutions kept with ``clean=False``)."""
        for i in range(self.len):
            self.probs[i].clean()

    def output_last(self):
        """Print the most recent merged result using ``output_config``."""
        format_output(self.ress[-1], self.output_config)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*utils_test.py*a*sh*

In [None]:
# Benchmark on 10 independent random split-disk instances (gap 4).
stat = Statistics(
    probs=[
        ot_2d_general(
            m=500, n=500,
            mup_gen=samp_2d_Caffarelli(0., 0., 1., 0.),
            nup_gen=samp_2d_Caffarelli(0., 0., 1., 4.),
            mu_gen=val_const(),
            nu_gen=val_const(),
            dist=dist_2d_euc_2,
        ) for i in range(10)
    ],
    merge_config=general_merge_config,
    output_config=general_output_config,
)

In [None]:
# Reference plans: solve each instance once with the interior-point method and
# keep the result as cx, so later runs report "error_xx" against it.
stat.set_cx(solve_mosek_interior_point, log=print)

In [None]:
# clean=False keeps the plans so instance 0 can be plotted; clean_last then clears them.
stat.test(solve_mosek_primal_simplex, title="MOSEK, simplex for primal", log=print, clean=False)
stat.output_last()
fh.fast(stat.probs[0].plot_link)
stat.clean_last()

In [None]:
stat.test(solve_mosek_dual_simplex, title="MOSEK, simplex for dual", log=print)
stat.output_last()

In [None]:
stat.test(solve_mosek_interior_point, title="MOSEK, interior point", log=print, clean=False)
stat.output_last()
fh.fast(stat.probs[0].plot_link)
stat.clean_last()

In [None]:
# !ConvertEnd*