In [2]:
import numpy as np
from numpy.linalg import pinv
from scipy.optimize import linprog
from scipy.linalg import fractional_matrix_power
import copy

In [3]:
def get_ind_tc(Px, Py):
    """Independent transition coupling of two transition matrices.

    Entry ``[dy * x + y, dy_col * x2 + y2]`` equals ``Px[x, x2] * Py[y, y2]``:
    the coupled chain moves both components independently. This is exactly
    the Kronecker product ``np.kron(Px, Py)``. Using it also indexes columns
    correctly when ``Py`` is not square.

    Parameters
    ----------
    Px : array_like, shape (dx, dx_col)
    Py : array_like, shape (dy, dy_col)

    Returns
    -------
    ndarray of float, shape (dx * dy, dx_col * dy_col)
    """
    return np.kron(np.asarray(Px, dtype=float), np.asarray(Py, dtype=float))

In [4]:
def exact_tce(Pz, c):
    """Transition-coupling evaluation: solve for the gain ``g`` and bias ``h``.

    Solves the stacked block system in the unknowns (g, h, w)::

        (I - Pz) g             = 0
        g + (I - Pz) h         = c
        h + (I - Pz) w         = 0

    The auxiliary ``w`` is discarded.

    Parameters
    ----------
    Pz : ndarray, shape (d, d)
        Transition matrix of the coupled chain.
    c : array_like with d entries
        Cost per coupled state. It is reshaped row-major to a (d, 1) column,
        so a (dx, dy) cost matrix matches the ``dy * x + y`` state order.

    Returns
    -------
    g, h : 1-D ndarrays of length d

    Notes
    -----
    NOTE(review): for a unichain ``Pz`` the w-block only fixes ``w`` up to an
    additive constant, so ``A`` is singular in exact arithmetic. ``solve``
    either raises (handled by the pseudo-inverse fallback) or returns one of
    the solutions. Consider ``np.linalg.lstsq`` if accuracy issues appear.
    """
    d = Pz.shape[0]
    c = np.reshape(c, (d, 1))
    eye = np.eye(d)
    zero = np.zeros((d, d))
    A = np.block(
        [
            [eye - Pz, zero, zero],
            [eye, eye - Pz, zero],
            [zero, eye, eye - Pz],
        ]
    )
    b = np.concatenate([np.zeros((d, 1)), c, np.zeros((d, 1))])
    try:
        sol = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        # Singular system: take the minimum-norm least-squares solution.
        sol = np.linalg.pinv(A) @ b
    g = sol[0:d].flatten()
    h = sol[d : 2 * d].flatten()
    return g, h

In [5]:
def computeot_lp(C, r, c):
    """Solve the discrete optimal-transport LP between marginals ``r`` and ``c``.

    Minimises ``sum_ij C[i, j] * pi[i, j]`` subject to ``pi >= 0``,
    ``sum_j pi[i, j] = r[i]`` and ``sum_i pi[i, j] = c[j]``.

    Parameters
    ----------
    C : array_like, shape (nx, ny)
        Cost matrix, flattened row-major (entry ``i * ny + j``).
    r : array_like with nx entries
        Row marginal.
    c : array_like with ny entries
        Column marginal.

    Returns
    -------
    lp_sol : 1-D ndarray of length nx * ny
        Optimal plan, row-major (entry ``i * ny + j`` is ``pi[i, j]``).
    lp_val : float
        Optimal objective value.

    Raises
    ------
    RuntimeError
        If the solver does not report success, for example when the two
        marginals have different total mass and the LP is infeasible.
    """
    r = np.ravel(r)
    c = np.ravel(c)
    nx = r.size
    ny = c.size

    # First nx constraints fix the row sums (marginal r). The next ny
    # constraints fix the column sums (marginal c) of the row-major plan.
    row_sums = np.kron(np.eye(nx), np.ones((1, ny)))
    col_sums = np.kron(np.ones((1, nx)), np.eye(ny))
    Aeq = np.vstack([row_sums, col_sums])
    beq = np.concatenate([r, c])

    cost = np.ravel(C)
    res = linprog(cost, A_eq=Aeq, b_eq=beq, bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(f"OT linear program failed: {res.message}")
    return res.x, res.fun

In [7]:
def _improve_coupling(values, Px, Py):
    """Build the coupling whose rows minimise the expected next-step ``values``.

    Row ``dy * x_row + y_row`` is an optimal coupling of ``Px[x_row]`` and
    ``Py[y_row]`` for the cost matrix ``values`` reshaped row-major to
    (dx, dy). When either row contains an entry equal to 1 (a point mass,
    assuming rows are probability vectors), the product coupling is used
    directly instead of solving the LP.
    """
    dx = Px.shape[0]
    dy = Py.shape[0]
    cost = np.reshape(values, (dx, dy))
    Pz = np.zeros((dx * dy, dx * dy))
    for x_row in range(dx):
        dist_x = Px[x_row, :]
        for y_row in range(dy):
            dist_y = Py[y_row, :]
            if np.any(dist_x == 1) or np.any(dist_y == 1):
                sol = np.outer(dist_x, dist_y)
            else:
                sol, _ = computeot_lp(cost, dist_x, dist_y)
            Pz[dy * x_row + y_row, :] = np.ravel(sol)
    return Pz


def exact_tci(g, h, P0, Px, Py, g_tol=1e-3, g_improve_tol=1e-7, h_improve_tol=1e-4):
    """Transition-coupling improvement step of exact OTC policy iteration.

    First tries to improve against the gain ``g``. If ``g`` is (nearly)
    constant, or improving against it changes ``P @ g`` by no more than
    ``g_improve_tol``, it improves against the bias ``h`` instead.

    Parameters
    ----------
    g, h : array_like with dx * dy entries
        Gain and bias of the current coupling (as returned by ``exact_tce``),
        ordered row-major over (x, y) pairs.
    P0 : ndarray, shape (dx * dy, dx * dy)
        Current transition coupling.
    Px, Py : ndarray
        Marginal transition matrices (assumed square: dx x dx, dy x dy).
    g_tol : float
        ``g`` counts as constant when ``max(g) - min(g) <= g_tol``.
    g_improve_tol, h_improve_tol : float
        The sup-norm change in ``P @ g`` (resp. ``P @ h``) must exceed this
        for the new coupling to replace ``P0``.

    Returns
    -------
    ndarray, shape (dx * dy, dx * dy)
        The improved coupling. If neither step improves enough, returns a copy
        of ``P0``; the caller detects convergence by exact equality.
    """
    g = np.ravel(g)
    h = np.ravel(h)

    # g has one entry per coupled state (dx * dy of them), so the constancy
    # test must look at all of them, not just the first dx.
    g_const = g.size == 0 or (g.max() - g.min()) <= g_tol
    if not g_const:
        Pz = _improve_coupling(g, Px, Py)
        if np.max(np.abs(P0 @ g - Pz @ g)) > g_improve_tol:
            return Pz

    # g is constant or could not be improved: improve against the bias h.
    Pz = _improve_coupling(h, Px, Py)
    if np.max(np.abs(P0 @ h - Pz @ h)) <= h_improve_tol:
        return P0.copy()
    return Pz

In [8]:
import numpy as np


def get_stat_dist(Pz):
    """Return a stationary distribution of the transition matrix ``Pz``.

    Takes the eigenvector of ``Pz.T`` (a left eigenvector of ``Pz``) whose
    eigenvalue is nearest to 1, keeps its real part, and rescales it so its
    entries sum to 1. If several eigenvalues are equally close to 1, the first
    one reported by ``argmin`` is used. The result is a probability vector
    only if that eigenvector has entries of a single sign.
    """
    eigvals, eigvecs = np.linalg.eig(np.transpose(Pz))
    nearest = np.abs(eigvals - 1).argmin()
    left_vec = np.real(eigvecs[:, nearest])
    return left_vec / left_vec.sum()


# Example usage: a 3-state chain whose stationary distribution is (19, 13, 14) / 46.
P = np.array([[0.5, 0.4, 0.1], [0.3, 0.2, 0.5], [0.4, 0.2, 0.4]])
stat_dist = get_stat_dist(P)
# Sanity checks: the result sums to 1 and is invariant under P.
assert np.isclose(stat_dist.sum(), 1.0)
assert np.allclose(stat_dist @ P, stat_dist)
print("Stationary distribution:", stat_dist)

Stationary distribution: [0.41304348 0.2826087  0.30434783]


In [9]:
def exact_otc1(Px, Py, c, tol=1e-10, max_iter=None, verbose=True):
    """Exact optimal transition coupling (OTC) by policy iteration.

    Starts from the independent coupling of ``Px`` and ``Py``. It then
    alternates transition-coupling evaluation (``exact_tce``) and improvement
    (``exact_tci``) until the coupling stops changing.

    Parameters
    ----------
    Px, Py : ndarray
        Transition matrices of the two chains (dx x dx and dy x dy).
    c : ndarray, shape (dx, dy)
        Cost of each pair of states.
    tol : float, default 1e-10
        The loop also stops once the largest entrywise change between
        successive couplings is at most ``tol``.
    max_iter : int or None, default None
        Maximum number of improvement iterations; ``None`` means no limit.
    verbose : bool, default True
        Print the iteration counter at the start of each iteration.

    Returns
    -------
    (iter_ctr, exp_cost, P, stat_dist)
        - ``iter_ctr``: number of iterations run.
        - ``exp_cost``: expected cost under the stationary distribution of ``P``.
        - ``P``: the final coupling (dx*dy x dx*dy).
        - ``stat_dist``: that stationary distribution, reshaped to (dx, dy).

        ``(None, None, None, None)`` if ``max_iter`` is reached first.
    """
    dx = Px.shape[0]
    dy = Py.shape[0]

    P = get_ind_tc(Px, Py)
    # All-ones sentinel for the first loop test. For a 1x1 coupling it equals
    # P, so the loop is skipped and the result is computed directly below.
    P_old = np.ones((dx * dy, dx * dy))
    iter_ctr = 0
    while np.max(np.abs(P - P_old)) > tol:
        if max_iter is not None and iter_ctr >= max_iter:
            return None, None, None, None
        if verbose:
            print(iter_ctr)
        iter_ctr += 1
        P_old = np.copy(P)

        # Transition coupling evaluation.
        g, h = exact_tce(P, c)

        # Transition coupling improvement.
        P = exact_tci(g, h, P_old, Px, Py)

        # exact_tci hands back a copy of P_old when nothing improves.
        if np.all(P == P_old):
            break

    stat_dist = np.reshape(get_stat_dist(P), (dx, dy))
    exp_cost = np.sum(stat_dist * c)
    return iter_ctr, exp_cost, P, stat_dist

In [10]:
def softmax_matrix(matrix):
    """Row-wise softmax (each slice along axis 1 is normalised to sum to 1).

    Each row is shifted by its maximum before exponentiation. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing
    (and producing NaN) for large entries. Empty input returns an empty
    float array.
    """
    matrix = np.asarray(matrix)
    if matrix.size == 0:
        return np.exp(matrix)
    weights = np.exp(matrix - np.max(matrix, axis=1, keepdims=True))
    return weights / np.sum(weights, axis=1, keepdims=True)


def softmax(x):
    """Softmax over all entries of ``x``: exponentiate, then normalise to sum 1.

    ``max(x)`` is subtracted first so ``np.exp`` cannot overflow for large
    inputs; the result is mathematically unchanged. Empty input (e.g. a row
    with no nonzero entries in ``adj_to_trans``) returns an empty array.
    """
    x = np.asarray(x)
    if x.size == 0:
        return np.exp(x)
    weights = np.exp(x - np.max(x))
    return weights / np.sum(weights)


def adj_to_trans(A):
    """Turn an adjacency/weight matrix into a transition matrix.

    In every row, the nonzero entries are replaced by the softmax of those
    entries; zero entries stay zero. A row with no nonzero entries therefore
    stays all zeros. The result is a float copy; ``A`` is not modified.
    """
    T = np.copy(A).astype(float)
    for i, row in enumerate(A):
        nonzero_cols = np.flatnonzero(row)
        T[i, nonzero_cols] = softmax(row[nonzero_cols])
    return T

In [11]:
def get_degree_cost(D1, D2):
    """Squared degree-difference cost between the nodes of two graphs.

    The degree of a node is its row sum in ``D1`` / ``D2``. Entry ``[i, j]``
    of the returned float array (shape ``(D1.shape[0], D2.shape[0])``) is
    ``(deg1[i] - deg2[j]) ** 2``.
    """
    deg_rows = np.sum(D1, axis=1)[:, np.newaxis]
    deg_cols = np.sum(D2, axis=1)[np.newaxis, :]
    return ((deg_rows - deg_cols) ** 2).astype(float)

In [12]:
def _undirected_path_graph(n, add_edges=(), remove_edges=()):
    """Adjacency matrix of the path 0-1-...-(n-1), followed by edge edits.

    Each edge in ``add_edges`` is set to 1, and afterwards each edge in
    ``remove_edges`` is set to 0, always in both directions (symmetric).
    """
    adj = np.zeros((n, n))
    left = np.arange(n - 1)
    adj[left, left + 1] = 1
    adj[left + 1, left] = 1
    for u, v in add_edges:
        adj[u, v] = adj[v, u] = 1
    for u, v in remove_edges:
        adj[u, v] = adj[v, u] = 0
    return adj


A1 = _undirected_path_graph(10, add_edges=[(4, 9)])
A2 = _undirected_path_graph(
    18,
    add_edges=[(0, 5), (3, 6), (9, 11), (12, 17)],
    remove_edges=[(5, 6), (10, 11)],
)
A3 = _undirected_path_graph(
    15,
    add_edges=[(1, 3), (3, 14), (8, 13), (8, 14), (6, 8), (11, 13)],
    remove_edges=[(2, 3), (7, 8), (12, 13)],
)
# Softmax-over-neighbours transition matrices and pairwise degree costs.
P1, P2, P3 = (adj_to_trans(adj) for adj in (A1, A2, A3))
c12, c13, c23 = (
    get_degree_cost(first, second)
    for first, second in ((A1, A2), (A1, A3), (A2, A3))
)

In [13]:
result = np.zeros((3, 1))

# Exact OTC between graphs 1 and 2. Rows 1 and 2 of `result` are meant for
# the (1, 3) and (2, 3) pairs, which this cell does not compute.
# `n_iter` (not `iter`) avoids shadowing the builtin for later cells.
n_iter, exp_cost12, P12, stat_dist12 = exact_otc1(P1, P2, c12)
if n_iter is None:
    raise RuntimeError("exact_otc1 did not converge for graphs 1 and 2")
result[0, 0] = exp_cost12
print(result)
print(n_iter)

0
1
2
3
[[0.16323382]
 [0.        ]
 [0.        ]]
4
