In [7]:
import numpy as np
import itertools
import torch
import torch.nn.functional as F
import leidenalg as la
import igraph as ig

import sys
sys.path.append("../")

from src.metrics.supervised import evaluate_supervised
# import matplotlib.pyplot as plt

In [8]:
def create_igraph(A, directed=False):
    """Build a weighted igraph graph from a dense adjacency matrix.

    Parameters
    ----------
    A : np.ndarray
        2-D adjacency matrix. Every non-zero entry ``A[i, j]`` becomes an
        edge ``(i, j)`` whose ``"weight"`` edge attribute is that entry.
    directed : bool, optional
        Forwarded to ``ig.Graph`` (default ``False``).

    Returns
    -------
    ig.Graph
        Graph with exactly ``A.shape[0]`` vertices. Each vertex carries an
        ``"id"`` attribute equal to its row index in ``A``.
    """
    n_vertices = A.shape[0]

    # row and column indices of the non-zero entries of the adjacency matrix
    rows, cols = np.nonzero(A)
    # weights of those connections, in the same order as the edge list below
    weights = A[rows, cols]
    # edges as a list of (i, j) tuples of plain Python ints
    edges = list(zip(rows.tolist(), cols.tolist()))

    # NOTE(review): with directed=False and a symmetric A, both (i, j) and
    # (j, i) are non-zero, so every undirected edge is inserted twice.
    # Behaviour kept as-is; confirm this is intended.

    # Pass the vertex count explicitly: without it the graph only has as many
    # vertices as the largest index in the edge list, so trailing isolated
    # nodes (or an all-zero matrix) would yield fewer vertices than rows of A.
    G = ig.Graph(n=n_vertices, edges=edges, directed=directed)

    # assign node ids and weights as attributes of the vertices and edges
    G.vs["id"] = list(range(n_vertices))
    G.es["weight"] = weights.tolist()

    return G

In [9]:
def permute_labels(C_pred, C_true):
    """Reorder the class columns of ``C_pred`` to agree best with ``C_true``.

    Every permutation of the first ``C_true.shape[-1]`` columns of ``C_pred``
    is tried, and the one with the most element-wise matches against
    ``C_true`` is applied. On ties the first permutation in
    ``itertools.permutations`` order wins.

    Parameters
    ----------
    C_pred : array-like (torch tensor or numpy array)
        Predicted assignments, classes along the last axis. Columns at index
        ``>= C_true.shape[-1]`` are never selected and are dropped from the
        result.
    C_true : array-like
        Reference assignments, classes along the last axis.

    Returns
    -------
    Same type as ``C_pred``
        ``C_pred`` with its last axis re-indexed by the best permutation.

    Notes
    -----
    The search is brute force over ``K!`` permutations (``K`` being the size
    of the last axis of ``C_true``), so it is only practical for small ``K``.
    """
    num_classes = C_true.shape[-1]

    # agreement[p][t]: how many positions of predicted column p equal true
    # column t. Computed once so each permutation below costs O(K) instead of
    # a full tensor comparison.
    agreement = [
        [int((C_pred[..., p] == C_true[..., t]).sum()) for t in range(num_classes)]
        for p in range(num_classes)
    ]

    # Start from the identity with a score below any achievable count, so a
    # valid permutation is always returned (even when nothing matches or the
    # inputs are empty).
    best_perm = tuple(range(num_classes))
    best_score = -1

    for perm in itertools.permutations(range(num_classes)):
        # output column t is predicted column perm[t]
        score = sum(agreement[p][t] for t, p in enumerate(perm))

        if score > best_score:
            best_score = score
            best_perm = perm

    return C_pred[..., list(best_perm)]

In [14]:
def train(dataset="data.npz", gamma=1, n_iter=10, seed=1234):
    """Run temporal Leiden community detection on a dataset and score it.

    Loads arrays ``"A"`` and ``"C"`` from ``../data/<dataset>``. For every
    sample it builds one igraph graph per adjacency matrix, partitions the
    sequence with ``la.find_partition_temporal``, then aligns the predicted
    labels to the reference with ``permute_labels`` and evaluates them with
    ``evaluate_supervised``.

    Parameters
    ----------
    dataset : str
        File name of the ``.npz`` archive inside ``../data/``.
    gamma : float
        Forwarded as ``resolution_parameter``.
    n_iter : int
        Forwarded as ``n_iterations``.
    seed : int
        Forwarded as ``seed``.

    Returns
    -------
    The value returned by ``evaluate_supervised`` for the aligned predictions.
    """
    data_path = "../data/" + dataset

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only load archives from a trusted source.
    # The context manager closes the underlying archive once both arrays are read.
    with np.load(data_path, allow_pickle=True) as data:
        # presumably A is (samples, time, nodes, nodes) - TODO confirm
        samples_adj_ts = data["A"]
        samples_comm_ts = torch.from_numpy(data["C"]).long()

    # one integer community label per leading index of A (last axis dropped)
    samples_comm_ts_pred = np.zeros(shape=samples_adj_ts.shape[:-1], dtype=int)

    for sample_idx, adj_ts in enumerate(samples_adj_ts):

        graph_ts = [create_igraph(adj_t) for adj_t in adj_ts]

        # the second return value is not used here
        community_ts, _ = la.find_partition_temporal(
            graph_ts,
            partition_type=la.RBConfigurationVertexPartition,
            resolution_parameter=gamma,
            n_iterations=n_iter,
            seed=seed,
        )

        samples_comm_ts_pred[sample_idx] = np.asarray(community_ts)

    samples_comm_ts_pred = torch.from_numpy(samples_comm_ts_pred)
    # one-hot width is the size of the last axis of A
    samples_comm_ts_pred = F.one_hot(samples_comm_ts_pred, num_classes=samples_adj_ts.shape[-1]).long()
    # community ids are arbitrary, so align them with the reference labelling
    samples_comm_ts_pred = permute_labels(samples_comm_ts_pred, samples_comm_ts)

    metrics = evaluate_supervised(samples_comm_ts_pred, samples_comm_ts)
    # TODO: save results

    return metrics

{'f1': 0.8666666746139526,
 'precision': 0.8666666746139526,
 'recall': 0.8666666746139526,
 'hamming': 0.10000002384185791,
 'jaccard': 0.9750000238418579,
 'nmi': 1.0}

In [6]:
# fig, ax = plt.subplots(figsize = (30, 1))
# ax.imshow(np.array(membership).T)
# plt.show()