In [1]:
import numpy as np
import ot
import pickle
from tqdm import trange
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix

import gwb as gwb
from gwb import GM as gm

In [2]:
# measures for node correctness
def nca(idxs):
    """Fraction of alignment rows on which all graphs agree.

    Parameters
    ----------
    idxs : array-like of shape (m, k)
        One row per aligned point; column ``j`` holds the node label matched
        in graph ``j``. Labels are cast to ``int``.

    Returns
    -------
    float
        Share of the ``m`` rows whose ``k`` entries are all identical.

    Raises
    ------
    ValueError
        If ``idxs`` is not a non-empty 2-D array (the ratio is undefined).
    """
    idxs = np.array(idxs, dtype=int)
    if idxs.ndim != 2 or idxs.size == 0:
        raise ValueError(
            "nca expects a non-empty 2-D array (rows = aligned points, columns = graphs)"
        )
    # A row is fully correct when every column equals its first column
    # (broadcast (m, k) against (m, 1)).
    all_agree = np.all(idxs == idxs[:, :1], axis=1)
    return np.count_nonzero(all_agree) / len(idxs)

def nc2(idxs):
    """Fraction of alignment rows on which at least two graphs agree.

    Parameters
    ----------
    idxs : iterable of rows
        One row per aligned point; each row lists the node label matched in
        every graph. Labels must be hashable.

    Returns
    -------
    float
        Share of rows that contain at least one repeated label.

    Raises
    ------
    ValueError
        If ``idxs`` contains no rows (the ratio is undefined).
    """
    rows = list(idxs)
    if not rows:
        raise ValueError("nc2 expects at least one row of node labels")
    # A row has a repeated label exactly when its set of labels is smaller
    # than the row itself.
    hits = sum(1 for row in rows if len(set(row)) < len(row))
    return hits / len(rows)

# Load Data

In [3]:
# Pickled synthetic PPI database (path relative to this notebook; the directory
# name suggests it ships with the S-GWL code release — confirm provenance).
PPI_DB_PATH = '../s-gwl-master/data/PPI_syn_database.pkl'

# SECURITY: pickle.load can execute arbitrary code — only load this file if it
# comes from a trusted source.
with open(PPI_DB_PATH, 'rb') as f:
    database = pickle.load(f)

  database = pickle.load(f)


In [4]:
# Per-graph data: cost matrices (densified below, so presumably stored sparse),
# node probability vectors, and index-to-node-label maps.
costs, probs, idx2nodes = database["costs"], database["probs"], database["idx2nodes"]
# Densify every cost matrix and stack them into one (N, n, n) array; this only
# works when all graphs have the same number of nodes.
costs = np.array([c.todense() for c in costs])
N = len(costs)
# Fail fast here instead of deep inside the GM/TB code if the graphs differ in
# size or the per-graph lists are out of step with each other.
assert costs.ndim == 3 and costs.shape[1] == costs.shape[2], \
    "expected N square cost matrices of identical size, got shape {0}".format(costs.shape)
assert len(probs) == N and len(idx2nodes) == N, \
    "costs, probs and idx2nodes must have one entry per graph"

# Construct GM Spaces

In [5]:
# Wrap each graph in a gauge-only GM space: its symmetrized cost matrix is
# passed as g and its probability vector, flattened to 1-D, as xi.
Xs = [
    gm(mode="gauge_only", g=0.5 * (costs[k] + costs[k].T), xi=probs[k].reshape(-1))
    for k in range(N)
]
print("GM spaces generated!")

GM spaces generated!


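# Tangential Barycenter Iterations

For each n = 3, …, N, run tangential-barycenter (TB) iterations on the first n GM spaces, stopping early as soon as the GWB loss increases, and track the node-correctness measures NCA and NC2 of every accepted iterate.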

In [10]:
# Run tangential-barycenter (TB) iterations on the first n GM spaces, n = 3..N.
# A run stops early as soon as the GWB loss increases; the reported values are
# those of the last iteration whose loss did not increase.
n_its_tb = 10    # maximum number of TB iterations per value of n
method = "prox"  # forwarded to gwb.tb
cr = "MCR"       # forwarded to gwb.tb

tb_results = {}  # n -> {"gwb_loss", "nca", "nc2"} of the last accepted iterate

for n in range(3, N + 1):
    print("-------------------")
    print("Number of graphs: {0}".format(n))
    print("-------------------")
    bary = 0  # starting value handed to gwb.tb; presumably means "no barycenter yet" — TODO confirm
    gwbl_prevs = []  # GWB loss of each accepted iteration
    ncas = []        # NCA of each accepted iteration
    nc2s = []        # NC2 of each accepted iteration
    for it in trange(n_its_tb):
        bary_prev = bary
        idxs, meas, Ps, ref_idx = gwb.tb(bary_prev, Xs[:n], method=method, cr=cr)
        bary = gwb.bary_from_tb(Xs[:n], idxs, meas)
        # Translate column i of idxs into node labels of graph i via idx2nodes[i];
        # after .T, rows follow the rows of idxs and columns correspond to graphs.
        nodes = np.array([[idx2nodes[i][j] for j in idxs[:, i]] for i in range(n)], dtype=int).T

        # Loss of the previous barycenter, evaluated with this iteration's Ps.
        gwbl_prev = gwb.gwb_loss(bary_prev, Xs[:n], Ps)
        if gwbl_prevs and gwbl_prev > gwbl_prevs[-1]:
            print("GWB Loss has increased at iteration {0}.".format(it))
            print("Stopping TB iterations.")
            break
        gwbl_prevs.append(gwbl_prev)
        ncas.append(nca(nodes))
        nc2s.append(nc2(nodes))
    else:
        # The loss never increased: still report the last accepted iterate below.
        print("Reached the maximum of {0} TB iterations.".format(n_its_tb))

    if gwbl_prevs:
        print("Final GWB Loss: {0}".format(gwbl_prevs[-1]))
        print("NCA: {0}".format(ncas[-1]))
        print("NC2: {0}".format(nc2s[-1]))
        tb_results[n] = {"gwb_loss": gwbl_prevs[-1], "nca": ncas[-1], "nc2": nc2s[-1]}

-------------------
Number of graphs: 3
-------------------


 50%|█████████████████████▌                     | 5/10 [11:44<11:44, 140.84s/it]


GWB Loss has increased at iteration 5.
Stopping TB iterations.
Final GWB Loss: 0.0014367286420870966
NCA: 0.6992031872509961
NC2: 0.9332669322709163
-------------------
Number of graphs: 4
-------------------


 30%|████████████▉                              | 3/10 [10:25<24:18, 208.35s/it]


GWB Loss has increased at iteration 3.
Stopping TB iterations.
Final GWB Loss: 0.001988952422281846
NCA: 0.6135458167330677
NC2: 0.9681274900398407
-------------------
Number of graphs: 5
-------------------


 30%|████████████▉                              | 3/10 [13:08<30:40, 262.99s/it]


GWB Loss has increased at iteration 3.
Stopping TB iterations.
Final GWB Loss: 0.002480494435587417
NCA: 0.5209163346613546
NC2: 0.9721115537848606
-------------------
Number of graphs: 6
-------------------


 20%|████████▌                                  | 2/10 [11:46<47:04, 353.10s/it]

GWB Loss has increased at iteration 2.
Stopping TB iterations.
Final GWB Loss: 0.002945820456937205
NCA: 0.4551792828685259
NC2: 0.9890438247011952



