In [1]:
import numpy as np
import time
import sys
import os
from scipy.optimize import linprog 
#from exact_otc import exact_otc

sys.path.append(os.path.abspath("../src"))
from pyotc.otc_backend.graph.utils import adj_to_trans, get_degree_cost
from pyotc.examples.stochastic_block_model import stochastic_block_model

  __import__('pkg_resources').declare_namespace(__name__)


In [2]:
import numpy as np
from numpy.linalg import pinv
from scipy.optimize import linprog
import copy
import ot


def get_ind_tc(Px, Py):
    """Build the independent (product) transition coupling of two chains.

    Entry ``[dy * i + k, dy_col * j + l]`` of the result equals
    ``Px[i, j] * Py[k, l]``, which is exactly the Kronecker product of
    ``Px`` and ``Py``.

    Args:
        Px: 2-D array-like of shape (dx, dx_col).
        Py: 2-D array-like of shape (dy, dy_col).

    Returns:
        Float ndarray of shape (dx * dy, dx_col * dy_col).
    """
    # np.kron produces the same layout as four nested Python loops, but
    # vectorised.  It strides the columns by dy_col (the previous loop used
    # dy, which is only correct when Py is square), so rectangular inputs are
    # handled correctly as well.
    return np.kron(np.asarray(Px, dtype=float), np.asarray(Py, dtype=float))


def exact_tce(Pz, c):
    """Transition coupling evaluation: solve the block system for (g, h).

    With ``d = Pz.shape[0]`` and ``c`` reshaped to ``d`` rows, the system
    solved for the stacked unknown ``(g, h, w)`` is::

        (I - Pz) g               = 0
               g + (I - Pz) h    = c
                   h + (I - Pz) w = 0

    Args:
        Pz: square (d, d) array; the transition coupling being evaluated.
        c: cost array with ``d`` entries in total (any shape reshapeable to
            ``(d, -1)``).

    Returns:
        Tuple ``(g, h)`` of 1-D arrays: the first and second ``d``-sized
        blocks of the solution.  The third block is discarded.
    """
    d = Pz.shape[0]
    c = np.reshape(c, (d, -1))
    eye = np.eye(d)
    zero = np.zeros((d, d))
    A = np.block([[eye - Pz, zero, zero],
                  [eye, eye - Pz, zero],
                  [zero, eye, eye - Pz]])
    b = np.concatenate([np.zeros((d, 1)), c, np.zeros((d, 1))])
    try:
        sol = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        # Only a singular system should trigger the pseudo-inverse fallback;
        # anything else (KeyboardInterrupt, shape errors, ...) must propagate.
        # NOTE(review): solve() only raises when it detects exact singularity,
        # so a numerically near-singular A is still solved directly.
        sol = np.matmul(pinv(A), b)

    g = sol[0:d].flatten()
    h = sol[d:2*d].flatten()
    return g, h


def computeot_pot(C, r, c):
    """Solve a discrete optimal transport problem with POT's ``ot.emd``.

    Args:
        C: cost matrix passed unchanged to ``ot.emd``.
        r: first marginal; flattened to 1-D before the call.
        c: second marginal; flattened to 1-D before the call.

    Returns:
        Tuple ``(lp_sol, lp_val)``: the plan returned by ``ot.emd`` and the
        sum of its elementwise product with ``C``.
    """
    # Flatten copies of both marginals so row/column vectors are accepted.
    source_mass = np.array(r).flatten()
    target_mass = np.array(c).flatten()

    plan = ot.emd(source_mass, target_mass, C)
    total_cost = np.sum(plan * C)

    return plan, total_cost


def check_constant(f, threshold=1e-3):
    """Return True when all entries of ``f`` agree to within ``threshold``.

    The largest pairwise absolute difference of a set of numbers is
    ``max - min``, so a single O(d) pass replaces the pairwise O(d^2) scan
    while giving the same answer.

    Args:
        f: array-like of numbers.
        threshold: largest spread still regarded as constant.

    Returns:
        ``False`` if some pair of entries differs by more than ``threshold``,
        ``True`` otherwise (including for empty input).
    """
    values = np.asarray(f, dtype=float).ravel()
    # A pairwise comparison involving NaN is never "> threshold", so NaN
    # entries cannot make the vector non-constant; drop them up front.
    values = values[~np.isnan(values)]
    if values.size == 0:
        return True
    spread = values.max() - values.min()
    # Written as "not (>)" so an undefined spread (inf - inf) still counts as
    # constant, exactly as the pairwise test would.
    return not (spread > threshold)


def setup_ot(f, Px, Py, Pz):
    """Fill ``Pz`` row by row with couplings of the rows of ``Px`` and ``Py``.

    For every pair ``(x_row, y_row)`` the row ``dy * x_row + y_row`` of ``Pz``
    is overwritten in place with a flattened coupling of ``Px[x_row, :]`` and
    ``Py[y_row, :]``: the outer product when either row contains an exact 1,
    otherwise the plan returned by ``computeot_pot`` for cost ``f`` reshaped
    to ``(dx, dy)``.

    Args:
        f: array with ``dx * dy`` entries, used as the OT cost.
        Px: transition matrix of the first chain (``dx`` rows).
        Py: transition matrix of the second chain (``dy`` rows).
        Pz: ``(dx * dy, dx * dy)`` array that is modified in place.

    Returns:
        The same ``Pz`` object that was passed in.
    """
    dx, dy = Px.shape[0], Py.shape[0]
    cost = np.reshape(f, (dx, dy))

    # np.ndindex walks (x_row, y_row) in the same row-major order as two
    # nested range loops.
    for x_row, y_row in np.ndindex(dx, dy):
        mu = Px[x_row, :]
        nu = Py[y_row, :]
        if (mu == 1).any() or (nu == 1).any():
            # A point mass admits only the product coupling; skip the solver.
            plan = np.outer(mu, nu)
        else:
            plan, _ = computeot_pot(cost, mu, nu)
        Pz[dy * x_row + y_row, :] = np.reshape(plan, (-1, dx * dy))

    return Pz


def exact_tci(g, h, P0, Px, Py):
    """Transition coupling improvement step.

    First tries to improve against ``g`` (skipped when ``check_constant``
    reports ``g`` as constant), then against ``h``.  A candidate built by
    ``setup_ot`` is kept only if it moves ``P @ f`` away from ``P0 @ f`` by
    more than a tolerance (1e-7 for ``g``, 1e-4 for ``h``); otherwise a copy
    of ``P0`` is returned.

    Args:
        g: first vector returned by the evaluation step.
        h: second vector returned by the evaluation step.
        P0: current transition coupling.
        Px: transition matrix of the first chain.
        Py: transition matrix of the second chain.

    Returns:
        The improved coupling, or a deep copy of ``P0`` when neither ``g`` nor
        ``h`` yields an improvement.
    """
    n_pairs = Px.shape[0] * Py.shape[0]
    Pz = np.zeros((n_pairs, n_pairs))

    if not check_constant(f=g):
        Pz = setup_ot(f=g, Px=Px, Py=Py, Pz=Pz)
        g_shift = np.max(np.abs(np.matmul(P0, g) - np.matmul(Pz, g)))
        if not (g_shift <= 1e-7):
            # The g-based candidate differs from P0: take it as the update.
            return Pz
        Pz = copy.deepcopy(P0)

    # g gave nothing new (or was constant): fall back to improving against h.
    Pz = setup_ot(f=h, Px=Px, Py=Py, Pz=Pz)
    h_shift = np.max(np.abs(np.matmul(P0, h) - np.matmul(Pz, h)))
    if h_shift <= 1e-4:
        Pz = copy.deepcopy(P0)

    return Pz

# def exact_tci(g, h, P0, Px, Py):
#     dx = Px.shape[0]
#     dy = Py.shape[0]
#     Pz = np.zeros((dx*dy, dx*dy))
    
#     g_const = True
#     for i in range(dx):
#         for j in range(i+1, dx):
#             if abs(g[i] - g[j]) > 1e-3:
#                 g_const = False
#                 break
#         if not g_const:
#             break
#     # If g is not constant, improve transition coupling against g.
#     if not g_const:
#         g_mat = np.reshape(g, (dx, dy))
#         for x_row in range(dx):
#             for y_row in range(dy):
#                 dist_x = Px[x_row, :]
#                 dist_y = Py[y_row, :]
#                 # Check if either distribution is degenerate.
#                 if any(dist_x == 1) or any(dist_y == 1):
#                     sol = np.outer(dist_x, dist_y)
#                 # If not degenerate, proceed with OT.
#                 else:
#                     sol, val = computeot_pot(g_mat, dist_x, dist_y)
#                 idx = dy*(x_row)+y_row
#                 Pz[idx, :] = np.reshape(sol, (-1, dx*dy))
#                 #P[idx, :] = sol
#         if np.max(np.abs(np.matmul(P0, g) - np.matmul(Pz, g))) <= 1e-7:
#             Pz = copy.deepcopy(P0)
#         else:
#             return Pz
#     ## Try to improve with respect to h.
#     h_mat = np.reshape(h, (dx, dy))
#     for x_row in range(dx):
#         for y_row in range(dy):
#             dist_x = Px[x_row, :]
#             dist_y = Py[y_row, :]
#             # Check if either distribution is degenerate.
#             if any(dist_x == 1) or any(dist_y == 1):
#                 sol = np.outer(dist_x, dist_y)
#             # If not degenerate, proceed with OT.
#             else:
#                 sol, val = computeot_pot(h_mat, dist_x, dist_y)
#             idx = dy*(x_row)+y_row
#             # print(x_row, y_row, P0[18, 1])
#             Pz[idx, :] = np.reshape(sol, (-1, dx*dy))

#     if np.max(np.abs(np.matmul(P0, h) - np.matmul(Pz, h))) <= 1e-4:
#         Pz = copy.deepcopy(P0)
#     return Pz


def get_best_stat_dist(P, c):
    """Find the stationary distribution of ``P`` with the smallest expected cost.

    Solves the linear program::

        minimise    c . pi
        subject to  (P^T - I) pi = 0,   sum(pi) = 1,   pi >= 0

    Args:
        P: square (n, n) transition matrix.
        c: cost array with ``n`` entries in total (any shape).

    Returns:
        Tuple ``(stat_dist, exp_cost)``: the optimal ``pi`` as a 1-D array of
        length ``n`` and the optimal objective value.

    Raises:
        ValueError: if ``c`` does not contain exactly ``n`` entries.
        RuntimeError: if ``linprog`` does not report success.
    """
    n = P.shape[0]
    # linprog documents 1-D c and b_eq, so flatten instead of passing columns.
    cost = np.asarray(c, dtype=float).reshape(-1)
    if cost.size != n:
        raise ValueError(
            f"cost has {cost.size} entries but P has {n} states"
        )

    # Stationarity rows followed by the normalisation row.
    A_eq = np.vstack([P.T - np.eye(n), np.ones((1, n))])
    b_eq = np.append(np.zeros(n), 1.0)
    bounds = [(0, None)] * n

    res = linprog(cost, A_eq=A_eq, b_eq=b_eq, bounds=bounds)
    if not res.success:
        # Without this check a failed solve would hand (None, None) to callers.
        raise RuntimeError(f"linprog failed to find a stationary distribution: {res.message}")

    return res.x, res.fun


def exact_otc(Px, Py, c, verbose=True):
    """Iterate evaluation/improvement steps until the coupling stops changing.

    Starts from the independent coupling ``get_ind_tc(Px, Py)`` and alternates
    ``exact_tce`` and ``exact_tci`` until the largest entrywise change of the
    coupling is at most 1e-10.

    Args:
        Px: transition matrix of the first chain (``dx`` rows).
        Py: transition matrix of the second chain (``dy`` rows).
        c: cost array with ``dx * dy`` entries.
        verbose: when True (the default), print the current change and the
            iteration number at the start of every iteration.

    Returns:
        Tuple ``(exp_cost, P, stat_dist)``: the optimal value reported by
        ``get_best_stat_dist``, the final coupling, and its stationary
        distribution reshaped to ``(dx, dy)``.
    """
    dx = Px.shape[0]
    dy = Py.shape[0]

    P_old = np.ones((dx*dy, dx*dy))
    P = get_ind_tc(Px, Py)
    iter_ctr = 0
    delta = np.max(np.abs(P - P_old))
    while delta > 1e-10:
        if verbose:
            print(delta)
            print(f"Iteration {iter_ctr}")
        iter_ctr += 1
        P_old = np.copy(P)

        # Transition coupling evaluation.
        g, h = exact_tce(P, c)

        # Transition coupling improvement.
        P = exact_tci(g, h, P_old, Px, Py)

        delta = np.max(np.abs(P - P_old))

    # The loop ends exactly when the coupling has converged (within
    # tolerance), so the result is always computed here.  Previously a
    # within-tolerance but not bit-identical fixed point, or a loop that
    # never ran, fell through to `return None, None, None`.
    stat_dist, exp_cost = get_best_stat_dist(P, c)
    stat_dist = np.reshape(stat_dist, (dx, dy))
    return exp_cost, P, stat_dist

In [3]:
# Seed NumPy's legacy global RNG so repeated runs draw the same numbers
# (presumably this is what stochastic_block_model samples from -- TODO confirm).
np.random.seed(1004)

In [4]:
# Four blocks of m nodes each; both graphs use the same block sizes and the
# same 4x4 matrix (0.9 on the diagonal, 0.1 elsewhere).
m = 10
block_sizes = (m,) * 4
block_probs = np.full((4, 4), 0.1)
np.fill_diagonal(block_probs, 0.9)

# Each call gets its own copy of the matrix, as with two separate literals.
A1 = stochastic_block_model(block_sizes, block_probs.copy())
A2 = stochastic_block_model(block_sizes, block_probs.copy())

P1, P2 = adj_to_trans(A1), adj_to_trans(A2)
c = get_degree_cost(A1, A2)

# Time one full exact_otc run on the pair of transition matrices.
start = time.time()
exp_cost, otc, stat_dist = exact_otc(P1, P2, c)
end = time.time()
print(exp_cost, end - start)

1.0
Iteration 0
0.109375
Iteration 1
0.125
Iteration 2
0.125
Iteration 3
0.1111111111111111
Iteration 4
0.710306974693498 26.720885038375854


In [5]:
# Manual replay of the first exact_otc iteration, with intermediates kept at
# notebook scope so they can be inspected in the cells below.
Px = P1
Py = P2
c = get_degree_cost(A1, A2)
dx = Px.shape[0]
dy = Py.shape[0]

# Start from the independent coupling and keep a snapshot for comparison.
# (The earlier `P_old = np.ones(...)` placeholder was overwritten before any
# use, so it is dropped.)
P = get_ind_tc(Px, Py)
P_old = np.copy(P)

# Transition coupling evaluation.
g, h = exact_tce(P, c)
# Transition coupling improvement, then the largest entrywise change.
P = exact_tci(g, h, P_old, Px, Py)
np.max(np.abs(P-P_old))

0.109375

In [6]:
g

array([5.34900015, 5.34900015, 5.34900015, ..., 5.34900015, 5.34900015,
       5.34900015])

In [7]:
h

array([ 9.54509398,  0.1623038 , -2.80030919, ..., -5.30228197,
       -4.62557099, -1.71031044])

In [8]:
P[10][200:400]

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.08333333, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [31]:
# One more manual evaluation/improvement step on the notebook-level P:
# snapshot P, recompute (g, h), improve, and show the largest entrywise change.
# NOTE: not idempotent -- every run of this cell advances P by one step.
P_old = np.copy(P)
g, h = exact_tce(P, c)
P = exact_tci(g, h, P_old, Px, Py)
np.max(np.abs(P-P_old))

0.125

In [32]:
# Same manual step as a standalone cell: snapshot P, recompute (g, h),
# improve, and show the largest entrywise change.
# NOTE: not idempotent -- every run of this cell advances P by one step.
P_old = np.copy(P)
g, h = exact_tce(P, c)
P = exact_tci(g, h, P_old, Px, Py)
np.max(np.abs(P-P_old))

0.125

In [33]:
# Evaluation step only: snapshot P and recompute (g, h) for the current P.
# The improvement step is left disabled so g and h can be inspected first.
P_old = np.copy(P)
g, h = exact_tce(P, c)
# P = exact_tci(g, h, P_old, Px, Py)
# np.max(np.abs(P-P_old))

In [34]:
g

array([0.71031912, 0.71031912, 0.71031912, ..., 0.71031912, 0.71031912,
       0.71031912])

In [35]:
h

array([16.08253384,  3.73351632,  0.72138359, ..., -0.08382943,
        0.22315136,  2.82742524])

In [18]:
# Rebuild, at notebook scope, the same block system (A, b) that exact_tce
# assembles, so it can be inspected directly.
# NOTE: Pz is an alias of P here, not a copy -- any later in-place write to Pz
# also changes P.
Pz = P
d = Pz.shape[0]
# NOTE: this rebinds the notebook-level c to an array with d rows.
c = np.reshape(c, (d, -1))
A = np.block([[np.eye(d) - Pz, np.zeros((d, d)), np.zeros((d, d))],
                [np.eye(d), np.eye(d) - Pz, np.zeros((d, d))],
                [np.zeros((d, d)), np.eye(d), np.eye(d) - Pz]])
b = np.concatenate([np.zeros((d, 1)), c, np.zeros((d, 1))])

In [28]:
g

array([4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140268,
       4.88140268, 4.88140268, 4.88140268, 4.88140268, 4.88140

In [15]:
# Step through the beginning of exact_tci by hand: allocate a fresh Pz and
# display whether check_constant regards the current g as constant.
# The rest of the function body is kept below as disabled reference code.
P0 = P_old
dx = Px.shape[0]
dy = Py.shape[0]
Pz = np.zeros((dx * dy, dx * dy))
g_const = check_constant(f=g)
g_const
# If g is not constant, improve transition coupling against g.
# if not g_const:
    # Pz = setup_ot(f=g, Px=Px, Py=Py, Pz=Pz)
    # if np.max(np.abs(np.matmul(P0, g) - np.matmul(Pz, g))) <= 1e-7:
    #     print(111)
    #     Pz = copy.deepcopy(P0)
    # else:
    #     return Pz
    
# # Try to improve with respect to h.
# Pz = setup_ot(f=h, Px=Px, Py=Py, Pz=Pz)
# if np.max(np.abs(np.matmul(P0, h) - np.matmul(Pz, h))) <= 1e-4:
#     Pz = copy.deepcopy(P0)

False

In [22]:
# Inline copy of the setup_ot loop, run against g, with debug prints added.
# It writes into the notebook-level Pz in place and relies on Px, Py, g and Pz
# already existing in the kernel.
f = g 
dx = Px.shape[0]
dy = Py.shape[0]
f_mat = np.reshape(f, (dx, dy))
for x_row in range(dx):
    for y_row in range(dy):
        dist_x = Px[x_row, :]
        dist_y = Py[y_row, :]
        # Check if either distribution is degenerate.
        if any(dist_x == 1) or any(dist_y == 1):
            # 111 is an ad-hoc marker showing the degenerate branch was taken.
            print(x_row, y_row, 111)
            sol = np.outer(dist_x, dist_y)
        # If not degenerate, proceed with OT.
        else:
            sol, val = computeot_pot(f_mat, dist_x, dist_y)
            # Dump the plan and its cost for one hand-picked state pair.
            if x_row == 0 and y_row == 4:
                print("sol", sol, val)
        idx = dy * (x_row) + y_row
        Pz[idx, :] = np.reshape(sol, (-1, dx * dy))


sol [[0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.33333333 0.         0.
  0.         0.         0.        ]
 [0.         0.33333333 0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.33333333
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]] 1.5194805194805077
0 5 111
1 5 111
2 5 111
3 5 111
4 5 111
5 5 111
6 5 111
7 5 111
8 5 111


In [13]:
Pz[0]

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.16666667, 0.16666667, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.33333333,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.33333333, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [None]:
""