In [1]:
import torch
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

reference: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html

### Creating Heterogeneous Graph

In [2]:
data = HeteroData()

num_users = 10
num_items = 100
num_features_user = 5
num_features_item = 7
num_transactions = 100
num_features_transaction = 3

data['user'].x = torch.rand((num_users, num_features_user))
data['item'].x = torch.rand((num_items, num_features_item))

data['user', 'buys', 'item'].edge_index = torch.stack([
    torch.randint(high=num_users, size=(num_transactions,)),
    torch.randint(high=num_items, size=(num_transactions,))
])

data['user', 'buys', 'item'].edge_attr = torch.rand((num_transactions, num_features_transaction))

data = T.ToUndirected()(data)

print(data)

HeteroData(
  [1muser[0m={ x=[10, 5] },
  [1mitem[0m={ x=[100, 7] },
  [1m(user, buys, item)[0m={
    edge_index=[2, 100],
    edge_attr=[100, 3]
  },
  [1m(item, rev_buys, user)[0m={
    edge_index=[2, 100],
    edge_attr=[100, 3]
  }
)


### Creating Heterogeneous GNNs
PyG provides three ways to create models on heterogenrous graph data:
1. `torch_geometric.nn.to_hetero()` to convert model
2. `conv.HeteroConv` to define individual functions for different types
3. Deploy existing heterogeneous GNN operators

For me, the second one is more clear and easy to understand, so I will use it in the following exampe.

In [9]:
from torch.nn import Linear
from torch_geometric.nn import HeteroConv, SAGEConv, GATv2Conv

class HeteroGNN(torch.nn.Module):
    def __init__(self, user_dim, item_dim, transaction_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        
        self.lin_proj_user = Linear(user_dim, hidden_dim)
        self.lin_proj_item = Linear(item_dim, hidden_dim)
        
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                # ('user', 'buys', 'item'): SAGEConv(hidden_dim, hidden_dim),
                ('user', 'buys', 'item'): GATv2Conv(hidden_dim, hidden_dim, heads=2, edge_dim=transaction_dim),
                ('item', 'rev_buys', 'user'): GATv2Conv(hidden_dim, hidden_dim, heads=2, edge_dim=transaction_dim)
                # ('item', 'rev_buys', 'user'): GATv2Conv(hidden_dim, hidden_dim,)
            })
            self.convs.append(conv)
        
        self.trans_user = Linear(hidden_dim, output_dim)
        self.trans_item = Linear(hidden_dim, output_dim)
    
    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        # linear projections
        for node_type, x in x_dict.items():
            if node_type == 'user':
                x_dict[node_type] = self.lin_proj_user(x)
            elif node_type == 'item':
                x_dict[node_type] = self.lin_proj_item(x)
        
        # message passing convolutions
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict, **{'edge_attr_dict': edge_attr_dict})
            # x_dict = conv(x_dict, edge_index_dict, 
            #               **{'edge_type1': edge_attr_dict[('user', 'buys', 'item')],
            #                'edge_type2': edge_attr_dict[('item', 'rev_buys', 'user')]})
            # x_dict = conv(x_dict, edge_index_dict, **edge_attr_dict)
            # x_dict = conv(x_dict, edge_index_dict)
            x_dict = {node_type: x.relu() for node_type, x in x_dict.items()}
        
        # final transformation
        for node_type, x in x_dict.items():
            if node_type == 'user':
                x_dict[node_type] = self.trans_user(x)
            elif node_type == 'item':
                x_dict[node_type] = self.trans_item(x)
        
        return x_dict


model_kwargs = {
    'user_dim': num_features_user,
    'item_dim': num_features_item,
    'transaction_dim': num_features_transaction,
    'hidden_dim': 32,
    'output_dim': 32,
    'num_layers': 2
}
model = HeteroGNN(**model_kwargs)
print(model)

HeteroGNN(
  (lin_proj_user): Linear(in_features=5, out_features=32, bias=True)
  (lin_proj_item): Linear(in_features=7, out_features=32, bias=True)
  (convs): ModuleList(
    (0): HeteroConv(num_relations=2)
    (1): HeteroConv(num_relations=2)
  )
  (trans_user): Linear(in_features=32, out_features=32, bias=True)
  (trans_item): Linear(in_features=32, out_features=32, bias=True)
)


In [10]:
output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)

current edge type: ('user', 'buys', 'item')
kwargs_dict:
dict_keys(['edge_attr_dict'])


RuntimeError: index 59 is out of bounds for dimension 0 with size 10

In [None]:
print(output['item'].shape)
print(output['user'].shape)

In [12]:
data.edge_index_dict[('user', 'buys', 'item')]

tensor([[ 2,  2,  6,  4,  3,  0,  3,  8,  1,  7,  4,  4,  1,  4,  8,  1,  0,  3,
          7,  8,  6,  2,  6,  9,  9,  5,  7,  0,  8,  5,  1,  4,  6,  8,  2,  2,
          3,  3,  9,  8,  5,  4,  6,  7,  5,  6,  1,  3,  7,  1,  6,  6,  5,  7,
          5,  4,  9,  2,  3,  5,  5,  3,  9,  9,  9,  4,  5,  2,  4,  2,  7,  0,
          9,  6,  6,  7,  4,  7,  1,  9,  3,  5,  5,  2,  1,  8,  3,  9,  6,  2,
          3,  8,  0,  0,  6,  1,  2,  5,  5,  2],
        [59, 40,  0, 72, 19, 81, 72, 97, 75, 96, 86, 12, 69, 35, 40, 72, 73, 29,
         46, 18, 32, 49, 66, 37, 41, 86, 80, 70, 82,  2, 45, 28, 46, 46, 98, 61,
         61,  3, 19, 33, 39, 32, 21, 93, 92,  9, 31, 49, 90, 98, 25, 31, 74, 41,
         19, 52,  9, 81,  9, 84, 25, 57, 36, 51, 16, 38, 46, 94, 42, 61, 77, 41,
         43, 32, 98,  8,  7,  5, 32, 15, 24, 38, 44,  8, 31, 50, 42, 58, 99, 22,
         78, 89, 53, 19,  3, 61, 59, 73, 28, 15]])

In [13]:
data.edge_index_dict[('item', 'rev_buys', 'user')]

tensor([[59, 40,  0, 72, 19, 81, 72, 97, 75, 96, 86, 12, 69, 35, 40, 72, 73, 29,
         46, 18, 32, 49, 66, 37, 41, 86, 80, 70, 82,  2, 45, 28, 46, 46, 98, 61,
         61,  3, 19, 33, 39, 32, 21, 93, 92,  9, 31, 49, 90, 98, 25, 31, 74, 41,
         19, 52,  9, 81,  9, 84, 25, 57, 36, 51, 16, 38, 46, 94, 42, 61, 77, 41,
         43, 32, 98,  8,  7,  5, 32, 15, 24, 38, 44,  8, 31, 50, 42, 58, 99, 22,
         78, 89, 53, 19,  3, 61, 59, 73, 28, 15],
        [ 2,  2,  6,  4,  3,  0,  3,  8,  1,  7,  4,  4,  1,  4,  8,  1,  0,  3,
          7,  8,  6,  2,  6,  9,  9,  5,  7,  0,  8,  5,  1,  4,  6,  8,  2,  2,
          3,  3,  9,  8,  5,  4,  6,  7,  5,  6,  1,  3,  7,  1,  6,  6,  5,  7,
          5,  4,  9,  2,  3,  5,  5,  3,  9,  9,  9,  4,  5,  2,  4,  2,  7,  0,
          9,  6,  6,  7,  4,  7,  1,  9,  3,  5,  5,  2,  1,  8,  3,  9,  6,  2,
          3,  8,  0,  0,  6,  1,  2,  5,  5,  2]])