In [2]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from networks import FractalNet, FractalNetShared, Net, GNN_no_rel, GNN
from subgraph import Graph_to_Subgraph
from train import train_model, get_qm9
from train import train_model, get_qm9
from subgraph import Subgraph
from layers import MP

In [3]:
# Load data: get_qm9 is a project helper (train.py) called with the local data directory.
dataset = get_qm9('./data')
# Take the first element of what get_qm9 returned — presumably one split (e.g. train); verify in train.py.
data = dataset[0]
# Use the first entry of that split as the initial sample and make sure it lives on the CPU.
sample_data = data[0].to('cpu')

  warn("Using non-standard permutation since permute.pt does not exist.")


In [4]:
# Scan the split for a smaller example graph to use as sample_data.
# A candidate replaces the current sample only when it has strictly fewer nodes;
# the scan ends at the first such candidate that has fewer than 5 nodes.
for data_point in data:
    if data_point.num_nodes >= sample_data.num_nodes:
        # not smaller than what we already hold
        continue
    if data_point.num_nodes < 5:
        # too small to be a useful example: stop searching altogether
        break
    sample_data = data_point.to('cpu')

In [5]:
# Display the selected sample graph (its repr lists the attribute shapes).
sample_data

Data(x=[6, 11], edge_index=[2, 12], edge_attr=[12, 4], y=[1, 19], pos=[6, 3], idx=[1], name='gdb_174', z=[6])

In [5]:
# Convert the sample graph with the project's Subgraph helper.
# NOTE(review): the meaning of mode='transformer_3' is defined in subgraph.py — confirm there.
subgraph_example = Subgraph(sample_data, mode='transformer_3').convert_to_subgraph()

In [6]:
# Print summary statistics of the converted graph.
# Note: the first three lines print tensor *shapes* (x, edge_index, subgraph_edge_index),
# not scalar counts, despite the "Number of ..." labels.
print('Number of nodes: ', subgraph_example.x.shape)
print('Number of edges: ', subgraph_example.edge_index.shape)
print('Subedge index', subgraph_example.subgraph_edge_index.shape)
# Print the largest node index referenced by subgraph_edge_index.
print('Maximum value in subedge index: ', torch.max(subgraph_example.subgraph_edge_index))

Number of nodes:  torch.Size([20, 8])
Number of edges:  torch.Size([2, 8])
Subedge index torch.Size([2, 45])
Maximum value in subedge index:  tensor(19)


In [15]:
# Print the full subgraph edge index tensor.
print(subgraph_example.subgraph_edge_index)
# NOTE(review): from here on sample_data refers to the converted subgraph object, not the
# original QM9 sample; re-running earlier cells afterwards will see the rebound name.
sample_data = subgraph_example

tensor([[ 5,  5,  5,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10,
         11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16,
         17, 17, 17, 18, 18, 18, 19, 19, 19],
        [ 5,  6,  7,  5,  6,  7,  5,  6,  7,  8,  9, 10,  8,  9, 10,  8,  9, 10,
         11, 12, 13, 11, 12, 13, 11, 12, 13, 14, 15, 16, 14, 15, 16, 14, 15, 16,
         17, 18, 19, 17, 18, 19, 17, 18, 19]])


In [54]:
import torch.nn as nn
import torch_geometric as tg
import torch
import torch.nn.functional as F
from layers import FractalMP, MP
import torch_geometric.nn as geom_nn
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
import os
from utils import catch_lone_sender
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

class FractalNet(nn.Module):
    """Debug variant of the fractal message-passing network.

    Every layer runs four message-passing (MP) stages in a fixed order:
    ground graph -> ground-to-sub -> sub graph -> sub-to-ground. After each
    stage, the rows of ``x`` that ``catch_lone_sender`` does not flag are
    restored to their value from before the stage, and a diagnostic line is
    printed that tells whether the non-ground rows were left unchanged.

    Args:
        node_features: size of the input node feature vectors.
        edge_features: forwarded as the second argument of every ``MP`` layer.
        hidden_features: width of the node embedding and of every ``MP`` layer.
        out_features: size of the final linear output.
        depth: number of layers (each layer = the four stages above).
        pool: ``"mean"``, ``"add"`` or ``"max"`` pools the ground-node rows per
            graph; any other value (e.g. ``None``) skips pooling, so the output
            has one row per node.
        add_residual_skip: if True, the layer input is added to the layer output.
    """

    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1, pool="mean", add_residual_skip=False):
        super().__init__()
        self.depth = depth
        self.pool = pool
        self.add_residual_skip = add_residual_skip
        self.embedding = nn.Linear(node_features, hidden_features)
        self.ground_mps = nn.ModuleList()
        self.ground_to_sub_mps = nn.ModuleList()
        self.sub_mps = nn.ModuleList()
        self.sub_to_ground_mps = nn.ModuleList()
        for _ in range(depth):
            self.ground_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
            self.ground_to_sub_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
            self.sub_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
            self.sub_to_ground_mps.append(MP(hidden_features, edge_features, hidden_features, hidden_features))
        self.output = nn.Linear(hidden_features, out_features)

    def _masked_step(self, mp, x, index, edge_attr, ground_node, num_nodes):
        """Run one MP stage, restore the rows outside its update mask, and print the debug check."""
        # catch_lone_sender is a project helper; its result is used as a boolean row mask
        # (presumably True for nodes that should be updated by this edge set — verify in utils).
        update_mask = catch_lone_sender(index, num_nodes)
        # The backup MUST be taken with the mask of *this* stage. Previously the sub-graph
        # stage reused the backup of the preceding stage, restoring stale or mis-shaped rows.
        x_backup = x[~update_mask]
        x_subnode_state = x[~ground_node]

        x = mp(x, index, edge_attr)
        x[~update_mask] = x_backup

        # Diagnostic: True means the non-ground rows came out of this stage unchanged.
        x_subnode_state_after = x[~ground_node]
        print('states of the not ground nodes after mp: ', torch.equal(x_subnode_state, x_subnode_state_after))
        return x

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, subnode_node_index, ground_node, subgraph_batch_index, batch_idx, edge_attr=None):
        """Embed the node features, run ``depth`` layers of four masked MP stages, then pool and project.

        ``ground_node`` is used as a boolean row mask over ``x``.
        ``subgraph_batch_index`` is accepted for interface compatibility but is not used.
        ``batch_idx`` is only read when ``pool`` is one of "mean"/"add"/"max".
        """
        num_nodes = x.shape[0]
        x = self.embedding(x)

        for i in range(self.depth):
            if self.add_residual_skip:
                x_0 = x

            # TODO: Check the order of edge indices; directed in which direction? subnode to node or vice versa
            # NOTE(review): the same edge_attr is handed to all four stages although they use
            # different edge sets — confirm that MP tolerates this when edge_attr is not None.
            x = self._masked_step(self.ground_mps[i], x, edge_index, edge_attr, ground_node, num_nodes)
            x = self._masked_step(self.ground_to_sub_mps[i], x, node_subnode_index, edge_attr, ground_node, num_nodes)
            x = self._masked_step(self.sub_mps[i], x, subgraph_edge_index, edge_attr, ground_node, num_nodes)
            x = self._masked_step(self.sub_to_ground_mps[i], x, subnode_node_index, edge_attr, ground_node, num_nodes)

            if self.add_residual_skip:
                x = x + x_0

        # global pooling over nodes whose ground node is true
        if self.pool == "mean":
            x = tg.nn.global_mean_pool(x[ground_node], batch_idx)
        elif self.pool == "add":
            x = tg.nn.global_add_pool(x[ground_node], batch_idx)
        elif self.pool == "max":
            x = tg.nn.global_max_pool(x[ground_node], batch_idx)
        x = self.output(x)
        return x

In [None]:
class FractalNet(nn.Module):
    """Fractal message-passing network with optional state masking.

    A layer consists of four message-passing (MP) stages executed in order:
    ground graph, ground-to-sub, sub graph, sub-to-ground. When ``masking`` is
    enabled, rows of ``x`` that ``catch_lone_sender`` does not flag for a stage
    are reset to their pre-stage value after that stage.

    Args:
        node_features: size of the input node feature vectors.
        edge_features: forwarded as the second argument of every ``MP`` layer.
        hidden_features: width of the embedding and of every ``MP`` layer.
        out_features: size of the final linear output.
        depth: number of layers.
        pool: ``"mean"``, ``"add"`` or ``"max"`` pools ground-node rows per graph;
            any other value leaves the node rows unpooled.
        add_residual_skip: add the layer input to the layer output.
        masking: restore the rows outside each stage's update mask.
    """

    # Attribute names of the per-stage ModuleLists, in execution order.
    _STAGES = ("ground_mps", "ground_to_sub_mps", "sub_mps", "sub_to_ground_mps")

    def __init__(self, node_features, edge_features, hidden_features, out_features, depth=1, pool="mean", add_residual_skip=False, masking=False):
        super().__init__()
        self.depth = depth
        self.pool = pool
        self.add_residual_skip = add_residual_skip
        self.masking = masking
        self.embedding = nn.Linear(node_features, hidden_features)

        for stage in self._STAGES:
            setattr(self, stage, nn.ModuleList())
        # One MP per stage and layer, created layer by layer in stage order.
        for _ in range(depth):
            for stage in self._STAGES:
                getattr(self, stage).append(
                    MP(hidden_features, edge_features, hidden_features, hidden_features)
                )

        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, subgraph_edge_index, node_subnode_index, subnode_node_index, ground_node, subgraph_batch_index, batch_idx, edge_attr=None):
        """Embed, run ``depth`` layers of four MP stages, optionally pool ground nodes, project.

        ``subgraph_batch_index`` is accepted but not used by this implementation.
        """
        num_nodes = x.shape[0]
        x = self.embedding(x)

        for layer in range(self.depth):
            if self.add_residual_skip:
                layer_input = x

            #TODO: Check the order of edge indices; directed in which direction? subnode to node or vice versa
            schedule = (
                (self.ground_mps[layer], edge_index),
                (self.ground_to_sub_mps[layer], node_subnode_index),
                (self.sub_mps[layer], subgraph_edge_index),
                (self.sub_to_ground_mps[layer], subnode_node_index),
            )
            for mp, index in schedule:
                # Mask and snapshot are computed for every stage, even when masking is off.
                update_mask = catch_lone_sender(index, num_nodes)
                frozen_rows = x[~update_mask]
                x = mp(x, index, edge_attr)
                if self.masking:
                    x[~update_mask] = frozen_rows

            if self.add_residual_skip:
                x = x + layer_input

        # global pooling over nodes whose ground node is true
        if self.pool in ("mean", "add", "max"):
            pool_fn = getattr(tg.nn, f"global_{self.pool}_pool")
            x = pool_fn(x[ground_node], batch_idx)
        return self.output(x)

In [55]:
# Send the sample datapoint through the fractal network.
# The idea is to check whether the backup does what was intended, i.e. the state of the
# non-ground nodes is not changed when they are not participating in the MP (this is needed
# because the update net updates all nodes).
# pool=None matches none of the pooling branches, so the output keeps one row per node.
fractal_net = FractalNet(sample_data.x.shape[1], 0, 64, 1, depth=1, pool=None, add_residual_skip=True)
# Call the module itself instead of .forward() so that nn.Module hooks are honoured.
fractal_net(
    sample_data.x,
    sample_data.edge_index,
    sample_data.subgraph_edge_index,
    sample_data.node_subnode_index,
    sample_data.subnode_node_index,
    sample_data.ground_node,
    subgraph_batch_index=sample_data.subgraph_batch_index,
    batch_idx=None,
)

states of the not ground nodes after mp:  True
states of the not ground nodes after mp:  False
states of the not ground nodes after mp:  False
states of the not ground nodes after mp:  True


tensor([[ 0.0665],
        [-0.1118],
        [ 0.0665],
        [ 0.0665],
        [ 0.0665],
        [ 0.0333],
        [-0.1199],
        [-0.1755],
        [ 0.0332],
        [-0.1200],
        [-0.1756],
        [ 0.0333],
        [-0.1199],
        [-0.1755],
        [ 0.0333],
        [-0.1199],
        [-0.1755],
        [ 0.0333],
        [-0.1199],
        [-0.1755]], grad_fn=<AddmmBackward0>)