In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric as tg

In [10]:
# import QM9 dataset
from torch_geometric.datasets import QM9

In [28]:
# Open the QM9 dataset stored under ~/data/QM9; it is downloaded on first use
# if it is not already present there.
data = QM9(root='~/data/QM9')
# Keep the first graph of the dataset as a small example to experiment with.
sample_data = data[0]

In [111]:
class Subgraph:
    """Expand a graph into "ground" nodes plus one full copy of the graph per ground node.

    For a graph with ``N`` nodes the converted graph has ``N * (N + 1)`` node rows:
    rows ``0 .. N-1`` are the ground nodes, and rows ``(k+1)*N .. (k+2)*N - 1`` are
    the subnodes of the subgraph attached to ground node ``k``.

    Args:
        graph: object exposing ``x`` (2-D node feature tensor), ``edge_index``
            (edge tensor indexed by node id) and ``clone()``. The input is never
            modified; all work happens on the clone.
        depth: accepted for interface compatibility; currently unused.
    """

    def __init__(self, graph, depth=1):
        self.num_nodes = graph.x.shape[0]
        self.subgraph = graph.clone()

    def convert_to_subgraph(self):
        """Run every expansion step and return the augmented clone.

        Calling this more than once yields the same result: every step derives
        its output from the original ``num_nodes`` rows and the untouched
        ``edge_index``.
        """
        self.add_subnode_features()
        self.add_node_flags()
        self.add_subnode_edges()
        self.add_node_subnode_edges()
        self.add_subgraph_batch_index()
        return self.subgraph

    def add_subnode_features(self):
        """Tile the ground-node features so every subgraph starts as a copy of the graph."""
        # Slice to the first num_nodes rows so a repeated call does not tile
        # already-tiled features; on the first call the slice is the whole tensor.
        ground_x = self.subgraph.x[: self.num_nodes]
        self.subgraph.x = ground_x.repeat(self.num_nodes + 1, 1)
        self.total_num_nodes = self.subgraph.x.shape[0]

    def add_node_flags(self):
        """Store ``ground_node``: a boolean mask that is True for the first ``num_nodes`` rows."""
        if hasattr(self.subgraph, 'x'):
            node_ids = torch.arange(self.subgraph.x.shape[0], device=self.subgraph.x.device)
            self.subgraph.ground_node = node_ids < self.num_nodes
        else:
            print('No node features found. Please add node features first.')

    def add_subnode_edges(self):
        """Store ``subgraph_edge_index``: one shifted copy of ``edge_index`` per subgraph.

        Subgraph ``k`` (0-based) uses node offset ``(k + 1) * num_nodes``. Each
        offset is used exactly once; the previous implementation seeded the
        result with offset ``num_nodes`` and then appended that same offset
        again in the first loop iteration, duplicating the first subgraph's edges.
        """
        edge_index = self.subgraph.edge_index
        offsets = torch.arange(
            1, self.num_nodes + 1, dtype=edge_index.dtype, device=edge_index.device
        ) * self.num_nodes
        # (rows, 1, E) + (1, N, 1) -> (rows, N, E); flattening the last two axes
        # gives the same column order as concatenating the per-subgraph blocks.
        shifted = edge_index.unsqueeze(1) + offsets.view(1, -1, 1)
        self.subgraph.subgraph_edge_index = shifted.reshape(
            edge_index.shape[0], self.num_nodes * edge_index.shape[1]
        )

    def add_node_subnode_edges(self):
        """Store ``node_subnode_index``: edges from each ground node to its own subnodes."""
        device = self.subgraph.x.device
        # Computed directly so this step does not depend on add_subnode_features
        # having been called first.
        total_num_nodes = self.num_nodes * (self.num_nodes + 1)
        ground_ids = torch.arange(self.num_nodes, device=device).repeat_interleave(self.num_nodes)
        subnode_ids = torch.arange(self.num_nodes, total_num_nodes, device=device)
        self.subgraph.node_subnode_index = torch.stack([ground_ids, subnode_ids], dim=0)

    def add_subgraph_batch_index(self):
        """Store ``subgraph_batch_index``: for every subnode, the index of its subgraph."""
        subgraph_ids = torch.arange(self.num_nodes, device=self.subgraph.x.device)
        self.subgraph.subgraph_batch_index = subgraph_ids.repeat_interleave(self.num_nodes)

In [114]:
class FractalMP(tg.nn.MessagePassing):
    """Message passing layer over a ground graph and its per-node subgraphs.

    One forward call embeds the node features, then applies the same
    message/update networks three times in sequence: over the ground-graph
    edges, over the ground-node -> subnode edges, and over the edges inside the
    subgraphs. Finally each ground node is overwritten with the mean of the
    subnodes of its subgraph.

    Args:
        node_features: width of the incoming node features.
        edge_features: width of ``edge_attr`` (0 when no edge attributes are used).
        hidden_features: width used by the embedding and the shared networks.
        out_features: width of the layer output. Must equal ``hidden_features``
            because the output of one pass is the input of the next.
        aggr: aggregation scheme forwarded to ``MessagePassing``.
        act: activation class, instantiated with no arguments.
        edge_inference: if True, each message is scaled by a learned sigmoid gate.

    Raises:
        ValueError: if ``out_features != hidden_features``.
    """
    def __init__(
        self,
        node_features,
        edge_features,
        hidden_features,
        out_features,
        aggr="add",
        act=nn.ReLU,
        edge_inference=False,
    ):
        super().__init__(aggr=aggr)
        if out_features != hidden_features:
            # The three passes in forward() share one set of networks, so the
            # update output has to be a valid input for the next pass.
            raise ValueError(
                "FractalMP applies the same networks on consecutive passes, so "
                f"out_features ({out_features}) must equal hidden_features ({hidden_features})."
            )
        self.edge_inference = edge_inference

        # forward() calls self.embedding; it was previously never defined, which
        # raised AttributeError on the first call.
        self.embedding = nn.Linear(node_features, hidden_features)

        # The shared networks operate on embedded (hidden-width) node states.
        self.message_net = nn.Sequential(
            nn.Linear(2 * hidden_features + edge_features, hidden_features),
            act(),
            nn.Linear(hidden_features, hidden_features),
            act(),
        )

        self.update_net = nn.Sequential(
            nn.Linear(hidden_features + hidden_features, hidden_features),
            act(),
            nn.Linear(hidden_features, out_features),
        )

        if edge_inference:
            self.edge_inferrer = nn.Sequential(
                nn.Linear(hidden_features, 1), nn.Sigmoid()
            )

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, ground_node, subgraph_batch_index, edge_attr=None):
        """Embed the nodes, run the three passes and pool subgraphs into ground nodes.

        Args:
            x: node features, one row per node (ground nodes and subnodes).
            edge_index: edges of the ground graph.
            subgraph_edge_index: edges inside the subgraphs.
            node_subnode_index: edges from ground nodes to their subnodes.
            ground_node: boolean mask over the rows of ``x``, True for ground nodes.
            subgraph_batch_index: subgraph id for each non-ground row of ``x``.
            edge_attr: optional edge attributes.
                NOTE(review): the same tensor is handed to all three passes even
                though the three edge sets generally differ in size — confirm
                the intended per-pass attributes before using edge_attr.

        Returns:
            Tensor with one ``out_features``-wide row per input node.
        """
        x = self.embedding(x)
        x = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        x = self.propagate(node_subnode_index, x=x, edge_attr=edge_attr)
        x = self.propagate(subgraph_edge_index, x=x, edge_attr=edge_attr)
        # Replace every ground node by the mean over the subnodes of its subgraph.
        x[ground_node] = tg.nn.global_mean_pool(x[~ground_node], subgraph_batch_index)
        return x

    def message(self, x_i, x_j, edge_attr):
        """Build one message per edge from both endpoint states and optional edge attributes."""
        # edge_attr is None when the caller supplies no edge attributes; skip it then.
        parts = [val for val in (x_i, x_j, edge_attr) if val is not None]
        message = self.message_net(torch.cat(parts, dim=-1))

        if self.edge_inference:
            # Soft edge gate: scale each message by a scalar in (0, 1).
            message = message * self.edge_inferrer(message)
        return message

    def update(self, message, x):
        """Combine each node's current state with its aggregated messages."""
        return self.update_net(torch.cat((x, message), dim=-1))

In [None]:
class FractalNet(nn.Module):
    """Stack of ``FractalMP`` layers with an input embedding and a graph-level readout.

    Args:
        node_features: width of the raw node features.
        edge_features: width of the edge attributes passed to every layer.
        hidden_features: width used by the embedding and by every layer.
        out_features: width of the per-graph output.
        depth: number of ``FractalMP`` layers.
    """

    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1):
        super().__init__()
        self.depth = depth
        self.embedding = nn.Linear(node_features, hidden_features)
        self.fractal_mps = nn.ModuleList(
            FractalMP(hidden_features, edge_features, hidden_features, hidden_features)
            for _ in range(depth)
        )
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, ground_node, subgraph_batch_index, edge_attr=None, batch=None):
        """Run all layers and mean-pool the ground nodes into one vector per graph.

        Args:
            x: node features, one row per node (ground nodes and subnodes).
            edge_index: edges of the ground graph.
            subgraph_edge_index: edges inside the subgraphs.
            node_subnode_index: edges from ground nodes to their subnodes.
            ground_node: boolean mask over the rows of ``x``, True for ground nodes.
            subgraph_batch_index: subgraph id for each non-ground row of ``x``.
            edge_attr: optional edge attributes, forwarded to every layer.
            batch: optional graph id for every row of ``x``. When omitted, all
                nodes are treated as belonging to a single graph.

        Returns:
            Tensor of shape ``(num_graphs, out_features)``.
        """
        x = self.embedding(x)
        for layer in self.fractal_mps:
            x = layer(x, edge_index, subgraph_edge_index, node_subnode_index, ground_node, subgraph_batch_index, edge_attr)
        x = self.output(x)

        # Graph-level readout over ground nodes only. subgraph_batch_index has
        # one entry per *subnode*, so it cannot index the ground-node rows; the
        # pooling needs one graph id per ground node instead.
        ground_x = x[ground_node]
        if batch is None:
            ground_batch = torch.zeros(ground_x.size(0), dtype=torch.long, device=ground_x.device)
        else:
            ground_batch = batch[ground_node]
        return tg.nn.global_mean_pool(ground_x, ground_batch)

In [115]:
# Expand the sample graph into ground nodes plus per-node subgraphs.
sample_graph = Subgraph(sample_data).convert_to_subgraph()

# Layer sizes: the input width comes from the data, no edge attributes are used.
node_features = sample_graph.x.size(1)
edge_features = 0
hidden_features = out_features = 64

layer = FractalMP(
    node_features=node_features,
    edge_features=edge_features,
    hidden_features=hidden_features,
    out_features=out_features,
)

# Single forward pass through the layer (edge_attr is left at its default).
out = layer(
    x=sample_graph.x,
    edge_index=sample_graph.edge_index,
    subgraph_edge_index=sample_graph.subgraph_edge_index,
    node_subnode_index=sample_graph.node_subnode_index,
    ground_node=sample_graph.ground_node,
    subgraph_batch_index=sample_graph.subgraph_batch_index,
)

In [117]:
# Peek at a single row (index 10) of the layer output as a quick sanity check.
out[10]

tensor([-0.0263,  0.0020, -0.0888,  0.0458,  0.0152, -0.0808,  0.0896, -0.1155,
        -0.0826, -0.0725, -0.0338, -0.0385,  0.1004, -0.0594,  0.1465,  0.0600,
        -0.1349, -0.0576,  0.1031, -0.0441, -0.0286, -0.0843,  0.0871,  0.0326,
        -0.1241, -0.1079,  0.1021,  0.0963, -0.1220,  0.0195, -0.0154, -0.0978,
        -0.0361,  0.1233,  0.1183,  0.1136, -0.1251, -0.0688,  0.0894,  0.0901,
        -0.0143,  0.0708,  0.0143, -0.1047, -0.0889,  0.0551,  0.0380,  0.0014,
         0.0119, -0.0628,  0.0366, -0.0665, -0.0357,  0.0654,  0.0114,  0.0592,
        -0.0043,  0.1494,  0.0709,  0.0020,  0.1297,  0.0260,  0.0497,  0.1348],
       grad_fn=<SelectBackward0>)

In [None]:
# MISC
def graph_to_subgraph(graph, depth=1):
    """Function-style draft of the ground-node / subgraph expansion of ``graph``.

    Works on a clone of ``graph`` and attaches ``ground_node``,
    ``subgraph_edge_index`` and ``interaction_index`` to it.

    NOTE(review): ``depth`` is not used in the visible body.
    NOTE(review): no ``return`` is visible in the portion of the function shown
    here; if the function ends at this point, callers receive ``None`` and the
    augmented clone is lost.
    NOTE(review): the ground-node -> subnode edges are stored as
    ``interaction_index`` here, while the ``Subgraph`` class stores the same
    tensor as ``node_subnode_index`` — confirm which name consumers expect.
    """
    num_nodes = graph.x.shape[0]
    subgraph = graph.clone()
    subgraph.x = subgraph.x.repeat(num_nodes+1,1) # Initialize subnodes by tiling the node features num_nodes+1 times
    total_num_nodes = subgraph.x.shape[0]
    ground_node = torch.arange(subgraph.x.shape[0]) < num_nodes # Mask: True for the first num_nodes rows (ground nodes), False for subnodes
    subgraph.ground_node = ground_node # Attach the mask to the cloned graph
    # Create the subgraph edge index: the original edge index shifted by a multiple of num_nodes for every subgraph.
    # NOTE(review): the seed below uses offset num_nodes and the first loop iteration (subg=0)
    # appends offset (0+1)*num_nodes again, so the first subgraph's edges appear twice.
    subgraph.subgraph_edge_index = subgraph.edge_index + num_nodes
    for subg in range(num_nodes):
        subgraph.subgraph_edge_index = torch.cat([subgraph.subgraph_edge_index, subgraph.edge_index + (subg+1)*num_nodes], dim=1)
    # Create the edge index for directed edges between each ground node and its subnodes
    # (row 0: ground node id repeated num_nodes times, row 1: consecutive subnode ids).
    subgraph.interaction_index = torch.stack([torch.arange(num_nodes).repeat_interleave(num_nodes), torch.arange(num_nodes, total_num_nodes)], dim=0)