In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_mean_pool

class ConnectomeTokenizer(nn.Module):
    """Turn each per-band connectome graph into a single fixed-size token.

    A single GINEConv layer (shared across all bands and samples) embeds the
    nodes of a graph; the node embeddings are then mean-pooled into one
    ``[out_dim]`` vector per graph.

    Args:
        in_channels: Width of the node features ``g.x``.
        hidden_dim: Hidden width of the MLP inside GINEConv.
        out_dim: Token dimension (width of the MLP output).
        edge_dim: Width of the edge features ``g.edge_attr`` (1 = one scalar
            weight per edge, the original hard-coded value).
    """

    def __init__(self, in_channels=32, hidden_dim=64, out_dim=128, edge_dim=1):
        super().__init__()

        mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.gnn = GINEConv(mlp, edge_dim=edge_dim)
        self.out_dim = out_dim

    def forward(self, list_of_graph_lists):
        """Tokenize a batch of samples, one token per band graph.

        Args:
            list_of_graph_lists: List of length B. Each element is the list of
                per-band PyG graphs of one sample (9 bands in this notebook).
                Every sample must contain the same number of graphs. Each graph
                provides ``x`` [N, in_channels], ``edge_index`` [2, E] and
                ``edge_attr`` [E, edge_dim]. The graphs are not modified.

        Returns:
            Tensor of shape [B, num_bands, out_dim], on the device of the node
            features.
        """
        sample_tokens = []
        for graphs_per_sample in list_of_graph_lists:
            band_tokens = [self._graph_token(g) for g in graphs_per_sample]
            sample_tokens.append(torch.cat(band_tokens, dim=0))  # [num_bands, out_dim]

        return torch.stack(sample_tokens, dim=0)  # [B, num_bands, out_dim]

    def _graph_token(self, g):
        """Embed one graph's nodes and mean-pool them into a [1, out_dim] token.

        This matches ``global_mean_pool(x, zeros(N))`` for a single graph. The
        mean is computed directly, so no batch tensor is attached to the input
        graph and nothing is allocated on a device other than the one ``g.x``
        lives on.
        """
        node_emb = self.gnn(g.x, g.edge_index, g.edge_attr)  # [N, out_dim]
        return node_emb.mean(dim=0, keepdim=True)  # [1, out_dim]


In [None]:
from torch_geometric.data import Data

# Simulate 1 graph (by default: 19 nodes, 32 features per node, 50 random edges)
def make_dummy_graph(num_nodes=19, num_features=32, num_edges=50):
    """Build a random PyG graph for smoke-testing the tokenizer.

    Args:
        num_nodes: Number of nodes in the graph.
        num_features: Node-feature width; must equal the tokenizer's ``in_channels``.
        num_edges: Number of random (source, target) node pairs. Self-loops and
            duplicate edges can occur.

    Returns:
        ``Data`` with ``x`` [num_nodes, num_features] (standard normal),
        ``edge_index`` [2, num_edges] and ``edge_attr`` [num_edges, 1]
        (one uniform [0, 1) weight per edge, matching ``edge_dim=1``).
    """
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.rand(num_edges, 1)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Simulate one batch of NUM_SAMPLES samples x NUM_BANDS graphs each (one graph per band).
# Seed first so both the dummy graphs and the model's initial weights are reproducible.
torch.manual_seed(0)
NUM_SAMPLES, NUM_BANDS = 2, 9
batch_graphs = [
    [make_dummy_graph() for _ in range(NUM_BANDS)]
    for _ in range(NUM_SAMPLES)
]

# Create and test the tokenizer (inference only, so no autograd graph is needed).
model = ConnectomeTokenizer(in_channels=32, hidden_dim=64, out_dim=128)
with torch.no_grad():
    tokens = model(batch_graphs)

# Expected [B, num_bands, out_dim] = [2, 9, 128], where B is the number of samples in the batch.
assert tokens.shape == (NUM_SAMPLES, NUM_BANDS, model.out_dim), tokens.shape
print("Token shape:", tokens.shape)


Token shape: torch.Size([2, 9, 128])


What it does:
Processes a batch of graphs (one graph per frequency band) and generates one token per graph.

Each token is a fixed-size vector ([out_dim]) that represents the graph-level embedding.

Input: a list of B samples, where each sample is a list of 9 PyG graphs (one per frequency band).

Output: a tensor of shape [B, 9, D] (D = `out_dim`), where each of the 9 tokens corresponds to one frequency band.

This is a dummy test. Once preprocessing is done (the real data converted into 9 connectomes per sample), the dummy graphs in this test can be replaced with the actual graphs.