<h1> Imports

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

<h1> Message Passing Layer

In [None]:
class MPNNLayer(MessagePassing):
    """
    Single edge-conditioned message-passing block (hidden_dim -> hidden_dim).

    For every edge j -> i a message is computed from [x_i, x_j, edge_attr]. Messages
    are mean-aggregated per target node and fused with the node's own features by an
    update MLP. The result is added back to the input (residual) and passed through
    LayerNorm.

    Mean aggregation keeps the aggregated magnitude independent of how many
    neighbours a node has, which matters on fully-connected graphs.

    Args:
        hidden_dim: node feature width, used for both input and output.
        edge_dim: width of the per-edge feature vectors.
        dropout: dropout probability inside both MLPs.
    """

    @staticmethod
    def _two_layer_mlp(in_dim: int, hidden_dim: int, dropout: float) -> nn.Sequential:
        # Linear -> ReLU -> Dropout -> Linear, the shape shared by the message and update MLPs.
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def __init__(self, hidden_dim: int, edge_dim: int, dropout: float = 0.0):
        # "mean" instead of "add": on a fully-connected graph each node receives a
        # message from every other node, so a sum would grow with graph size.
        super().__init__(aggr="mean")

        # Message MLP input: target features, source features, edge features.
        self.msg_mlp = self._two_layer_mlp(2 * hidden_dim + edge_dim, hidden_dim, dropout)
        # Update MLP input: current node features concatenated with the aggregated message.
        self.upd_mlp = self._two_layer_mlp(2 * hidden_dim, hidden_dim, dropout)
        # Normalises the residual sum; intended to help against oversmoothing.
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Run one round of message passing; the output has the same width as ``x``."""
        # propagate() calls message -> aggregate (mean) -> update.
        updated = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
        # Residual connection keeps each node's own signal before normalising.
        return self.norm(x + updated)

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Build one message per edge from (target, source, edge) features."""
        # x_i: receiving node [E, H]; x_j: sending node [E, H]; edge_attr: [E, E_dim]
        features = torch.cat((x_i, x_j, edge_attr), dim=-1)  # [E, 2H + E_dim]
        return self.msg_mlp(features)  # [E, H]

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Combine each node's current features with its aggregated incoming message."""
        # x and aggr_out are both [N, H]; concatenated to [N, 2H], projected back to [N, H].
        return self.upd_mlp(torch.cat((x, aggr_out), dim=-1))



<h1> MPNN Encoder Model

In [None]:
class MicEncoderMPNN(nn.Module):
    """
    Microphone-embedding encoder built from stacked MPNNLayer blocks.

    Pipeline:
        node features -> hidden_dim (Linear, ReLU, LayerNorm)
        -> num_layers x MPNNLayer
        -> Linear to out_dim

    Args:
        node_in_dim: width of the raw per-node features.
        edge_in_dim: width of the per-edge features consumed by every MPNNLayer.
        hidden_dim: width used inside the message-passing stack.
        out_dim: width of the returned node embeddings (token dim for cross-attention).
        num_layers: number of MPNNLayer blocks.
        dropout: dropout probability forwarded to each MPNNLayer.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 128,
        out_dim: int = 128,
        num_layers: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()

        # Input projection into the hidden space.
        encoder_parts = [nn.Linear(node_in_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]
        self.node_encoder = nn.Sequential(*encoder_parts)

        # Message-passing stack; every block maps hidden_dim -> hidden_dim.
        self.layers = nn.ModuleList(
            MPNNLayer(hidden_dim=hidden_dim, edge_dim=edge_in_dim, dropout=dropout)
            for _ in range(num_layers)
        )

        # Output projection to the token dimension used downstream.
        self.node_head = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """
        Encode every node of the graph.

        Args:
            x: node features [N, node_in_dim]
            edge_index: [2, E]
            edge_attr: [E, edge_in_dim]
        Returns:
            [N, out_dim] microphone embeddings (tokens).
        """
        hidden = self.node_encoder(x)
        for block in self.layers:
            hidden = block(hidden, edge_index, edge_attr)
        return self.node_head(hidden)

<h1> Instantiating and Testing 

In [None]:
# Example usage.
# NOTE(review): `data` is not defined in this part of the notebook. It is presumably a
# torch_geometric Data object with x [N, F], edge_index [2, E] and edge_attr [E, D]; confirm.
# MicEncoderMPNN has no `aggr` argument (the aggregation is fixed to "mean" inside
# MPNNLayer), so passing aggr="mean" here raised a TypeError and has been removed.
model = MicEncoderMPNN(
    node_in_dim=data.x.size(-1),
    edge_in_dim=data.edge_attr.size(-1),
    hidden_dim=128,
    out_dim=128,
    num_layers=2,
    dropout=0.1,
)

mic_tokens = model(data.x, data.edge_index, data.edge_attr)  # [N, 128]

In [None]:
# Training loop example.
# NOTE(review): CrossEntropyLoss treats the out_dim-wide embeddings returned by the encoder
# as class logits. That is only meaningful if data.y holds integer class labels in
# [0, out_dim). A task-specific head is probably intended here; confirm.
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.


def train(data):
    """Run one optimisation step on `data`; return (loss tensor, node embeddings)."""
    model.train()  # make sure dropout layers are active
    optimizer.zero_grad()  # Clear gradients.
    # The encoder requires edge features and returns a single [N, out_dim] tensor,
    # not an (out, h) tuple.
    h = model(data.x, data.edge_index, data.edge_attr)
    # Restrict the loss to training nodes when the graph carries a mask;
    # otherwise every node contributes.
    train_mask = getattr(data, "train_mask", None)
    if train_mask is not None:
        loss = criterion(h[train_mask], data.y[train_mask])
    else:
        loss = criterion(h, data.y)
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h


for epoch in range(1000):
    loss, h = train(data)
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
