In [160]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from networks import FractalNet, FractalNetShared, Net, GNN_no_rel, GNN
from subgraph import Graph_to_Subgraph
from train import train_model, get_qm9
from train import train_model, get_qm9
from subgraph import Subgraph
from layers import MP, SimpleMP
from utils import catch_lone_sender
from torch_geometric.utils import remove_isolated_nodes
from torch_geometric.nn import GCNConv

In [161]:
# Load QM9 through the project helper, reading from the local ./data directory.
dataset = get_qm9('./data')
# Keep the first element returned by get_qm9 and use its first graph as the
# initial sample, moved to the CPU.
# NOTE(review): dataset[0] is presumably one split (e.g. training) — confirm in train.get_qm9.
data = dataset[0]
sample_data = data[0].to('cpu')

  warn("Using non-standard permutation since permute.pt does not exist.")


In [162]:
# Look for a smaller graph than the current sample_data.
# NOTE: this is not a global-minimum search. sample_data tracks the running
# minimum only until the first candidate that is smaller than the current
# sample AND has fewer than 6 nodes; at that point the scan stops.
for data_point in data:
    if data_point.num_nodes < sample_data.num_nodes:
        # Stop once a smaller candidate has fewer than 6 nodes (i.e. 5 or fewer);
        # that candidate is NOT stored.
        if data_point.num_nodes < 6:
            break
        sample_data = data_point.to('cpu')

In [164]:
# Show the graph that was selected as the working sample.
sample_data

Data(x=[8, 11], edge_index=[2, 16], edge_attr=[16, 4], y=[1, 19], pos=[8, 3], idx=[1], name='gdb_144', z=[8])

In [165]:
# Convert the sample graph with the project's Subgraph helper in 'fractal' mode;
# the result carries the attributes inspected and used in the cells below.
subgraph_example = Subgraph(sample_data, mode='fractal').convert_to_subgraph()

In [166]:
# Print summary information about the converted subgraph.
# NOTE: the labels say "Number of ...", but the printed values are tensor
# shapes (feature matrix and edge-index shapes), not scalar counts.
print('Number of nodes: ', subgraph_example.x.shape)
print('Number of edges: ', subgraph_example.edge_index.shape)
print('Subedge index', subgraph_example.subgraph_edge_index.shape)
# Largest index that appears anywhere in the subgraph edge index.
print('Maximum value in subedge index: ', torch.max(subgraph_example.subgraph_edge_index))

Number of nodes:  torch.Size([72, 5])
Number of edges:  torch.Size([2, 16])
Subedge index torch.Size([2, 128])
Maximum value in subedge index:  tensor(71)


In [190]:
from torch_geometric.utils import add_self_loops, degree
from torch.nn import Linear, Parameter
from torch_geometric.nn.conv import MessagePassing

class GCNConv(MessagePassing):
    """Minimal graph-convolution layer built on ``MessagePassing``.

    The layer linearly transforms the node features, scales every message by
    the symmetric normalization ``deg(row)^-1/2 * deg(col)^-1/2`` (degrees are
    counted over the target indices of ``edge_index``), sums the messages per
    node and finally adds a learnable bias.

    Note: defining this class rebinds the name ``GCNConv``, so it replaces any
    ``GCNConv`` imported earlier in the same namespace.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        self_loops: If ``True``, self-loops are added to ``edge_index`` before
            normalization and propagation. Defaults to ``False``, in which
            case ``edge_index`` is used exactly as given.
    """

    def __init__(self, in_channels, out_channels, self_loops=False):
        super().__init__(aggr='add')  # messages are summed per target node
        self.self_loops = self_loops
        self.lin = Linear(in_channels, out_channels, bias=False)
        # Allocate without initialising; reset_parameters() zeroes it below.
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear weights and set the bias to zero."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        """Apply one round of normalized message passing.

        Args:
            x: Node feature matrix of shape ``[N, in_channels]``.
            edge_index: Edge index tensor of shape ``[2, E]``.

        Returns:
            Tensor of shape ``[N, out_channels]`` with the updated features.
        """
        # Step 1 (optional): add self-loops to the adjacency structure.
        if self.self_loops:
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: linearly transform the node feature matrix.
        x = self.lin(x)

        # Step 3: compute the per-edge normalization coefficients.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes that never appear as a target have degree 0 -> inf; zero them
        # so they contribute nothing instead of propagating inf/nan.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: build the normalized messages and aggregate them.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: add the bias out-of-place so the aggregated tensor returned
        # by propagate() is not modified in place.
        return out + self.bias

    def message(self, x_j, norm):
        """Scale each per-edge message by its normalization coefficient.

        Args:
            x_j: Per-edge features of shape ``[E, out_channels]``.
            norm: Per-edge coefficients of shape ``[E]``.
        """
        return norm.view(-1, 1) * x_j

In [191]:
# Set up the inputs and layers for a single message-passing experiment.
# node_features matches the width of subgraph_example.x (5 columns).
node_features = 5
hidden_features = 2
edge_features = 0

# Work on a copy so the features stored on subgraph_example are not modified.
x = subgraph_example.x.clone()
# Index tensors / masks produced by the Subgraph conversion.
edge_index = subgraph_example.edge_index
subgraph_edge_index = subgraph_example.subgraph_edge_index
node_subnode_index = subgraph_example.node_subnode_index
subnode_node_index = subgraph_example.subnode_node_index
ground_node = subgraph_example.ground_node

# NOTE(review): the meaning of MP's four positional arguments is defined in
# layers.py — presumably (node in, edge in, hidden, out); confirm there.
MP_layer = MP(hidden_features, edge_features, hidden_features, hidden_features)
GCN_layer = GCNConv(hidden_features, hidden_features)
# Linear projection from the raw node features to the hidden width.
embedding = nn.Linear(node_features, hidden_features)

# Embed the node features; num_nodes is the number of rows of the embedded matrix.
x = embedding(x)
num_nodes = x.shape[0]

In [192]:
# Keep a copy of the embedded features from before any message passing.
x_original = x.clone()
# Rows where update_mask is False keep their previous values: they are saved
# here and written back after the layer call.
# NOTE(review): catch_lone_sender presumably flags the nodes that should be
# updated for this edge index — confirm in utils.py.
update_mask = catch_lone_sender(edge_index, num_nodes)
x_backup = x[~update_mask]
# NOTE(review): the constant +3 looks like a debugging offset that makes the
# updated rows easy to spot in the printed output — confirm it is intentional.
x = MP_layer(x, edge_index) +3
x[~update_mask] = x_backup

In [193]:
# Show the node states after the edge_index message-passing step.
x

tensor([[ 2.5152,  2.4932],
        [ 2.6387,  2.7755],
        [ 2.6122,  2.8730],
        [ 2.6146,  2.8703],
        [ 2.5930,  2.6250],
        [ 2.5305,  2.5506],
        [ 2.5305,  2.5506],
        [ 2.5305,  2.5506],
        [ 0.5239, -0.7072],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.7347, -0.3035],
        [ 0.5418, -0.4718],
        [ 0.5418, -0.4718],
        [ 0.5418, -0.4718],
        [ 0.5239, -0.7072],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.7347, -0.3035],
        [ 0.5418, -0.4718],
        [ 0.5418, -0.4718],
        [ 0.5418, -0.4718],
        [ 0.5239, -0.7072],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.7347, -0.3035],
        [ 0.5418, -0.4718],
        [ 0.5418, -0.4718],
        [ 0.5418, -0.4718],
        [ 0.5239, -0.7072],
        [ 0.4755, -0.0507],
        [ 0.4755, -0.0507],
        [ 0.4755, -0

In [194]:
# Same masked-update pattern as for edge_index, now over node_subnode_index:
# rows where update_mask is False are saved and restored after the layer call.
update_mask = catch_lone_sender(node_subnode_index, num_nodes)
x_backup = x[~update_mask]
# Note: this reassigns x from its previous value, so re-running the cell
# applies the layer again and changes the result.
x = MP_layer(x, node_subnode_index)
x[~update_mask] = x_backup

In [195]:
# Show the node states after the node_subnode_index message-passing step.
x

tensor([[ 2.5152,  2.4932],
        [ 2.6387,  2.7755],
        [ 2.6122,  2.8730],
        [ 2.6146,  2.8703],
        [ 2.5930,  2.6250],
        [ 2.5305,  2.5506],
        [ 2.5305,  2.5506],
        [ 2.5305,  2.5506],
        [-0.3770, -0.3682],
        [-0.3386, -0.1820],
        [-0.3386, -0.1820],
        [-0.3386, -0.1820],
        [-0.3774, -0.3369],
        [-0.3618, -0.3108],
        [-0.3618, -0.3108],
        [-0.3618, -0.3108],
        [-0.3681, -0.3568],
        [-0.3424, -0.1777],
        [-0.3424, -0.1777],
        [-0.3424, -0.1777],
        [-0.3685, -0.3254],
        [-0.3529, -0.2993],
        [-0.3529, -0.2993],
        [-0.3529, -0.2993],
        [-0.3671, -0.3555],
        [-0.3429, -0.1772],
        [-0.3429, -0.1772],
        [-0.3429, -0.1772],
        [-0.3675, -0.3241],
        [-0.3518, -0.2980],
        [-0.3518, -0.2980],
        [-0.3518, -0.2980],
        [-0.3671, -0.3554],
        [-0.3429, -0.1772],
        [-0.3429, -0.1772],
        [-0.3429, -0

In [159]:
# Show the raw (un-embedded) node feature matrix stored on the subgraph example.
subgraph_example.x

tensor([[0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.]])

In [16]:
# Apply the MP layer over subgraph_edge_index and display the result.
# Note: x is overwritten with its own update, so this cell is not idempotent.
x = MP_layer(x, subgraph_edge_index)
x

tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [1.3722, 0.6890],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [1.3725, 0.6894],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [1.3725, 0.6894],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [1.3725, 0.6894],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [1.3725, 0.6894],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480],
        [0.3251, 0.1480]], grad_fn=<ScatterAddBackward0>)

In [17]:
# Apply the MP layer over subnode_node_index and display the result.
# Note: x is overwritten with its own update, so this cell is not idempotent.
x = MP_layer(x, subnode_node_index)
x

tensor([[1.6612, 0.7883],
        [1.6612, 0.7883],
        [1.6612, 0.7883],
        [1.6612, 0.7883],
        [1.6612, 0.7883],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000]], grad_fn=<ScatterAddBackward0>)

In [20]:
# Forward pass with edge_index restricted to the ground-node rows of x.
# Snapshot the non-ground rows first so we can verify they stay untouched.
x_subnode_original = x[~ground_node]
# Drop isolated nodes from edge_index; the returned mask is not used here.
filtered_edge_index, _, mask = remove_isolated_nodes(edge_index)
# Only the rows selected by ground_node are replaced by the layer output.
x[ground_node] = MP_layer(x[ground_node], filtered_edge_index)
# Check that x[~ground_node] still equals the snapshot taken before the update.
print('Are they the same tensors: ', torch.equal(x[~ground_node], x_subnode_original))
print('Node states: ', x)
# Keep the updated ground-node rows for comparison in a later cell.
x_node_original = x[ground_node]

Are they the same tensors:  True
Node states:  tensor([[1.2365, 0.2394],
        [1.2365, 0.2394],
        [1.2365, 0.2394],
        [0.6183, 0.1197],
        [0.6183, 0.1197],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000]], grad_fn=<IndexPutBackward0>)


In [93]:
# Apply the MP layer over node_subnode_index.
# Note: x is overwritten with its own update, so this cell is not idempotent.
x = MP_layer(x, node_subnode_index)
# Check whether the ground-node rows are unchanged by this step. Only
# x[ground_node] is compared: x_node_original holds just the ground-node rows,
# so comparing it with the full x would be False on shape alone.
print('Are they the same tensors: ', torch.equal(x[ground_node], x_node_original))
print('Node states: ', x)

Are they the same tensors:  False
Node states:  tensor([[ 0.2830, -0.3501],
        [ 0.2766, -0.3489],
        [ 0.2734, -0.3482],
        [ 0.2730, -0.3482],
        [ 0.2730, -0.3482],
        [ 0.2730, -0.3482],
        [ 0.2972, -0.3940],
        [ 0.2972, -0.3940],
        [ 0.2964, -0.3933],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2972, -0.3940],
        [ 0.2972, -0.3940],
        [ 0.2964, -0.3933],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2972, -0.3940],
        [ 0.2972, -0.3940],
        [ 0.2964, -0.3933],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2972, -0.3940],
        [ 0.2972, -0.3940],
        [ 0.2964, -0.3933],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2902, -0.3878],
        [ 0.2972, -0.3940],
        [ 0.2972, -0.3940],
        [ 0.2964, -0.3933],
        [ 0.2902, -0.3878],


In [29]:
# Run the layer over subnode_node_index once and split the result by node
# type. A single forward pass avoids duplicated work and guarantees that both
# slices come from the same output.
subnode_node_output = MP_layer(x, subnode_node_index)
ground_node_states = subnode_node_output[ground_node]
subnode_node_states = subnode_node_output[~ground_node]
# Print the ground-node rows and the remaining (sub-node) rows separately.
print('Ground node states: ', ground_node_states)
print('Subnode node states: ', subnode_node_states)

Ground node states:  tensor([[-0.0924, -1.3850],
        [-0.0924, -1.3850],
        [-0.2471, -1.0797],
        [-0.2157, -1.1582],
        [-0.2157, -1.1582],
        [-0.2157, -1.1582]], grad_fn=<IndexBackward0>)
Subnode node states:  tensor([[ 0.0076, -0.8455],
        [ 0.0076, -0.8455],
        [-0.1285, -0.5832],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [ 0.0076, -0.8455],
        [ 0.0076, -0.8455],
        [-0.1285, -0.5832],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [ 0.0076, -0.8455],
        [ 0.0076, -0.8455],
        [-0.1285, -0.5832],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [ 0.0076, -0.8455],
        [ 0.0076, -0.8455],
        [-0.1285, -0.5832],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [-0.1175, -0.7103],
        [ 0.0076, -0.8455],
        [ 0.0076, -0.8455],
        [-0.1285, -0.5832],
      