In [5]:
import torch
import torch.nn as nn

# Demo: build one message per edge with a separate MLP per edge type, then
# pick, for every edge, the output of the MLP that matches the edge's type.

# Problem sizes.
num_nodes, node_feature_dim = 6, 5
edge_feature_dim, num_edge_types = 1, 3

# Random feature matrix with one row per node.
# NOTE(review): the earlier notes here mentioned multiplying the "last
# element" by 1000, but no such step exists in the code; h is left unscaled.
h = torch.rand(num_nodes, node_feature_dim)

# Edge list: each row is a pair of node indices (first column, second column).
edge_index = torch.tensor([[0, 2], [2, 3], [4, 5], [1, 3]])
num_edges = edge_index.size(0)

# Split the edge list into its two columns and gather the node features
# for both endpoints of every edge.
first_nodes, second_nodes = edge_index.unbind(dim=1)
h_i = h[first_nodes]
h_j = h[second_nodes]

# One random feature vector per edge (a single scalar each with the sizes above).
dists = torch.rand(num_edges, edge_feature_dim)

# Random integer type in [0, num_edge_types) for every edge.
edge_type = torch.randint(low=0, high=num_edge_types, size=(num_edges,))
print('edge_type is', edge_type)

# MLP input per edge: both endpoint feature rows followed by the edge feature.
msg_inputs = torch.cat((h_i, h_j, dists), dim=1)

# One small Linear + ReLU network per edge type (randomly initialised).
msg_input_dim = 2 * node_feature_dim + edge_feature_dim
mlps_msg = []
for _ in range(num_edge_types):
    mlps_msg.append(
        nn.Sequential(nn.Linear(msg_input_dim, node_feature_dim), nn.ReLU())
    )

# Run every network on every edge and stack the results along a new leading
# axis, giving a (num_edge_types, num_edges, node_feature_dim) tensor.
per_type_outputs = [net(msg_inputs) for net in mlps_msg]
msg = torch.stack(per_type_outputs)
print('msg shape is', msg.shape)
print('msg is', msg)

# For edge e keep only msg[edge_type[e], e]: paired integer indexing selects
# one (type, edge) entry per edge and yields (num_edges, node_feature_dim).
edge_positions = torch.arange(num_edges)
msg = msg[edge_type, edge_positions]

print(msg)


edge_type is tensor([0, 1, 2, 1])
msg shape is torch.Size([3, 4, 5])
msg is tensor([[[5.8993e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [5.8809e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.6389e-01],
         [2.2130e+02, 1.8343e+02, 8.1066e+00, 0.0000e+00, 0.0000e+00],
         [3.8427e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

        [[5.6941e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0815e-02],
         [1.0759e-01, 0.0000e+00, 0.0000e+00, 2.0461e-01, 1.3025e-01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 3.8301e+02, 0.0000e+00],
         [1.2436e-01, 0.0000e+00, 0.0000e+00, 5.4822e-01, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4998e-01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9538e-01, 3.7851e-01],
         [5.1898e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 2.1061e-01, 1.1131e-01, 6.1730e-01]]],
       grad_fn=<StackBackward0>)
tensor([[0.5899, 0.0000, 0.0000, 