# Graph Transformer Model

DiGress trains a graph transformer network, proposed by Dwivedi & Bresson (2021), as its denoising model. In this section, we go through the implementation details of the model to see how it predicts the clean graph from the noisy graph.


- **$\mathbf{X}$**: Node features matrix, shape ( bs, $n, d_x$ )
- **$\mathbf{E}$**: Edge features matrix, shape (bs, $n, n, d_e$ )
- **$\mathbf{y}$**: Global features vector, shape (bs, $d_y$ )
- **node_mask**: Node mask, shape (bs, $n$)


In [3]:
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

In [4]:
# Helper functions
def assert_correctly_masked(variable, node_mask):
    """Check that ``variable`` is (numerically) zero wherever ``node_mask`` is 0.

    ``node_mask`` is cast to long and inverted, so any entry it marks as 0 must
    hold a value with absolute size below 1e-4 in ``variable``.

    Raises:
        AssertionError: if a masked-out entry leaks a non-zero value.
    """
    leaked = variable * (1 - node_mask.long())
    worst = leaked.abs().max().item()
    assert worst < 1e-4, 'Variables not masked properly.'

def masked_softmax(x, mask, **kwargs):
    """Softmax of ``x`` after setting every entry with ``mask == 0`` to -inf.

    The mask is applied with boolean indexing, so it addresses the *leading*
    dimensions of ``x``: a (bs, n, n, n_head) mask hides every trailing df entry of
    a (bs, n, n, n_head, df) score tensor. ``kwargs`` (e.g. ``dim=2``) are passed
    to the softmax. The input tensor is not modified.

    If the mask is entirely zero, ``x`` itself is returned without normalization.
    A slice that is fully masked along the softmax dim yields NaN.
    """
    if mask.sum() == 0:
        return x
    scores = x.clone()
    scores[mask == 0] = float("-inf")
    return scores.softmax(**kwargs)

### Mapping Node/Edge Features($\mathbf X, \mathbf E$) to Global Features $y$


\begin{equation}
\operatorname{PNA}(\boldsymbol{X})=\operatorname{cat}(\operatorname{mean}(\boldsymbol{X}), \min (\boldsymbol{X}), \max (\boldsymbol{X}), \operatorname{std}(\boldsymbol{X})) \boldsymbol{W}
\end{equation}

In [5]:
class Xtoy(nn.Module):
    """Pool node features into a global-feature vector (PNA-style readout).

    The mean, min, max and std over the node axis are concatenated (in that order)
    and projected to ``dy`` by one linear layer. No node mask is applied here, so
    every row of ``X`` takes part in the statistics.
    """

    def __init__(self, dx, dy):
        """ Map node features to global features """
        super().__init__()
        self.lin = nn.Linear(4 * dx, dy)

    def forward(self, X):
        """X: (bs, n, dx) -> (bs, dy)."""
        pooled = [
            X.mean(dim=1),          # bs, dx
            X.min(dim=1).values,    # bs, dx
            X.max(dim=1).values,    # bs, dx
            X.std(dim=1),           # bs, dx
        ]
        return self.lin(torch.hstack(pooled))  # (bs, 4 * dx) -> (bs, dy)


class Etoy(nn.Module):
    """Pool edge features into a global-feature vector.

    The mean, min, max and std are taken over both node axes (all n * n pairs,
    including the diagonal and any padding), concatenated in that order, and
    projected to ``dy`` by one linear layer.
    """

    def __init__(self, d, dy):
        """ Map edge features to global features. """
        super().__init__()
        self.lin = nn.Linear(4 * d, dy)

    def forward(self, E):
        """E: (bs, n, n, de) -> (bs, dy).

        Features relative to the diagonal of E could potentially be added.
        """
        pair_dims = (1, 2)
        lowest = E.min(dim=2).values.min(dim=1).values    # bs, de
        highest = E.max(dim=2).values.max(dim=1).values   # bs, de
        pooled = (E.mean(dim=pair_dims), lowest, highest, torch.std(E, dim=pair_dims))
        return self.lin(torch.hstack(pooled))             # (bs, 4 * de) -> (bs, dy)

## Self Attention `NodeEdgeBlock`
Self-attention layer that updates the representations on the node, edges, and global features
$$\mathbf{X}_{\text {new }}, \mathbf{E}_{\text {new}}, \mathbf{y}_{\text {new}}=\operatorname{SelfAttn}\left(\mathbf{X}, \mathbf{E}, \mathbf{y}, \text{node}_\text{mask}\right)$$

### 1. Linear Projections and Maskings

\begin{equation}
\begin{aligned}
& \mathbf{Q}=\mathbf{X} \mathbf{W}_Q \odot \mathbf{x}_{\text {mask }} \\
& \mathbf{K}=\mathbf{X} \mathbf{W}_K \odot \mathbf{x}_{\text {mask }} \\
& \mathbf{V}=\mathbf{X} \mathbf{W}_V \odot \mathbf{x}_{\text {mask }}
\end{aligned}
\end{equation}

### 2. Reshape $\mathbf{Q,K,V}$

\begin{equation}
\begin{aligned}
   & \mathbf Q = \mathbf Q. \textrm{reshape(bs, n, nhead, df)} \\
   & \mathbf K = \mathbf K. \textrm{reshape(bs, n, nhead, df)} \\
   & \mathbf V = \mathbf V. \textrm{reshape(bs, n, nhead, df)}
\end{aligned}
\end{equation}

### 2. Calculate the Attention Score 
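The reshaped queries and keys are combined per head. Note that the implementation uses an elementwise product rather than a dot product, so the scores keep the feature dimension $\textrm{df}$:

\begin{equation}
    \mathbf Y_{ij} = \frac{\mathbf Q_i \odot \mathbf K_j}{\sqrt{\textrm{df}}}, \qquad \mathbf Y \in \mathbb R^{bs \times n \times n \times \textrm{nhead} \times \textrm{df}}
\end{equation}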

### 3. FiLM Layers


#### (a) Incorporate Edge Features to the self-attention score

\begin{equation}
\begin{aligned}
\mathbf{E}_1 & = \left( \mathbf{E} \mathbf{W}^{\textrm{mul}}_E \odot \mathbf{e}_{\textrm{mask}} \right).\textrm{reshape}(\mathrm{bs}, n, n, \textrm{nhead}, \textrm{df}) \\
\mathbf{E}_2 & = \left( \mathbf{E} \mathbf{W}^{\textrm{add}}_E \odot \mathbf{e}_{\textrm{mask}} \right).\textrm{reshape}(\mathrm{bs}, n, n, \textrm{nhead}, \textrm{df}) \\
\mathbf{Y} & = \mathbf{Y} \odot (\mathbf{E}_1 + 1) + \mathbf{E}_2
\end{aligned}
\end{equation}



#### (b) Incorporate $y$ to $\mathbf E$

\begin{equation}
\begin{aligned}
& \mathbf E_{\textrm{new}} = \mathbf Y.\textrm{reshape}(\textrm{bs}, n, n, d_x) \\
& \mathbf E_{\textrm{out}} = \left( \mathbf y W^{\textrm{add}}_{ye} + (\mathbf y W^{\textrm{mul}}_{ye} + 1) \odot \mathbf E_{\textrm{new}} \right) W_{e,\textrm{out}} \odot \mathbf{e}_{\textrm{mask}}
\end{aligned}
\end{equation}

### 4. Compute Normalized Attention Scores 

\begin{equation}
    \textrm{Attn} = \textrm{softmax}_j\left(\mathbf Y \textrm{ with } \mathbf Y_{ij} = -\infty \textrm{ for masked nodes } j\right) \in \mathbb R^{bs \times n\times n \times \textrm{nhead} \times \textrm{df}}
\end{equation}

### 5. Compute Weighted Values 
This step aggregates information from connected nodes weighted by the computed attention scores, effectively updating node representations based on their neighborhood.

\begin{equation}
\mathbf V_{\textrm{weighted}, i} = \sum_j \textrm{Attn}_{ij} \odot \mathbf V_j, \qquad \mathbf V_{\textrm{weighted}} \in \mathbb R^{bs \times n \times dx} \textrm{ (after merging the heads)}
\end{equation}

### 6. Update Representations
Node, edge, and global feature representations are updated through additional FiLM layers and linear transformations:

#### (a) Incorporate $y$ to $\mathbf X$

\begin{equation}
\begin{aligned}
    &  \mathbf X_{\textrm{new}} = \mathbf y W_{yx}^{\textrm{add}} + (\mathbf y W_{yx}^{\textrm{mul}} + 1) \odot \mathbf V_{\textrm{weighted}} \\
    &  \mathbf X_{\textrm{out}} = \mathbf X_{\textrm{new}} W_{xx} \odot \mathbf x_{\textrm{mask}}
\end{aligned}
\end{equation}


#### (b) Process $y$ based on $\mathbf{X, E}$

\begin{equation}
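    y_{\textrm{out}} = \textrm{MLP}_{\textrm{out}}\left(y W_{yy} + \textrm{Xtoy}(\mathbf X) + \textrm{Etoy}(\mathbf E)\right), \quad \textrm{MLP}_{\textrm{out}} = \textrm{Linear} \to \textrm{ReLU} \to \textrm{Linear}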
\end{equation}


In [12]:
class NodeEdgeBlock(nn.Module):
    """Self-attention layer that updates node (X), edge (E) and global (y) features.

    Scores are computed per feature (elementwise Q_i * K_j / sqrt(df), not a dot
    product) and FiLM-modulated by the edge features. They are then used twice:
    (1) flattened and FiLM-modulated by y to become the new edge features, and
    (2) after a masked softmax over the key nodes, as weights over the values to
    update the node features. y is updated from itself plus pooled X and E.

    Args:
        dx: node feature size; must be divisible by ``n_head``.
        de: edge feature size.
        dy: global feature size.
        n_head: number of attention heads (each with df = dx // n_head features).
        **kwargs: ``device`` and ``dtype`` are forwarded to every sub-layer (this is
            what ``XEyTransformerLayer`` passes); any other key is ignored.
    """

    def __init__(self, dx, de, dy, n_head, **kwargs):
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.dy = dy
        self.df = dx // n_head
        self.n_head = n_head

        # Honour the factory kwargs instead of silently dropping them; otherwise a layer
        # built with dtype=torch.float64 would end up with float32 attention weights.
        factory_kwargs = {'device': kwargs.get('device'), 'dtype': kwargs.get('dtype')}

        # Attention
        self.q = Linear(dx, dx, **factory_kwargs)
        self.k = Linear(dx, dx, **factory_kwargs)
        self.v = Linear(dx, dx, **factory_kwargs)
        # FiLM E to the attention scores
        self.e_add = Linear(de, dx, **factory_kwargs)
        self.e_mul = Linear(de, dx, **factory_kwargs)
        # FiLM y to E
        self.y_e_mul = Linear(dy, dx, **factory_kwargs)  # Warning: here it's dx and not de
        self.y_e_add = Linear(dy, dx, **factory_kwargs)
        # FiLM y to X
        self.y_x_mul = Linear(dy, dx, **factory_kwargs)
        self.y_x_add = Linear(dy, dx, **factory_kwargs)
        # Process y
        self.y_y = Linear(dy, dy, **factory_kwargs)
        self.x_y = Xtoy(dx, dy)
        self.e_y = Etoy(de, dy)
        # Xtoy / Etoy take no factory kwargs, so move them after construction.
        move_kwargs = {key: val for key, val in factory_kwargs.items() if val is not None}
        if move_kwargs:
            self.x_y.to(**move_kwargs)
            self.e_y.to(**move_kwargs)
        # Output layers
        self.x_out = Linear(dx, dx, **factory_kwargs)
        self.e_out = Linear(dx, de, **factory_kwargs)
        self.y_out = nn.Sequential(nn.Linear(dy, dy, **factory_kwargs), nn.ReLU(),
                                   nn.Linear(dy, dy, **factory_kwargs))

    def forward(self, X, E, y, node_mask):
        """
        :param X: bs, n, dx        node features
        :param E: bs, n, n, de     edge features
        :param y: bs, dy           global features
        :param node_mask: bs, n    multiplied into X/E results; 0 marks padding nodes
        :return: newX (bs, n, dx), newE (bs, n, n, de), new_y (bs, dy).
        """
        _, n, _ = X.shape
        x_mask = node_mask.unsqueeze(-1)        # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)           # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)           # bs, 1, n, 1

        # 1. Map X to queries, keys and values; padding nodes are zeroed.
        Q = self.q(X) * x_mask                  # bs, n, dx
        K = self.k(X) * x_mask                  # bs, n, dx
        V = self.v(X) * x_mask                  # bs, n, dx
        assert_correctly_masked(Q, x_mask)

        # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df
        Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df))
        K = K.reshape((K.size(0), K.size(1), self.n_head, self.df))
        V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))

        Q = Q.unsqueeze(2)                      # bs, n, 1, n_head, df  (query node i)
        K = K.unsqueeze(1)                      # bs, 1, n, n_head, df  (key node j)
        V = V.unsqueeze(1)                      # bs, 1, n, n_head, df  (value of node j)

        # 3. Unnormalized attention, kept per feature (no sum over df).
        Y = Q * K                               # bs, n, n, n_head, df
        Y = Y / math.sqrt(Y.size(-1))
        assert_correctly_masked(Y, (e_mask1 * e_mask2).unsqueeze(-1))

        # 4. FiLM the scores with the edge features.
        E1 = self.e_mul(E) * e_mask1 * e_mask2                                    # bs, n, n, dx
        E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))  # bs, n, n, n_head, df
        E2 = self.e_add(E) * e_mask1 * e_mask2                                    # bs, n, n, dx
        E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))  # bs, n, n, n_head, df
        Y = Y * (E1 + 1) + E2                   # bs, n, n, n_head, df

        # 5. The modulated scores become the new edge features, FiLM-ed by y.
        newE = Y.flatten(start_dim=3)                    # bs, n, n, dx
        ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1)  # bs, 1, 1, dx
        ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1)  # bs, 1, 1, dx
        newE = ye1 + (ye2 + 1) * newE
        newE = self.e_out(newE) * e_mask1 * e_mask2      # bs, n, n, de
        assert_correctly_masked(newE, e_mask1 * e_mask2)

        # 6. Softmax over the key nodes j (dim 2); masked keys are set to -inf first.
        softmax_mask = e_mask2.expand(-1, n, -1, self.n_head)  # bs, n, n, n_head
        attn = masked_softmax(Y, softmax_mask, dim=2)          # bs, n, n, n_head, df

        # 7. Weighted sum of the values over j, then merge the heads.
        weighted_V = (attn * V).sum(dim=2)                     # bs, n, n_head, df
        weighted_V = weighted_V.flatten(start_dim=2)           # bs, n, dx

        # 8. FiLM the aggregated values with y, project, and re-mask.
        yx1 = self.y_x_add(y).unsqueeze(1)      # bs, 1, dx
        yx2 = self.y_x_mul(y).unsqueeze(1)      # bs, 1, dx
        newX = yx1 + (yx2 + 1) * weighted_V
        newX = self.x_out(newX) * x_mask        # bs, n, dx
        assert_correctly_masked(newX, x_mask)

        # 9. New global features from y plus pooled statistics of the input X and E
        #    (Xtoy / Etoy receive them without node_mask applied).
        y = self.y_y(y)
        e_y = self.e_y(E)
        x_y = self.x_y(X)
        new_y = y + x_y + e_y
        new_y = self.y_out(new_y)               # bs, dy

        return newX, newE, new_y

In [13]:
# Smoke test: run one NodeEdgeBlock on random inputs and report the output shapes.
dx, de, dy = 64, 32, 16   # node / edge / global feature sizes
n_head = 8
bs = 2   # batch size
n = 10   # number of nodes

node_edge_block = NodeEdgeBlock(dx, de, dy, n_head)

# Random inputs; an all-ones mask means no node is padding.
X = torch.randn(bs, n, dx)
E = torch.randn(bs, n, n, de)
y = torch.randn(bs, dy)
node_mask = torch.ones(bs, n)

newX, newE, new_y = node_edge_block(X, E, y, node_mask)

for label, tensor in (("newX", newX), ("newE", newE), ("new_y", new_y)):
    print(f"Shape of {label}:", tensor.shape)

torch.Size([2, 10, 10, 64])
torch.Size([2, 10, 10, 8, 8])
Shape of newX: torch.Size([2, 10, 64])
Shape of newE: torch.Size([2, 10, 10, 32])
Shape of new_y: torch.Size([2, 16])


Transformer that updates node, edge and global features
- **$\mathbf{X}$**: Node features matrix, shape ( bs, $n, d_x$ )
- **$\mathbf{E}$**: Edge features matrix, shape (bs, $n, n, d_e$ )
- **$\mathbf{y}$**: Global features vector, shape (bs, $d_y$ )
- **node_mask**: Node mask, shape (bs, $n$)
- **$\mathbf{W}$**: Weight matrices for the linear layers
- **$\mathbf{b}$**: Bias vectors for the linear layers
- **LN**: Layer normalization
- **Dropout**: Dropout operation
- **ReLU**: Rectified Linear Unit activation function

In [7]:
class XEyTransformerLayer(nn.Module):
    """Post-norm transformer layer over node (X), edge (E) and global (y) features.

    Each stream gets a dropout + residual + LayerNorm around the shared
    NodeEdgeBlock self-attention, followed by a ReLU feed-forward network with its
    own dropout + residual + LayerNorm.

    Args:
        dx: node feature size.
        de: edge feature size.
        dy: global feature size.
        n_head: number of attention heads in the NodeEdgeBlock.
        dim_ffX, dim_ffE, dim_ffy: hidden widths of the X / E / y feed-forward networks.
        dropout: dropout probability; 0 disables it.
        layer_norm_eps: eps used by every LayerNorm.
        device, dtype: factory kwargs forwarded to the sub-layers.
    """

    def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048,
                 dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1,
                 layer_norm_eps: float = 1e-5, device=None, dtype=None) -> None:
        factory = {'device': device, 'dtype': dtype}
        super().__init__()

        def norm(width):
            return LayerNorm(width, eps=layer_norm_eps, **factory)

        self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **factory)

        # Sub-modules are created in the original order, which fixes both the
        # random initialisation sequence and the state_dict layout.
        self.linX1 = Linear(dx, dim_ffX, **factory)
        self.linX2 = Linear(dim_ffX, dx, **factory)
        self.normX1, self.normX2 = norm(dx), norm(dx)
        self.dropoutX1, self.dropoutX2, self.dropoutX3 = Dropout(dropout), Dropout(dropout), Dropout(dropout)

        self.linE1 = Linear(de, dim_ffE, **factory)
        self.linE2 = Linear(dim_ffE, de, **factory)
        self.normE1, self.normE2 = norm(de), norm(de)
        self.dropoutE1, self.dropoutE2, self.dropoutE3 = Dropout(dropout), Dropout(dropout), Dropout(dropout)

        self.lin_y1 = Linear(dy, dim_ffy, **factory)
        self.lin_y2 = Linear(dim_ffy, dy, **factory)
        self.norm_y1, self.norm_y2 = norm(dy), norm(dy)
        self.dropout_y1, self.dropout_y2, self.dropout_y3 = Dropout(dropout), Dropout(dropout), Dropout(dropout)

        self.activation = F.relu

    def _feed_forward(self, h, lin_in, drop_hidden, lin_out, drop_out, norm):
        """Residual FFN: norm(h + drop_out(lin_out(drop_hidden(activation(lin_in(h))))))."""
        update = lin_out(drop_hidden(self.activation(lin_in(h))))
        return norm(h + drop_out(update))

    def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
        """Apply one layer.

        X: (bs, n, dx), E: (bs, n, n, de), y: (bs, dy), node_mask: (bs, n).
        Returns the updated (X, E, y) with unchanged shapes.
        """
        att_X, att_E, att_y = self.self_attn(X, E, y, node_mask=node_mask)

        # Residual + LayerNorm around self-attention (dropouts applied in X, E, y order).
        X = self.normX1(X + self.dropoutX1(att_X))
        E = self.normE1(E + self.dropoutE1(att_E))
        y = self.norm_y1(y + self.dropout_y1(att_y))

        # Feed-forward sub-layers, same stream order.
        X = self._feed_forward(X, self.linX1, self.dropoutX2, self.linX2, self.dropoutX3, self.normX2)
        E = self._feed_forward(E, self.linE1, self.dropoutE2, self.linE2, self.dropoutE3, self.normE2)
        y = self._feed_forward(y, self.lin_y1, self.dropout_y2, self.lin_y2, self.dropout_y3, self.norm_y2)

        return X, E, y

### XEyTransformerLayer

```python
def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048,
                 dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1,
                 layer_norm_eps: float = 1e-5)
```

Transformer that updates node, edge and global features
- **$\mathbf{X}$**: Node features matrix, shape ( bs, $n, d_x$ )
- **$\mathbf{E}$**: Edge features matrix, shape (bs, $n, n, d_e$ )
- **$\mathbf{y}$**: Global features vector, shape (bs, $d_y$ )
- **node_mask**: Node mask, shape (bs, $n$ )
- **$\mathbf{W}$**: Weight matrices for the linear layers
- **$\mathbf{b}$**: Bias vectors for the linear layers
- **LN**: Layer normalization
- **Dropout**: Dropout operation
- **ReLU**: Rectified Linear Unit activation function

#### Self Attention

$$\mathbf{X}_{\text {new }}, \mathbf{E}_{\text {new}}, \mathbf{y}_{\text {new}}=\operatorname{SelfAttn}\left(\mathbf{X}, \mathbf{E}, \mathbf{y}, \text{node}_\text{mask}\right)$$


#### Residual and Layer Normalization 

\begin{gathered}
\mathbf{X}_{\text {residual }}=\mathbf{X}+\operatorname{Dropout}\left(\mathbf{X}_{\text {new }}\right) \\
\mathbf{X}=\operatorname{LN}\left(\mathbf{X}_{\text {residual }}\right)
\end{gathered}


####  Feed-Forward Layer 

\begin{gathered}
\mathbf{X}_{\mathrm{ff}}=\operatorname{Dropout}\left(\operatorname{ReLU}\left(\mathbf{X} \mathbf{W}_{X 1}+\mathbf{b}_{X 1}\right)\right) \\
\mathbf{X}_{\mathrm{ff}}=\mathbf{X}_{\mathrm{ff}} \mathbf{W}_{X 2}+\mathbf{b}_{X 2} \\
\mathbf{X}_{\mathrm{ff}}=\operatorname{Dropout}\left(\mathbf{X}_{\mathrm{ff}}\right) \\
\mathbf{X}=\mathrm{LN}\left(\mathbf{X}+\mathbf{X}_{\mathrm{ff}}\right)
\end{gathered}


### NodeEdgeBlock
Self-attention layer that also updates the representations on the edges

```python
def __init__(self, dx, de, dy, n_head, **kwargs):
```

#### Linear Projections and Maskings 

\begin{aligned}
& \mathbf{Q}=\mathbf{X} \mathbf{W}_Q \odot \mathbf{x}_{\text {mask }} \\
& \mathbf{K}=\mathbf{X} \mathbf{W}_K \odot \mathbf{x}_{\text {mask }} \\
& \mathbf{V}=\mathbf{X} \mathbf{W}_V \odot \mathbf{x}_{\text {mask }}
\end{aligned}


#### Reshape for Multi-head Attention

\begin{aligned}
&\mathbf{Q} = \operatorname{reshape}\left(\mathbf{Q}, \left(\mathrm{bs}, n, \mathrm{n_head}, \mathrm{df}\right)\right) \\
&\mathbf{K} = \operatorname{reshape}\left(\mathbf{K}, \left(\mathrm{bs}, n, \mathrm{n_head}, \mathrm{df}\right)\right) \\
&\mathbf{V} = \operatorname{reshape}\left(\mathbf{V}, \left(\mathrm{bs}, n, \mathrm{n_head}, \mathrm{df}\right)\right)
\end{aligned}


#### Compute Unnormalized Attention Scores

$$\mathbf{Y}_{ij} = \frac{(\mathbf{Q}_i \odot \mathbf{x}_{\text{mask},i}) \odot (\mathbf{K}_j \odot \mathbf{x}_{\text{mask},j})}{\sqrt{\mathrm{df}}} \quad \text{(elementwise; no sum over df)}$$


#### Incorporate Edge Features 

\begin{gathered}
\mathbf{E}_{\text {mul }}=\operatorname{reshape}\left(\mathbf{E} \mathbf{W}_{\text {e_mul }},\left(\mathrm{bs}, n, n, \mathrm{n} _ \text {head, df }\right)\right) \odot \mathbf{e}_{\text {mask1 }} \odot \mathbf{e}_{\text {mask2 }} \\
\mathbf{E}_{\text {add }}=\operatorname{reshape}\left(\mathbf{E} \mathbf{W}_{\text {e_add }},\left(\mathrm{bs}, n, n, \mathrm{n} _ \text {head, df }\right)\right) \odot \mathbf{e}_{\text {mask1 }} \odot \mathbf{e}_{\text {mask2 }} \\
\mathbf{Y}=\mathbf{Y} \odot\left(\mathbf{E}_{\text {mul }}+1\right)+\mathbf{E}_{\text {add }}
\end{gathered}


#### Compute Weighted Values

\begin{aligned}
& \mathbf{V} = \operatorname{reshape}(\mathbf{V}, (\text{bs}, 1, n, \text{n_head}, \text{df})) \odot \mathbf{x}_{\text{mask}} \quad (\text{expand mask}) \\
& \mathbf{weighted_V} = \sum_{j} \mathbf{A}_{ij} \mathbf{V}_{j} \\
& \mathbf{weighted_V} = \operatorname{reshape}(\mathbf{weighted_V}, (\text{bs}, n, d_x))
\end{aligned}


#### Incorporate Global Features to Node Features

\begin{aligned}
& \mathbf{y}_{\text{x_add}} = \mathbf{y} \mathbf{W}_{\text{y_x_add}} \\
& \mathbf{y}_{\text{x_mul}} = \mathbf{y} \mathbf{W}_{\text{y_x_mul}} \\
& \mathbf{newX} = \mathbf{y}_{\text{x_add}} + (\mathbf{y}_{\text{x_mul}} + 1) \odot \mathbf{weighted_V}
\end{aligned}


#### Output Node Features

$$\mathbf{newX} = \mathbf{newX} \mathbf{W}_{\text{x_out}} \odot \mathbf{x}_{\text{mask}}$$


### Summary

The update of the node features $\mathbf{X}$ in the `NodeEdgeBlock` involves:
1. Projecting $\mathbf{X}$ to queries $\mathbf{Q}$, keys $\mathbf{K}$, and values $\mathbf{V}$.
2. Computing scaled dot-product attention scores.
3. Incorporating edge features into the attention scores.
4. Applying a masked softmax to obtain attention weights.
5. Computing weighted sums of values based on the attention weights.
6. Incorporating global features into the updated node features.
7. Producing the final updated node features.

### GraphTransformer 

The `GraphTransformer` class combines multi-layer perceptrons (MLPs) and transformer layers to process graph data. 

1. **Initial Feature Processing:**

   \begin{aligned}
   &\mathbf{X}_{\text{in}} = \sigma(\mathbf{W}_{1X} \sigma(\mathbf{W}_{0X} \mathbf{X} + \mathbf{b}_{0X}) + \mathbf{b}_{1X}) \\
   &\mathbf{E}_{\text{in}} = \frac{1}{2} (\sigma(\mathbf{W}_{1E} \sigma(\mathbf{W}_{0E} \mathbf{E} + \mathbf{b}_{0E}) + \mathbf{b}_{1E}) + \sigma(\mathbf{W}_{1E} \sigma(\mathbf{W}_{0E} \mathbf{E}^T + \mathbf{b}_{0E}) + \mathbf{b}_{1E})) \\
   &\mathbf{y}_{\text{in}} = \sigma(\mathbf{W}_{1y} \sigma(\mathbf{W}_{0y} \mathbf{y} + \mathbf{b}_{0y}) + \mathbf{b}_{1y}) \\
   \end{aligned}


2. **Transformer Layers:**
   For each transformer layer $i$:

   $$\mathbf{X}, \mathbf{E}, \mathbf{y} = \text{tf_layers}[i](\mathbf{X}, \mathbf{E}, \mathbf{y}, \text{node_mask})$$


3. **Final Feature Processing:** the output MLPs have no final activation, and the residual connection uses the original (noisy) inputs $\mathbf{X}_0, \mathbf{E}_0, \mathbf{y}_0$ passed to `forward`, truncated to the output dimensions:



   \begin{aligned}
   &\mathbf{X}_{\text{out}} = \mathbf{W}_{3X} \sigma(\mathbf{W}_{2X} \mathbf{X} + \mathbf{b}_{2X}) + \mathbf{b}_{3X} + \mathbf{X}_0[\dots, :d_X^{\text{out}}] \\
   &\mathbf{E}_{\text{out}} = \left(\mathbf{W}_{3E} \sigma(\mathbf{W}_{2E} \mathbf{E} + \mathbf{b}_{2E}) + \mathbf{b}_{3E} + \mathbf{E}_0[\dots, :d_E^{\text{out}}]\right) \odot \text{diag_mask} \\
   &\mathbf{y}_{\text{out}} = \mathbf{W}_{3y} \sigma(\mathbf{W}_{2y} \mathbf{y} + \mathbf{b}_{2y}) + \mathbf{b}_{3y} + \mathbf{y}_0[\dots, :d_y^{\text{out}}] \\
   \end{aligned}


4. **Symmetrizing Edge Features:**

   $$\mathbf{E}_{\text{out}} = \frac{1}{2} (\mathbf{E}_{\text{out}} + \mathbf{E}_{\text{out}}^T)$$


5. **Return Masked Outputs:**

   $$\text{utils.PlaceHolder}(\mathbf{X} = \mathbf{X}_{\text{out}}, \mathbf{E} = \mathbf{E}_{\text{out}}, \mathbf{y} = \mathbf{y}_{\text{out}}).mask(\text{node_mask})$$


### Summary
The `GraphTransformer` processes node ($\mathbf{X}$), edge ($\mathbf{E}$), and global ($\mathbf{y}$) features through MLPs, applies multiple transformer layers to update the features, and then processes the updated features through output MLPs before returning the final results with symmetric edge features and masked nodes.

In [None]:





class XEyTransformerLayer(nn.Module):
    """Transformer layer that jointly updates node, edge and global features.

    Layout (post-norm, per stream S in X, E, y):
        S = LN1(S + Dropout1(SelfAttn(X, E, y)_S))
        S = LN2(S + Dropout3(Lin2(Dropout2(ReLU(Lin1(S))))))

    Args:
        dx / de / dy: node / edge / global feature sizes.
        n_head: number of attention heads of the NodeEdgeBlock.
        dim_ffX / dim_ffE / dim_ffy: hidden widths of the per-stream feed-forward nets.
        dropout: dropout probability (0 disables dropout).
        layer_norm_eps: eps of every LayerNorm.
        device / dtype: forwarded to the sub-layers.
    """

    def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048,
                 dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1,
                 layer_norm_eps: float = 1e-5, device=None, dtype=None) -> None:
        fkw = dict(device=device, dtype=dtype)
        super().__init__()

        self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **fkw)

        # Node stream
        self.linX1 = Linear(dx, dim_ffX, **fkw)
        self.linX2 = Linear(dim_ffX, dx, **fkw)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **fkw)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **fkw)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # Edge stream
        self.linE1 = Linear(de, dim_ffE, **fkw)
        self.linE2 = Linear(dim_ffE, de, **fkw)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **fkw)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **fkw)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        # Global stream
        self.lin_y1 = Linear(dy, dim_ffy, **fkw)
        self.lin_y2 = Linear(dim_ffy, dy, **fkw)
        self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **fkw)
        self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **fkw)
        self.dropout_y1 = Dropout(dropout)
        self.dropout_y2 = Dropout(dropout)
        self.dropout_y3 = Dropout(dropout)

        self.activation = F.relu

    def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
        """Run attention + feed-forward on (X, E, y).

        X: (bs, n, dx), E: (bs, n, n, de), y: (bs, dy), node_mask: (bs, n).
        Returns a tuple (X, E, y) with the input shapes.
        """
        attended = self.self_attn(X, E, y, node_mask=node_mask)

        # Stage 1: residual + LayerNorm around attention, streams processed X, E, y.
        post_attn = (
            (self.dropoutX1, self.normX1),
            (self.dropoutE1, self.normE1),
            (self.dropout_y1, self.norm_y1),
        )
        feats = [norm(old + drop(new))
                 for old, new, (drop, norm) in zip((X, E, y), attended, post_attn)]

        # Stage 2: residual feed-forward, same stream order.
        ffn = (
            (self.linX1, self.dropoutX2, self.linX2, self.dropoutX3, self.normX2),
            (self.linE1, self.dropoutE2, self.linE2, self.dropoutE3, self.normE2),
            (self.lin_y1, self.dropout_y2, self.lin_y2, self.dropout_y3, self.norm_y2),
        )
        X, E, y = (norm(h + drop_out(lin_out(drop_mid(self.activation(lin_in(h))))))
                   for h, (lin_in, drop_mid, lin_out, drop_out, norm) in zip(feats, ffn))

        return X, E, y



class GraphTransformer(nn.Module):
    """Denoising network: input MLPs -> stack of XEyTransformerLayer -> output MLPs.

    The output adds a residual from the raw inputs (truncated to the output sizes),
    zeroes the edge diagonal and symmetrises the edges before masking.

    n_layers : int -- number of XEyTransformerLayer blocks
    input_dims, output_dims : dict with keys 'X', 'E', 'y' -- input / output feature sizes
    hidden_mlp_dims : dict with keys 'X', 'E', 'y' -- hidden width of the input/output MLPs
    hidden_dims : dict with keys 'dx', 'de', 'dy', 'n_head', 'dim_ffX', 'dim_ffE'
        -- transformer sizes
    act_fn_in, act_fn_out : activation modules of the input / output MLPs. They default
        to a fresh nn.ReLU(); previously ``nn.ReLU()`` was only written as a type
        annotation, so both arguments were mandatory.
    """
    def __init__(self, n_layers: int, input_dims: dict, hidden_mlp_dims: dict, hidden_dims: dict,
                 output_dims: dict, act_fn_in: "nn.Module | None" = None,
                 act_fn_out: "nn.Module | None" = None):
        super().__init__()
        if act_fn_in is None:
            act_fn_in = nn.ReLU()
        if act_fn_out is None:
            act_fn_out = nn.ReLU()

        self.n_layers = n_layers
        self.out_dim_X = output_dims['X']
        self.out_dim_E = output_dims['E']
        self.out_dim_y = output_dims['y']

        self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X'], hidden_mlp_dims['X']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)

        self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E'], hidden_mlp_dims['E']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)

        self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)

        # NOTE: hidden_dims['dim_ffy'] is not forwarded, so the y feed-forward keeps
        # XEyTransformerLayer's default width (2048). Left unchanged because forwarding
        # it would alter parameter shapes of existing models.
        self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
                                                            de=hidden_dims['de'],
                                                            dy=hidden_dims['dy'],
                                                            n_head=hidden_dims['n_head'],
                                                            dim_ffX=hidden_dims['dim_ffX'],
                                                            dim_ffE=hidden_dims['dim_ffE'])
                                        for _ in range(n_layers)])

        # Output MLPs end without an activation.
        self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['X'], output_dims['X']))

        self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['E'], output_dims['E']))

        self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['y'], output_dims['y']))

    def forward(self, X, E, y, node_mask):
        """Map noisy X (bs, n, dX), E (bs, n, n, dE), y (bs, dy) to predicted features.

        Returns ``utils.PlaceHolder(X=..., E=..., y=...)`` masked with ``node_mask``.
        """
        bs, n = X.shape[0], X.shape[1]

        # Boolean mask that is False on the diagonal, so E_ii predictions are zeroed.
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)  # bs, n, n, 1

        # Residual taken from the raw inputs, truncated to the output sizes.
        X_to_out = X[..., :self.out_dim_X]
        E_to_out = E[..., :self.out_dim_E]
        y_to_out = y[..., :self.out_dim_y]

        # Symmetrise the embedded edges so E_ij and E_ji start out identical.
        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2
        logging.debug("X shape: %s", X.shape)
        # NOTE(review): `utils` is not imported in the visible cells; presumably the
        # DiGress utils module providing PlaceHolder -- confirm it is in scope.
        after_in = utils.PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask)
        logging.debug("after_in.X shape: %s", after_in.X.shape)
        X, E, y = after_in.X, after_in.E, after_in.y

        for layer in self.tf_layers:
            X, E, y = layer(X, E, y, node_mask)

        X = self.mlp_out_X(X) + X_to_out
        E = (self.mlp_out_E(E) + E_to_out) * diag_mask
        y = self.mlp_out_y(y) + y_to_out

        # Enforce an undirected graph: E = (E + E^T) / 2.
        E = 1/2 * (E + torch.transpose(E, 1, 2))

        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
