In [4]:
from torch_geometric.datasets import QM9
# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import logging
import time

import os.path as osp

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import GRU, Linear, ReLU, Sequential
import torch.optim as optim

import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, Set2Set, MessagePassing, global_mean_pool, aggr
from torch_geometric.utils import remove_self_loops
from torch_geometric.data import Data

In [2]:
target = 0 # the first property is the one to be predicted

class TargetTransform:
    def __call__(self, data):
        # Specify target.
        data.y = data.y[:, target]
        return data

path = './datasets/QM9'
transform = T.Compose([TargetTransform(), T.Distance(norm=False)]) # add the distance into edge attributes
dataset = QM9(path, transform=transform)

In [3]:
data = dataset[0]
print('After transformation:')
print(data)
print(data.edge_index)
print(data.edge_attr)
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

After transformation:
Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 5], y=[1], pos=[5, 3], idx=[1], name='gdb_1', z=[5])
tensor([[0, 0, 0, 0, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 0, 0, 0]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919],
        [1.0000, 0.0000, 0.0000, 0.0000, 1.0919]])
Number of nodes: 5
Number of edges: 8
Average node degree: 1.60
Has isolated nodes: False
Has self-loops: False
Is undirected: True


By defining the message() and update() functions, you control how information is propagated through the graph and how nodes update their representations based on the received messages. These functions enable the GNN to learn from the graph structure and capture useful information for the prediction task.

During the forward() pass, the GNN iteratively performs message passing and aggregation steps across the graph, combining information from neighboring nodes to update each node's representation. The output of the forward() function provides the final representations that can be further processed or used for downstream tasks

In [37]:
model_name = 'gnn_details'
config = {
    "in_channels": 3,
    "out_channels": 1,
    "hidden_dim_1": 4,
    "hidden_dim_2": 5,
    "hidden_dim_3": 32,
    "lr": 0.001,
    "batch_size": 20,
    "num_epochs": 3,
    "criterion": nn.MSELoss(),
    }

In [47]:
# Split datasets.
torch.manual_seed(12345)
dataset = dataset.shuffle()
test_dataset = dataset[:10000]
val_dataset = dataset[10000:20000]
train_dataset = dataset[20000:]
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False) #10000
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False) #10000
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True) #110831

In [48]:
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[0,4,2], [1,3,1], [2,-1,1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)

In [49]:
data.x

tensor([[ 0.,  4.,  2.],
        [ 1.,  3.,  1.],
        [ 2., -1.,  1.]])

In [59]:
# Define the customized message passing layer
class GNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GNN, self).__init__(aggr='mean')
        self.lin = nn.Linear(in_channels, out_channels)
        self.lin2 = nn.Linear(5, out_channels)

    def forward(self, x, edge_index, edge_attr):
        print('forward')
        print('x',x)
        print('x.shape',x.shape)
        print('edge_index',edge_index)
        print('edge_index.shape',edge_index.shape)
        return self.propagate(edge_index, edge_attr=edge_attr,x=x)

    def message(self, x_j, edge_attr):
        print('message')
        print('x_j',x_j)
        print('x_j.shape',x_j.shape) # num_edges * node_features
        print('lin(x_j)',self.lin(x_j))
        print('lin(x_j).shape',self.lin(x_j).shape) # num_edges * 64
        #print('edge_attr ',self.lin2(edge_attr)) 
        #print('edge_attr.shape ',self.lin2(edge_attr).shape) # num_edges * 64
        return self.lin(x_j)#+self.lin2(edge_attr) # can use different operators and assign them different weights
        #represents the feature representations of the neighboring nodes.
    
    def update(self, aggr_out, x):
        print('update')
        print('x', x)
        print('lin(x)', self.lin(x))
        print('aggr_out',aggr_out)
        print('aggr_out.shape',aggr_out.shape)
        print('lin(x)+aggr_out',self.lin(x)+ aggr_out)
        return aggr_out+self.lin(x)

# Define the GNN-based model
class GNNModel(nn.Module):
    def __init__(self, in_channels, hidden_dim_1, hidden_dim_2, hidden_dim_3, out_channels):
        super(GNNModel, self).__init__()
        self.gnn = GNN(in_channels, hidden_dim_1)
        self.gnn2 = GNN(hidden_dim_1, hidden_dim_2)
        # Use a global sort aggregation:
        self.global_pool = aggr.MeanAggregation()
        self.fc1 = nn.Linear(hidden_dim_2, hidden_dim_3)
        self.fc2 = nn.Linear(hidden_dim_3, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.gnn(x, edge_index, edge_attr).relu()
        x = self.gnn2(x, edge_index, edge_attr).relu()
        x = self.global_pool(x,batch)
        x = self.fc1(x).relu()
        x = self.fc2(x)
        return x

In [60]:
model = GNNModel(config["in_channels"], config["hidden_dim_1"], config["hidden_dim_2"], config["hidden_dim_3"], config["out_channels"])
optimizer = optim.Adam(model.parameters(), lr=config["lr"])
print(model)

GNNModel(
  (gnn): GNN()
  (gnn2): GNN()
  (global_pool): MeanAggregation()
  (fc1): Linear(in_features=5, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=1, bias=True)
)


In [61]:
output = model(data).flatten()

forward
x tensor([[ 0.,  4.,  2.],
        [ 1.,  3.,  1.],
        [ 2., -1.,  1.]])
x.shape torch.Size([3, 3])
edge_index tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])
edge_index.shape torch.Size([2, 4])
message
x_j tensor([[ 0.,  4.,  2.],
        [ 1.,  3.,  1.],
        [ 1.,  3.,  1.],
        [ 2., -1.,  1.]])
x_j.shape torch.Size([4, 3])
lin(x_j) tensor([[ 1.8508,  0.8710, -0.9157, -1.1656],
        [ 1.1132,  0.1986, -0.9161, -1.3484],
        [ 1.1132,  0.1986, -0.9161, -1.3484],
        [ 0.2837,  0.1969,  0.9099,  0.1637]], grad_fn=<AddmmBackward0>)
lin(x_j).shape torch.Size([4, 4])
update
x tensor([[ 0.,  4.,  2.],
        [ 1.,  3.,  1.],
        [ 2., -1.,  1.]])
lin(x) tensor([[ 1.8508,  0.8710, -0.9157, -1.1656],
        [ 1.1132,  0.1986, -0.9161, -1.3484],
        [ 0.2837,  0.1969,  0.9099,  0.1637]], grad_fn=<AddmmBackward0>)
aggr_out tensor([[ 1.1132,  0.1986, -0.9161, -1.3484],
        [ 1.0673,  0.5340, -0.0029, -0.5010],
        [ 1.1132,  0.1986, -0.9161, -1.34

In [54]:
device = torch.device('cpu')

# Training loop
model.to(device)
model.train()

min_valid_loss = np.inf
loss_values = []
val_loss_values = []

for epoch in range(config["num_epochs"]):
    total_loss = 0
    for batch in val_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch).flatten()
        loss = config["criterion"](output, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs

    average_loss = total_loss  / len(val_loader.dataset)
    loss_values.append(average_loss)
    print(f'Epoch: {epoch}, Training Loss: {average_loss:.4f}')

x.shape  torch.Size([349, 11])
edge_index.shape  torch.Size([2, 736])
forward
message
x_j  tensor([[0., 1., 0.,  ..., 0., 0., 2.],
        [0., 1., 0.,  ..., 0., 0., 2.],
        [0., 1., 0.,  ..., 0., 0., 2.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
x_j.shape  torch.Size([736, 11])
lin(x_j)  tensor([[ 1.5753, -0.8763,  1.1009,  ...,  0.5212,  0.6490,  2.5236],
        [ 1.5753, -0.8763,  1.1009,  ...,  0.5212,  0.6490,  2.5236],
        [ 1.5753, -0.8763,  1.1009,  ...,  0.5212,  0.6490,  2.5236],
        ...,
        [ 0.6298, -0.4255,  0.4990,  ...,  0.2184,  0.3819,  0.3936],
        [ 0.6298, -0.4255,  0.4990,  ...,  0.2184,  0.3819,  0.3936],
        [ 0.6298, -0.4255,  0.4990,  ...,  0.2184,  0.3819,  0.3936]],
       grad_fn=<AddmmBackward0>)
lin(x_j).shape  torch.Size([736, 64])
update
aggr_out  tensor([[ 1.3615, -0.7446,  0.8518,  ...,  0.1454,  0.7390,  1.4452],
        [ 1.6736, -0

x.shape  torch.Size([358, 11])
edge_index.shape  torch.Size([2, 742])
forward
message
x_j  tensor([[0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 3.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
x_j.shape  torch.Size([742, 11])
lin(x_j)  tensor([[ 1.4758, -0.8939,  1.1788,  ...,  0.7621,  0.4967,  2.7443],
        [ 1.4758, -0.8939,  1.1788,  ...,  0.7621,  0.4967,  2.7443],
        [ 1.4758, -0.8939,  1.1788,  ...,  0.7621,  0.4967,  2.7443],
        ...,
        [ 0.6302, -0.4255,  0.4994,  ...,  0.2183,  0.3823,  0.3939],
        [ 0.6302, -0.4255,  0.4994,  ...,  0.2183,  0.3823,  0.3939],
        [ 0.6302, -0.4255,  0.4994,  ...,  0.2183,  0.3823,  0.3939]],
       grad_fn=<AddmmBackward0>)
lin(x_j).shape  torch.Size([742, 64])
update
aggr_out  tensor([[ 0.9168, -0.5294,  0.6114,  ...,  0.1737,  0.5257,  0.8165],
        [ 1.7872, -0

x.shape  torch.Size([342, 11])
edge_index.shape  torch.Size([2, 696])
forward
message
x_j  tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
x_j.shape  torch.Size([696, 11])
lin(x_j)  tensor([[ 2.4290, -1.2862,  1.4841,  ...,  0.1197,  1.2581,  2.9302],
        [ 2.4290, -1.2862,  1.4841,  ...,  0.1197,  1.2581,  2.9302],
        [ 1.7826, -0.8411,  0.9546,  ...,  0.0457,  0.9630,  2.0907],
        ...,
        [ 0.6310, -0.4255,  0.5004,  ...,  0.2188,  0.3833,  0.3948],
        [ 0.6310, -0.4255,  0.5004,  ...,  0.2188,  0.3833,  0.3948],
        [ 0.6310, -0.4255,  0.5004,  ...,  0.2188,  0.3833,  0.3948]],
       grad_fn=<AddmmBackward0>)
lin(x_j).shape  torch.Size([696, 64])
update
aggr_out  tensor([[ 1.2068, -0.6333,  0.7275,  ...,  0.1323,  0.6731,  1.2428],
        [ 1.8421, -0

x.shape  torch.Size([353, 11])
edge_index.shape  torch.Size([2, 748])
forward
message
x_j  tensor([[0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 3.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
x_j.shape  torch.Size([748, 11])
lin(x_j)  tensor([[ 1.4655, -0.8939,  1.1667,  ...,  0.7517,  0.4846,  2.7330],
        [ 1.4655, -0.8939,  1.1667,  ...,  0.7517,  0.4846,  2.7330],
        [ 1.4655, -0.8939,  1.1667,  ...,  0.7517,  0.4846,  2.7330],
        ...,
        [ 0.6284, -0.4255,  0.4973,  ...,  0.2158,  0.3803,  0.3920],
        [ 0.6284, -0.4255,  0.4973,  ...,  0.2158,  0.3803,  0.3920],
        [ 0.6284, -0.4255,  0.4973,  ...,  0.2158,  0.3803,  0.3920]],
       grad_fn=<AddmmBackward0>)
lin(x_j).shape  torch.Size([748, 64])
update
aggr_out  tensor([[ 0.8638, -0.5382,  0.6466,  ...,  0.2905,  0.4458,  0.9234],
        [ 1.3124, -0

x.shape  torch.Size([368, 11])
edge_index.shape  torch.Size([2, 746])
forward
message
x_j  tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
x_j.shape  torch.Size([746, 11])
lin(x_j)  tensor([[ 2.4113, -1.2862,  1.4632,  ...,  0.0958,  1.2375,  2.9110],
        [ 2.4113, -1.2862,  1.4632,  ...,  0.0958,  1.2375,  2.9110],
        [ 2.4752, -1.3759,  1.5103,  ...,  0.0722,  1.3094,  2.2904],
        ...,
        [ 0.6242, -0.4255,  0.4923,  ...,  0.2110,  0.3754,  0.3875],
        [ 0.6242, -0.4255,  0.4923,  ...,  0.2110,  0.3754,  0.3875],
        [ 0.6242, -0.4255,  0.4923,  ...,  0.2110,  0.3754,  0.3875]],
       grad_fn=<AddmmBackward0>)
lin(x_j).shape  torch.Size([746, 64])
update
aggr_out  tensor([[ 1.5497, -0.9007,  1.0013,  ...,  0.1416,  0.8424,  1.3389],
        [ 2.0913, -1

lin(x_j).shape  torch.Size([750, 64])
update
aggr_out  tensor([[ 0.8853, -0.5338,  0.6231,  ...,  0.2255,  0.4800,  0.8645],
        [ 1.5293, -0.8929,  1.0607,  ...,  0.3821,  0.6978,  1.9771],
        [ 1.3195, -0.7143,  0.8400,  ...,  0.2494,  0.6533,  1.6586],
        ...,
        [ 2.2676, -1.4110,  1.6560,  ...,  0.5472,  0.9947,  2.7223],
        [ 2.4168, -1.2862,  1.4695,  ...,  0.0988,  1.2438,  2.9167],
        [ 1.6669, -0.8587,  1.0135,  ...,  0.2685,  0.7919,  2.2938]],
       grad_fn=<DivBackward0>)
aggr_out.shape  torch.Size([362, 64])
x.shape  torch.Size([352, 11])
edge_index.shape  torch.Size([2, 732])
forward
message
x_j  tensor([[0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 3.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])
x_j.shape  torch.Size([732, 11])
lin(x_j)  tensor([[ 1.4500, -0.8939,  1.1487,  ...,  0.7358,  0.

KeyboardInterrupt: 