In [None]:
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot
# NOTE: this cell instantiates GraphEncoder, which is defined in the NEXT cell.
# Run that cell first (or move this one below it) so Restart & Run All succeeds.
torch.manual_seed(0)

NUM_NODES = 1000       # node vocabulary size (rows of the embedding table)
NUM_RELATIONS = 2      # distinct edge types used in `types` below
EMBEDDING_SIZE = 128   # node embedding width = width of every GNN layer
GNN_LAYERS = 3

# GraphEncoder takes a path to a .npy matrix of shape (NUM_NODES, EMBEDDING_SIZE);
# write a random one to a temp dir so this example is self-contained.
emb_path = Path(tempfile.mkdtemp()) / "node_embedding.npy"
np.save(emb_path, np.random.default_rng(0).standard_normal((NUM_NODES, EMBEDDING_SIZE), dtype=np.float32))

nodes = torch.randint(0, NUM_NODES, (1, 5))  # batch of 1 graph with 5 nodes
# One edge index per graph, shape [2, num_edges]: row 0 = source, row 1 = target
# (positions 0..4 within the graph, not vocabulary ids).
edges = [[[0, 1, 2, 3], [1, 2, 3, 4]]]
types = [[0, 1, 0, 1]]  # one relation id (< NUM_RELATIONS) per edge, per graph

model = GraphEncoder(
    num_nodes=NUM_NODES,
    num_relations=NUM_RELATIONS,
    gnn_layers=GNN_LAYERS,
    embedding_size=EMBEDDING_SIZE,
    initilized_embedding=str(emb_path),
)
output = model(nodes, edges, types)  # [1, 5, EMBEDDING_SIZE]
dot = make_dot(output, params=dict(model.named_parameters()))
dot.render("graph_encoder", format="png")  # Save as a PNG image


In [None]:
class GraphEncoder(nn.Module):
    """Encode each graph in a batch with a stack of relational graph-conv layers.

    Node ids are looked up in a (optionally pretrained) embedding table and then
    passed through ``gnn_layers`` RGCNConv layers. ReLU + dropout are applied
    between layers but not after the last one.

    :param num_nodes: size of the node vocabulary (rows of the embedding table).
    :param num_relations: number of distinct edge types passed to RGCNConv.
    :param gnn_layers: number of RGCNConv layers.
    :param embedding_size: node embedding width; also the in/out width of every layer.
    :param initilized_embedding: optional path to a ``.npy`` array of shape
        ``(num_nodes, embedding_size)`` used to initialise the embedding table,
        which stays trainable. ``None`` keeps the default random initialisation.
    :param dropout_ratio: dropout probability applied between GNN layers.
    :raises ValueError: if the pretrained array's shape does not match
        ``(num_nodes, embedding_size)``.
    """

    def __init__(self, num_nodes, num_relations, gnn_layers, embedding_size, initilized_embedding=None, dropout_ratio=0.3):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.gnn_layers = gnn_layers
        self.embedding_size = embedding_size
        self.dropout_ratio = dropout_ratio

        self.node_embedding = nn.Embedding(num_nodes, embedding_size)
        if initilized_embedding is not None:
            self._load_pretrained_embedding(initilized_embedding)

        self.dropout = nn.Dropout(dropout_ratio)

        # nn.ModuleList registers each layer as a submodule (parameters, .to(), state_dict).
        # If RGCN is too slow, a plain GCN layer can be swapped in here.
        self.gnn = nn.ModuleList(
            RGCNConv(embedding_size, embedding_size, num_relations=num_relations)
            for _ in range(gnn_layers)
        )

    def _load_pretrained_embedding(self, path):
        """Copy the ``.npy`` matrix at ``path`` into ``self.node_embedding.weight`` (kept trainable)."""
        # nn.Embedding.from_pretrained is a classmethod that *returns* a new module;
        # calling it on the instance and discarding the result loads nothing, so
        # copy into the existing parameter instead.
        weight = self.node_embedding.weight
        # np.load typically yields float64; cast to the parameter's dtype.
        pretrained = torch.as_tensor(np.load(path), dtype=weight.dtype)
        if pretrained.shape != weight.shape:
            raise ValueError(
                f"pretrained embedding has shape {tuple(pretrained.shape)}, "
                f"expected {tuple(weight.shape)} (num_nodes, embedding_size)"
            )
        with torch.no_grad():
            weight.copy_(pretrained)

    def forward(self, nodes, edges, types):
        """Embed every graph in the batch independently.

        :param nodes: integer Tensor ``[batch_size, num_nodes]`` of node ids; all
            graphs share ``num_nodes`` so their outputs can be stacked.
        :param edges: sequence of length ``batch_size``; ``edges[b]`` is a list or
            tensor of shape ``[2, num_edges_b]`` whose entries are row positions
            (``0..num_nodes-1``) into graph ``b``'s node list.
        :param types: sequence of length ``batch_size``; ``types[b]`` holds the
            ``num_edges_b`` relation ids for ``edges[b]``.
        :return: Tensor ``[batch_size, num_nodes, embedding_size]``.
        """
        device = nodes.device
        last_layer = len(self.gnn) - 1

        node_embeddings = []
        for bid in range(nodes.size(0)):
            # as_tensor accepts lists or tensors without the copy-construct warning
            # that torch.tensor() emits for tensor input.
            edge_index = torch.as_tensor(edges[bid], dtype=torch.long, device=device)  # [2, num_edges]
            edge_type = torch.as_tensor(types[bid], dtype=torch.long, device=device)  # [num_edges]

            embed = self.node_embedding(nodes[bid])
            for lidx, rgcn in enumerate(self.gnn):
                embed = rgcn(embed, edge_index=edge_index, edge_type=edge_type)
                if lidx != last_layer:
                    # Non-linearity and dropout only between layers, not after the final one.
                    embed = self.dropout(F.relu(embed))

            node_embeddings.append(embed)

        return torch.stack(node_embeddings, 0)  # [batch_size, num_nodes, embedding_size]