In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import add_self_loops, degree

In [None]:

class GraphConvNet(MessagePassing):
    """Graph-level regressor: node MLP, one normalized propagation, mean pool.

    Node features pass through two linear layers with ReLU, are aggregated
    once over the (self-loop augmented) edges with symmetric degree
    normalization, projected to a single channel, and finally averaged per
    graph.
    """

    def __init__(self, num_features, hidden_channels):
        """Build the layers.

        Args:
            num_features: Size of each input node feature vector.
            hidden_channels: Width of the two hidden linear layers.
        """
        super().__init__(aggr='add')  # "Add" aggregation.
        self.lin1 = nn.Linear(num_features, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch=None):
        """Compute one scalar prediction per graph.

        Args:
            x: Node feature matrix; the first dimension indexes nodes and the
                last must equal ``num_features``.
            edge_index: Edge list that unpacks into two index tensors
                (one per edge endpoint).
            batch: Optional tensor assigning each node to a graph. When
                omitted, all nodes are treated as belonging to one graph.

        Returns:
            The per-graph mean of the non-negative single-channel node
            outputs, as produced by ``global_mean_pool``.
        """
        num_nodes = x.size(0)

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Step 2: Transform node features with a two-layer MLP.
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))

        # Step 3: Compute the symmetric normalization and propagate messages.
        # Every node has at least its self-loop in `row`, so deg >= 1 and
        # the inverse square root stays finite.
        # NOTE(review): degree is taken over `row`; for directed graphs this
        # differs from a degree over `col` -- confirm which is intended.
        row, col = edge_index
        deg = degree(row, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        x = self.propagate(edge_index, x=x, norm=norm)

        # Step 4: Project to one channel and pool to graph level.
        x = F.relu(self.lin3(x))
        if batch is None:
            # No batch vector supplied: pool all nodes into a single graph.
            batch = torch.zeros(num_nodes, dtype=torch.long, device=x.device)
        return global_mean_pool(x, batch)  # global mean pooling

    def message(self, x_j, norm):
        """Scale each neighbor feature row by its per-edge coefficient."""
        # Reshape norm to a column so it broadcasts across feature channels.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out
