In [None]:
import gwb as gwb
from gwb import GM as gm

import matplotlib.pyplot as plt
import numpy as np
import ot
import networkx as nx

from networkx.algorithms.community.asyn_fluid import asyn_fluidc
from networkx.algorithms.community import greedy_modularity_communities
from sklearn.cluster import SpectralClustering
from sklearn.metrics import adjusted_mutual_info_score

In [None]:
# --- Experiment configuration ---
#n_trials = 10
n_partitions = 2      # number of groups handed to nx.random_partition_graph
partition_size = 50   # nodes per group
num_nodes = n_partitions*partition_size
p_in = 0.4            # passed as p_in (within-group edge probability)
p_out = 0.1           # passed as p_out (between-group edge probability)
n_its_tb = 3          # how many times the TB loop in the next cell runs
i_init_tb = 0         # index into Xs of the space used as the initial barycenter
N = 10 #number of graphs

#create graphs
# Gs[j]  : j-th random partition graph (seeded with 10*j, so re-runs give the same graphs)
# GTs[j] : integer array with the "block" attribute of nodes 0..num_nodes-1 of Gs[j]
Gs = []
GTs = []
for graph_idx in range(N):
    G = nx.random_partition_graph(n_partitions*[partition_size],
                                  p_in=p_in, p_out=p_out, directed=False, seed=10*graph_idx)
    Gs.append(G)
    # A comprehension with its own index: the previous inner loop reused the
    # outer loop variable, which only worked because the seed had already been
    # computed before the inner loop overwrote it.
    gt = [G.nodes[node]["block"] for node in range(num_nodes)]
    GTs.append(np.array(gt))
print("Graphs generated!")

#create gm-spaces
# One gwb GM object per graph, built from its node and edge arrays.
Xs = []
for G in Gs:
    Edges = np.array(G.edges)
    Nodes = np.array(G.nodes)
    Xs.append(gm(mode="graph",gauge_mode="adjacency",Nodes=Nodes,Edges=Edges))
print("GM spaces generated!")

In [None]:
#TB iterations and spectral clustering
# Each pass clusters the space that is fed INTO gwb.TB (bary_prev), carries the
# resulting labels over to every input graph using the Ps returned by that same
# call, and scores them against the ground truth with adjusted mutual information.
# Consequently AMIs_per_TB_it[0] scores the initial space Xs[i_init_tb], and the
# barycenter returned by the final pass is never clustered here.
bary = Xs[i_init_tb]
AMIs_per_TB_it = []

# The estimator settings never change between passes, so build it once.
# random_state pins scikit-learn's randomised steps: without it the AMI values
# differ from run to run even though the graphs themselves are seeded.
# NOTE(review): gwb.TB may have its own randomness; this only fixes the clustering step.
# NOTE(review): the cluster count is kept at the literal 2 (== n_partitions today);
# the rounding-based label transfer below looks specific to two labels (0/1), so
# revisit that step before raising the cluster count.
sc = SpectralClustering(2, affinity='precomputed', n_init=100,
                        assign_labels='discretize', random_state=0)

for i in range(n_its_tb):
    bary_prev = bary
    bary,log = gwb.TB(bary_prev,Xs,ws = ot.unif(N),mode="avg_gauge_only",log=True)
    # Relies on the log holding exactly three entries in this order; only Ps is used below.
    idxs, meas, Ps = log.values()
    #bary = sample_GM(bary,n=500)

    #spectral clustering on barycenter
    # bary_prev.g is used directly as the precomputed affinity matrix.
    predict_bary_prev = sc.fit_predict(bary_prev.g)
    #print(predict_bary_prev)

    AMIs = []
    for k in range(N):
        # NOTE(review): presumably Ps[k] couples bary_prev with Xs[k] and Xs[k].xi is
        # the measure on Xs[k], making this a weighted average of barycenter labels
        # per node of Xs[k] -- confirm against the gwb documentation.
        transferred = predict_bary_prev.dot(Ps[k]/Xs[k].xi)
        labels_k = np.array(np.round(transferred, 0), dtype=int)
        ami = adjusted_mutual_info_score(GTs[k], labels_k)
        AMIs.append(ami)
    AMIs_per_TB_it.append(np.array(AMIs))
print("TB iterations and clustering completed!")

In [None]:
# Visualise the gauge matrix of the first input space (built with gauge_mode="adjacency").
# Explicit fig/ax plus title, labels and colorbar so the figure stands on its own;
# plt.show() keeps the AxesImage repr out of the cell output.
fig, ax = plt.subplots()
im = ax.imshow(Xs[0].g)
ax.set_title("Gauge matrix g of Xs[0]")
ax.set_xlabel("node index")
ax.set_ylabel("node index")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Visualise the gauge matrix of bary_prev, i.e. the space that was fed into the
# last TB pass and clustered there (not the barycenter returned by that pass).
# Explicit fig/ax plus title, labels and colorbar so the figure stands on its own;
# plt.show() keeps the AxesImage repr out of the cell output.
fig, ax = plt.subplots()
im = ax.imshow(bary_prev.g)
ax.set_title("Gauge matrix g of bary_prev (input to the last TB pass)")
ax.set_xlabel("node index")
ax.set_ylabel("node index")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Average the AMI over the graphs (axis 1) to get one score per TB pass (axis 0).
mean_ami_per_pass = np.mean(AMIs_per_TB_it, axis=1)
print(mean_ami_per_pass)
# The two names below are not defined in any cell shown here; kept for reference.
#print(np.mean(AMIs_GMC))
#print(np.mean(AMIs_AF))