# Can STAGE Handle the Pressure? Testing Its Generalizability Across Graph Models
In this notebook, we provide an overview of our implementation of STAGE and the various model architectures we explored. Given the size of our codebase, we focus on showcasing the key components of the code. For clarity, we have copied the implementation of each model into this notebook and included explanations of their main components. 

**Note that the code cells are for reference only and are not executable. At the end of the notebook, we have included a bash script to facilitate the full training and evaluation of the models.**

## RGCN

### RGCNConv
The implementation of RGCNConv is in the file `nbfnet/edge_rgcn_conv.py`. For readability, we reproduce it in the code cell at the end of this section.

#### Incorporating Edge Features in RGCNConv
In traditional RGCNConv implementations, the focus is primarily on node features and relation types, while edge attributes (such as edge weights or embeddings) are often ignored or treated as static. For our use case, it is critical to incorporate edge features directly into the message-passing framework to enrich the representation power of the model.

This modified RGCNConv implementation extends the standard RGCNConv to handle edge embeddings. They enter the message-passing phase in two steps, both controlled by `edge_method`:
1. Transforming Edge Embeddings: For `method2` and `method4`, edge embeddings are transformed via a learnable linear layer (`edgegraph_mlp`); for every other value they are left unaltered.
2. Combining Messages and Edge Embeddings: For `method1` and `method2`, the edge embeddings are added directly to the node messages; for every other value they are concatenated with them.
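
The two choices above can be summarized in a small helper. This is only a sketch that mirrors the branching in `RGCNConv.__init__`; the function names `edge_strategy` and `combine` are hypothetical, and plain Python lists stand in for tensors so the snippet runs without dependencies.

```python
def edge_strategy(edge_method):
    """Return (stage_method, transform_edge) for a given edge_method string."""
    # "method1"/"method2" add the edge embedding; anything else concatenates it.
    stage_method = "add" if edge_method in ("method1", "method2") else "cat"
    # "method2"/"method4" first pass the edge embedding through a linear layer.
    transform_edge = edge_method in ("method2", "method4")
    return stage_method, transform_edge


def combine(message, edge_embed, stage_method):
    """Combine one message vector with one edge embedding (lists stand in for tensors)."""
    if stage_method == "add":
        # Element-wise sum: both vectors must have the same length.
        return [m + e for m, e in zip(message, edge_embed)]
    if stage_method == "cat":
        # Concatenation: the result is wider than the original message.
        return list(message) + list(edge_embed)
    raise ValueError(f"Unknown stage method {stage_method}")


for method in ("method1", "method2", "method3", "method4"):
    print(method, edge_strategy(method))
```

Note that concatenation widens the message, which is why the layer needs the extra linear map `lin_f` in that mode, whereas addition requires the edge embedding and the message to share a dimension.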

#### Efficiency Improvements
This implementation introduces several optimizations to make the enhanced RGCNConv practical for large graphs:
1. Basis Decomposition: For graphs with many relation types, maintaining separate weight matrices for each type can be computationally prohibitive. Basis decomposition reduces the number of parameters while still allowing expressive modeling of relations.
2. Sparse Operations: By aggregating messages with the `scatter` operation from `torch_scatter`, the implementation avoids dense adjacency matrix computations, allowing it to scale to large, sparse graphs.
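
To make the parameter saving from basis decomposition concrete, here is a minimal dependency-free sketch. Each relation's weight matrix is a weighted sum of shared basis matrices, matching the roles of `self.comp` and `self.weight` in the layer. The helper names and the example sizes (100 relations, 4 bases) are illustrative assumptions, not values taken from our experiments.

```python
def compose_relation_weight(coefficients, bases):
    """Build one relation's weight matrix as a weighted sum of shared basis matrices."""
    rows, cols = len(bases[0]), len(bases[0][0])
    weight = [[0.0] * cols for _ in range(rows)]
    for coeff, basis in zip(coefficients, bases):
        for i in range(rows):
            for j in range(cols):
                weight[i][j] += coeff * basis[i][j]
    return weight


def count_weights(num_relation, input_dim, output_dim, num_bases=0):
    """Count relation weight entries (biases ignored) with or without bases."""
    if num_bases > 0:
        # Shared bases plus one coefficient per (relation, basis) pair.
        return num_bases * input_dim * output_dim + num_relation * num_bases
    # One full matrix per relation.
    return num_relation * input_dim * output_dim


bases = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
print(compose_relation_weight([2.0, 3.0], bases))   # [[2.0, 3.0], [3.0, 2.0]]
print(count_weights(100, 256, 256))                 # 6553600
print(count_weights(100, 256, 256, num_bases=4))    # 262544
```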

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot

class RGCNConv(MessagePassing):
    """
    A relational graph convolutional network (RGCN) layer with additional edge features.
    Supports edge embeddings, multiple aggregation strategies, and various update mechanisms.

    Args:
        input_dim (int): Dimension of input node features.
        output_dim (int): Dimension of output node features.
        num_relation (int): Number of relation types.
        aggregate_func (str): Aggregation function ('add', 'mean', 'max', etc.).
        layer_norm (bool): Whether to apply layer normalization.
        activation (str or Callable): Activation function to use (e.g., 'relu', 'tanh').
        num_bases (int, optional): Number of bases for relation weight decomposition. Defaults to 0.
        edge_method (str): Determines how edge features are used ('method1', 'method2', etc.).
        edge_embed_dim (int, optional): Dimension of edge embeddings. Defaults to None.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        num_relation,
        aggregate_func,
        layer_norm,
        activation,
        num_bases=0,
        edge_method="method1",
        edge_embed_dim=None,
    ):
        super(RGCNConv, self).__init__()

        # Initialize layer parameters
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.aggregate_func = aggregate_func
        self.edge_embed_dim = edge_embed_dim
        self.num_bases = num_bases

        # Determine edge processing strategy
        if edge_method in ["method1", "method2"]:
            self.stage_method = "add"
        else:
            self.stage_method = "cat"
        self.transform_edge = edge_method in ["method2", "method4"]

        # Initialize layer normalization if specified
        self.layer_norm = nn.LayerNorm(output_dim) if layer_norm else None

        # Initialize activation function
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        # Linear transformation for node updates
        self.lin_s = nn.Linear(input_dim, output_dim)
        nn.init.xavier_uniform_(self.lin_s.weight)

        # Linear transformation for edge and node combination
        if self.stage_method == "cat":
            assert edge_embed_dim is not None, "edge_embed_dim must be specified for concatenation"
            self.lin_f = nn.Linear(edge_embed_dim + input_dim, output_dim)
            nn.init.xavier_uniform_(self.lin_f.weight)
        elif self.stage_method == "add":
            self.lin_f = nn.Identity()

        # Edge embedding transformation
        if edge_embed_dim is not None:
            self.edgegraph_mlp = nn.Linear(edge_embed_dim, output_dim)
            nn.init.xavier_uniform_(self.edgegraph_mlp.weight)

        # Relation weight initialization
        if num_bases > 0:
            self.weight = nn.Parameter(torch.empty(num_bases, input_dim, output_dim))
            self.comp = nn.Parameter(torch.empty(num_relation, num_bases))
            glorot(self.weight)
            glorot(self.comp)
        else:
            self.lin_r = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_relation)])
            for lin in self.lin_r:
                nn.init.xavier_uniform_(lin.weight)

    def forward(self, input, edge_index, edge_type, edge_weight=None, edge_embed=None):
        """
        Forward pass for the RGCNConv layer.

        Args:
            input (torch.Tensor): Node features of shape (num_nodes, input_dim).
            edge_index (torch.Tensor): Edge indices of shape (2, num_edges).
            edge_type (torch.Tensor): Edge type indices of shape (num_edges,).
            edge_weight (torch.Tensor, optional): Edge weights of shape (num_edges,). Defaults to None.
            edge_embed (torch.Tensor, optional): Edge feature embeddings. Defaults to None.

        Returns:
            torch.Tensor: Updated node features of shape (num_nodes, output_dim).
        """
        num_node = input.size(0)

        # Default edge weight is 1 for all edges
        if edge_weight is None:
            edge_weight = torch.ones(len(edge_type), device=input.device)

        # Ensure edge embeddings are handled correctly
        edge_type = edge_type.to(torch.long)
        edge_embed = edge_embed if edge_embed is not None else 0

        # Perform message passing
        output = self.propagate(
            edge_index=edge_index,
            input=input,
            edge_type=edge_type,
            edge_embed=edge_embed,
            size=(num_node, num_node),
            edge_weight=edge_weight,
        )
        return output

    def message(self, input_j, edge_type, edge_embed):
        """
        Compute messages to be passed along edges.

        Args:
            input_j (torch.Tensor): Source node features of shape (num_edges, input_dim).
            edge_type (torch.Tensor): Edge type indices of shape (num_edges,).
            edge_embed (torch.Tensor): Edge feature embeddings.

        Returns:
            torch.Tensor: Messages of shape (num_edges, output_dim).
        """
        # Compute relation-specific messages
        num_edges, _ = input_j.size()
        message = torch.zeros((num_edges, self.output_dim), device=input_j.device)
        # compute weight if using basis decomposition
        if self.num_bases > 0:
            weight = (self.comp @ self.weight.view(self.num_bases, -1)).view(
                self.num_relation, self.input_dim, self.output_dim
            )
        for rel_type in range(self.num_relation):
            mask = (edge_type == rel_type).unsqueeze(-1)
            # rel_mapped: (num_edges, self.output_dim)
            if self.num_bases > 0:
                rel_mapped = torch.matmul(
                    input_j,
                    weight[rel_type],
                )
            else:
                rel_mapped = self.lin_r[rel_type](input_j)
            message += rel_mapped * mask

        # Incorporate edge embeddings
        transformed_edge_embed = self.edgegraph_mlp(edge_embed) if self.transform_edge else edge_embed
        if self.stage_method == "cat":
            message = torch.cat([message, transformed_edge_embed], dim=-1)
        elif self.stage_method == "add":
            message += transformed_edge_embed
        else:
            raise ValueError(f"Unknown stage method {self.stage_method}")

        return message

    def aggregate(self, input, edge_weight, index, edge_type, dim_size):
        """
        Aggregate messages from neighbors.

        Args:
            input (torch.Tensor): Messages to aggregate.
            edge_weight (torch.Tensor): Edge weights for scaling messages.
            index (torch.Tensor): Target nodes of messages.
            edge_type (torch.Tensor): Types of edges in the graph.
            dim_size (int): Total number of nodes.

        Returns:
            torch.Tensor: Aggregated messages.
        """
        # Scale messages by edge weights
        shape = [1] * input.ndim
        shape[0] = -1
        edge_weight = edge_weight.view(shape)
        weighted_input = input * edge_weight

        # Aggregate messages based on edge type
        # Size the buffer by the message width: in "cat" mode messages are wider than output_dim
        output = torch.zeros((dim_size, input.size(-1)), device=input.device)
        for rel_type in range(self.num_relation):
            mask = edge_type == rel_type
            output += scatter(
                weighted_input[mask],
                index[mask],
                dim=0,
                dim_size=dim_size,
                reduce=self.aggregate_func,
            )
        return output

    def update(self, update, input):
        """
        Update node features after aggregation.

        Args:
            update (torch.Tensor): Aggregated messages of shape (num_nodes, output_dim).
            input (torch.Tensor): Previous node features of shape (num_nodes, input_dim).

        Returns:
            torch.Tensor: Updated node features.
        """
        # Perform linear transformation and apply activation function
        output = self.lin_s(input) + self.lin_f(update)
        if self.layer_norm:
            output = self.layer_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

### RGCN Model
The implementation of the RGCN model is in the file `nbfnet/rgcn.py`. The following code cell is a simplified version of the RGCN model for explanatory purposes.

The model consists of multiple RGCN layers implemented using the `RGCNConv` class, which extends the standard RGCN to handle edge embeddings. Each layer processes input node features, edge indices, edge types, and optional edge attributes to generate updated node features. These features are computed through relational message-passing and aggregation, with optional shortcut connections to facilitate training deeper models.

The forward pass initializes node features with a placeholder tensor of ones, since the input graph carries no node features. It then reads the graph structure (edge indices and types) and the optional edge embeddings, and propagates messages through the RGCN layers. For each triple in the batch, the model gathers the source and target node embeddings, looks up the relation embedding, and computes a link prediction score with DistMult.
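
The scoring step above is a three-way element-wise product followed by a sum over the feature dimension. A minimal sketch with plain Python lists in place of tensors (the function name `distmult_score` is ours, for illustration only):

```python
def distmult_score(source_emb, relation_emb, target_emb):
    """Sum over dimensions of the element-wise product of the three embeddings."""
    return sum(s * r * t for s, r, t in zip(source_emb, relation_emb, target_emb))


print(distmult_score([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))  # 63.0
```

Because the product is commutative, this score is unchanged when the source and target embeddings are swapped, so it cannot distinguish the direction of a relation by itself.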

In [None]:
from dataclasses import dataclass
from torch_geometric.data import Data

@dataclass
class RGCNConfig:
    """
    Configuration class for the RGCN model.

    Attributes:
        input_dim (int): Dimension of input node features.
        num_layers (int): Number of RGCN layers.
        aggregate_func (str): Aggregation function (e.g., "mean", "add", "max").
        short_cut (int): Whether to use shortcut connections between layers.
        layer_norm (int): Whether to apply layer normalization after each layer.
        activation (str): Activation function to use (e.g., "relu").
        concat_hidden (int): Whether to concatenate outputs of all layers.
        num_bases (int): Number of bases for parameter sharing in relation modeling.
        use_stage (int): Whether to enable STAGE features (edge embeddings).
        edge_method (str): Method for handling edge embeddings.
    """

    input_dim: int = 256
    num_layers: int = 6
    aggregate_func: str = "mean"
    short_cut: int = 1
    layer_norm: int = 1
    activation: str = "relu"
    concat_hidden: int = 0
    num_bases: int = 0
    use_stage: int = 1
    edge_method: str = "method1"


class RGCN(nn.Module):
    """
    Relational Graph Convolutional Network (RGCN) with support for edge embeddings.

    Args:
        num_relation (int): Number of relation types.
        edge_embed_dim (int or None): Dimension of edge embeddings (if applicable).
        cfg (RGCNConfig): Configuration object for the RGCN.
    """

    def __init__(self, num_relation, edge_embed_dim, cfg: RGCNConfig):
        super(RGCN, self).__init__()

        # Disable edge embeddings if STAGE is not used
        if not cfg.use_stage:
            edge_embed_dim = None

        # Initialize attributes
        self.dims = [cfg.input_dim] * (cfg.num_layers + 1)  # Feature dimensions for all layers
        self.num_relation = num_relation
        self.short_cut = cfg.short_cut
        self.concat_hidden = cfg.concat_hidden
        self.edge_embed_dim = edge_embed_dim
        self.num_bases = cfg.num_bases

        # Define RGCN layers
        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(
                RGCNConv(
                    self.dims[i],
                    self.dims[i + 1],
                    num_relation,
                    cfg.aggregate_func,
                    cfg.layer_norm,
                    cfg.activation,
                    num_bases=cfg.num_bases,
                    edge_method=cfg.edge_method,
                    edge_embed_dim=edge_embed_dim,
                )
            )

        # Compute feature dimension if concatenating hidden states
        feature_dim = cfg.input_dim * cfg.num_layers

        # Embedding for relations
        self.relation_emb = nn.Embedding(num_relation, cfg.input_dim)
        nn.init.xavier_uniform_(self.relation_emb.weight, gain=nn.init.calculate_gain(cfg.activation))

        # Define final linear layer if concatenating hidden states
        if self.concat_hidden:
            self.final_linear = nn.Linear(feature_dim, cfg.input_dim)

        # Print number of parameters
        num_params = sum(p.numel() for p in self.parameters())
        print(f"Number of parameters in RGCN: {num_params}")

    def forward(self, data: Data, batch: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the RGCN.

        Args:
            data (Data): PyTorch Geometric Data object containing the graph structure and features.
            batch (torch.Tensor): Tensor of shape [batch_size, num_negative + 1, 3]
                                   containing (source node, target node, relation) triples.

        Returns:
            torch.Tensor: Scores for each triple in the batch.
        """
        # Initialize node features as ones (placeholder for featureless input)
        x = torch.ones((data.num_nodes, self.dims[0]), device=data.edge_index.device)

        # Retrieve edge information from the data object
        edge_index = data.edge_index  # Edge indices [2, num_edges]
        edge_type = data.original_edge_type  # Edge types [num_edges]

        # Retrieve edge embeddings if available
        if self.edge_embed_dim is not None:
            edge_embed = data.edge_embeddings  # Edge embeddings [num_edges, edge_embed_dim]
        else:
            edge_embed = None

        # Retrieve edge weights if provided
        edge_weight = data.edge_weight if hasattr(data, "edge_weight") else None

        # To store outputs of all layers for concatenation (if enabled)
        hidden_states = []

        # Pass input through each RGCN layer
        for layer in self.layers:
            new_x = layer(x, edge_index, edge_type, edge_weight, edge_embed)

            # Apply shortcut connection if enabled
            if self.short_cut:
                new_x = new_x + x

            # Update current node features
            x = new_x
            hidden_states.append(x)

        # Concatenate hidden states if configured
        if self.concat_hidden:
            x = torch.cat(hidden_states, dim=-1)  # Concatenate along feature dimension
            x = self.final_linear(x)  # Reduce concatenated features to output dimension

        # Expand node embeddings for batch processing
        x = x.expand(batch.size(0), -1, -1)

        # Extract source and target node embeddings from the batch
        source_nodes = batch[:, :, 0].unsqueeze(-1).expand(-1, -1, x.size(-1))
        target_nodes = batch[:, :, 1].unsqueeze(-1).expand(-1, -1, x.size(-1))
        relations = batch[:, :, 2]  # Relation indices

        source_emb = x.gather(1, source_nodes)  # Source node embeddings
        target_emb = x.gather(1, target_nodes)  # Target node embeddings
        relation_emb = self.relation_emb(relations)  # Relation embeddings

        # Compute triple scores with DistMult
        score = torch.sum(source_emb * relation_emb * target_emb, dim=-1)

        return score

## CompGCN

### CompGCNConv

The implementation of CompGCNConv is in the file `nbfnet/compgcn/compgcn_conv.py`. For readability, we reproduce it in the code cell at the end of this section. The cell is not runnable; it is included only to show how we extend the original CompGCNConv implementation to incorporate edge features.

#### Incorporating Edge Features in CompGCNConv

In the original CompGCNConv implementation, edge features were not explicitly considered. To enrich the representation power of the model, we extended CompGCNConv to handle edge embeddings.

This extended implementation modifies the message method to include edge embeddings:
1. Transforming Edge Embeddings: Depending on the specified edge method, edge embeddings are either directly added to the relation embeddings or transformed via a learnable weight matrix before addition.
2. Combining Edge Features with Messages: The edge embeddings are integrated into the message-passing process at different stages, depending on the edge method chosen.
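
The difference between the injection stages can be sketched as follows. This is a simplified illustration, not the layer itself: it assumes the `mult` composition operator, omits the mode-specific weight matrix, the learnable `e_weight` transform used by `method2` and `method4`, and the edge normalization, and it uses plain Python lists instead of tensors. The function name `message_with_edge` is hypothetical.

```python
def message_with_edge(x_j, rel_emb, edge_embed, edge_method):
    """Toy message for one edge: early injection (method1/2) vs late injection (method3/4)."""
    if edge_method in ("method1", "method2"):
        # Early: the edge embedding shifts the relation embedding before composition.
        rel_emb = [r + e for r, e in zip(rel_emb, edge_embed)]
    # Composition of source node features and relation embedding ("mult" operator).
    out = [x * r for x, r in zip(x_j, rel_emb)]
    if edge_method in ("method3", "method4"):
        # Late: the edge embedding is added to the composed message.
        out = [o + e for o, e in zip(out, edge_embed)]
    return out


x_j, rel_emb, edge_embed = [1.0, 2.0], [3.0, 4.0], [1.0, 1.0]
print(message_with_edge(x_j, rel_emb, edge_embed, "method1"))  # [4.0, 10.0]
print(message_with_edge(x_j, rel_emb, edge_embed, "method3"))  # [4.0, 9.0]
```

With early injection the edge embedding is scaled by the source node features through the composition, while late injection contributes it to the message independently of the source node.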

#### Efficiency Improvements

The original implementation of CompGCNConv was already highly efficient due to its reliance on scatter operations for sparse aggregation. These operations scale well to large, sparse graphs. Our modifications retain this efficiency while adding the capability to model edge features, ensuring the model remains practical for large-scale datasets.
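
For readers unfamiliar with scatter-based aggregation, the following dependency-free sketch shows what a scatter with sum reduction computes: each edge's message is accumulated into the row of its target node. The work is linear in the number of edges, and no `num_nodes x num_nodes` adjacency matrix is ever built. The helper name `scatter_add_rows` is hypothetical; the real layers rely on the library's scatter kernels.

```python
def scatter_add_rows(messages, index, dim_size):
    """Sum each message into the output row given by its target-node index."""
    width = len(messages[0]) if messages else 0
    out = [[0.0] * width for _ in range(dim_size)]
    for msg, target in zip(messages, index):
        for k, value in enumerate(msg):
            out[target][k] += value
    return out


# Three edges: two point at node 2, one at node 0; node 1 receives nothing.
messages = [[1.0, 0.0], [2.0, 1.0], [0.5, 0.5]]
index = [2, 0, 2]
print(scatter_add_rows(messages, index, dim_size=3))
# [[2.0, 1.0], [0.0, 0.0], [1.5, 0.5]]
```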

In [None]:
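from torch.nn import Parameter
from torch_scatter import scatter_add

# Note: `get_param` and `ccorr` are helper functions that are not shown in this notebook.


class CompGCNConv(MessagePassing):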
    """
    A Compositional Graph Convolutional Network (CompGCN) layer.
    This layer supports directed graphs with multiple relation types and can incorporate edge features.

    Args:
        in_channels (int): Dimension of input node features.
        out_channels (int): Dimension of output node features.
        num_rels (int): Number of relation types.
        act (Callable): Activation function to apply to the output.
        params (Namespace, optional): Additional parameters such as dropout, bias, and edge method.
    """

    def __init__(self, in_channels, out_channels, num_rels, act=lambda x: x, params=None):
        # Plain super() avoids the infinite recursion that super(self.__class__, self) causes in subclasses
        super().__init__()

        # Initialize attributes
        self.p = params
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_rels = num_rels
        self.act = act
        self.device = None

        # Learnable weights for transformations
        self.e_weight = get_param((in_channels, out_channels))  # For edge transformations
        self.w_loop = get_param((in_channels, out_channels))  # For self-loops
        self.w_in = get_param((in_channels, out_channels))  # For incoming edges
        self.w_out = get_param((in_channels, out_channels))  # For outgoing edges
        self.w_pp = get_param((in_channels, out_channels))  # For pairwise edges
        self.w_rel = get_param((in_channels, out_channels))  # For relation embeddings
        self.loop_rel = get_param((1, in_channels))  # Embedding for self-loop relation

        # Regularization components
        self.drop = torch.nn.Dropout(self.p.dropout)  # Dropout layer
        self.bn = torch.nn.BatchNorm1d(out_channels)  # Batch normalization

        # Optional bias term
        if self.p.bias:
            self.register_parameter("bias", Parameter(torch.zeros(out_channels)))

    def forward(self, x, edge_index, edge_type, rel_embed, edge_embed):
        """
        Forward pass for the CompGCN layer.

        Args:
            x (torch.Tensor): Node features of shape (num_nodes, in_channels).
            edge_index (torch.Tensor): Edge indices of shape (2, num_edges).
            edge_type (torch.Tensor): Relation types of edges (num_edges,).
            rel_embed (torch.Tensor): Relation embeddings of shape (num_rels, in_channels).
            edge_embed (torch.Tensor, optional): Edge features (num_edges, in_channels).

        Returns:
            torch.Tensor: Updated node features of shape (num_nodes, out_channels).
            torch.Tensor: Updated relation embeddings of shape (num_rels, out_channels).
        """
        # Initialize device and augment relation embeddings with self-loop relation
        if self.device is None:
            self.device = edge_index.device
        rel_embed = torch.cat([rel_embed, self.loop_rel], dim=0)

        # Partition edges into different types
        num_pp_edge = edge_type[edge_type == edge_type.max()].size(0)
        num_edges = (edge_index.size(1) - num_pp_edge) // 2
        num_ent = x.size(0)

        # Separate edge indices and types for in, out, and pairwise edges
        self.in_index, self.out_index, self.pp_index = (
            edge_index[:, :num_edges],
            edge_index[:, num_edges : 2 * num_edges],
            edge_index[:, 2 * num_edges :],
        )
        self.in_type, self.out_type, self.pp_type = (
            edge_type[:num_edges],
            edge_type[num_edges : 2 * num_edges],
            edge_type[2 * num_edges :],
        )

        # Partition edge embeddings, if provided
        if edge_embed is not None:
            self.in_embed, self.out_embed, self.pp_embed = (
                edge_embed[:num_edges],
                edge_embed[num_edges : 2 * num_edges],
                edge_embed[2 * num_edges :],
            )
        else:
            self.in_embed = self.out_embed = self.pp_embed = None

        # Add self-loop edges and compute normalization factors
        self.loop_index = torch.stack([torch.arange(num_ent), torch.arange(num_ent)]).to(self.device)
        self.loop_type = torch.full((num_ent,), rel_embed.size(0) - 1, dtype=torch.long).to(self.device)

        self.in_norm = self.compute_norm(self.in_index, num_ent)
        self.out_norm = self.compute_norm(self.out_index, num_ent)
        self.pp_norm = self.compute_norm(self.pp_index, num_ent)

        # Perform message passing for in, out, self-loop, and pairwise edges
        in_res = self.propagate(
            "add",
            self.in_index,
            x=x,
            edge_type=self.in_type,
            rel_embed=rel_embed,
            edge_norm=self.in_norm,
            mode="in",
            edge_embed=self.in_embed,
        )
        loop_res = self.propagate(
            "add",
            self.loop_index,
            x=x,
            edge_type=self.loop_type,
            rel_embed=rel_embed,
            edge_norm=None,
            mode="loop",
            edge_embed=None,
        )
        out_res = self.propagate(
            "add",
            self.out_index,
            x=x,
            edge_type=self.out_type,
            rel_embed=rel_embed,
            edge_norm=self.out_norm,
            mode="out",
            edge_embed=self.out_embed,
        )
        pp_res = self.propagate(
            "add",
            self.pp_index,
            x=x,
            edge_type=self.pp_type,
            rel_embed=rel_embed,
            edge_norm=self.pp_norm,
            mode="pp",
            edge_embed=self.pp_embed,
        )

        # Aggregate results from different edge types
        out = (
            self.drop(in_res) * (1 / 4)
            + self.drop(out_res) * (1 / 4)
            + loop_res * (1 / 4)
            + self.drop(pp_res) * (1 / 4)
        )

        # Apply bias and batch normalization
        if self.p.bias:
            out = out + self.bias
        out = self.bn(out)

        # Return updated node and relation embeddings
        return self.act(out), torch.matmul(rel_embed, self.w_rel)[:-1]

    def rel_transform(self, ent_embed, rel_embed):
        """
        Apply a compositional transformation between entity and relation embeddings.

        Args:
            ent_embed (torch.Tensor): Entity embeddings.
            rel_embed (torch.Tensor): Relation embeddings.

        Returns:
            torch.Tensor: Transformed embeddings.
        """
        if self.p.opn == "corr":
            trans_embed = ccorr(ent_embed, rel_embed)
        elif self.p.opn == "sub":
            trans_embed = ent_embed - rel_embed
        elif self.p.opn == "mult":
            trans_embed = ent_embed * rel_embed
        else:
            raise NotImplementedError

        return trans_embed

    def message(self, x_j, edge_type, rel_embed, edge_norm, mode, edge_embed):
        """
        Compute messages to be passed along edges.

        Args:
            x_j (torch.Tensor): Features of source nodes (num_edges, in_channels).
            edge_type (torch.Tensor): Edge types (num_edges,).
            rel_embed (torch.Tensor): Relation embeddings (num_rels, in_channels).
            edge_norm (torch.Tensor, optional): Normalization factors for edges.
            mode (str): Type of edges ('in', 'out', 'loop', 'pp').
            edge_embed (torch.Tensor, optional): Edge features (num_edges, in_channels).

        Returns:
            torch.Tensor: Messages (num_edges, out_channels).
        """
        weight = getattr(self, f"w_{mode}")  # Select weight matrix for the mode
        rel_emb = torch.index_select(rel_embed, 0, edge_type)  # Relation-specific embeddings

        # Optionally incorporate edge embeddings
        if edge_embed is not None:
            if self.p.edge_method == "method1":
                rel_emb = rel_emb + edge_embed
            elif self.p.edge_method == "method2":
                edge_embed_transformed = torch.mm(edge_embed, self.e_weight)
                rel_emb = rel_emb + edge_embed_transformed

        # Apply relational transformation and combine with weights
        xj_rel = self.rel_transform(x_j, rel_emb)
        out = torch.mm(xj_rel, weight)

        # Additional edge embedding handling
        if edge_embed is not None:
            if self.p.edge_method == "method3":
                out = out + edge_embed
            elif self.p.edge_method == "method4":
                edge_embed_transformed = torch.mm(edge_embed, self.e_weight)
                out = out + edge_embed_transformed

        return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out):
        """
        Update step after aggregation.

        Args:
            aggr_out (torch.Tensor): Aggregated node features.

        Returns:
            torch.Tensor: Updated node features.
        """
        return aggr_out

    def compute_norm(self, edge_index, num_ent):
        """
        Compute edge normalization factors.

        Args:
            edge_index (torch.Tensor): Edge indices (2, num_edges).
            num_ent (int): Number of entities (nodes).

        Returns:
            torch.Tensor: Edge normalization factors (num_edges,).
        """
        row, col = edge_index
        edge_weight = torch.ones_like(row).float()
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_ent)  # Degree of nodes
        norm = edge_weight / ((deg[row].pow(0.5) + deg[col].pow(0.5)))
        return norm

    def __repr__(self):
        """
        String representation of the layer.
        """
        return "{}({}, {}, num_rels={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.num_rels
        )

### CompGCN Model
The implementation of `CompGCN` is in the file `nbfnet/compgcn/models.py`. The following code cell is a simplified version of the CompGCN model for explanatory purposes.

The `CompGCN` class serves as the main entry point, initializing the appropriate variant of the model (`CompGCN_TransE` or `CompGCN_DistMult`) based on the chosen scoring function. These variants build upon the `CompGCNBase`, which implements the core graph convolutional layers using the `CompGCNConv` module. Each layer processes node features, relation embeddings, and optional edge embeddings, iteratively updating the node and relation representations. The forward pass uses initialized node embeddings, passes them through the layers, and computes scores for triples by extracting subject, relation, and object embeddings.

The TransE variant computes scores using a distance-based metric in the embedding space, while the DistMult variant uses an element-wise multiplication of the subject, relation, and object embeddings followed by a summation. Both variants use dropout for regularization to improve generalization. The implementation also leverages bases for parameter sharing in relation-specific embeddings, making it efficient for large graphs with many relation types.
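
As a rough illustration of the distance-based score, the sketch below treats a relation as a translation from source to target and subtracts the resulting distance from the margin `gamma`. The use of the L1 norm and the exact form `gamma - distance` are assumptions on our part for this example, since `CompGCN_TransE` is not reproduced in this notebook; the function name `transe_score` is likewise hypothetical.

```python
def transe_score(source_emb, relation_emb, target_emb, gamma=40.0):
    """Margin minus the L1 distance between (source + relation) and target."""
    distance = sum(
        abs(s + r - t) for s, r, t in zip(source_emb, relation_emb, target_emb)
    )
    return gamma - distance


print(transe_score([1.0, 2.0], [1.0, 1.0], [2.0, 4.0]))  # 39.0
```

A triple whose target equals the translated source exactly attains the maximum score `gamma`, and the score decreases as the target moves away from that point.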

In [None]:
@dataclass
class CompGCNConfig:
    """
    Configuration class for the CompGCN model.

    Attributes:
        input_dim (int): Dimension of input features for nodes.
        num_layers (int): Number of layers in the CompGCN.
        num_bases (int): Number of bases for relation-specific weight sharing.
        use_stage (int): Whether to enable edge embedding usage (1 for yes, 0 for no).
        score_func (str): Scoring function ('distmult' or 'transe').
        dropout (float): Dropout rate for input embeddings.
        hid_drop (float): Dropout rate for hidden layers.
        gamma (float): Margin used in scoring function.
        bias (int): Whether to include bias terms (1 for yes, 0 for no).
        opn (str): Compositional operator to use ('corr', 'sub', 'mult').
        edge_method (str): Method to handle edge embeddings ('method1', 'method2', 'method3').

    Raises:
        AssertionError: Ensures valid configuration parameters during initialization.
    """

    input_dim: int = 256
    num_layers: int = 2
    num_bases: int = 0
    use_stage: int = 1
    score_func: str = "distmult"
    dropout: float = 0.1
    hid_drop: float = 0.3
    gamma: float = 40.0  # Margin
    bias: int = 1
    opn: str = "corr"
    edge_method: str = "method1"

    def __post_init__(self):
        # Validate configuration parameters
        assert self.edge_method in ["method1", "method2", "method3"]
        assert self.opn in ["corr", "sub", "mult"]
        assert self.score_func in ["transe", "distmult"]
        assert self.use_stage in [0, 1]


class CompGCN(torch.nn.Module):
    """
    CompGCN model for multi-relational graphs, supporting edge embeddings.

    Args:
        num_relation (int): Number of relation types.
        edge_embed_dim (int or None): Dimension of edge embeddings (if applicable).
        cfg (CompGCNConfig): Configuration object for CompGCN.
    """

    def __init__(self, num_relation, edge_embed_dim, cfg: CompGCNConfig):
        super(CompGCN, self).__init__()

        # Disable edge embeddings if not using STAGE
        if not cfg.use_stage:
            edge_embed_dim = None
        self.edge_embed_dim = edge_embed_dim

        # Select the model based on the scoring function
        if cfg.score_func == "transe":
            self.model = CompGCN_TransE(num_relation // 2, edge_embed_dim, cfg)
        elif cfg.score_func == "distmult":
            self.model = CompGCN_DistMult(num_relation // 2, edge_embed_dim, cfg)
        else:
            raise NotImplementedError

    def forward(self, data, batch):
        """
        Forward pass for CompGCN.

        Args:
            data (Data): Graph data containing edge indices, edge types, and optional edge embeddings.
            batch (torch.Tensor): Tensor of triples ordered as (source, target, relation),
                with shape (batch_size, num_negatives + 1, 3).

        Returns:
            torch.Tensor: Scores for the input triples.
        """
        # Set the number of nodes in the graph
        self.num_nodes = data.num_nodes
        self.model.num_nodes = self.num_nodes

        # Extract edge data
        edge_index = data.edge_index
        edge_type = data.original_edge_type

        # Extract edge embeddings if enabled
        if self.edge_embed_dim is not None:
            edge_embed = data.edge_embeddings
        else:
            edge_embed = None

        # Extract batch data (source, target, relation triples)
        source_nodes = batch[:, :, 0]
        target_nodes = batch[:, :, 1]
        relations = batch[:, :, 2]

        # Forward through the selected model
        return self.model(edge_index, edge_type, source_nodes, relations, target_nodes, edge_embed)


class BaseModel(torch.nn.Module):
    """
    Base class for CompGCN models, providing common utilities.

    Args:
        cfg (CompGCNConfig): Configuration object for the model.
    """

    def __init__(self, cfg: CompGCNConfig):
        super(BaseModel, self).__init__()
        self.cfg = cfg
        self.act = torch.tanh  # Activation function
        self.bceloss = torch.nn.BCELoss()  # Binary cross-entropy loss

    def loss(self, pred, true_label):
        """
        Compute loss for binary classification tasks.

        Args:
            pred (torch.Tensor): Predictions.
            true_label (torch.Tensor): Ground truth labels.

        Returns:
            torch.Tensor: Computed loss value.
        """
        return self.bceloss(pred, true_label)


class CompGCNBase(BaseModel):
    """
    Base class for CompGCN layers with forward propagation.

    Args:
        num_rel (int): Number of relations.
        edge_embed_dim (int or None): Dimension of edge embeddings.
        cfg (CompGCNConfig): Configuration object.
    """

    def __init__(self, num_rel, edge_embed_dim, cfg: CompGCNConfig):
        super(CompGCNBase, self).__init__(cfg)

        # Initialize relation embeddings
        if self.cfg.num_bases > 0:
            self.init_rel = get_param((self.cfg.num_bases, self.cfg.input_dim))
        else:
            self.init_rel = get_param(
                (num_rel * 2 if self.cfg.score_func != "transe" else num_rel, self.cfg.input_dim)
            )

        # Initialize CompGCN layers
        self.layers = torch.nn.ModuleList()
        for _ in range(self.cfg.num_layers):
            self.layers.append(
                CompGCNConv(self.cfg.input_dim, self.cfg.input_dim, num_rel, act=self.act, params=self.cfg)
            )

    def forward_base(self, edge_index, edge_type, sub, rel, obj, drop, edge_embed=None):
        """
        Forward pass through the base CompGCN model.

        Args:
            edge_index (torch.Tensor): Edge indices.
            edge_type (torch.Tensor): Edge types.
            sub (torch.Tensor): Subject node indices.
            rel (torch.Tensor): Relation indices.
            obj (torch.Tensor): Object node indices.
            drop (torch.nn.Dropout): Dropout layer.
            edge_embed (torch.Tensor, optional): Edge embeddings.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Subject, relation, and object embeddings.
        """
        # Initialize relation embeddings
        r = self.init_rel if self.cfg.score_func != "transe" else torch.cat([self.init_rel, -self.init_rel], dim=0)

        # Initialize node features
        init_embed = torch.ones((self.num_nodes, self.cfg.input_dim), device=sub.device)
        x = init_embed

        # Pass through all layers
        for layer in self.layers:
            x, r = layer(x, edge_index, edge_type, rel_embed=r, edge_embed=edge_embed)
            x = drop(x)

        # Extract embeddings for subjects, relations, and objects
        batch_size, num_neg_plus_1 = sub.size()
        sub_emb = torch.index_select(x, 0, sub.view(-1)).view(batch_size, num_neg_plus_1, -1)
        rel_emb = torch.index_select(r, 0, rel.view(-1)).view(batch_size, num_neg_plus_1, -1)
        obj_emb = torch.index_select(x, 0, obj.view(-1)).view(batch_size, num_neg_plus_1, -1)

        return sub_emb, rel_emb, obj_emb


class CompGCN_TransE(CompGCNBase):
    """
    CompGCN model with TransE scoring function.

    Args:
        num_rel (int): Number of relations.
        edge_embed_dim (int or None): Dimension of edge embeddings.
        cfg (CompGCNConfig): Configuration object.
    """

    def __init__(self, num_rel, edge_embed_dim, cfg: CompGCNConfig):
        # Zero-argument super() avoids the infinite recursion that super(self.__class__, self) causes when subclassed
        super().__init__(num_rel, edge_embed_dim, cfg)
        self.drop = torch.nn.Dropout(self.cfg.hid_drop)

    def forward(self, edge_index, edge_type, sub, rel, obj, edge_embed=None):
        # Compute embeddings for subjects, relations, and objects
        sub_emb, rel_emb, obj_emb = self.forward_base(edge_index, edge_type, sub, rel, obj, self.drop, edge_embed)

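        # TransE scoring function: margin minus the L1 distance between (subject + relation) and object.
        # All three embeddings have shape (batch_size, num_neg + 1, dim), so we reduce over the last dimension.
        pred_emb = sub_emb + rel_emb
        x = self.cfg.gamma - torch.norm(pred_emb - obj_emb, p=1, dim=-1)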

        return x


class CompGCN_DistMult(CompGCNBase):
    """
    CompGCN model with DistMult scoring function.

    Args:
        num_rel (int): Number of relations.
        edge_embed_dim (int or None): Dimension of edge embeddings.
        cfg (CompGCNConfig): Configuration object.
    """

    def __init__(self, num_rel, edge_embed_dim, cfg: CompGCNConfig):
        # Zero-argument super() avoids the infinite recursion that super(self.__class__, self) causes when subclassed
        super().__init__(num_rel, edge_embed_dim, cfg)
        self.drop = torch.nn.Dropout(self.cfg.hid_drop)

    def forward(self, edge_index, edge_type, sub, rel, obj, edge_embed=None):
        # Compute embeddings for subjects, relations, and objects
        sub_emb, rel_emb, obj_emb = self.forward_base(edge_index, edge_type, sub, rel, obj, self.drop, edge_embed)

        # DistMult scoring function: element-wise multiplication followed by summation
        x = sub_emb * rel_emb * obj_emb
        x = torch.sum(x, dim=2)

        return x

## NBFNet
The STAGE framework already includes an NBFNet implementation, so we used it directly without modification. This let us focus on integrating it into our overall workflow and evaluating its performance.

## STAGE + Friends
The `EdgeGraphsNBFNet` model combines STAGE embeddings with a final graph model (NBFNet, RGCN, or CompGCN) to perform relational learning tasks. The model is configured through `EdgeGraphsNBFNetConfig`, which sets the edge embedding dimension, the number of layers in the edge embedding module, the type of edge model (GINEConv or GCNConv), and the final model to use.

The model first generates edge embeddings using an `MPNN` (Message Passing Neural Network) module. The `MPNN` processes edge graphs through message passing and supports either `GINEConv` or `GCNConv` for convolution operations. Edge graph embeddings are pooled using a global pooling operation (`global_add_pool`) to generate a compact representation for each edge graph. Additionally, uniform embeddings are generated for user-product edges, ensuring consistent representation for edges without associated graph embeddings. These embeddings are then concatenated to create a unified edge embedding matrix.
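
The pooling and concatenation steps can be illustrated with a small standalone sketch. It uses plain PyTorch rather than `torch_geometric`: the hypothetical helper `add_pool` is a stand-in for `global_add_pool` (it sums rows that share a group index), and the toy sizes are assumptions chosen only for illustration.

```python
import torch


def add_pool(h, group, num_groups):
    # Sum the rows of h that share the same group index (sum pooling per edge graph).
    out = torch.zeros(num_groups, h.size(1), dtype=h.dtype)
    out.index_add_(0, group, h)
    return out


embed_dim = 4

# Five edge-graph nodes; the first two belong to edge graph 0, the rest to edge graph 1.
h = torch.arange(20, dtype=torch.float32).view(5, embed_dim)
edgegraph2ppedge = torch.tensor([0, 0, 1, 1, 1])
edgegraph_reprs = add_pool(h, edgegraph2ppedge, num_groups=2)

# One shared, learnable embedding for every user-product edge.
up_emb = torch.nn.Embedding(1, embed_dim)
num_edges = 5  # assumed total number of edges in the main graph
num_up_edges = num_edges - edgegraph_reprs.size(0)
upgraph_emb = up_emb.weight.repeat((num_up_edges, 1))

# User-product rows come first, followed by one row per edge graph.
edge_embeddings = torch.vstack([upgraph_emb, edgegraph_reprs])
print(edge_embeddings.shape)  # torch.Size([5, 4])
```

Because user-product rows are stacked first, this layout assumes that the user-product edges precede the edges backed by edge graphs in `data.edge_index`.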

The processed edge embeddings are added to the graph data object and passed to the selected final model, which could be NBFNet, RGCN, or CompGCN, as specified in the configuration. Each of these models operates on the relational graph data to perform link prediction tasks.

The implementation relies on sparse message passing and sum pooling (`global_add_pool`) to compute the edge graph embeddings, and it supports multi-relational graphs.

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv, GCNConv
from torch_geometric.nn.pool import global_add_pool
from dataclasses import dataclass, field

from .nbfmodel import NBFNet, NBFNetConfig
from .rgcn import RGCN, RGCNConfig
from .compgcn.models import CompGCN, CompGCNConfig


class MPNN(torch.nn.Module):
    """
    Message Passing Neural Network (MPNN) for edge embedding generation.

    Args:
        input_dim (int): Input feature dimension for nodes.
        hidden_dim (int): Hidden layer dimension.
        num_layers (int): Number of MPNN layers.
        edge_model (str): Type of edge model ('GINEConv' or 'GCNConv').
        edge_dim (int): Dimension of edge attributes.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, edge_model, edge_dim):
        super().__init__()

        self.convs = torch.nn.ModuleList()  # Convolution layers
        self.bns = torch.nn.ModuleList()  # Batch normalization layers

        # Initialize layers
        for _ in range(num_layers):
            if edge_model == "GINEConv":
                mlp = torch.nn.Sequential(
                    torch.nn.Linear(input_dim, hidden_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_dim, hidden_dim),
                )
                self.convs.append(GINEConv(nn=mlp, edge_dim=edge_dim))
            elif edge_model == "GCNConv":
                self.convs.append(GCNConv(input_dim, hidden_dim))
            self.bns.append(torch.nn.BatchNorm1d(hidden_dim))
            input_dim = hidden_dim

    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass through the MPNN.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Edge indices.
            edge_attr (torch.Tensor): Edge attributes.

        Returns:
            torch.Tensor: Updated node features.
        """
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index, edge_attr)  # Message passing
            x = bn(x).relu()  # Batch normalization and ReLU activation

        return x


@dataclass
class EdgeGraphsNBFNetConfig:
    """
    Configuration class for EdgeGraphsNBFNet.

    Attributes:
        edge_embed_dim (int): Dimension of edge embeddings.
        edge_embed_num_layers (int): Number of layers in the edge embedding MPNN.
        edge_model (str): Edge model type ('GINEConv' or 'GCNConv').
        use_p_value (int): Whether to use p-values in edge attributes.

        final_model (str): Final model type ('nbf', 'rgcn', or 'compgcn').
        nbf (NBFNetConfig): Configuration for NBFNet.
        rgcn (RGCNConfig): Configuration for RGCN.
        compgcn (CompGCNConfig): Configuration for CompGCN.
    """

    edge_embed_dim: int = 256
    edge_embed_num_layers: int = 1
    edge_model: str = "GINEConv"
    use_p_value: int = 1

    final_model: str = "nbf"
    nbf: NBFNetConfig = field(default_factory=NBFNetConfig)
    rgcn: RGCNConfig = field(default_factory=RGCNConfig)
    compgcn: CompGCNConfig = field(default_factory=CompGCNConfig)


class EdgeGraphsNBFNet(nn.Module):
    """
    EdgeGraphsNBFNet combines edge graph embeddings with NBFNet, RGCN, or CompGCN.

    Args:
        num_relation (int): Number of relations in the graph.
        cfg (EdgeGraphsNBFNetConfig): Configuration object.
    """

    def __init__(self, num_relation, cfg: EdgeGraphsNBFNetConfig):
        super().__init__()

        self.edge_embed_dim = cfg.edge_embed_dim

        # Initialize the final model based on the configuration
        if cfg.final_model == "nbf":
            self.model = NBFNet(num_relation, cfg.edge_embed_dim, cfg.nbf)
        elif cfg.final_model == "rgcn":
            self.model = RGCN(num_relation, cfg.edge_embed_dim, cfg.rgcn)
        elif cfg.final_model == "compgcn":
            self.model = CompGCN(num_relation, cfg.edge_embed_dim, cfg.compgcn)
        else:
            raise ValueError(f"Invalid final model: {cfg.final_model}")

        # Define edge embedding model
        edge_dim = 2 if cfg.use_p_value else 1  # Include p-values in edge attributes if enabled
        self.edgegraph_model = MPNN(
            input_dim=1,  # Edge-graph nodes (data.edgegraph_x) have 1 initial feature
            hidden_dim=cfg.edge_embed_dim,
            num_layers=cfg.edge_embed_num_layers,
            edge_model=cfg.edge_model,
            edge_dim=edge_dim,
        )

        # Uniform embedding for user-product edges
        self.up_emb = torch.nn.Embedding(1, cfg.edge_embed_dim)

        self.edge_model = cfg.edge_model
        self.use_p_value = cfg.use_p_value

        # Print the number of parameters in the final model
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"Number of parameters in self.model: {num_params}")

    def forward(self, data, batch):
        """
        Forward pass for EdgeGraphsNBFNet.

        Args:
            data (Data): Graph data containing edge features and edge graphs.
            batch (torch.Tensor): Batch of triples for scoring.

        Returns:
            torch.Tensor: Scores for the input triples.
        """
        # Adjust edge graph attributes for compatibility with the edge model
        if data.edgegraph_edge_attr.dim() == 1:
            data.edgegraph_edge_attr = data.edgegraph_edge_attr.unsqueeze(-1)
        if self.edge_model == "GCNConv":
            data.edgegraph_edge_attr = data.edgegraph_edge_attr[:, 0:1]

        # Generate edge graph embeddings using MPNN
        h = self.edgegraph_model(data.edgegraph_x, data.edgegraph_edge_index, data.edgegraph_edge_attr)
        edgegraph_reprs = global_add_pool(h, data.edgegraph2ppedge)  # Pool embeddings for each edge graph

        # Generate embeddings for user-product edges
        num_up_edges = data.edge_index.size(-1) - edgegraph_reprs.size(0)
        upgraph_emb = self.up_emb.weight.repeat((num_up_edges, 1))

        # Combine user-product and edge graph embeddings
        edge_embeddings = torch.vstack([upgraph_emb, edgegraph_reprs])

        # Add edge embeddings to the data object
        data.edge_embeddings = edge_embeddings
        data.x = None  # Clear node features to avoid conflicts

        # Pass the processed data to the selected final model
        return self.model(data, batch)

## Training and Evaluation Script
This bash script automates training and evaluation of three models (`CompGCN`, `RGCN`, and `NBFNet`) in a multi-dataset setup. For each target dataset (phone, refrig, shoe, bed, desktop), the script trains on the remaining four datasets and tests on the target dataset, as specified by the corresponding YAML configuration file (`config/${config}.yaml`). For each target dataset, it iterates over four edge embedding methods (`method1`, `method2`, `method3`, and `method4`) and trains both the `CompGCN` and `RGCN` models with each method. The results are saved in separate directories (`exp/final_${config}/compgcn_${method}` and `exp/final_${config}/rgcn_${method}`) and are also logged as WandB plots. The NBFNet model is trained and tested once per target dataset, without varying the edge embedding method, with results saved in `exp/final_${config}/nbf`. This setup gives a systematic evaluation across all datasets, embedding methods, and models.

All the plots and numerical results used in the post are generated using the output files from this script.
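
As a rough sanity check on the size of this experiment grid, the following sketch enumerates the runs described above. The helper `planned_runs` is hypothetical and is not part of the codebase. In the real script the train/test split comes from the YAML files, so the `train_on` list here only mirrors the leave-one-dataset-out description.

```python
datasets = ["phone", "refrig", "shoe", "bed", "desktop"]
methods = ["method1", "method2", "method3", "method4"]


def planned_runs(datasets, methods):
    runs = []
    for target in datasets:
        # Train on every dataset except the held-out target.
        train_on = [d for d in datasets if d != target]
        for method in methods:
            for model in ("compgcn", "rgcn"):
                save_dir = f"exp/final_{target}/{model}_{method}"
                runs.append((target, train_on, model, save_dir))
        # NBFNet is run once per target, independent of the edge method.
        runs.append((target, train_on, "nbf", f"exp/final_{target}/nbf"))
    return runs


runs = planned_runs(datasets, methods)
print(len(runs))  # 5 targets * (4 methods * 2 models + 1 NBFNet run) = 45
```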

In [None]:
%%bash
for config in phone refrig shoe bed desktop
do
    for method in method1 method2 method3 method4
    do
        python script/run.py --config_path config/${config}.yaml --use_wb 1 --save_dir exp/final_${config}/compgcn_${method} --edgegraph.compgcn.edge_method $method --seed 1 --edgegraph.final_model compgcn
        python script/run.py --config_path config/${config}.yaml --use_wb 1 --save_dir exp/final_${config}/rgcn_${method} --edgegraph.rgcn.edge_method $method --seed 1 --edgegraph.final_model rgcn
    done
    python script/run.py --config_path config/${config}.yaml --use_wb 1 --save_dir exp/final_${config}/nbf --seed 1 --edgegraph.final_model nbf
done