In [2]:
# Links: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data
from torch_geometric.data import Data
import torch

In [13]:
class PairData(Data):
    """A single ``Data`` example that carries two graphs side by side.

    Graph 1 lives in ``x_1`` / ``edge_index_1`` and graph 2 in ``x_2`` /
    ``edge_index_2``. When several examples are collated into a batch, each
    edge index must be shifted by the node count of its *own* graph, not by
    a shared node count, which is what ``__inc__`` below arranges.
    """

    def __inc__(self, key, value, *args, **kwargs):
        """Return the batching increment for attribute ``key``.

        ``edge_index_1`` is offset by the number of rows in ``x_1`` and
        ``edge_index_2`` by the number of rows in ``x_2``; every other key
        falls back to the default ``Data.__inc__`` behaviour.
        """
        node_attr = {'edge_index_1': 'x_1', 'edge_index_2': 'x_2'}.get(key)
        if node_attr is not None:
            return getattr(self, node_attr).size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [37]:
# Build one graph pair: graph 1 is the 4-node path 0-1-2-3 and graph 2 is the
# 3-node path 0-1-2; each undirected edge is stored in both directions.

# Node features of shape (num_nodes, num_node_features) and type torch.float32
x_1 = torch.tensor([[0, 0, 0],
                    [1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]], dtype=torch.float32)
x_2 = torch.tensor([[4, 4, 4],
                    [5, 5, 5],
                    [6, 6, 6]], dtype=torch.float32)

# Edge indices of shape (2, num_edges) and type torch.long
edge_index_1 = torch.tensor([[0, 1, 1, 2, 2, 3],
                             [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_index_2 = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]], dtype=torch.long)

# Edge features of shape (num_edges, num_edge_features) and type torch.float32
edge_attr_1 = torch.tensor([[0],
                            [1],
                            [2],
                            [3],
                            [4],
                            [5]], dtype=torch.float32)
edge_attr_2 = torch.tensor([[6],
                            [7],
                            [8],
                            [9]], dtype=torch.float32)

# Pair label of shape (1,) and type torch.long
y = torch.tensor([1], dtype=torch.long)

# Sanity-check the hand-written tensors before wrapping them: every edge
# endpoint must be an existing node row, and each edge column needs exactly
# one edge-feature row. Catches silent mismatches when the literals are edited.
assert int(edge_index_1.max()) < x_1.size(0), "edge_index_1 references a node missing from x_1"
assert int(edge_index_2.max()) < x_2.size(0), "edge_index_2 references a node missing from x_2"
assert edge_attr_1.size(0) == edge_index_1.size(1), "need one edge_attr_1 row per edge in edge_index_1"
assert edge_attr_2.size(0) == edge_index_2.size(1), "need one edge_attr_2 row per edge in edge_index_2"


data = PairData(x_1=x_1, edge_index_1=edge_index_1, edge_attr_1=edge_attr_1,  # Graph 1.
                x_2=x_2, edge_index_2=edge_index_2, edge_attr_2=edge_attr_2,  # Graph 2.
                y=y)

print(data.edge_attr_1)

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.]])


In [15]:
from torch_geometric.loader import DataLoader

In [32]:
# Collate two copies of the same pair into one mini-batch. `follow_batch`
# asks the loader to build `x_1_batch` / `x_2_batch` vectors that map every
# node row of each graph back to the example it came from.
data_list = [data] * 2
loader = DataLoader(data_list, batch_size=2, follow_batch=['x_1', 'x_2'])

for batch in loader:
    print(batch.x_2)
    # Print the node-to-example assignment for graph 1, then graph 2.
    for assignment in (batch.x_1_batch, batch.x_2_batch):
        print("Which nodes correspond to which graph:", assignment)

tensor([[4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.],
        [4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.]])
Which nodes correspond to which graph: tensor([0, 0, 0, 0, 1, 1, 1, 1])
Which nodes correspond to which graph: tensor([0, 0, 0, 1, 1, 1])
