In [1]:
# This notebook converts the graph data from the pickle file containing the list of graph representations 
# into anoter pickle file containing the list of graph representations in the format of PyG format.
import torch
import pickle
import numpy as np

# Load data (MAC)
pickle_data_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/pickle/jh101_grs.pickle"
pickle_data = pickle.load(open(pickle_data_path, "rb"))

In [2]:
# Credit: https://github.com/pyg-team/pytorch_geometric/issues/964

import numpy as np

def build_K_n(num_nodes):
    """
        Builds the edge_index for a complete graph K_n for num_nodes = n.
    """
    # Initialize edge index matrix
    E = np.zeros((2, num_nodes * (num_nodes - 1)), dtype=np.int64)

    # Populate 1st row
    for node in range(num_nodes):
        for neighbor in range(num_nodes - 1):
            E[0, node * (num_nodes - 1) + neighbor] = node

    # Populate 2nd row
    neighbors = []
    for node in range(num_nodes):
        neighbors.append(list(np.arange(node)) + list(np.arange(node + 1, num_nodes)))
    E[1, :] = [item for sublist in neighbors for item in sublist]

    return E

In [3]:
# Convert entries of pickle_data from [[A, x, edge_attr], Y] to [[edge_index, x, new_edge_attr], Y]
# where edge_index is a complete graph, x is unchanged, and new_edge_attr is the weights of adjacency matrix A
def adj_to_edge_attr(A, edge_index, edge_attr = None):
    """
    Converts a fully connected weighted adjacency matrix to edge features.
    
    
    Args:
        A (numpy array): Adjacency matrix.
        edge_attr (numpy array): Edge features of shape (num_edges, num_edge_features), optionally included.
    
    returns:
        edge_attr (numpy array): New edge fatures of shape (num_edges, num_edge_features + 1).
    """

    num_nodes = A.shape[0]

    # If edge_attr is not included, initialize it
    if edge_attr is None:
        num_edges = int(num_nodes * (num_nodes - 1) / 2)
        edge_attr = np.zeros((num_edges, 1))
        for i in range(num_edges):
            k, l = edge_index[:, i]
            edge_attr[i] = A[k, l]
    
        return edge_attr
    
    # TODO: Case when edge_attr is included
    else:
        return None


# Initialize new data list
complete_graph = build_K_n(pickle_data[0][0][0].shape[0])
pyg_data = []

for i in range(len(pickle_data)):
    pyg_data.append([[complete_graph, pickle_data[i][0][1], adj_to_edge_attr(pickle_data[i][0][0], complete_graph)], 
                    pickle_data[i][1]])

In [4]:
# Pickle pyg_data
save_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/pickle/"
file_name = "jh101_grs_pyg"
with open(save_path + file_name + ".pkl", "wb") as f:
    pickle.dump(pyg_data, f)

In [6]:
pyg_data_path = "/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/pickle/jh101_grs_pyg.pkl"
pyg_data = pickle.load(open(pyg_data_path, "rb"))

print(len(pyg_data))
print(len(pickle_data))
print(type(pyg_data[0][0][0]))
print("Edge features stored in edge_attr:", pyg_data[0][0][2])
print("Edge features stored in adj:", pickle_data[0][0][0])

4484
4484
<class 'numpy.ndarray'>
Edge features stored in edge_attr: [[ 0.66069462]
 [-0.12578357]
 [-2.10979259]
 ...
 [-0.41955958]
 [ 0.48147154]
 [ 1.84190391]]
Edge features stored in adj: [[ 2.44109208  0.66069462 -0.12578357 ...  0.83818009 -1.58988827
  -0.80577129]
 [ 0.66069462  2.44109208  1.6717804  ...  0.60534856 -2.1600128
  -1.95782143]
 [-0.12578357  1.6717804   2.44109208 ...  0.43843359 -1.49362948
  -1.54959986]
 ...
 [ 0.83818009  0.60534856  0.43843359 ...  2.44109208 -1.4207016
  -1.2727665 ]
 [-1.58988827 -2.1600128  -1.49362948 ... -1.4207016   2.44109208
   1.89210687]
 [-0.80577129 -1.95782143 -1.54959986 ... -1.2727665   1.89210687
   2.44109208]]


In [None]:
# Link for pairs of graphs: https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html
# Link for creating datasets: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html
# Link for Data handling tutorial: https://www.youtube.com/watch?v=Vz5bT8Xw6Dc&list=PLGMXrbDNfqTzqxB1IGgimuhtfAhGd8lHF&index=5