In [None]:
import torch
from torch_geometric.nn import HeteroConv, GATv2Conv
from torch_geometric.data import HeteroData

In [None]:
# Model hyperparameters. The keys mirror the HeteroGNN constructor's
# hidden_dim / out_dim / n_heads / dropout parameters.
hyperparameters = dict(
    hidden_dim=64,  # per-head width of each GATv2 layer
    out_dim=32,  # width of the final per-node-type projection
    n_heads=2,  # attention heads per GATv2 layer
    dropout=0.1,  # attention + feature dropout probability
)

In [None]:
class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN built from per-edge-type GATv2 convolutions.

    Each layer is a ``HeteroConv`` that holds one ``GATv2Conv`` per edge type
    in ``metadata[1]``. Messages reaching the same destination node type
    through different edge types are summed (``aggr="sum"``). After both
    layers, a per-node-type ``Linear`` maps the hidden representation to
    ``out_dim``.

    Feature widths:
        * layer 1 concatenates heads, so its width is ``hidden_dim * n_heads``;
        * layer 2 averages heads, so its width is ``hidden_dim``, which is the
          input width of ``self.linear``.

    Input channel sizes are lazy (``-1``) and are inferred on the first
    forward pass.
    NOTE(review): lazy parameters only exist after one forward call, so
    confirm the optimizer is created after a warm-up pass.

    NOTE(review): ``HeteroConv`` only returns embeddings for node types that
    are the destination of at least one edge type. Node types that only act as
    sources drop out after layer 1. Confirm that the graph has reverse edges
    (e.g. via ``ToUndirected``) if every node type needs an output.

    Args:
        metadata: ``(node_types, edge_types)`` pair, presumably obtained from
            ``HeteroData.metadata()``. Node types key ``self.linear``; edge
            types key the per-relation convolutions.
        hidden_dim: Per-head output width of each GATv2 layer.
        out_dim: Width of the final per-node-type projection.
        n_heads: Number of attention heads in each GATv2 layer.
        dropout: Probability used for the attention dropout inside GATv2 and
            for the feature dropout applied after each layer.
        edge_dim: Width of the edge features passed as ``edge_attr``. Defaults
            to 13, the value previously hard-coded.
    """

    def __init__(
        self,
        metadata,
        hidden_dim: int,
        out_dim: int,
        n_heads: int,
        dropout: float,
        edge_dim: int = 13,
    ):
        super().__init__()
        node_types, edge_types = metadata[0], metadata[1]

        # Construction order and attribute names match the original layout so
        # that existing state_dict checkpoints keep loading.
        self.conv1 = self._make_hetero_gat(
            edge_types, hidden_dim, n_heads, dropout, edge_dim, concat=True
        )
        self.conv2 = self._make_hetero_gat(
            edge_types, hidden_dim, n_heads, dropout, edge_dim, concat=False
        )

        self.linear = torch.nn.ModuleDict(
            {
                node_type: torch.nn.Linear(hidden_dim, out_dim)
                for node_type in node_types
            }
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.gelu = torch.nn.GELU()

    @staticmethod
    def _make_hetero_gat(edge_types, hidden_dim, n_heads, dropout, edge_dim, concat):
        """Build a sum-aggregated HeteroConv with one GATv2Conv per edge type."""
        return HeteroConv(
            {
                edge_type: GATv2Conv(
                    # Lazy (src, dst) input widths, inferred on the first call.
                    in_channels=(-1, -1),
                    out_channels=hidden_dim,
                    edge_dim=edge_dim,
                    heads=n_heads,
                    dropout=dropout,
                    residual=False,
                    # Self-loops are not meaningful for bipartite relations
                    # whose source and destination node types differ.
                    add_self_loops=False,
                    concat=concat,
                )
                for edge_type in edge_types
            },
            aggr="sum",
        )

    def forward(
        self,
        x_dict,
        edge_index_dict,
        edge_attr_dict,
    ):
        """Run both GATv2 layers, then project each node type to ``out_dim``.

        Args:
            x_dict: Mapping from node type to node feature tensor.
            edge_index_dict: Mapping from edge type to ``edge_index`` tensor.
            edge_attr_dict: Mapping from edge type to edge feature tensor
                (width ``edge_dim``).

        Returns:
            Dict from each node type emitted by the second layer to its
            ``out_dim``-wide output.
        """
        for conv in (self.conv1, self.conv2):
            # HeteroConv strips the "_dict" suffix and passes
            # edge_attr_dict[edge_type] to each conv as ``edge_attr``.
            x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
            x_dict = {k: self.dropout(self.gelu(v)) for k, v in x_dict.items()}

        return {k: self.linear[k](v) for k, v in x_dict.items()}