In [1]:
import networkx as nx
import numpy as np
from itertools import permutations, product
from functools import partial
import scipy.sparse as sps
import scipy.linalg as la
from scipy.optimize import minimize, linear_sum_assignment, LinearConstraint, Bounds
import sys, time
sys.path.insert(1, '/home/me/persim-fork/persim')
from gromov_hausdorff import estimate, find_lb, find_ub, find_ub_of_min_distortion

In [849]:
def generate_graph(n, p, enforce_n=False):
    while True:
        G = nx.erdos_renyi_graph(n, p)
        cc = list(max(nx.connected_components(G), key=len))
        if len(cc) == n or not enforce_n:
            A = nx.to_numpy_array(G)[np.ix_(cc, cc)]
            break
    
    D = sps.csgraph.shortest_path(A, directed=False, unweighted=True)
    
    return A, D


def pi2Pi(pi, m): # {1,…,m}^n → {0,1}^n×m
    return np.identity(m)[pi]


def Pi2pi(Pi): #{0,1}^n×m → {1,…,m}^n
    return np.argmax(Pi, axis=1)


def project_P(P): #row-stochastic [0,1]^n×m → {0,1}^n×m 
    return pi2Pi(Pi2pi(P), P.shape[1])


def dis(Pi, DX, DY):
    return np.abs(DX - Pi @ DY @ Pi.T).max()


def dis_l2(Pi, DX, DY):
    return ((DX - Pi @ DY @ Pi.T)**2).sum()


def build_heat_kernel(A, t=1):
    lambdas, Phi = la.eigh(sps.csgraph.laplacian(A))
    return Phi @ np.diagflat(np.exp(-t * lambdas)) @ Phi.T


def compile_performance(n, p, max_iter, inj, size, f_type, mx_type, show_graph_properties=True):
    perf = performance.get((n, p, max_iter, inj, size, f_type, mx_type), None)
    perf_desc = f"{perf['acc']:.1%} correct in {perf['time']:.0f}s" if perf else '-'*10
    if show_graph_properties and perf:
        perf_desc += f" (N={perf['order']}, diam={perf['diam']})"
#     print('n={}, p={}, max_iter={}, {}, {}, {}: {:.0%} correct ({} graph pairs of avg. n={:.1f} & diam={:.1f}) in {:.0f}s'.format(
#         n, p, max_iter, injs[inj], f_types[f_type], mx_types[mx_type], perf['acc'], N, perf['order'], perf['diam'], perf['time']))
    return (f'n={n}, p={p}, max_iter={max_iter}, {inj_dict[inj]}, {size_dict[size]}, {f_type_dict[f_type]}'
            f', {mx_type_dict[mx_type]}: {perf_desc}')
 

In [1014]:
def perturb_pi_XY(pi_XY, rand_n):
    pi_XY_perturbed = pi_XY.copy()
    indexes = np.random.choice(len(pi_XY), rand_n, replace=False)
    pi_XY_perturbed[indexes] = pi_XY_perturbed[indexes][np.random.permutation(rand_n)]
    
    return pi_XY_perturbed


# Frank-Wolfe algoritm.
def minimize_over_Pi(n, m, f, grad, alpha_jac, alpha_hess, dis_, dis_l2_, P0=None,
                     tol=1e-8, max_iter=100, verbose=True, is_inj=False):
    if P0 is None:
        P = np.full((n, m), 1/m)
    elif P0.shape == (n, m):
        P = P0.copy()
    elif P0.shape == (m, n):
        P = P0.T
        P[np.all(P == 0, axis=1)] = 1/m
    else:
        raise ValueError('P0 has shape {} for n={}, m={}'.format(P0.shape, n, m))
        
    if is_inj:
        assert n <= m

    #P = np.full((n, m), 1/m) if P0 is None else P0 # barycenter init (!! consider randomizing it as in FAQ?)
    if verbose:
        Pi = project_P(P)
        print('-'*10 + f'INIT dis(P)={dis_(P)}, dis_l2(P)={dis_l2_(P)}, dis(Pi)={dis_(Pi)},'
              f'dis_l2(Pi)={dis_l2_(P)}, f(P)={f(P)}')

    for i in range(max_iter):
        grad_at_P = grad(P)
            
        if is_inj: # injectivity assumes n ≤ m
            #Q = np.full((n, m), 1/m)
            #r, c = linear_sum_assignment(grad_at_P)
            #Q[r] = pi2Pi(c, m)
            Q = pi2Pi(linear_sum_assignment(grad_at_P)[1], m)
        else:
            Q = pi2Pi(np.argmin(grad_at_P, axis=1), m)
        dQP = Q - P
        
        res = minimize(lambda x: f(P + x*dQP), 0.5, bounds=[(0, 1)],
                       jac=lambda x: alpha_jac(x, P, dQP))#, hess=lambda x: alpha_hess(x, P, Q))
        alpha = res.x
        
        grad_norm = np.sum(grad_at_P**2)
        alpha_dQP = alpha*dQP
        alpha_dQP_norm = np.sum((alpha*dQP)**2)
        if grad_norm < tol or alpha_dQP_norm < tol:
            if verbose:
                print('-'*10 + f'STOP α={alpha}, ‖∇f‖²={grad_norm:.2f}, ‖α*dQP‖²={alpha_dQP_norm:.2f}',
                      '' if res.success else '(FAILURE)')
            break
                  
        try:
            assert np.allclose(np.sum(P + alpha_dQP, axis=1), np.ones(n))
        except AssertionError:
            print(P, alpha, dQP)
            raise AssertionError
        
        P += alpha_dQP
        if verbose:
            Pi = project_P(P)
            print(f'iter {i}: dis(P)={dis_(P)}, dis_l2(P)={dis_l2_(P)}, dis(Pi)={dis_(Pi)},'
                  f'dis_l2(Pi)={dis_l2_(P)}, f(P)={f(P)}, α={alpha}, ‖∇f‖²={grad_norm:.2f},'
                  f'‖α*dQP‖²={alpha_dQP_norm:.2f}', '' if res.success else '(FAILURE)')
        
    return project_P(P)


def def_functions_for_frank_wolfe(X, Y, f_type):
    if f_type == 'sq': #'‖Δ‖²'
        def S(P):
            return -X@P@Y
    elif f_type == 'sq+': #'‖Δ‖² for non-surj'
        def S(P):
            return -X@P@Y + .5*P@Y@P.T@P@Y
    elif f_type in {'exp', 'mix', 'mix+'}: #'‖e^|Δ|‖_1'
        def S(P):
            return np.e**X@P@np.e**-Y + np.e**-X@P@np.e**Y
    elif f_type == 'c_exp': #'‖c^|Δ|‖_1'
        delta = .1 # smallest positive |d_X(x, x') - d_Y(y, y')|
        c = (len(X)**2 + 1)**(1/delta) # so that argmin ‖c^|Δ|‖_1 = argmin ‖Δ‖_∞
        c = (len(X)/2 + 1)**(1/delta)
        def S(P):
            return c**X@P@2**-Y + c**-X@P@2**Y
    else:
        raise ValueError('unknown f_type, Vlad')

    def f(P):
        return np.sum(P * S(P))

    def grad(P):
        if f_type == 'sq+':
            #return 2 * (-X@P@Y + np.sum(P@Y@P.T, axis=0)[:, None] * (P@Y))
            return 2 * (-X@P@Y + P@Y@P.T@P@Y)
        else:
            return 2*S(P)

    def alpha_jac(alpha, P, d): # d is dQP = Q - P
        if f_type == 'sq+':            
            Pxd = P.T@d
            YPPY = Y@P.T@P@Y
            YPdY = Y@(Pxd + Pxd.T)@Y
            YddY = Y@d.T@d@Y
            return 2*alpha**3*np.sum(d * (d@YddY)) +\
                1.5*alpha**2*(np.sum(d*(P@YddY)) + np.sum(P*(d@YddY)) + np.sum(d*(d@YPdY))) +\
                alpha*(np.sum(P*(P@YddY)) + np.sum(d*(P@YPdY)) + np.sum(P*(d@YPdY)) + np.sum(d*(d@YPPY))) +\
                .5*(np.sum(P*(P@YPdY)) + np.sum(d*(P@YPPY)) + np.sum(P*(d@YPPY))) +\
                np.sum(d*(-X@P@Y)) + np.sum(P*(-X@d@Y)) + 2*alpha*np.sum(d*(-X@d@Y))
        else:
            return np.sum(d * S(P)) + np.sum(P * S(d)) + 2*alpha*np.sum(d * S(d))

    def alpha_hess(alpha, P, d): # d is dQP = Q - P
        if f_type == 'sq+':
            Pxd = P.T@d
            YPPY = Y@P.T@P@Y
            YPdY = Y@(Pxd + Pxd.T)@Y
            YddY = Y@d.T@d@Y
            return 6*alpha**2*np.sum(d * (d@YddY)) +\
                3*alpha*(np.sum(d*(P@YddY)) + np.sum(P*(d@YddY)) + np.sum(d*(d@YPdY))) +\
                np.sum(P*(P@YddY)) + np.sum(d*(P@YPdY)) + np.sum(P*(d@YPdY)) + np.sum(d*(d@YPPY)) +\
                2*np.sum(d*(-X@d@Y))
        else:
            return 2*np.sum(d * S(d))

    return f, grad, alpha_jac, alpha_hess
    

def find_ub_XY(X, Y, DX, DY, f_type, max_iter, inj, verbose):
    f_XY, grad_XY, alpha_jac_XY, alpha_hess_XY = def_functions_for_frank_wolfe(X, Y, f_type)
    dis_XY = partial(dis, DX=DX, DY=DY)
    dis_l2_XY = partial(dis_l2, DX=DX, DY=DY)
    minimize_over_Pi_XY = partial(
        minimize_over_Pi, n=len(X), m=len(Y), f=f_XY, grad=grad_XY, dis_=dis_XY, dis_l2_=dis_l2_XY,
        alpha_jac=alpha_jac_XY, alpha_hess=alpha_hess_XY, max_iter=max_iter, verbose=verbose)
    
    if inj == -1:
        Pi_XY = minimize_over_Pi_XY(is_inj=False)
        ub = dis_XY(Pi_XY) / 2
        inj_ub = np.nan
    else:
        if f_type.startswith('mix'):
            first_f_type = 'sq' if f_type == 'mix' else 'sq+'
            inj_Pi = minimize_over_Pi(
                len(X), len(Y), *def_functions_for_frank_wolfe(X, Y, first_f_type), is_inj=True,
                dis_=dis_XY, dis_l2_=dis_l2_XY, max_iter=max_iter, verbose=verbose)
        else:
            inj_Pi = minimize_over_Pi_XY(is_inj=True)
            
        inj_ub = dis_XY(inj_Pi) / 2
        #print('inj ub is ', inj_ub)
        if inj == 0:
            Pi_XY = minimize_over_Pi_XY(P0=inj_Pi, is_inj=False)
            ub = dis_XY(Pi_XY) / 2
        elif inj == 1:
            ub = np.nan
            Pi_XY = inj_Pi

    return inj_ub, ub, Pi_XY


def find_ub_FAQ(X, Y, DX, DY, f_type, max_iter, inj, verbose):
    if len(X) > len(Y): # ensure |X| ≤ |Y|
        X, Y = Y, X
        DX, DY = DY, DX
    
    f_XY, grad_XY, alpha_jac_XY, alpha_hess_XY = def_functions_for_frank_wolfe(X, Y, f_type)
    f_YX, grad_YX, alpha_jac_YX, alpha_hess_YX = def_functions_for_frank_wolfe(Y, X, f_type)
    dis_XY = partial(dis, DX=DX, DY=DY)
    dis_YX = partial(dis, DX=DY, DY=DX)
    minimize_over_Pi_XY = partial(
        minimize_over_Pi, n=len(X), m=len(Y), f=f_XY, grad=grad_XY, dis_=dis_XY,
        alpha_jac=alpha_jac_XY, alpha_hess=alpha_hess_XY, max_iter=max_iter, verbose=verbose)
    minimize_over_Pi_YX = partial(
        minimize_over_Pi, n=len(Y), m=len(X), f=f_YX, grad=grad_YX, dis_=dis_YX,
        alpha_jac=alpha_jac_YX, alpha_hess=alpha_hess_YX, max_iter=max_iter, verbose=verbose)
    
    if inj == -1:
        Pi_XY = minimize_over_Pi_XY(is_inj=False)
        Pi_YX = minimize_over_Pi_YX(is_inj=False)
        ub = max(dis_XY(Pi_XY), dis_YX(Pi_YX)) / 2
        inj_ub = np.nan
    else:
        if f_type == 'mix':
            inj_Pi = minimize_over_Pi(len(X), len(Y), *def_functions_for_frank_wolfe(X, Y, 'sq'),
                                      is_inj=True, dis_=dis_XY, max_iter=max_iter, verbose=verbose)
        else:
            inj_Pi = minimize_over_Pi_XY(is_inj=True)
            
        inj_ub = dis_XY(inj_Pi) / 2
        print('inj ub is ', inj_ub)
        if inj == 0:
            Pi_XY = minimize_over_Pi_XY(P0=inj_Pi, is_inj=False)
            Pi_YX = minimize_over_Pi_YX(P0=inj_Pi, is_inj=False)
            ub = max(dis_XY(Pi_XY), dis_YX(Pi_YX)) / 2
        elif inj == 1:
            ub = np.nan
            Pi_XY = Pi_YX = inj_Pi
    
    return inj_ub, ub, Pi_XY, Pi_YX


In [1018]:
all_ns = [1000, 100, 10]
all_ps = [.01, .05, .1]
all_max_iters = [100, 500]
inj_dict = {-1: 'non-inj', 0: 'inj→non', 1: 'inj'}
size_dict = {-1: 'X>Y', 0: 'X=Y', 1: 'X<Y'}
f_type_dict = {'sq': '‖Δ‖²', 'sq+': '‖Δ‖²+', 'exp': '‖e^|Δ|‖_1', 'c_exp': '‖c^|Δ|‖_1', 'mix': '‖Δ‖²→‖e^|Δ|‖', 
               'mix+': '‖Δ‖²+→‖e^|Δ|‖'}
mx_type_dict = {0: 'AX,AY', 1: 'DX,DY'}

In [1019]:
performance = dict()

In [None]:
verbose = False
N = 50 # number of graphs in one dataset
i = 0
size_diff_rel = .5
ns = [10, 100, 500] # [10]
ps = [.25, .05, .01] # [.1]
max_iters = [100]
injs = [0, 1]
sizes = [1]
f_types = ['sq', 'sq+', 'exp', 'mix', 'mix+']
mx_types = [0, 1]
n_combinations = len(ns) * len(max_iters) * len(injs) * len(sizes) * len(f_types) * len(mx_types)
for n, p in zip(ns, ps): 
    As = []
    Ds = []
    pi_YXs = []
    for _ in range(N):
        A, D = generate_graph(n, p, enforce_n=False)
        As.append(A)
        Ds.append(D)
        pi_YXs.append(np.random.permutation(len(A)))
    
    for max_iter, inj, size, f_type, mx_type in product(max_iters, injs, sizes, f_types, mx_types):
        i += 1
        
        perf_key = n, p, max_iter, inj, size, f_type, mx_type
        if perf_key in performance or (f_type == 'mix' and inj != 0):
            print(f'SKIPPING {i}', compile_performance(*perf_key))
            continue
        
        start = time.time()
        exact_XY = []
#         exact_YX = []
        for _, (AX, DX, pi_YX) in enumerate(zip(As, Ds, pi_YXs)):
            if size == 1:
                size_diff = round(size_diff_rel * len(AX))
                AY_ = AX[np.ix_(pi_YX, pi_YX)]
                AY = np.zeros((len(AX) + size_diff, len(AX) + size_diff))
                AY[:len(AX), :len(AX)] = AY_
                AY[len(AX):, :] = AY[:, len(AX):] = DX.max() # to guarantee compliance with triangle ineq.
                AY[np.arange(size_diff) + len(AX), np.arange(size_diff) + len(AX)] = 0
                DY = sps.csgraph.shortest_path(AY, directed=False)
                assert np.all(DY[:len(AX), :len(AX)] == DX[np.ix_(pi_YX, pi_YX)]), 'Y violates triangle inequality'
            elif size == 0:
                AY = AX[np.ix_(pi_YX, pi_YX)]
                DY = DX[np.ix_(pi_YX, pi_YX)]
                
            if mx_type == 0:
                X, Y = AX, AY
            elif mx_type == 1:
                X, Y = DX, DY
            
            inj_ub_XY, ub_XY, Pi_XY = find_ub_XY(X, Y, DX, DY, f_type, max_iter, inj, verbose)
        #     if ub_XY > inj_ub_XY:
        #         print('#{}: ub_XY > inj_ub_XY ({} > {})'.format(i, ub_XY, inj_ub_XY))
#             inj_ub_YX, ub_YX, Pi_YX = find_ub_XY(AY, AX, DY, DX, 'norm2(Delta)', inj, False)
        #     if ub_YX > inj_ub_YX:
        #         print('#{}: ub_YX > inj_ub_YX ({} > {})'.format(i, ub_YX, inj_ub_YX))

            exact_XY.append((inj_ub_XY if inj == 1 else ub_XY) == 0)
#             exact_YX.append(ub_YX == 0)

#             if i+1 >= n:
#                 break

        exact_XY = np.array(exact_XY)
    
        performance[perf_key] = {'acc': exact_XY.mean(), 'time': time.time() - start,
                                 'order': np.mean([len(A) for A in As]), 'diam': np.mean([D.max() for D in Ds])}
#         exact_YX = np.array(exact_YX)
#         print('exact: XY {}, YX {}, XY & YX {} ({}s, {} pairs of avg.order {} and avg.diam {})'.format(
#             exact_XY.mean(), exact_YX.mean(), np.mean(exact_XY & exact_YX), round(time.time() - start, 2),
#             n, np.mean([len(A) for A in As[:n]]), np.mean([D.max() for D in Ds[:n]])))

        print(f'DONE {i}/{n_combinations}:', compile_performance(*perf_key))


DONE 1/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖², AX,AY: 0.0% correct in 0s (N=8.52, diam=4.24)
DONE 2/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖², DX,DY: 0.0% correct in 0s (N=8.52, diam=4.24)
DONE 3/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²+, AX,AY: 52.0% correct in 6s (N=8.52, diam=4.24)
DONE 4/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²+, DX,DY: 90.0% correct in 2s (N=8.52, diam=4.24)
DONE 5/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖e^|Δ|‖_1, AX,AY: 26.0% correct in 0s (N=8.52, diam=4.24)
DONE 6/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖e^|Δ|‖_1, DX,DY: 62.0% correct in 0s (N=8.52, diam=4.24)
DONE 7/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²→‖e^|Δ|‖, AX,AY: 4.0% correct in 1s (N=8.52, diam=4.24)
DONE 8/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²→‖e^|Δ|‖, DX,DY: 0.0% correct in 0s (N=8.52, diam=4.24)
DONE 9/60: n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²+→‖e^|Δ|‖, AX,AY: 28.0% correct in 3s (N=8.52, diam=4.24)
DONE 10/60: n

In [1005]:
ks = performance.copy()
for n, p, max_iter, inj, size, f_type, mx_type in ks:
    if size == 1:#f_type == 'sq+' or n == 10:
        del performance[n, p, max_iter, inj, size, f_type, mx_type]

In [980]:
#for n, p, max_iter, inj, size, f_type, mx_type in product(ns, ps, max_iters, injs, sizes, f_types, mx_types):
#    print(compile_performance(n, p, max_iter, inj, size, f_type, mx_type))
for n, p, max_iter, inj, size, f_type, mx_type in sorted(performance.keys()):
    if size == 1:
    #if n == 100 and f_type == 'sq':
        print(compile_performance(n, p, max_iter, inj, size, f_type, mx_type))


n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖e^|Δ|‖_1, AX,AY: 0.0% correct in 0s (N=8.7, diam=3.96)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖e^|Δ|‖_1, DX,DY: 0.0% correct in 0s (N=8.7, diam=3.96)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²→‖e^|Δ|‖, AX,AY: 0.0% correct in 0s (N=8.84, diam=4.04)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²→‖e^|Δ|‖, DX,DY: 0.0% correct in 0s (N=8.84, diam=4.04)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖², AX,AY: 0.0% correct in 0s (N=8.92, diam=4.3)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖², DX,DY: 0.0% correct in 0s (N=8.92, diam=4.3)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²+, AX,AY: 0.0% correct in 6s (N=8.92, diam=4.3)
n=10, p=0.25, max_iter=100, inj→non, X<Y, ‖Δ‖²+, DX,DY: 0.0% correct in 0s (N=8.92, diam=4.3)
n=10, p=0.25, max_iter=100, inj, X<Y, ‖e^|Δ|‖_1, AX,AY: 0.0% correct in 0s (N=8.7, diam=3.96)
n=10, p=0.25, max_iter=100, inj, X<Y, ‖e^|Δ|‖_1, DX,DY: 0.0% correct in 0s (N=8.7, diam=3.96)
n=10, p=0.25, max_iter=100, inj, X<Y, 

In [None]:
'''
n=100, p=0.05, max_iter=100, inj, X=Y, ‖Δ‖², AX & AY: 48.0% correct in 4s
n=100, p=0.05, max_iter=100, inj, X=Y, ‖Δ‖², DX & DY: 100.0% correct in 1s
n=100, p=0.05, max_iter=100, inj, X=Y, ‖e^|Δ|‖_1, AX & AY: 60.0% correct in 15s
n=100, p=0.05, max_iter=100, inj, X=Y, ‖e^|Δ|‖_1, DX & DY: 30.0% correct in 15s
n=500, p=0.05, max_iter=100, inj, X=Y, ‖Δ‖², AX & AY: 98.0% correct in 75s
n=500, p=0.05, max_iter=100, inj, X=Y, ‖Δ‖², DX & DY: 100.0% correct in 29s
n=500, p=0.05, max_iter=100, inj, X=Y, ‖e^|Δ|‖_1, AX & AY: 86.0% correct in 402s
n=500, p=0.05, max_iter=100, inj, X=Y, ‖e^|Δ|‖_1, DX & DY: 100.0% correct in 169s
n=10, p=0.1, max_iter=100, inj, X=Y, ‖Δ‖², AX & AY: 88.0% correct in 0s
n=10, p=0.1, max_iter=100, inj, X=Y, ‖Δ‖², DX & DY: 90.0% correct in 0s
n=10, p=0.1, max_iter=100, inj, X=Y, ‖e^|Δ|‖_1, AX & AY: 90.0% correct in 0s
n=10, p=0.1, max_iter=100, inj, X=Y, ‖e^|Δ|‖_1, DX & DY: 86.0% correct in 0s
'''

In [597]:
AX0, DX0, pi_YX0 = AX.copy(), DX.copy(), pi_YX.copy() # conserve problematic combination (AX,AY work but DX,DY don't on 'sq', inj=1)

In [926]:
AX, DX, pi_YX = AX0.copy(), DX0.copy(), pi_YX0.copy() # restore the problematic combination
n = len(AX)
AY_ = AX[np.ix_(pi_YX, pi_YX)]
DY_ = DX[np.ix_(pi_YX, pi_YX)]
AY = np.zeros((n + 1, n + 1))
AY[:n, :n] = AY_
AY[n - 1, n] = AY[n, n - 1] = .1
DY = sps.csgraph.shortest_path(AY, directed=False)

In [683]:
AX, DX = generate_graph(20, .1, enforce_n=True) # generate a combination
n = len(AX)
pi_YX = np.random.permutation(n)
AY_ = AX[np.ix_(pi_YX, pi_YX)]
DY_ = DX[np.ix_(pi_YX, pi_YX)]
AY = np.zeros((n + 1, n + 1))
AY[:n, :n] = AY_
AY[n - 1, n] = AY[n, n - 1] = .1
DY = sps.csgraph.shortest_path(AY, directed=False)

In [944]:
#find_ub_XY(DX, DY, partial(dis, DX=DX, DY=DY), 1, 100, 0, True)
#_,_,Pi = find_ub_XY(DY, DX, partial(dis, DX=DY, DY=DX), 1, 100, 1, True)
#dis(Pi[[0, 1, 2, 3, 4, 5, 6, 7]], DY, DX), Pi

inj_ub, ub, Pi_from_center = find_ub_XY(DX, DY, DX, DY, 'sq+', 100, 1, True)
# Pi_correct=[18,  0,  5, 14,  6, 17,  3, 11, 19,  7,  1, 12,  9, 10, 13, 15,  8, 2,  4, 16]
# pi_XY==Pi_correct for DX, DY_
inj_ub, ub, Pi2pi(Pi_from_center)#, dis(Pi_XY, DX, DY)

----------INIT dis(P)=3.5782312925170063, dis_l2(P)=596.5024758202601, dis(Pi)=6.0,dis_l2(Pi)=596.5024758202601, f(P)=-1156.7487620898696
iter 0: dis(P)=2.565446928251457, dis_l2(P)=251.9854858550616, dis(Pi)=3.0,dis_l2(Pi)=251.9854858550616, f(P)=-1329.0072570724694, α=[0.82685087], ‖∇f‖²=934176.79,‖α*dQP‖²=13.02 
----------STOP α=[0.], ‖∇f‖²=5578313.54, ‖α*dQP‖²=0.00 


(1.5,
 nan,
 array([18,  0, 17,  3,  6,  5, 10, 13, 20,  7,  1, 12,  9, 14,  4, 15,  8,
         2, 11, 16]))

In [945]:
Pi0 = pi2Pi(np.argsort(pi_YX), len(DY)) # correct mapping X → Y (dis = 0)
Pi = Pi_from_center
Pi2pi(Pi0)

array([18,  0,  5, 14,  6, 17,  3, 11, 19,  7,  1, 12,  9, 10, 13, 15,  8,
        2,  4, 16])

In [532]:
'''
[Pi0 = correct, Pi = obtained from F-W minimization (‖Δ‖², DX, DY, injections)] with no initialization

!!!! why Pi0 is not a local minimum for ‖Δ‖² on DX, DY and injections (but it is on AX, AY)?
Is it because P.T@P != I (since |Y| > |X|) and even for injective mapping:
argmin ‖DX - P@DY@P.T‖² = argmin -2<P, DX@P@DY> + ‖P@DY@P.T‖² != argmin -<P, DX@P@DY>
?

If so, then -<Pi, DX@Pi@DY> is less than -<Pi0, DX@Pi0@DY>, but ‖Pi0@DY@Pi0.T‖² ≥ ‖Pi@DY@Pi.T‖²
SEEMS NOT (b/c -<Pi, DX@Pi@DY> ≥ -<Pi0, DX@Pi0@DY>) BUT THE ABOVE MIGHT BE THE REASON TO NOT USE 'sq' (or 'mix')
for spaces of different sizes !!

UPD:
 - when no initial guess (i.e. it's the barycenter), F-W ends up in local minimum Pi (== Pi_from_center)
 - when initialized with Pi0, F-W ends up in Pi_from_Pi0 due to non-inj and
     argmin ‖DX - P@DY@P.T‖² != argmin -<P, DX@P@DY>
'''


'\n!! why correct Pi_XY is not a local minimum for ‖Δ‖² on DX, DY and injections (but it is on AX, AY)?\nIs it because P.T@P != I (since |DY| > |DX|) and even for injective mapping:\nargmin ‖DX - P@DY@P.T‖² = argmin -2<P, DX@P@DY> + ‖P@DY@P.T‖² != argmin -<P, DX@P@DY>\n?\n'

In [619]:
print('-<Pi, DX@Pi@DY>={}, -<Pi0, DX@Pi0@DY>={}, ‖Pi@DY@Pi.T‖²={}, ‖Pi0@DY@Pi0.T‖²={}'.format(
    np.sum(Pi * (-DX@Pi@DY)), np.sum(Pi0 * (-DX@Pi0@DY)),
    np.sum((Pi@DY@Pi.T)**2), np.sum((Pi0@DY@Pi0.T)**2)
))
print('-<Pi, AX@Pi@AY>={}, -<Pi0, AX@Pi0@AY>={}, ‖Pi@AY@Pi.T‖²={}, ‖Pi0@AY@Pi0.T‖²={}'.format(
    np.sum(Pi * (-AX@Pi@AY)), np.sum(Pi0 * (-AX@Pi0@AY)),
    np.sum((Pi@AY@Pi.T)**2), np.sum((Pi0@AY@Pi0.T)**2)
))
2*np.sum(Pi * (-DX@Pi@DY)) + np.sum((Pi@DY@Pi.T)**2), 2*np.sum(Pi0 * (-DX@Pi0@DY)) + np.sum((Pi0@DY@Pi0.T)**2)

-<Pi, DX@Pi@DY>=-3408.6, -<Pi0, DX@Pi0@DY>=-3204.0, ‖Pi@DY@Pi.T‖²=4650.78, ‖Pi0@DY@Pi0.T‖²=4594.0
-<Pi, AX@Pi@AY>=-22.0, -<Pi0, AX@Pi0@AY>=-10.0, ‖Pi@AY@Pi.T‖²=38.019999999999996, ‖Pi0@AY@Pi0.T‖²=46.0


(-2166.42, -1814.0)

In [950]:
# Shows that minimization diverges from initially correct mapping (this time it's b/c of |Y| > |X|:
# argmin ‖DX - P@DY@P.T‖² = argmin -2<P, DX@P@DY> + ‖P@DY@P.T‖² != argmin -<P, DX@P@DY>).
# Using 'sq+' instead of 'sq' seems to fix it.
f, grad_f, alpha_jac, alpha_hess = def_functions_for_frank_wolfe(DX, DY, 'sq')
dis_ = partial(dis, DX=DX, DY=DY)
dis_l2_ = partial(dis_l2, DX=DX, DY=DY)
Pi_from_Pi0 = minimize_over_Pi(
    len(DX), len(DY), f, grad_f, alpha_jac, alpha_hess, dis_, dis_l2_,
    P0=Pi0,
    max_iter=max_iter, is_inj=True, verbose=True)
print(dis_(Pi_from_Pi0), dis_(Pi0), np.sum(Pi_from_Pi0 * (-DX@Pi_from_Pi0@DY)), np.sum(Pi0 * (-DX@Pi0@DY)))
2*np.sum(Pi_from_Pi0 * (-DX@Pi_from_Pi0@DY)) + np.sum((Pi_from_Pi0@DY@Pi_from_Pi0.T)**2), 2*np.sum(Pi0 * (-DX@Pi0@DY)) + np.sum((Pi0@DY@Pi0.T)**2)

----------INIT dis(P)=0.0, dis_l2(P)=0.0, dis(Pi)=0.0,dis_l2(Pi)=0.0, f(P)=-2910.0
iter 0: dis(P)=2.0, dis_l2(P)=73.18, dis(Pi)=2.0,dis_l2(Pi)=73.18, f(P)=-2952.2000000000003, α=[1.], ‖∇f‖²=26206538.40,‖α*dQP‖²=6.00 
----------STOP α=[0.5], ‖∇f‖²=26771558.52, ‖α*dQP‖²=0.00 
2.0 0.0 -2952.2000000000003 -2910.0


(-2836.8200000000006, -2910.0)

In [581]:
Pi2pi(Pi), Pi2pi(Pi_XY)

(array([18,  0,  5, 14,  6, 17,  3, 11, 20, 13,  1, 12,  9, 10, 19, 15,  8,
         2,  4, 16]),
 array([20,  0,  5, 14,  6, 17,  3, 11, 18, 19,  1, 12,  9, 10, 13, 15,  8,
         2,  4, 16]))

In [662]:
f, grad_f, alpha_jac, alpha_hess = def_functions_for_frank_wolfe(DX, DY, 'sq')
dis_ = partial(dis, DX=DX, DY=DY)
Pi = minimize_over_Pi(
    len(DX), len(DY), f, grad_f, alpha_jac, alpha_hess, dis_,
    P0=Pi_from_center,
    max_iter=max_iter, is_inj=True, verbose=True)
print(dis_(Pi), dis_(Pi0), np.sum(Pi * (-DX@Pi@DY)), np.sum(Pi0 * (-DX@Pi0@DY)))

----------INIT dis(P)=4.10, dis(Pi)=4.1, f(P)=-2870.2
----------STOP α=[0.5], ‖∇f‖²=26643304.84, ‖α*dQP‖²=0.00 
4.1 0.0 -2870.2 -2910.0
