In [13]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class MessagePassing(Layer):
    def __init__(self, aggregate='sum', **kwargs):
        super(MessagePassing, self).__init__(**kwargs)
        self.aggregate_fn = self.get_aggregate_function(aggregate)

    def build(self, input_shape):
        super(MessagePassing, self).build(input_shape)

    def call(self, inputs, **kwargs):
        x, a = inputs
        return self.propagate(x, a)

    def propagate(self, x, a, **kwargs):
        # Compute messages
        messages = self.message(x, a, **kwargs)

        # Aggregate messages
        aggregated_messages = self.aggregate_fn(messages)

        # Update node embeddings
        updated_embeddings = self.update(aggregated_messages, x, **kwargs)

        return updated_embeddings

    def message(self, x, a, **kwargs):
        # Perform message passing using the adjacency matrix a
        messages = tf.matmul(a, x)

        return messages

    def update(self, aggregated_messages, x, **kwargs):
        # By default, the update function returns the aggregated messages as node embeddings
        return aggregated_messages

    @staticmethod
    def get_aggregate_function(aggregate):
        if aggregate == 'sum':
            return tf.reduce_sum
        elif aggregate == 'mean':
            return tf.reduce_mean
        elif aggregate == 'max':
            return tf.reduce_max
        elif aggregate == 'min':
            return tf.reduce_min
        elif aggregate == 'prod':
            return tf.reduce_prod
        else:
            raise ValueError("Unsupported aggregation function: {}".format(aggregate))

# Test the MessagePassing layer
import numpy as np

# Sample node features (latitude and longitude)
node_features = np.array([
    [37.7749, -122.4194],  # Node 1 (San Francisco)
    [34.0522, -118.2437],  # Node 2 (Los Angeles)
    [40.7128, -74.0060],   # Node 3 (New York)
    [41.8781, -87.6298]    # Node 4 (Chicago)
])

# Sample adjacency matrix
adjacency_matrix = np.array([
    [0, 1, 1, 0],  # Node 1 is connected to Node 2 and Node 3
    [1, 0, 0, 1],  # Node 2 is connected to Node 1 and Node 4
    [1, 0, 0, 0],  # Node 3 is connected to Node 1
    [0, 1, 0, 0]   # Node 4 is connected to Node 2
], dtype=np.float32)

print("updated_embeddings")
aggf=['min','sum','max','mean','prod']
# Create the message passing layer
for i in aggf:
     message_passing_layer = MessagePassing(aggregate=i)

     # Pass the inputs to the message passing layer
     updated_embeddings = message_passing_layer([node_features, adjacency_matrix])
     print(i," ",updated_embeddings)



updated_embeddings
min   tf.Tensor(-210.0492, shape=(), dtype=float32)
sum   tf.Tensor(-416.71692, shape=(), dtype=float32)
max   tf.Tensor(79.653, shape=(), dtype=float32)
mean   tf.Tensor(-52.089615, shape=(), dtype=float32)
prod   tf.Tensor(4477792000000000.0, shape=(), dtype=float32)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph

# Define a simple MPNN layer
class MPNNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(MPNNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def message_func(self, edges):
        return {'msg': edges.src['h']}  # Send the node features as the message

    def reduce_func(self, nodes):
        return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}  # Sum up received messages

    def forward(self, g, features):
        g.ndata['h'] = self.linear(features)  # Initialize node features
        g.update_all(self.message_func, self.reduce_func)  # Message passing
        return g.ndata.pop('h')  # Get the updated node features after message passing

# Define a simple graph and its node features
# The graph has 5 nodes and 4 edges: [(0, 1), (1, 2), (1, 3), (3, 4)]
# Node features: [node_0, node_1, node_2, node_3, node_4]
graph = DGLGraph()
graph.add_nodes(5)
graph.add_edges([0, 1, 1, 3], [1, 2, 3, 4])
node_features = torch.tensor([[0.1], [0.2], [0.3], [0.4], [0.5]], dtype=torch.float32)

# Define the Message Passing Neural Network
class MPNN(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats):
        super(MPNN, self).__init__()
        self.layer1 = MPNNLayer(in_feats, hidden_feats)
        self.layer2 = MPNNLayer(hidden_feats, out_feats)

    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x

# Create the MPNN model
mpnn_model = MPNN(in_feats=1, hidden_feats=2, out_feats=1)

# Perform message passing on the graph
output = mpnn_model(graph, node_features)

print("Updated node features after message passing:")
print(output)
