In [None]:
# !Convert*entropy_sinkhorn.py*w*sh*

In [None]:
import math

import numpy

In [None]:
def init(m, n, c, gamma):
    """Build the starting state for the Sinkhorn iteration.

    Args:
        m: Length of the first scaling vector.
        n: Length of the second scaling vector.
        c: Cost array; it is negated, divided by ``gamma`` and exponentiated
            element-wise.
        gamma: Regularisation parameter used as the divisor of ``c``.

    Returns:
        A tuple ``(u, v, xi)`` where ``u`` is an all-ones vector of length
        ``m``, ``v`` is an all-ones vector of length ``n`` and
        ``xi = exp(-c / gamma)`` is the kernel used by the scaling updates.
    """
    return numpy.ones(m), numpy.ones(n), numpy.exp(-c / gamma)

In [None]:
def update(u, v, xi, mu, nu):
    """Perform one Sinkhorn scaling sweep.

    The first scaling vector is recomputed from ``v`` and then the second one
    is recomputed from that fresh value. The incoming ``u`` is accepted for
    interface symmetry but is not read.

    Args:
        u: Current first scaling vector (ignored; overwritten by the sweep).
        v: Current second scaling vector.
        xi: Kernel matrix.
        mu: Target marginal matched by the first scaling vector.
        nu: Target marginal matched by the second scaling vector.

    Returns:
        The pair ``(u, v)`` of updated scaling vectors.
    """
    row_scaling = mu / xi.dot(v)
    # The column update deliberately uses the row scaling computed just above.
    col_scaling = nu / xi.T.dot(row_scaling)
    return row_scaling, col_scaling

In [None]:
def recover(m, n, u, v, xi):
    """Assemble the transport plan from the scaling vectors and the kernel.

    Entry ``(i, j)`` of the result is ``u[i] * xi[i, j] * v[j]``, obtained by
    broadcasting ``u`` as an ``(m, 1)`` column and ``v`` as a ``(1, n)`` row.

    Args:
        m: Number of rows of the plan.
        n: Number of columns of the plan.
        u: First scaling vector, reshaped to ``(m, 1)``.
        v: Second scaling vector, reshaped to ``(1, n)``.
        xi: Kernel matrix.

    Returns:
        The scaled kernel ``u * xi * v`` as an array.
    """
    as_column = u.reshape((m, 1))
    as_row = v.reshape((1, n))
    return as_column * xi * as_row

In [None]:
def solve_sinkhorn(
    p,
    it, gamma, scale=None,
    fh=None, figs=None, stat=False,
    *args, **kwargs
):
    """Run a fixed number of Sinkhorn iterations and store the plan on ``p``.

    The marginals ``p.mu`` and ``p.nu`` are multiplied by ``scale`` before
    iterating, and the recovered plan is divided by ``scale`` again before it
    is written to ``p.s``.

    Args:
        p: Problem object. ``p.c`` (2-D cost array), ``p.mu`` and ``p.nu`` are
            read; ``p.s`` is overwritten with the resulting plan.
        it: Number of scaling sweeps to perform.
        gamma: Regularisation parameter passed to ``init``.
        scale: Factor applied to the marginals. Defaults to ``sqrt(m * n)``
            where ``(m, n)`` is the shape of ``p.c``.
        fh: Optional figure handler. When given, the diagnostics selected by
            ``figs`` are recorded every iteration and plotted through
            ``fh.new`` / ``fh.ax`` / ``fh.show`` / ``fh.close``.
        figs: Container of diagnostic names; ``"error"`` (L1 marginal
            violations, divided by ``scale``) and ``"loss"`` (``sum(c * plan)``
            divided by ``scale``) are recognised. ``None`` means no figures.
        stat: When true, also return a summary dictionary.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        ``p`` when ``stat`` is false, otherwise the tuple ``(p, summary)``
        where ``summary`` has the keys ``"title"``, ``"size"``, ``"loss"``,
        ``"vars"`` and ``"iters"``.
    """
    m, n = p.c.shape

    if scale is None:
        scale = math.sqrt(m*n)

    # A ``None`` default replaces the former shared mutable ``{}`` default.
    if figs is None:
        figs = ()

    mu, nu = scale*p.mu, scale*p.nu
    c = p.c

    # Diagnostics are only meaningful when there is a handler to plot them on.
    track_error = fh is not None and "error" in figs
    track_loss = fh is not None and "loss" in figs

    error_mu = []
    error_nu = []
    loss = []

    u, v, xi = init(m, n, c, gamma)

    itc = 0

    for _ in range(it):
        u, v = update(u, v, xi, mu, nu)

        itc += 1

        # Building the full m-by-n plan is the expensive part of an iteration,
        # so skip it entirely unless some diagnostic actually needs it.
        if track_error or track_loss:
            plan = recover(m, n, u, v, xi)
            if track_error:
                error_mu.append(numpy.linalg.norm(plan.sum(axis=1) - mu, 1) / scale)
                error_nu.append(numpy.linalg.norm(plan.sum(axis=0) - nu, 1) / scale)
            if track_loss:
                loss.append((c * plan).sum() / scale)

    # Undo the marginal scaling so that p.s matches the unscaled p.mu / p.nu.
    p.s = recover(m, n, u, v, xi) / scale

    if track_error:
        fh.new(1, 1, 1)
        fh.ax.semilogy(numpy.array(error_mu), label="Error of mu")
        fh.ax.semilogy(numpy.array(error_nu), label="Error of nu")
        fh.ax.legend()
        fh.show()
        fh.close()
    if track_loss:
        fh.new(1, 1, 1)
        fh.ax.plot(numpy.array(loss), label="Loss")
        fh.ax.legend()
        fh.show()
        fh.close()

    if not stat:
        return p

    summary = {
        "title": "Sinkhorn algorithm",
        "size": [m, n],
        "loss": (c * p.s).sum(),
        "vars": m*n + m + n,
        "iters": itc,
    }
    return p, summary

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*entropy_sinkhorn_test.py*w*sehx*

In [None]:
# !Switch*
from handler import FigureHandler
from dataset import ot_2d_Caffarelli
from stats import Statistics
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_2d_Caffarelli
# from stats import Statistics
# from entropy_sinkhorn import solve_sinkhorn
# !SwitchEnd*

In [None]:
# Figure handler for this notebook; its `write` method is later passed to
# `Statistics` as the log sink.
# NOTE(review): the `# !Switch*` / `# !SwitchCase*` / `# !SwitchEnd*` lines look
# like directives for the notebook-to-script converter (the commented-out
# variant presumably becomes active in the generated script) — confirm, and
# keep those lines byte-identical.
# !Switch*
fh = FigureHandler(redir=True)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
# Build the statistics runner over a single problem produced by
# `ot_2d_Caffarelli(500, 500, 1)`; its log output is routed to `fh.write`.
# NOTE(review): the meaning of the three generator arguments is defined in
# `dataset` — presumably the two problem sizes plus a variant/seed; confirm.
stat = Statistics(
    probs=[
        ot_2d_Caffarelli(500, 500, 1)
    ],
    prob="Test problems",
    log=fh.write,
)

In [None]:
# Benchmark the Sinkhorn solver with 2000 iterations and gamma = 1e-2.
# NOTE(review): the keyword arguments are presumably forwarded by `stat.test`
# to `solve_sinkhorn`, and `output_last` presumably reports the most recent
# run — confirm against `stats.Statistics`.
stat.test(
    solve_sinkhorn,
    it=2000, gamma=1e-2,
)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Run the MOSEK interior-point solver on the same problem set with its default
# arguments, presumably as a reference to compare the Sinkhorn result against.
stat.test(
    solve_mosek_interior_point,
)
stat.output_last()

In [None]:
# !ConvertEnd*