In [1]:
import numpy as np 
import pandas as pd
import networkx as nx
import scipy as sp
from sklearn.utils import Bunch
from scipy.io import loadmat
from dyneusr.core import DyNeuGraph  
from collections import defaultdict

In [2]:
## Load the NeuMapper result
mat = loadmat('../data/res_hax.mat')
res = mat['res_hax'][0][0]   # single MATLAB struct -> numpy structured record
# Expose each struct field as a Bunch attribute (field name -> field value).
res = Bunch(**{field: res[field] for field in res.dtype.names})
# Unwrap a nested 'res' field if present, else a 'var' field, else keep the Bunch as-is.
res = next((res[key] for key in ('res', 'var') if key in res), res)

In [3]:
# Load the one-hot timing-label matrix (tab-separated; read_table defaults to sep='\t').
timing_onehot = pd.read_table('SBJ02_timing_onehot.tsv')

In [4]:
## Convert to KeplerMapper format
# One row per shape-graph node; nonzero columns mark that node's members.
membership = res.clusterBins
# Node-by-node overlap: entry (i, j) counts the members shared by nodes i and j.
# Cast before the product. `membership` is uint8, and a uint8 matmul wraps
# modulo 256, so an overlap of exactly 256 (512, ...) shared members would
# silently become 0 and the edge would be lost.
adjacency = membership.astype(np.int64) @ membership.T.astype(np.int64)
np.fill_diagonal(adjacency, 0)           # drop self-overlap (no self-loops)
adjacency = (adjacency > 0).astype(int)  # binary adjacency: any shared member => edge

In [5]:
# Inspect the membership matrix. Rows are enumerated as graph nodes when the
# node dict is built; nonzero columns are that node's member indices
# (presumably rows of the timing data / TRs -- TODO confirm).
membership

array([[0, 0, 0, ..., 0, 0, 1],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 1, 0, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0]], dtype=uint8)

In [6]:
# Build a NetworkX graph from the node adjacency matrix (what nx.Graph(ndarray)
# dispatches to internally), then export it as a node-link dict.
G = nx.from_numpy_array(adjacency)
graph = nx.node_link_data(G)

In [7]:
# Re-key node membership in DyNeuSR/KeplerMapper form: {node_id: [member index, ...]}
nodes = defaultdict(
    list,
    {node_id: np.flatnonzero(row).tolist() for node_id, row in enumerate(membership)},
)

In [8]:
# Update the link format, e.g. {source: [target, ...]}.
# Build from G.edges() rather than graph['links']. A later cell replaces
# graph['links'] with this dict, so reading it back makes this cell fail on
# re-run. This also avoids depending on the node-link key name, which newer
# NetworkX releases may change. G.edges() yields the same (u, v) pairs, in the
# same order, that node_link_data serialises.
links = defaultdict(list)
for u, v in G.edges():
    if u != v:  # skip self-loops (the adjacency diagonal was zeroed, so none expected)
        links[u].append(v)

In [9]:
# Swap in the DyNeuSR-format node and link dicts (existing keys keep their positions).
graph.update(nodes=nodes, links=links)

In [10]:
## Visualize the shape graph with DyNeuSR's DyNeuGraph
# Node colouring comes from the one-hot timing labels; the HTML output path is
# relative to the current working directory.
dG = DyNeuGraph(y=timing_onehot, G=graph)
dG.visualize('haxby_decoding_neumapper_dyneusr.html')

label,group,value,row_count
rest,0,588,1452
scissors,1,108,1452
face,2,108,1452
cat,3,108,1452
shoe,4,108,1452
house,5,108,1452
scrambledpix,6,108,1452
bottle,7,108,1452
chair,8,108,1452


Already serving localhost:None 
[Force Graph] file:///home/poetz/Desktop/PhD/Projects/Projects%20CSSS/TDA/CSSS22-TDA/code/haxby_decoding_neumapper_dyneusr.html


In [11]:
# Inspect the final graph dict after its 'nodes'/'links' were replaced.
# NOTE(review): this prints every node's member list; consider a summary
# (e.g. node/link counts) for large graphs.
graph

{'directed': False,
 'multigraph': False,
 'graph': {},
 'nodes': defaultdict(list,
             {0: [1451],
              1: [1450],
              2: [1449],
              3: [1448],
              4: [1447],
              5: [1446],
              6: [1444, 1445],
              7: [1443],
              8: [1442],
              9: [1441],
              10: [1440],
              11: [1438, 1439],
              12: [1437],
              13: [1436],
              14: [1435],
              15: [1434],
              16: [1433],
              17: [1431, 1432],
              18: [1430],
              19: [1427, 1428],
              20: [1424, 1425, 1426],
              21: [1423],
              22: [1417],
              23: [1414],
              24: [1412, 1413, 1414],
              25: [1405, 1406, 1407, 1408, 1409, 1410],
              26: [1404, 1405],
              27: [1402, 1403],
              28: [1398, 1399, 1400],
              29: [1397],
              30: [1396],
              31: 