In [None]:
# !Convert*discrete_transportation_simplex_networkx.py*w*sh*

In [None]:
import math

import numpy as np

import networkx as nx
from networkx.algorithms import bipartite

In [None]:
def init(m, n):
    """
    Create the empty bipartite graph and zeroed dual arrays.

    Row-side nodes are labelled 1..m (attribute bipartite=0) and
    column-side nodes are labelled -1..-n (attribute bipartite=1).

    Returns the graph, a length-m array, a length-n array and an
    (n, m) array, all filled with zeros.
    """
    graph = nx.Graph()

    row_nodes = range(1, m + 1)
    col_nodes = range(-1, -n - 1, -1)
    graph.add_nodes_from(row_nodes, bipartite=0)
    graph.add_nodes_from(col_nodes, bipartite=1)

    return graph, np.zeros(m), np.zeros(n), np.zeros((n, m))

In [None]:
def find_solution(m, n, mu, nu, g):
    """
    Build an initial basic solution by sweeping rows and columns in order.

    Each row i (node i+1) ships its amount mu[i] to the current column
    (node -j-1) until either the row or the column is used up; the
    shipped amount is stored as the edge attribute ``weight``.  ``nu`` is
    copied first, so the caller's array is not modified.

    When a row and a column are exhausted at the same time and further
    rows remain, a zero-weight edge from the next row to that column is
    added before moving on to the next column.

    Returns the same graph object ``g`` with the edges added.
    """
    col = 0
    remaining = nu.copy()

    for row in range(m):
        supply = mu[row]

        while col < n:
            col_node = -col - 1

            if not remaining[col] >= supply:
                # The column runs out first: ship what it can take and
                # continue with the next column for the same row.
                g.add_edge(row + 1, col_node, weight=remaining[col])
                supply = supply - remaining[col]
                remaining[col] = 0
                col += 1
                continue

            # The row fits entirely into the current column.
            remaining[col] = remaining[col] - supply
            g.add_edge(row + 1, col_node, weight=supply)
            if remaining[col] == 0 and row < m - 1:
                # Row and column exhausted together: link the next row to
                # this column with a zero-weight edge.
                g.add_edge(row + 2, col_node, weight=0)
                col += 1
            break

    return g

In [None]:
def primal_value(m, n, dist, g):
    """
    Calculate the primal value with respect to the solution.

    Sums ``weight * dist[row, col]`` over every edge of ``g``.  Row-side
    nodes carry positive labels (row index + 1) and column-side nodes
    carry negative labels (-(column index) - 1), so for each edge the
    larger endpoint is the row node and the smaller one the column node.
    Orienting by sign makes the result independent of the order in which
    the graph reports the two endpoints of an edge.

    ``m`` and ``n`` are accepted for interface compatibility and are not
    used.
    """
    return sum(
        g.edges[x, y]['weight'] * dist[max(x, y) - 1, -min(x, y) - 1]
        for x, y in g.edges
    )

In [None]:
def find_dual(m, n, dist, g, mu_dual, nu_dual, c_dual):
    """
    Find the dual variables according to the primal variables.

    Starting from row node 1 with potential 0, the graph is traversed
    breadth-first.  For every edge (row i, column j) that is followed the
    potentials are chosen so that ``mu_dual[i] + nu_dual[j] == dist[i, j]``.
    ``mu_dual`` and ``nu_dual`` are updated in place; entries of nodes that
    are not reached from node 1 keep their previous values.

    The incoming ``c_dual`` is ignored; a freshly computed (n, m) array of
    reduced costs ``dist.T - mu_dual - nu_dual`` is returned instead.

    Returns ``(mu_dual, nu_dual, c_dual)``.
    """
    # The root is node 1, i.e. index 0 of mu_dual.  (Writing index 1 here
    # would fail for m == 1 and only pins the root by accident otherwise.)
    mu_dual[0] = 0

    order = [1]
    seen = {1}
    head = 0

    # Stop as soon as all m + n nodes are queued or the queue is drained.
    while head < len(order) and len(order) < m + n:
        node = order[head]
        head += 1

        for nb in g.neighbors(node):
            if nb in seen:
                continue
            seen.add(nb)
            order.append(nb)
            if nb > 0:
                # nb is a row node reached from column node `node`.
                mu_dual[nb-1] = dist[nb-1, abs(node)-1] - nu_dual[-node-1]
            else:
                # nb is a column node reached from row node `node`.
                nu_dual[-nb-1] = dist[node-1, abs(nb)-1] - mu_dual[node-1]

    c_dual = dist.T - mu_dual.reshape((1, m)) - nu_dual.reshape((n, 1))

    return mu_dual, nu_dual, c_dual

In [None]:
def find_loop(u, v, g):
    """
    Find the loop that appears in the graph once the edge u--v is added.

    A breadth-first search is started at ``v`` and stopped as soon as
    ``u`` shows up among the neighbours being scanned.  The recorded
    predecessors are then followed from ``u`` back to ``v``.

    Returns a list of ``(node, next_node)`` pairs describing the walk
    u -> ... -> v, followed by the closing pair ``(v, u)``.
    """
    frontier = [v]
    parent = {v: 0}      # 0 is never a node label; it only marks the start
    cursor = 0
    reached = False

    while not reached:
        current = frontier[cursor]
        cursor += 1
        for nb in g.neighbors(current):
            if nb not in parent:
                frontier.append(nb)
                parent[nb] = current
            if nb == u:
                reached = True
                break

    cycle = []
    node = u
    while node != v:
        step = parent[node]
        cycle.append((node, step))
        node = step
    cycle.append((v, u))

    return cycle

In [None]:
def update(u, v, dist, g):
    """
    Update the graph after adding the edge u--v.

    The loop closed by the new edge is obtained from ``find_loop`` as a
    list of ``(a, b)`` pairs.  Pairs whose first node is positive are the
    edges whose weight is reduced; the remaining pairs (including the new
    edge itself, stored as ``(v, u)``) are the edges whose weight is raised.

    If the total distance of the reduced edges does not exceed that of
    the raised edges, the new edge is removed again and the graph is left
    as it was.  Otherwise the smallest weight among the reduced edges is
    moved around the loop and the first edge attaining that minimum is
    removed.

    Returns the same graph object ``g``.
    """
    cycle = find_loop(u, v, g)
    g.add_edge(u, v, weight=0)

    cost_reduced = sum(dist[a-1, abs(b)-1] for a, b in cycle if a > 0)
    cost_raised = sum(dist[b-1, abs(a)-1] for a, b in cycle if a < 0)

    # Smallest weight on the reduced edges and the edge that carries it.
    shift = float('inf')
    for a, b in cycle:
        if a > 0:
            weight = g.edges[a, b]['weight']
            if shift > weight:
                shift = weight
                drop_a, drop_b = a, b

    if cost_reduced <= cost_raised:
        # Rerouting would not lower the cost: undo the insertion.
        g.remove_edge(u, v)
        return g

    for a, b in cycle:
        if a > 0:
            g.edges[a, b]['weight'] -= shift
        else:
            g.edges[a, b]['weight'] += shift
    g.remove_edge(drop_a, drop_b)

    return g

In [None]:
def export_solution(m, n, g):
    """
    Export the solution in a numpy array.

    Returns an (m, n) array whose entry [i, j] is the ``weight`` of the
    edge between row node i+1 and column node -j-1, and 0 where no such
    edge exists.

    Row nodes are positive and column nodes negative, so for each edge
    the larger endpoint is the row node and the smaller one the column
    node.  Orienting by sign keeps the result independent of the order in
    which the graph reports the two endpoints of an edge.
    """
    arr = np.zeros((m, n))
    for x, y in g.edges:
        arr[max(x, y) - 1, -min(x, y) - 1] = g.edges[x, y]['weight']
    return arr

In [None]:
def draw_graph(g, fh):
    """
    Draw the bipartite graph in a two-column layout.

    The two node sets returned by ``bipartite.sets`` are placed at x=1
    and x=2 respectively, each node at the height of its enumeration
    index, and the graph is drawn with labels on the axes ``fh.ax``.
    """
    first_side, second_side = bipartite.sets(g)

    layout = {node: (1, rank) for rank, node in enumerate(first_side)}
    layout.update({node: (2, rank) for rank, node in enumerate(second_side)})

    nx.draw(g, pos=layout, ax=fh.ax, with_labels=True)

In [None]:
def solve_transportation_simplex_networkx(
    p,
    eps, it, scale=None,
    fh=None, figs={}, log=None, stat=False,
    *args, **kwargs
):
    """
    Solve a transportation problem with a graph-based simplex iteration.

    Parameters
    ----------
    p : problem object
        Read: ``p.c`` (cost array of shape (m, n)), ``p.mu`` and ``p.nu``
        (marginals).  Written: ``p.s`` (the (m, n) solution array).
    eps : float
        The iteration stops once the smallest reduced cost is >= -eps.
    it : int
        Maximum number of iterations.
    scale : float, optional
        Factor applied to ``p.mu`` and ``p.nu`` while solving and divided
        out of the results; defaults to ``sqrt(m * n)``.
    fh : figure handler, optional
        When given, used for plotting via ``fh.new``, ``fh.ax``,
        ``fh.show`` and ``fh.close``.
    figs : container of str
        ``"value"`` plots the primal value per iteration, ``"graph"``
        draws the final graph.  Only membership is tested; the default
        container is never modified.
    log : callable, optional
        Called with a progress string after every iteration.
    stat : bool
        When true, a dict with ``"title"``, ``"loss"`` and ``"iters"`` is
        returned alongside ``p``.
    *args, **kwargs
        Accepted and ignored.

    Returns
    -------
    ``p`` or ``(p, stats_dict)`` depending on ``stat``.
    """
    m, n = p.c.shape

    if scale is None:
        scale = math.sqrt(m * n)

    mu, nu = scale*p.mu, scale*p.nu
    dist = p.c

    # Primal values are only recorded when they will be plotted.
    track_value = fh is not None and "value" in figs
    if track_value:
        value = []

    g, mu_dual, nu_dual, c_dual = init(m, n)
    its = 0
    g = find_solution(m, n, mu, nu, g)

    while its < it:
        mu_dual, nu_dual, c_dual = find_dual(m, n, dist, g, mu_dual, nu_dual, c_dual)

        # c_dual has shape (n, m): split the flat argmin into its
        # (column-side, row-side) indices.
        sink, source = divmod(np.argmin(c_dual), m)

        if c_dual[sink, source] >= -eps:
            break

        g = update(source+1, -sink-1, dist, g)
        its += 1

        if track_value:
            value.append(primal_value(m, n, dist, g) / scale)

        if log is not None:
            log("its = {0}".format(its))

    p.s = export_solution(m, n, g) / scale

    if track_value:
        fh.new(1, 1, 1)
        fh.ax.plot(np.array(value), label="Primal value")
        fh.ax.legend()
        fh.show()
        fh.close()

    if fh is not None and "graph" in figs:
        fh.new(1, 1, 1)
        draw_graph(g, fh)
        fh.show()
        fh.close()

    if not stat:
        return p

    summary = {
        "title": "Transportation simplex with networkx",
        "loss": primal_value(m, n, dist, g) / scale,
        "iters": its,
    }
    return p, summary

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*discrete_transportation_simplex_networkx_test.py*w*sehx*

In [None]:
# !Switch*
from handler import FigureHandler
from dataset import ot_2d_Caffarelli
from stats import Statistics
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_2d_Caffarelli
# from stats import Statistics
# from discrete_transportation_simplex_networkx import solve_transportation_simplex_networkx
# !SwitchEnd*

In [None]:
# !Switch*
# Figure/log handler shared by the test cells below: fh.write is handed to
# Statistics as the logger and fh itself is passed to the solver for plotting.
fh = FigureHandler(redir=True)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
# Test harness over a single problem built by ot_2d_Caffarelli(10, 10, 1);
# the harness logs through fh.write.
# NOTE(review): the meaning of the three ot_2d_Caffarelli arguments is not
# visible in this notebook - confirm against the dataset module.
stat = Statistics(
    probs=[
        ot_2d_Caffarelli(10, 10, 1)
    ],
    prob="Test problems",
    log=fh.write,
)

In [None]:
# Run the networkx transportation simplex through the harness, asking for
# both figures the solver knows: the primal-value curve ("value") and the
# drawing of the final graph ("graph").  In the solver, eps is the stopping
# tolerance on the smallest reduced cost and it is the iteration cap.
# NOTE(review): presumably Statistics.test forwards these keyword arguments
# to the solver - confirm against the stats module.
stat.test(
    solve_transportation_simplex_networkx,
    eps=1e-12,
    it=6000,
    fh=fh, figs={"value", "graph"},
)
stat.output_last()

In [None]:
# Rebind `stat` to a new harness over ot_2d_Caffarelli(100, 100, 1); the
# cells below all run against this instance.
# NOTE(review): presumably a larger instance than the (10, 10, 1) one above -
# confirm the argument meaning against the dataset module.
stat = Statistics(
    probs=[
        ot_2d_Caffarelli(100, 100, 1)
    ],
    prob="Test problems",
    log=fh.write,
)

In [None]:
# Same solver settings as before, but only the primal-value curve is
# requested (no "graph" entry in figs, so the final graph is not drawn).
stat.test(
    solve_transportation_simplex_networkx,
    eps=1e-12,
    it=6000,
    fh=fh, figs={"value"},
)
stat.output_last()

In [None]:
# Same solver settings again, this time without fh/figs.  The solver's fh
# defaults to None, in which case it records no values and plots nothing.
# NOTE(review): presumably used to measure the solver without plotting
# overhead - confirm intent.
stat.test(
    solve_transportation_simplex_networkx,
    eps=1e-12,
    it=6000,
)
stat.output_last()

In [None]:
from solver_mosek import solve_mosek_interior_point

In [None]:
# Run the imported solve_mosek_interior_point on the same `stat` harness with
# no extra arguments.
# NOTE(review): presumably serves as a reference result for comparison with
# the networkx simplex runs above - confirm.
stat.test(
    solve_mosek_interior_point,
)
stat.output_last()

In [None]:
# !ConvertEnd*