<h1> Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

from torch.utils.data import Dataset
from torch_geometric.data import Data 
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dense_to_sparse
import h5py


  from .autonotebook import tqdm as notebook_tqdm


<h1> Data Set Class

In [2]:
class h5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = h5_path

        with h5py.File(self.h5_path, "r") as f:
            self.keys = list(f.keys())

        self._edge_cache = {} # cache for fully connected edges to improve performance
        self._f = None  # open file once 

    def __len__(self):
        return len(self.keys)
    
    def __getitem__(self, index):
        # --- open file ---
        f = self._get_file()

        # --- load sample ---
        sample = f[self.keys[index]]

        # --- load and cast raw features from sample ---
        csm = torch.from_numpy(sample["csm"][:]).squeeze().to(torch.complex64) # (N, N), complex64
        eigmode = torch.from_numpy(sample["eigmode"][:]).to(torch.complex64) # (N, N), complex64
        eigmode = torch.view_as_real(eigmode).to(torch.float32)  
        coords = torch.from_numpy(sample["cartesian_coordinates"][:]).T.to(torch.float32) # (N, 3), float32 
        loc = torch.from_numpy(sample["loc"][:]).to(torch.float32) # (3, nsources), float32
        source_strength = torch.from_numpy(sample["source_strength_analytic"][:]).squeeze(0).to(torch.float32) # (nsources,), float32


        # --- normalize raw features ---
        #TODO: check alternative approach normalize autopower by trace and cross spectra by coherence
        csm = csm / torch.trace(csm).real
        source_strength = source_strength / source_strength.sum()

        # --- define node features ---        
        theta = torch.atan2(coords[:, 1], coords[:, 0])
        cos_theta = torch.cos(theta) # (N,), float32
        sin_theta = torch.sin(theta) # (N,), float32

        r = torch.sqrt(coords[:, 0]**2 + coords[:, 1]**2) # (N,), float32
        r = r / (r.max() + 1e-8) # normalize radius  
        
        autopower = torch.diagonal(csm) # (N,), complex64
        autopower_real = autopower.real # (N,), float32
        autopower_imag = autopower.imag # (N,), float32

        #TODO: implement positional encoding (Min-Sang Baek, Joon-Hyuk Chang, and Israel Cohen) 
 
        # --- define adjacency--- 
        N = coords.size(0)
        edge_index = self.get_fully_connected_edges(N)   # (2, E), cached, no self-loops

        src, dst = edge_index  # (E,), (E,)

        # --- define edge features ---
        cross_spectra = csm[src, dst]  # (E, 1), complex64
        cross_spectra_real = cross_spectra.real # (E, 1), float32
        cross_spectra_imag = cross_spectra.imag # (E, 1), float32

        dx = (coords[dst, 0] - coords[src, 0])
        dy = (coords[dst, 1] - coords[src, 1])   
        dist = torch.sqrt(dx**2 + dy**2 + 1e-8) # (E, 1), float32
        
        unit_direction_x = dx / dist # (E, 1), float32 
        unit_direction_y = dy / dist # (E, 1), float32

        cos_sim = (cos_theta[src] * cos_theta[dst] + sin_theta[src] * sin_theta[dst]) # (E, 1), float32, computed with trigonometric identity

        #TODO: implement directional features (Jingjie Fan, Rongzhi Gu, Yi Luo, and Cong Pang)


        # --- build feature vectors ---
        node_feat = self.build_feature(coords, r, cos_theta, sin_theta, autopower_real, autopower_imag, dim=1) # (N, F_node)
        edge_attr = self.build_feature(cross_spectra_real,cross_spectra_imag, dist, unit_direction_x, unit_direction_y, cos_sim, dim=1)  # (E, F_edge)

        # ---  define eigmode tokens analog to Kujawaski et. al---
        eigmode = torch.cat([torch.cat([eigmode[..., 0], -eigmode[..., 1]], dim=-1), torch.cat([eigmode[..., 1],  eigmode[..., 0]], dim=-1),],dim=-2,)

        # --- labels ---
        loc_strongest_source = loc[:,torch.argmax(source_strength)]
        loc_strongest_source = loc_strongest_source[:2].unsqueeze(0) #x and y coordinates only

        strength_strongest_source = source_strength[torch.argmax(source_strength)] 

        # --- build PyG Data ---
        data = Data(
            x=node_feat,                 # (N, F_node)
            edge_index=edge_index,       # (2, E)
            edge_attr=edge_attr,         # (E, F_edge)
            #TODO: Change to multiple sources and strengths later on
            y=loc_strongest_source,      # label used by training loop
        )

        data.eigmode = eigmode

        return data
    

    #--- utility functions ---
    @staticmethod
    def build_feature(*feats, dim=-1):
        """
        Utility function to construct a feature tensor from multiple inputs.

        If a tensor is 1D (shape: [N]), it is automatically expanded to
        shape [N, 1] so that it can be concatenated with higher-dimensional
        feature tensors.

        Parameters
        ----------
        *feats : torch.Tensor
            Feature tensors to be combined. Must be broadcast-compatible
            except for the concatenation dimension.
        dim : int, optional
            Dimension along which to concatenate the features (default: -1).

        Returns
        -------
        torch.Tensor
            Concatenated feature tensor.
        """
        feats = [feature.unsqueeze(-1) if feature.dim() == 1 else feature for feature in feats]
        return torch.cat(feats, dim=dim)

    def _get_file(self):
        """
        Lazily opens the HDF5 file and keeps it open for reuse
        to avoids repeatedly opening and closing the HDF5 file on every
        __getitem__ call. Reduces I/O overhead.

        """
        if self._f is None:
            self._f = h5py.File(self.h5_path, "r")
        return self._f

    def get_fully_connected_edges(self, N):
        """
        Returns the edge_index of a fully connected directed graph with N nodes,
        excluding self-loops and caches the result for performance.

        Parameters
        ----------
        N : int
            Number of nodes in the graph.

        Returns
        -------
        edge_index : torch.Tensor
            Edge index tensor 
        """
        if N not in self._edge_cache:
            adj = torch.ones(N, N, dtype=torch.bool)
            adj.fill_diagonal_(False)
            self._edge_cache[N] = dense_to_sparse(adj)[0]

        return self._edge_cache[N]

<h1> Modules

In [3]:
class MPNNLayer(MessagePassing):
    """
    One message-passing block with edge features, mean aggregation,
    residual connection, and LayerNorm to reduce oversmoothing in fully-connected graphs.

    Input/Output node dim stays constant: hidden_dim -> hidden_dim
    """
    def __init__(self, hidden_dim: int, edge_dim: int, dropout: float = 0.0):
        
        super().__init__(aggr="mean") # fully connected graph, so messages dont blow up with "add" aggregation

        self.msg_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.upd_mlp = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.norm = nn.LayerNorm(hidden_dim) # helps with oversmoothing

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        # propagate: message passing + aggregation + update
        out = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
        # residual helps with oversmoothing
        return self.norm(x + out)
        #return x + out

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        # x_i: target node features [E, H]
        # x_j: source node features [E, H]
        # edge_attr:             [E, E_dim]
        msg_in = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2H + E_dim]
        return self.msg_mlp(msg_in)                        # [E, H]

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # aggr_out: [N, H], x: [N, H]
        upd_in = torch.cat([x, aggr_out], dim=-1)          # [N, 2H]
        return self.upd_mlp(upd_in)                        # [N, H]

class MPNNTokenizer(nn.Module):
    """
    MPNN-based tokenizer to convert graphs into token embeddings for attention mechanism:
      node_in -> hidden_dim -> (L x MPNNLayer) -> out_dim
    """
    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 128,
        out_dim: int = 128,
        num_layers: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()

        # --- project node features to hidden dim
        self.node_encoder = nn.Sequential(
            nn.Linear(node_in_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
        )

        # --- multiple MPNN layers
        self.layers = nn.ModuleList([
            MPNNLayer(hidden_dim=hidden_dim, edge_dim=edge_in_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

        # --- final projection to the token dim you want for cross-attention
        self.node_head = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """
        x:         [N, node_in_dim]
        edge_index:[2, E]
        edge_attr: [E, edge_in_dim]
        returns:   [N, out_dim] microphone embeddings (tokens)
        """
        h = self.node_encoder(x)
        for layer in self.layers:
            h = layer(h, edge_index, edge_attr)
        return self.node_head(h)

class SelfAttentionEncoder(nn.Module):
    """
    Transformer encoder following ViT-Base architecture:
    - 12 layers of multi-head self-attention (8 heads, D=128)
    - Takes microphone embeddings and outputs pooled vector for MLP head
    """
    def __init__(
        self,
        embed_dim: int = 128,
        num_heads: int = 8, # embed_dim must be divisible by num_heads 
        num_layers: int = 12,
        dropout: float = 0.1,
    ):
        super().__init__()
        
        self.embed_dim = embed_dim

        # --- learnable CLS token (1 token, D) ---
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # --- Multihead self-attention layer ---
        multihead_self_attention_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,  # typical 4x expansion in transformer FFN
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,  # pre-norm for better training stability
        )
        
        # -- Transformer encoder with multiple multihead self-attention layers ---
        self.transformer_encoder = nn.TransformerEncoder(
            multihead_self_attention_layer,
            num_layers=num_layers,
        )
        
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        tokens: [N, embed_dim] or [B, N, embed_dim] microphone embeddings from MPNNTokenizer
        returns: [embed_dim] or [B, embed_dim] encoded features after global pooling
        """
        # Adds batch dimension if needed: [N, D] -> [1, N, D]
        squeeze_output = False
        if tokens.dim() == 2:
            tokens = tokens.unsqueeze(0)
            squeeze_output = True
        
        #B = tokens.size(0)
        # prepend CLS token: [B, 1, D] + [B, N, D] -> [B, 1+N, D]
        #cls = self.cls_token.expand(B, -1, -1)
        #tokens = torch.cat([cls, tokens], dim=1)

        # --- Multihead self-attention layers ---
        encoded = self.transformer_encoder(tokens)  # [B, N, D]
        
        # --- Global pooling (mean over all microphone tokens) ---
        # TODO: try other pooling mechanism f.e. CLI Token
        #pooled = encoded.mean(dim=1)  # [B, D]
        # --- CLS pooling ---
        pooled = encoded[:, 0, :]  # [B, D]
        
        # Remove batch dimension if input was unbatched
        if squeeze_output:
            pooled = pooled.squeeze(0)  # [D]
        
        return pooled

class PredictionHead(nn.Module):
    """
    Two-layer MLP (512 neurons each) that outputs source locations and strengths.
    - Location head: outputs I source locations (x, y coordinates)
    - Strength head: outputs normalized strengths via Softmax
    """
    def __init__(
        self,
        embed_dim: int = 128,
        mlp_hidden_dim: int = 512,
        num_output_sources: int = 1,  # I = 1 source component
        dropout: float = 0.1,
    ):
        super().__init__()
        
        self.num_output_sources = num_output_sources
        
        # --- Two-layer MLP with 512 neurons each ---
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        
        # --- Source location head (2D coordinates per source) ---
        self.location_head = nn.Linear(mlp_hidden_dim, num_output_sources * 2)
        
        # --- Source strength head (normalized via softmax) ---
        #self.strength_head = nn.Linear(mlp_hidden_dim, num_output_sources)
        
    def forward(self, encoded_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        encoded_features: [embed_dim] or [B, embed_dim] from TransformerEncoder
        
        Returns:
        --------
        locations: [num_output_sources, 2] or [B, num_output_sources, 2] 
                   predicted source locations (x, y)
        strengths: [num_output_sources] or [B, num_output_sources]
                   normalized source strengths (sum to 1)
        """
        # Handle both batched and unbatched input
        squeeze_output = False
        if encoded_features.dim() == 1:
            encoded_features = encoded_features.unsqueeze(0)
            squeeze_output = True
        
        # --- MLP processing ---
        features = self.mlp(encoded_features)  # [B, mlp_hidden_dim]
        
        # --- LOCATION HEAD OUTPUT ---
        # Raw output is [B, num_output_sources * 2]
        locations = self.location_head(features)  # [B, I * 2]
        # Reshape to [B, I, 2] where each source has (x, y) coordinates
        locations = locations.view(-1, self.num_output_sources, 2)  # [B, I, 2]
        
        # --- STRENGTH HEAD OUTPUT ---
        # Raw output is [B, num_output_sources]
        #strengths = self.strength_head(features)  # [B, I]
        # Apply softmax to normalize strengths (they sum to 1)
        #strengths = torch.softmax(strengths, dim=-1)  # [B, I]
        
        # Remove batch dimension if input was unbatched
        if squeeze_output:
            locations = locations.squeeze(0)  # [I, 2]
            #strengths = strengths.squeeze(0)  # [I]
        
        return locations# , strengths

<h1> Model

In [4]:
class MPNNTransformerModel(nn.Module):
    """
    Full pipeline:
      data (f.e. h5Dataset) -> MPNNTokenizer -> SelfAttentionEncoder -> PredictionHead

    Expected inputs (PyGeometric style):
      x:         [N, node_in_dim]
      edge_index:[2, E]
      edge_attr: [E, edge_in_dim]

    Output:
      locations: [I, 2] (unbatched) or [B, I, 2] if you pass batched tokens later
    """
    def __init__(
        self,
        # --- tokenizer params --- #
        node_in_dim: int,
        edge_in_dim: int,
        mpnn_hidden_dim: int = 128,
        token_dim: int = 128,
        mpnn_num_layers: int = 1,
        mpnn_dropout: float = 0.0,
        # --- self-attention encoder params --- #
        attn_num_heads: int = 8, # token dim must be divisible by attn_num_heads
        attn_num_layers: int = 2, #12,
        attn_dropout: float = 0.0, #0.1,
        # --- prediction head params --- #
        head_mlp_hidden_dim: int = 512,
        num_output_sources: int = 1,
        head_dropout: float = 0.0 #0.1,
    ):
        super().__init__()

        # --- tokenizer (graph -> mic tokens) ---
        self.tokenizer = MPNNTokenizer(
            node_in_dim=node_in_dim,
            edge_in_dim=edge_in_dim,
            hidden_dim=mpnn_hidden_dim,
            out_dim=token_dim,
            num_layers=mpnn_num_layers,
            dropout=mpnn_dropout,
        )

        # --- self-attention encoder (tokens -> pooled embedding) ---
        self.encoder = SelfAttentionEncoder(
            embed_dim=token_dim,
            num_heads=attn_num_heads,
            num_layers=attn_num_layers,
            dropout=attn_dropout,
        )

        # --- prediction head (pooled embedding -> outputs) ---
        self.head = PredictionHead(
            embed_dim=token_dim,
            mlp_hidden_dim=head_mlp_hidden_dim,
            num_output_sources=num_output_sources,
            dropout=head_dropout,
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
    ) -> torch.Tensor:
        
        # 1) graph -> mic tokens: [N, D]
        tokens = self.tokenizer(x=x, edge_index=edge_index, edge_attr=edge_attr)

        # 2) tokens -> pooled: [D] (or [B, D] if tokens were batched)
        #encoded = self.encoder(tokens)

        # 3) pooled -> locations: [I, 2] (or [B, I, 2])
        #locations = self.head(encoded)
        return tokens

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
    ) -> torch.Tensor:
        self.eval()
        return self.forward(x=x, edge_index=edge_index, edge_attr=edge_attr)

    def forward_from_data(self, data) -> torch.Tensor:
        """
        Convenience for PyG Data/Batch objects that expose .x, .edge_index, .edge_attr
        """
        return self.forward(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr)


<h1>Testing 

In [None]:
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Dataset
h5_path = "10samples.h5"  # <- CHANGE THIS
ds = h5Dataset(h5_path)

loader = DataLoader(ds, batch_size=10)

# take feature dims from one sample
sample0 = ds[0]
node_in_dim = sample0.x.shape[-1]
edge_in_dim = sample0.edge_attr.shape[-1]

# Build model
model = MPNNTransformerModel(
    node_in_dim=node_in_dim,
    edge_in_dim=edge_in_dim,
    num_output_sources=1,  
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-2, weight_decay=0.0)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=50,   # every 10 epochs
    gamma=0.5       # multiply LR by 0.5
)

model.train()

#for data in loader:
    #data = data.to(device)
    #pred = model.forward_from_data(data)  
    #pred.shape

for batch in loader:
    #print(batch.batch)    
    #print(batch.batch) 
    #print(batch.batch.shape) 
    p

#Shape (Microphone Count * Batch Size, Token Dimension)


tensor([False, False, False, False, False,  True,  True,  True,  True,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])




IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)