<h2>Imports</h2>

In [1]:
import torch
from torch_geometric.data import Data

import sys
sys.path.append("../src")
from h5Dataset import h5Dataset
from model import MPNNTransformerModel

  from .autonotebook import tqdm as notebook_tqdm


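<h1>Permutation Logic</h1>

Relabel the nodes of a graph with a permutation `perm`, where `perm[i] = j` means node `i` of the permuted graph is node `j` of the original graph.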

In [2]:
def _check_permutation(p: torch.Tensor, n: int, name: str) -> None:
    """Raise ValueError unless ``p`` is a 1-D tensor holding each of ``0..n-1`` exactly once."""
    if p.dim() != 1 or p.numel() != n:
        raise ValueError(
            f"{name} must be a 1-D tensor of length {n}, got shape {tuple(p.shape)}"
        )
    expected = torch.arange(n, device=p.device, dtype=p.dtype)
    if not torch.equal(torch.sort(p).values, expected):
        raise ValueError(f"{name} must contain each index 0..{n - 1} exactly once")


def permute_graph_data(
    data: "Data",
    perm: torch.Tensor,
    edge_perm: "torch.Tensor | None" = None,
) -> "Data":
    """Return a copy of ``data`` whose nodes are relabelled by ``perm``.

    Convention: ``perm[i] = j`` means node ``i`` of the returned graph is node
    ``j`` of the original graph, i.e. ``x_perm[i] == data.x[perm[i]]``.

    Relabelling nodes does not reorder edges: column ``e`` of the new
    ``edge_index`` still describes original edge ``e``, so ``edge_attr`` stays
    aligned column-for-column without any lookup. This also keeps multi-edges
    (repeated ``(src, dst)`` pairs) with distinct attributes correct. Pass
    ``edge_perm`` to additionally shuffle the edge order.

    Only ``x``, ``edge_index``, ``edge_attr`` and ``y`` are carried over; any
    other attributes of ``data`` are not copied.

    Args:
        data: Graph whose ``x`` has nodes along dim 0 and whose ``edge_index``
            has shape ``(2, E)`` (as indexed below).
        perm: 1-D tensor containing each of ``0..N-1`` exactly once.
        edge_perm: Optional 1-D permutation of ``0..E-1`` applied to the edge
            columns and, in the same order, to ``edge_attr``.

    Returns:
        A new ``Data`` object describing the same graph with relabelled nodes.

    Raises:
        ValueError: if ``perm`` or ``edge_perm`` is not a valid permutation.
    """
    num_nodes = data.x.size(0)
    device = data.x.device
    # Index tensors must live on the same device as the tensors they index.
    perm = perm.to(device)
    _check_permutation(perm, num_nodes, "perm")

    # inv_perm[j] = i  <=>  perm[i] = j: maps original node ids to new node ids.
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(num_nodes, device=device)

    x_perm = data.x[perm]
    # Rename both endpoints of every edge; the column order is unchanged.
    edge_index_perm = inv_perm[data.edge_index]

    edge_attr = getattr(data, "edge_attr", None)
    edge_attr_perm = edge_attr.clone() if edge_attr is not None else None

    if edge_perm is not None:
        edge_perm = edge_perm.to(edge_index_perm.device)
        _check_permutation(edge_perm, edge_index_perm.size(1), "edge_perm")
        edge_index_perm = edge_index_perm[:, edge_perm]
        if edge_attr_perm is not None:
            edge_attr_perm = edge_attr_perm[edge_perm]

    y = getattr(data, "y", None)
    return Data(
        x=x_perm,
        edge_index=edge_index_perm,
        edge_attr=edge_attr_perm,
        y=y.clone() if y is not None else None,
    )

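<h1>Quick Test</h1>

Sanity check that the (untrained) model's prediction does not change when the nodes of an input graph are randomly relabelled.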

In [3]:
def quick_test(
    h5_path: str,
    device_id: int = 0,
    sample_idx: int = 1,
    num_perms: int = 3,
    atol: float = 1e-5,
    seed: "int | None" = None,
) -> float:
    """Check that the model's prediction is invariant to node relabelling.

    Loads one sample from ``h5_path``, builds a freshly initialised
    ``MPNNTransformerModel`` and compares its prediction on the original graph
    with predictions on ``num_perms`` randomly node-permuted copies.

    Args:
        h5_path: Path to the HDF5 dataset file.
        device_id: CUDA device index, used only when CUDA is available;
            otherwise everything runs on the CPU.
        sample_idx: Index of the dataset sample to test (default 1, i.e. the
            second sample).
        num_perms: Number of random node permutations to try.
        atol: Largest absolute prediction difference still reported as PASS.
        seed: If given, seeds torch's RNG so model initialisation and the
            permutations are reproducible.

    Returns:
        The largest absolute difference between the original and any
        permuted prediction.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Honour device_id instead of silently ignoring it.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{device_id}")
    else:
        device = torch.device("cpu")

    dataset = h5Dataset(h5_path)
    print(f"Dataset size: {len(dataset)} samples")

    data = dataset[sample_idx].to(device)
    num_nodes = data.x.size(0)
    print(f"Sample has {num_nodes} nodes (microphones)")

    model = MPNNTransformerModel(
        node_in_dim=data.x.shape[-1],
        edge_in_dim=data.edge_attr.shape[-1],
        num_output_sources=1,
    ).to(device)
    model.eval()

    max_diff = 0.0
    with torch.no_grad():
        pred_orig = model.forward_from_data(data)
        target = getattr(data, "y", None)

        print("\nOriginal graph:")
        print(f"  Prediction: {pred_orig.squeeze().cpu().numpy()}")
        if target is not None:
            print(f"  Target:     {target.squeeze().cpu().numpy()}")

        print("\nPermuted graphs:")
        for i in range(num_perms):
            # Create the permutation on the same device as the graph tensors.
            perm = torch.randperm(num_nodes, device=device)
            data_perm = permute_graph_data(data, perm).to(device)
            pred_perm = model.forward_from_data(data_perm)

            diff = torch.abs(pred_orig - pred_perm).max().item()
            max_diff = max(max_diff, diff)

            print(f"  Permutation {i+1}: {pred_perm.squeeze().cpu().numpy()} "
                  f"(diff: {diff:.2e})")

    # Report the verdict; previously max_diff was computed but never used.
    status = "PASS" if max_diff <= atol else "FAIL"
    print(f"\nMax diff over {num_perms} permutations: {max_diff:.2e} "
          f"(atol {atol:.0e}) -> {status}")
    return max_diff


SAMPLES_H5 = "../data/samples/10samples.h5"
quick_test(h5_path=SAMPLES_H5)

Dataset size: 10 samples
Sample has 5 nodes (microphones)

Original graph:
  Prediction: [0.03259538 0.02138368]
  Target:     [ 0.03151155 -0.06910084]

Permuted graphs:
  Permutation 1: [0.03259539 0.02138364] (diff: 4.28e-08)
  Permutation 2: [0.03259538 0.02138364] (diff: 3.54e-08)
  Permutation 3: [0.03259538 0.02138365] (diff: 2.79e-08)




<h2>Notes</h2>

Worked example of the `perm` / `inv_perm` conventions used by `permute_graph_data`.

In [4]:
# Permutation:         perm[i] = j      -> position i of the permuted graph holds original node j.
# Inverse permutation: inv_perm[j] = i  -> original node j ends up at position i of the permuted graph.

# Original node order: A B C D  (indices 0, 1, 2, 3)
labels = ["A", "B", "C", "D"]

# perm: index = new position, entry = old position
perm = torch.tensor([2, 0, 3, 1])
assert [labels[j] for j in perm.tolist()] == ["C", "A", "D", "B"]

# inv_perm: index = old position, entry = new position (same arrangement C A D B)
inv_perm = torch.empty_like(perm)
inv_perm[perm] = torch.arange(perm.numel())
assert inv_perm.tolist() == [1, 3, 0, 2]

# Edges A->B, C->A, D->C in original indices
edge_index = torch.tensor([
    [0, 2, 3],  # source nodes
    [1, 0, 2],  # destination nodes
])

# Relabel every endpoint with its new position: A->B becomes 1->3, C->A becomes 0->1, D->C becomes 2->0
edge_index_perm = inv_perm[edge_index]
assert edge_index_perm.tolist() == [[1, 0, 2], [3, 1, 0]]
# Mapping back through perm recovers the original edges
assert torch.equal(perm[edge_index_perm], edge_index)

edge_index_perm