In [None]:
import torch
from torch_geometric.data import Data
import numpy as np

In [2]:
def create_graph(feature_matrix, adjacency_matrix, weight_matrix):
    """Build a PyG ``Data`` graph from dense feature, adjacency and weight matrices.

    Parameters
    ----------
    feature_matrix : array-like or torch.Tensor, shape (N, M)
        Node features, one row per node. Non-tensor input is converted to
        ``torch.float``; tensor input is used as-is (its dtype is kept).
    adjacency_matrix : array-like or torch.Tensor, shape (N, N)
        Dense adjacency; every non-zero entry ``[i, j]`` becomes a directed
        edge ``i -> j``.
    weight_matrix : array-like or torch.Tensor, shape (N, N)
        Dense edge weights; entry ``[i, j]`` is attached to edge ``i -> j``.
        Entries where the adjacency is zero are ignored.

    Returns
    -------
    torch_geometric.data.Data
        ``x`` = the features, ``edge_index`` = contiguous ``(2, E)`` integer
        tensor of (source, target) pairs in row-major order, and
        ``edge_attr`` = 1-D ``(E,)`` tensor of the matching weights.

    Raises
    ------
    ValueError
        If the adjacency is not a square 2-D matrix, the weight matrix shape
        differs from the adjacency shape, or the number of feature rows is not N.
    """
    # Non-tensor inputs (lists, numpy arrays) are converted to float tensors.
    if not torch.is_tensor(feature_matrix):
        feature_matrix = torch.tensor(feature_matrix, dtype=torch.float)
    if not torch.is_tensor(adjacency_matrix):
        adjacency_matrix = torch.tensor(adjacency_matrix, dtype=torch.float)
    if not torch.is_tensor(weight_matrix):
        weight_matrix = torch.tensor(weight_matrix, dtype=torch.float)

    # Validate shapes up front: a mismatch would otherwise surface as an
    # IndexError during weight lookup, or silently yield a graph whose edges
    # reference nodes that have no feature row.
    if adjacency_matrix.dim() != 2 or adjacency_matrix.size(0) != adjacency_matrix.size(1):
        raise ValueError(
            f"adjacency_matrix must be square (N, N), got {tuple(adjacency_matrix.shape)}"
        )
    if weight_matrix.shape != adjacency_matrix.shape:
        raise ValueError(
            f"weight_matrix shape {tuple(weight_matrix.shape)} does not match "
            f"adjacency_matrix shape {tuple(adjacency_matrix.shape)}"
        )
    num_nodes = adjacency_matrix.size(0)
    if feature_matrix.dim() == 0 or feature_matrix.size(0) != num_nodes:
        raise ValueError(
            f"feature_matrix must have {num_nodes} rows (one per node), "
            f"got shape {tuple(feature_matrix.shape)}"
        )

    # (row, col) indices of the non-zero entries, transposed to PyG's (2, E)
    # layout. `.t()` returns a non-contiguous view, so materialize it.
    edge_index = torch.nonzero(adjacency_matrix).t().contiguous()

    # Weight of each edge, aligned column-for-column with edge_index.
    edge_weights = weight_matrix[edge_index[0], edge_index[1]]

    return Data(x=feature_matrix, edge_index=edge_index, edge_attr=edge_weights)

In [3]:
# Seed NumPy's global RNG (used for the random features below) and PyTorch's
# RNG so that Restart & Run All reproduces the same numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

<torch._C.Generator at 0x116042af0>

In [4]:
N = 4  # Number of nodes/agents
M = 3  # Number of features per node

# Random node features, one row per node: shape (N, M).
feature_matrix = np.random.randn(N, M)

# Hand-written example adjacency (N x N): a 4-cycle 0-1-2-3-0, each edge
# listed in both directions.
adjacency_matrix = np.array([
    [0, 1, 0, 1],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [1, 0, 1, 0]
])

# Hand-written edge weights (N x N), non-zero exactly where an edge exists.
weight_matrix = np.array([
    [0.0, 0.5, 0.0, 0.3],
    [0.5, 0.0, 0.2, 0.0],
    [0.0, 0.2, 0.0, 0.7],
    [0.3, 0.0, 0.7, 0.0]
])

# The matrices above are hard-coded, so fail loudly if N is changed without
# updating them, or if the weight pattern drifts from the adjacency.
assert adjacency_matrix.shape == (N, N), "adjacency_matrix must be N x N"
assert weight_matrix.shape == (N, N), "weight_matrix must be N x N"
assert np.array_equal(adjacency_matrix != 0, weight_matrix != 0), \
    "weights should be non-zero exactly where the adjacency has an edge"

In [5]:
# Build the example graph and sanity-check it against the dense inputs.
graph = create_graph(feature_matrix, adjacency_matrix, weight_matrix)
assert graph.num_nodes == N
assert graph.num_edges == np.count_nonzero(adjacency_matrix)
graph  # last expression: display the Data summary

In [6]:
# TODO: add message passing with PyG, e.g. train a simple node classifier on `graph`.