In [None]:
# !Convert*discrete_transportation_simplex.py*w*sh*

In [None]:
import math

import numpy as np

In [None]:
def init(m, n):
    """
    Allocate the empty data structures used by the transportation simplex.

    Node ids: sinks are ``-n..-1`` and sources are ``1..m`` (0 is unused).

    Returns
    -------
    record : dict[int, set]
        Empty adjacency set for every node, sinks first then sources.
    sol : ndarray, shape (m, n)
        Zero transport plan.
    mu_dual : ndarray, shape (m,)
        Zero source potentials.
    nu_dual : ndarray, shape (n,)
        Zero sink potentials.
    c_dual : ndarray, shape (n, m)
        Zero reduced-cost matrix (sink-major).
    """
    node_ids = list(range(-n, 0)) + list(range(1, m + 1))
    record = {node: set() for node in node_ids}
    return (
        record,
        np.zeros((m, n)),
        np.zeros(m),
        np.zeros(n),
        np.zeros((n, m)),
    )

In [None]:
def find_solution(m, n, mu, nu, record, sol):
    """
    Build an initial basic solution with the north-west corner rule.

    Sources are processed in order ``0..m-1``. Each ships its mass ``mu[i]`` to
    the sinks from left to right, starting at the first sink that still has
    remaining capacity. Every cell that receives a (possibly zero) amount is
    recorded as a basic edge in both adjacency sets of ``record``: source node
    ``i + 1`` and sink node ``-(j + 1)``.

    When a source exactly exhausts a sink's remaining capacity and is not the
    last source, the zero-flow cell (next source, same sink) is also recorded
    and the column advances. This keeps the next source attached to the basis.

    If the sinks run out (column index reaches ``n``) while a source still has
    supply left, that remainder is not placed anywhere.

    ``record`` and ``sol`` are modified in place and returned; ``nu`` is copied
    and left untouched.
    """
    capacity = nu.copy()  # remaining capacity of each sink
    col = 0

    def _connect(row_idx, col_idx):
        # Register the basic edge (source row_idx) -- (sink col_idx) both ways.
        record[row_idx + 1].add(-col_idx - 1)
        record[-col_idx - 1].add(row_idx + 1)

    for row in range(m):
        supply = mu[row]
        while col < n:
            avail = capacity[col]
            if avail >= supply:
                # This sink absorbs everything the source still has.
                capacity[col] = avail - supply
                sol[row, col] = supply
                _connect(row, col)
                if capacity[col] == 0 and row < m - 1:
                    # Degenerate tie: link the next source to this sink with a
                    # zero-flow basic edge, then move on to the next sink.
                    _connect(row + 1, col)
                    col += 1
                break
            # The source fills this sink completely and continues to the next.
            sol[row, col] = avail
            _connect(row, col)
            supply = supply - avail
            capacity[col] = 0
            col += 1

    return record, sol

In [None]:
def primal_value(m, n, dist, record, sol):
    """
    Return the transport cost sum(sol[i, j] * dist[i, j]) over the basic edges.

    Only cells listed in ``record`` for source nodes ``1..m`` are summed, in
    source order and then in each adjacency set's iteration order. ``n`` is
    accepted for signature symmetry but not used.
    """
    total = 0
    for src in range(1, m + 1):
        row = src - 1
        for snk in record[src]:
            col = -snk - 1
            total = total + sol[row, col] * dist[row, col]
    return total

In [None]:
def find_dual(m, n, dist, record, mu_dual, nu_dual, c_dual):
    """
    Compute dual potentials for the current basis and the reduced costs.

    The source potential ``mu_dual[0]`` is fixed to 0. The graph in ``record``
    is then traversed breadth-first from source node 1. Each newly reached node
    gets the potential that makes ``mu_dual[i] + nu_dual[j] == dist[i, j]`` on
    the edge it was reached through. The traversal stops once ``m + n`` nodes
    have been reached or nothing more is reachable. Potentials of unreached
    nodes keep their previous values.

    Node ids: sources ``1..m``, sinks ``-1..-n``.

    Returns
    -------
    mu_dual, nu_dual : ndarray
        The input arrays, updated in place.
    c_dual : ndarray, shape (n, m)
        Freshly computed reduced costs
        ``c_dual[j, i] = dist[i, j] - mu_dual[i] - nu_dual[j]``.
        The incoming ``c_dual`` argument is not read.
    """
    mu_dual[0] = 0.
    node_count = m + n
    order = [1]
    seen = {1}
    head = 0

    while len(order) < node_count and head < len(order):
        node = order[head]
        head += 1
        for nbr in record[node]:
            if nbr in seen:
                continue
            order.append(nbr)
            seen.add(nbr)
            if nbr > 0:
                # A source reached from a sink.
                mu_dual[nbr - 1] = dist[nbr - 1, abs(node) - 1] - nu_dual[-node - 1]
            else:
                # A sink reached from a source.
                nu_dual[-nbr - 1] = dist[node - 1, abs(nbr) - 1] - mu_dual[node - 1]

    c_dual = dist.T - mu_dual.reshape((1, m)) - nu_dual.reshape((n, 1))
    return mu_dual, nu_dual, c_dual

In [None]:
def find_loop(u, v, record):
    """
    Find the cycle closed by adding the edge u--v to the basis graph.

    A breadth-first search starts at ``v`` over ``record``, the graph *before*
    u--v is inserted, and runs until ``u`` is reached. The path u -> ... -> v is
    read back through the BFS parent links and closed with the new edge (v, u).

    Parameters
    ----------
    u, v : int
        End points of the entering edge. Node ids: sources ``1..m``, sinks
        ``-1..-n``.
    record : dict[int, set[int]]
        Adjacency sets of the current basis. Not modified.

    Returns
    -------
    list[tuple[int, int]]
        ``[(u, p(u)), (p(u), p(p(u))), ..., (x, v), (v, u)]``, where ``p`` is
        the BFS parent.

    Raises
    ------
    ValueError
        If ``u`` cannot be reached from ``v``, i.e. the basis is not connected.
        Previously this surfaced as an unexplained ``IndexError``.
    """
    parent = {v: None}  # node -> node it was first reached from
    frontier = [v]
    head = 0

    while u not in parent:
        if head >= len(frontier):
            raise ValueError(
                "node {0} is not reachable from node {1}: "
                "the basis is not connected".format(u, v)
            )
        node = frontier[head]
        head += 1
        for nbr in record[node]:
            if nbr not in parent:
                parent[nbr] = node
                frontier.append(nbr)
                if nbr == u:
                    break

    loop = []
    t = u
    while t != v:
        loop.append((t, parent[t]))
        t = parent[t]
    loop.append((v, u))  # the entering edge closes the cycle
    return loop

In [None]:
def update(u, v, dist, record, sol):
    """
    Pivot the edge u--v into the basis if doing so lowers the cost.

    ``u`` is a source node (``1..m``) and ``v`` a sink node (``-1..-n``). The
    cycle closed by u--v is found with ``find_loop``. Edges listed source-first
    in that cycle would lose flow. Edges listed sink-first, including the new
    edge (v, u), would gain flow.

    The pivot is applied only when the losing edges cost strictly more than the
    gaining ones. In that case the smallest flow on a losing edge (the first
    one in cycle order on ties) is shifted around the cycle, u--v is added, and
    that losing edge leaves the basis.

    If u--v is already basic, or the pivot is not improving, nothing changes.
    Previously an already-basic u--v was *removed* here, which disconnected
    the basis.

    ``record`` and ``sol`` are modified in place and returned.
    """
    if v in record[u]:
        # Already basic: the "cycle" would be u--v traversed twice, so
        # dis1 == dis2. The old rejection path then deleted this basic edge.
        return record, sol

    loop = find_loop(u, v, record)

    # Cost of edges that lose flow (source-first) vs. edges that gain flow
    # (sink-first, including the entering edge).
    dis1 = sum(dist[i-1, abs(j)-1] for i, j in loop if i > 0)
    dis2 = sum(dist[j-1, abs(i)-1] for i, j in loop if i < 0)

    if dis1 <= dis2:
        # Not improving (can happen through round-off in the reduced costs).
        return record, sol

    # Ratio test: the leaving edge carries the smallest flow among losing edges.
    min1 = float('inf')
    for i, j in loop:
        if i > 0 and sol[i-1, -j-1] < min1:
            min1 = sol[i-1, -j-1]
            x1, y1 = i, j

    for i, j in loop:
        if i > 0:
            sol[i-1, -j-1] -= min1
        else:
            sol[j-1, -i-1] += min1

    record[u].add(v)
    record[v].add(u)
    record[x1].remove(y1)
    record[y1].remove(x1)

    return record, sol

In [None]:
def solve_transportation_simplex(
    p,
    eps, it, scale=None,
    fh=None, figs=None, log=None, stat=False,
    *args, **kwargs
):
    """
    Solve a discrete transport problem with the transportation simplex method.

    Parameters
    ----------
    p : problem object
        Must provide ``p.c`` (cost matrix, shape ``(m, n)``), ``p.mu`` (source
        masses, length ``m``) and ``p.nu`` (sink masses, length ``n``). The
        computed plan is stored in ``p.s`` (shape ``(m, n)``).
    eps : float
        Optimality tolerance: iteration stops once every reduced cost is
        ``>= -eps``.
    it : int
        Maximum number of pivots.
    scale : float, optional
        Factor applied to ``p.mu`` and ``p.nu`` before solving. The plan and
        reported values are divided by it afterwards. Defaults to
        ``sqrt(m * n)``.
    fh : figure handler, optional
        Used only when ``"value" in figs``. Must provide ``new``, ``ax``,
        ``show`` and ``close``.
    figs : container of str, optional
        Figures to draw. Only ``"value"`` (primal value after each pivot) is
        recognised. ``None`` (the default) means no figures.
    log : callable, optional
        Called with a progress string after every pivot.
    stat : bool
        If true, also return a statistics dict.
    *args, **kwargs
        Accepted and ignored. Presumably this keeps the signature uniform for
        the benchmarking harness; confirm.

    Returns
    -------
    p, or ``(p, s)`` when ``stat`` is true, where ``s`` holds ``"title"``,
    ``"loss"`` (final primal value) and ``"iters"`` (pivots performed).
    """
    if figs is None:
        figs = ()

    m, n = p.c.shape

    if scale is None:
        scale = math.sqrt(m * n)

    mu, nu = scale*p.mu, scale*p.nu
    dist = p.c

    # The primal value trace is only collected when it is going to be plotted.
    track_value = fh is not None and "value" in figs
    value = []

    record, sol, mu_dual, nu_dual, c_dual = init(m, n)
    record, sol = find_solution(m, n, mu, nu, record, sol)

    its = 0
    while its < it:
        mu_dual, nu_dual, c_dual = find_dual(m, n, dist, record, mu_dual, nu_dual, c_dual)

        # c_dual has shape (n, m): row = sink index, column = source index.
        pos = np.argmin(c_dual)
        sink, source = pos // m, pos % m

        if c_dual[sink, source] >= -eps:
            break  # all reduced costs (numerically) non-negative: optimal

        src_node, snk_node = source + 1, -sink - 1

        if snk_node in record[src_node]:
            # A basic edge can only look negative through round-off. It cannot
            # enter the basis, and since nothing would change, the same edge
            # would be selected on every remaining iteration.
            break

        record, sol = update(src_node, snk_node, dist, record, sol)

        if snk_node not in record[src_node]:
            # update() rejected the pivot as non-improving (round-off). The
            # basis is unchanged, so the next iteration would deterministically
            # pick the same edge again: stop instead of spinning until `it`.
            break

        its += 1

        if track_value:
            value.append(primal_value(m, n, dist, record, sol) / scale)

        if log is not None:
            log("its = {0}".format(its))

    p.s = sol / scale

    if track_value:
        fh.new(1, 1, 1)
        fh.ax.plot(np.array(value), label="Primal value")
        fh.ax.legend()
        fh.show()
        fh.close()

    if stat:
        s = {
            "title": "Transportation simplex",
            "loss": primal_value(m, n, dist, record, sol) / scale,
            "iters": its,
        }
        return p, s
    else:
        return p

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*discrete_transportation_simplex_test.py*w*sehx*

In [None]:
# !Switch*
from handler import FigureHandler
from dataset import ot_2d_Caffarelli
from stats import Statistics
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_2d_Caffarelli
# from stats import Statistics
# from discrete_transportation_simplex import solve_transportation_simplex
# !SwitchEnd*

In [None]:
# !Switch*
fh = FigureHandler(redir=True)  # NOTE(review): redir=True presumably routes output through fh.write; confirm in `handler`
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
# Benchmark harness over a single test problem; log output goes through fh.write.
# NOTE(review): the meaning of ot_2d_Caffarelli(100, 100, 1) is defined in
# `dataset` — presumably two grid sizes and a parameter/seed; confirm.
stat = Statistics(
    probs=[
        ot_2d_Caffarelli(100, 100, 1)
    ],
    prob="Test problems",
    log=fh.write,
)

In [None]:
# Transportation simplex with the per-pivot primal value plotted.
# eps: stop once all reduced costs are >= -1e-12; it: at most 6000 pivots.
stat.test(
    solve_transportation_simplex,
    eps=1e-12,
    it=6000,
    fh=fh, figs={"value"},
)
stat.output_last()

In [None]:
# Same solver settings as above, but without plotting (fh/figs omitted).
stat.test(
    solve_transportation_simplex,
    eps=1e-12,
    it=6000,
)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Reference run with the MOSEK interior-point solver, for comparison with the
# simplex results above (presumably requires MOSEK to be installed; confirm).
stat.test(
    solve_mosek_interior_point,
)
stat.output_last()

In [None]:
# !ConvertEnd*