In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch_geometric.nn import MessagePassing

In [2]:
"""
   Toy batch: 2 samples (graphs) with 4 nodes each and embedding size 1,
   i.e. every graph is represented by a 4 x 1 vector.
   x_dis holds one scalar position per node; it defines the spatial
   relationship (distance) between the nodes.
"""
data = Data()
batch_size = 2
num_nodes = 4
# Positions 0..7 as a float column vector of shape (batch_size * num_nodes, 1).
x_dis = torch.arange(batch_size * num_nodes, dtype=torch.float32).unsqueeze(1)
print(x_dis)

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.]])


In [3]:
"""
   PyTorch Geometric differs from the usual PyTorch convention for input dimensions.
   Plain PyTorch typically follows N x C x E (N: batch_size, C: channel_num, E: embedding size),
   whereas PyTorch Geometric flattens the batch into (N*C) x E.
   An additional variable "batch" is therefore needed to assign each node to its sample.
   See details (not necessary for understanding this part):
   https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
"""
# One graph id per node: every id in 0..batch_size-1 is repeated num_nodes times in a row.
batch = torch.arange(batch_size).repeat_interleave(num_nodes)
print(batch)

tensor([0, 0, 0, 0, 1, 1, 1, 1])


In [4]:
"""
   radius_graph defines the connectivity between nodes based on the positions in x_dis:
   two nodes are connected when their distance is at most r and "batch" places them in the
   same graph (nodes 3 and 4 are 1 apart but stay unconnected because they belong to
   different graphs). loop=True additionally gives every node a self-loop.
   see details: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.RadiusGraph
"""

edge_index = radius_graph(x_dis, r=1, batch=batch, loop=True)
print(edge_index)

tensor([[0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7],
        [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7]])


In [5]:
"""
   Message Passing networks
   For details, see: 
   https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
   https://pytorch-geometric.readthedocs.io/en/latest/notes/sparse_tensor.html

"""


class MPNN(MessagePassing):
    """Minimal message-passing layer without learnable parameters.

    Every node receives the unmodified features of its neighbours and
    reduces them with the configured aggregation (a sum by default).

    Args:
        aggr: Aggregation scheme forwarded to ``MessagePassing``
            (for example ``'add'``, ``'mean'`` or ``'max'``).
            Defaults to ``'add'``, which is what ``MPNN()`` always used.
    """

    def __init__(self, aggr: str = 'add'):
        super().__init__(aggr=aggr)

    def forward(self, x, edge_index):
        """Propagate node features along the edges.

        Args:
            x: Node feature tensor with one row per node.
            edge_index: Connectivity tensor describing the edges
                (e.g. the result of ``radius_graph``).

        Returns:
            The aggregated neighbour features, one row per node.
        """
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        """Build the message for each edge: the neighbour features ``x_j``, unchanged."""
        return x_j

In [6]:
# Instantiate the message-passing layer; it has no learnable parameters.
net = MPNN()

In [7]:
# One integer feature per node. The length is derived from batch_size * num_nodes
# (instead of a hard-coded 8) so x always has as many rows as the graph behind
# edge_index has nodes.
x = torch.arange(batch_size * num_nodes).view(-1, 1)  # input data
# Run one round of message passing over the radius graph.
output = net(x, edge_index)

In [8]:
# Show the input node features (one row per node).
print(x)

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])


In [9]:
# Each row is the sum of the node's own value (self-loop) and the values of its
# neighbours inside the same graph, e.g. node 1 -> 0 + 1 + 2 = 3.
print(output)

tensor([[ 1],
        [ 3],
        [ 6],
        [ 5],
        [ 9],
        [15],
        [18],
        [13]])
