In [2]:
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.datasets import HGBDataset
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the 'acm' graph of the HGB benchmark; files live under data/ and no transform is applied.
dataset = HGBDataset(root='data/', name='acm', transform=None)
# Work with the first entry of the dataset from here on.
data = dataset[0]

In [4]:
# Inspect the 'paper' node store: features (x), labels (y), and the train/test masks.
data['paper']

{'x': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'y': tensor([2, 2, 2,  ..., 2, 2, 2]), 'train_mask': tensor([False,  True, False,  ..., False, False,  True]), 'test_mask': tensor([ True, False,  True,  ...,  True,  True, False])}

In [15]:
# Feature dimensionality per node type; a value of 0 means that type carries no features.
data.num_features

{'paper': 1902, 'author': 1902, 'subject': 1902, 'term': 0}

In [None]:
def create_global_node_ids(data: "HeteroData"):
    """
    Compute the first global node ID of every node type.

    Node types are laid out back-to-back in the order of ``data.node_types``,
    so a node's global ID is its type's offset plus its local index
    (see ``get_global_node_id``).

    Args:
        data: Heterogeneous graph exposing ``node_types`` and, per node type,
            ``data[node_type].num_nodes``.

    Returns:
        dict mapping each node type to the global ID of its first node.
    """
    node_type_offsets = {}
    offset = 0
    for ntype in data.node_types:
        node_type_offsets[ntype] = offset
        offset += data[ntype].num_nodes
    # The mapping must be handed back: callers pass it on to the edge builders.
    return node_type_offsets

def get_global_node_id(node_type, node_id, node_offsets):
    """
    Translate a type-local node index into the global node numbering.

    Args:
        node_type: Key identifying the node type in ``node_offsets``.
        node_id: Local index within that type; anything that supports
            ``offset + node_id`` works (callers in this file pass index tensors).
        node_offsets: Mapping from node type to the global ID of its first node.

    Returns:
        ``node_id`` shifted by the offset of ``node_type``.
    """
    type_offset = node_offsets[node_type]
    global_id = type_offset + node_id
    return global_id

def build_intra_edge_index(data: "HeteroData", node_offsets: dict):
    """
    Build the intra-layer edges of the multiplex graph.

    Every edge type of ``data`` becomes one layer. Within layer ``k`` a node
    with global ID ``g`` is numbered ``g + k * data.num_nodes``, so the layers
    occupy disjoint ID ranges and all their edges can share one edge index.

    Args:
        data: Heterogeneous graph exposing ``edge_types``, ``num_nodes`` and
            ``data[edge_type].edge_index``.
        node_offsets: Mapping from node type to the global ID of its first
            node, as consumed by ``get_global_node_id``.

    Returns:
        Tensor of shape ``[2, E]`` holding the edges of all layers concatenated
        in ``data.edge_types`` order (``[2, 0]`` when there are no edge types).
    """
    num_nodes = data.num_nodes  # loop-invariant size of a single layer
    intra_edge_indices = []
    for layer_id, edge_type in enumerate(data.edge_types):
        src_type, _, dst_type = edge_type
        edge_index = data[edge_type].edge_index

        # Local (per-type) indices -> global indices -> indices inside this layer.
        layer_shift = layer_id * num_nodes
        src_layer = get_global_node_id(src_type, edge_index[0], node_offsets) + layer_shift
        dst_layer = get_global_node_id(dst_type, edge_index[1], node_offsets) + layer_shift

        intra_edge_indices.append(torch.stack([src_layer, dst_layer], dim=0))

    # Return only after every layer has been processed.
    if not intra_edge_indices:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.cat(intra_edge_indices, dim=1)

def build_inter_edge_index(data: "HeteroData"):
    """
    Build the inter-layer edges of the multiplex graph.

    Each node is linked to its own copy in every other layer, in both
    directions. The copy of node ``g`` in layer ``k`` has ID
    ``g + k * data.num_nodes``.

    Args:
        data: Heterogeneous graph exposing ``num_nodes`` and ``edge_types``
            (one layer per edge type).

    Returns:
        Long tensor of shape ``[2, N * L * (L - 1)]`` for ``N`` nodes and ``L``
        layers. Columns are ordered by node, then by layer pair ``(a, b)`` with
        ``a < b``, each pair contributing ``a -> b`` followed by ``b -> a``.
    """
    num_nodes = data.num_nodes
    num_layers = len(data.edge_types)

    # Unordered layer pairs (a < b), in the order a nested loop would visit them.
    layer_pairs = [(a, b) for a in range(num_layers) for b in range(a + 1, num_layers)]
    lower_shift = torch.tensor([a for a, _ in layer_pairs], dtype=torch.long) * num_nodes
    upper_shift = torch.tensor([b for _, b in layer_pairs], dtype=torch.long) * num_nodes

    # Broadcast [N, 1] + [P] -> [N, P]: the ID of every node in both layers of every pair.
    node_ids = torch.arange(num_nodes, dtype=torch.long).unsqueeze(1)
    lower = node_ids + lower_shift
    upper = node_ids + upper_shift

    # Interleave the two directions of each link, then flatten node-major.
    inter_src = torch.stack([lower, upper], dim=-1).reshape(-1)
    inter_dst = torch.stack([upper, lower], dim=-1).reshape(-1)  # bidirectional edges
    return torch.stack([inter_src, inter_dst], dim=0)

def get_node_features(data: "HeteroData"):
    """
    Assemble the node feature matrix of the multiplex graph.

    Per-type feature matrices are stacked in ``data.node_types`` order (the
    global node numbering) and the result is repeated once per layer, i.e.
    once per edge type. Node types without features get an all-zero block.

    Args:
        data: Heterogeneous graph exposing ``num_features`` (dict of node type
            to feature dimension), ``node_types``, ``edge_types`` and, per node
            type, ``num_nodes`` and (when featured) ``x``.

    Returns:
        Tensor of shape ``[len(data.edge_types) * total_nodes, feature_dim]``.

    Note:
        All node types that do have features are assumed to share the same
        dimension; otherwise the concatenation fails.
    """
    feat_dims = data.num_features
    node_types = list(data.node_types)

    # Width of the zero padding; the same for every featureless type, so compute it once.
    feat_width = max(feat_dims.values(), default=0)
    # Create padding on the device of the real features so the concatenation
    # also works when the graph has been moved off the default device.
    device = next((data[nt].x.device for nt in node_types if feat_dims[nt] != 0), None)

    node_feat_list = []
    for node_type in node_types:
        store = data[node_type]
        if feat_dims[node_type] == 0:  # no features for this node type
            x = torch.zeros((store.num_nodes, feat_width), dtype=torch.float, device=device)
        else:
            x = store.x
        node_feat_list.append(x)

    global_node_features = torch.cat(node_feat_list, dim=0)
    # One copy of the global feature matrix per layer.
    return global_node_features.repeat(len(data.edge_types), 1)

def get_labels(data: "HeteroData"):
    """
    Assemble the label vector of the multiplex graph.

    Per-type labels are concatenated in ``data.node_types`` order (the global
    node numbering) and repeated once per layer, i.e. once per edge type.

    Args:
        data: Heterogeneous graph exposing ``node_types``, ``edge_types`` and,
            per node type, ``keys()``, ``num_nodes`` and optionally ``y``.

    Returns:
        1-D tensor of length ``len(data.edge_types) * total_nodes``.

    Note:
        Node types without ``y`` are filled with 0, which may also be a real
        class ID. Select labelled nodes with the split masks rather than by
        label value.
    """
    node_types = list(data.node_types)
    # Create placeholders on the device of the real labels so the concatenation
    # also works when the graph has been moved off the default device.
    device = next((data[nt].y.device for nt in node_types if 'y' in data[nt].keys()), None)

    global_labels = []
    for node_type in node_types:
        store = data[node_type]
        if 'y' in store.keys():
            labels = store.y
        else:
            labels = torch.zeros(store.num_nodes, dtype=torch.long, device=device)
        global_labels.append(labels)

    global_labels = torch.cat(global_labels, dim=0)
    # One copy of the global label vector per layer.
    return global_labels.repeat(len(data.edge_types))

def _collect_split_mask(data: "HeteroData", attr: str, num_layers: int):
    """Concatenate the per-type mask ``attr`` (all-False where a type lacks it) and repeat it per layer."""
    node_types = list(data.node_types)
    # Create placeholders on the device of the real masks so the concatenation
    # also works when the graph has been moved off the default device.
    device = next(
        (getattr(data[nt], attr).device for nt in node_types if attr in data[nt].keys()),
        None,
    )

    per_type = []
    for node_type in node_types:
        store = data[node_type]
        if attr in store.keys():
            per_type.append(getattr(store, attr))
        else:
            per_type.append(torch.zeros(store.num_nodes, dtype=torch.bool, device=device))

    if not per_type:  # graph without node types
        return torch.zeros(0, dtype=torch.bool)
    # The per-type pieces must be joined into one tensor before they can be repeated.
    return torch.cat(per_type, dim=0).repeat(num_layers)


def masks(data: "HeteroData"):
    """
    Get masks for training, validation, and testing based on the original data.

    Each mask is concatenated over node types in ``data.node_types`` order (the
    global node numbering) and repeated once per layer, i.e. once per edge
    type. A node type that lacks a given mask contributes all-False entries,
    so a split that no node type defines comes back entirely False.

    Args:
        data: Heterogeneous graph exposing ``node_types``, ``edge_types`` and,
            per node type, ``keys()``, ``num_nodes`` and optionally
            ``train_mask`` / ``val_mask`` / ``test_mask``.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of 1-D boolean tensors, each
        of length ``len(data.edge_types) * total_nodes``.
    """
    num_layers = len(data.edge_types)
    multiplex_train_mask = _collect_split_mask(data, 'train_mask', num_layers)
    multiplex_val_mask = _collect_split_mask(data, 'val_mask', num_layers)
    multiplex_test_mask = _collect_split_mask(data, 'test_mask', num_layers)

    return multiplex_train_mask, multiplex_val_mask, multiplex_test_mask
            

def build_multiplex_graph(data: "HeteroData") -> "HeteroData":
    """
    Build a multiplex graph from the heterogeneous graph data.

    The result has a single node type, ``'node'``, holding one copy of every
    original node per layer (one layer per edge type of ``data``), and two
    edge types between those nodes:

    * ``('node', 'intra_layer', 'node')``: the original edges, kept inside
      their own layer.
    * ``('node', 'inter_layer', 'node')``: links between the copies of the
      same node in different layers.

    Labels and split masks are not attached here; see ``get_labels`` and
    ``masks``.

    Args:
        data: Heterogeneous input graph.

    Returns:
        ``HeteroData`` with ``['node'].x`` and the two ``edge_index`` tensors.
    """
    node_offsets = create_global_node_ids(data)

    intra_edge_index = build_intra_edge_index(data, node_offsets)
    inter_edge_index = build_inter_edge_index(data)

    node_type = 'node'
    multiplex_data = HeteroData()
    # Features belong to a node type, not to the graph-level store.
    multiplex_data[node_type].x = get_node_features(data)
    # Edge stores must be keyed by a (source, relation, destination) triplet; a
    # bare string on an empty HeteroData would create a node type instead.
    multiplex_data[node_type, 'intra_layer', node_type].edge_index = intra_edge_index
    multiplex_data[node_type, 'inter_layer', node_type].edge_index = inter_edge_index

    return multiplex_data
            