# Hierarchical Graph Neural Networks with [<img src="https://raw.githubusercontent.com/tgp-team/torch-geometric-pool/main/docs/source/_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp](https://github.com/tgp-team/torch-geometric-pool)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/hierarchical_gnns.ipynb)

Graph pooling is a fundamental operation in Graph Neural Networks (GNNs) that enables **hierarchical learning** by coarsening graphs into smaller representations. Just as convolutional neural networks use pooling to build hierarchical features in images, graph pooling allows GNNs to capture multi-scale patterns in graph-structured data.

In this tutorial, we'll explore [<img src="https://raw.githubusercontent.com/tgp-team/torch-geometric-pool/main/docs/source/_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp](https://github.com/tgp-team/torch-geometric-pool) (Torch Geometric Pool), a comprehensive library that provides a unified framework for graph pooling operators. We'll learn how to:

- Understand the **SRC framework** ($\texttt{SEL}$-$\texttt{RED}$-$\texttt{CON}$) that unifies all pooling operators
- **Visualize** pooling operations to understand what they do
- Build GNNs with and without pooling to understand the benefits
- Work with both **sparse** and **dense** poolers
- Create **hierarchical architectures** with multiple pooling layers
- Optimize performance with **caching** and **precoarsening**
- Create **custom pooling operators** by mixing and matching components

For more details, see the [tgp paper](https://arxiv.org/abs/2512.12642) and the [documentation](https://torch-geometric-pool.readthedocs.io).


**This is the exercise version** of the tutorial. The full step-by-step solution for the custom pooler exercise is in the [Advanced](https://github.com/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/advanced.ipynb) notebook.

Let's get started!

## Setup and Installation

First, let's install the required libraries. If you're running this in Google Colab, you'll need to restart the session after running the installation cells below.

In [None]:
import sys
if 'google.colab' in sys.modules:
    %pip install torch==2.4.1 --index-url https://download.pytorch.org/whl/cu124

In [None]:
import sys
if 'google.colab' in sys.modules:
    %pip install torch_geometric==2.6.1
    %pip install torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.4.0+cu124.html
    %pip install pygsp==0.6.1
    %pip install -q git+https://github.com/tgp-team/torch-geometric-pool.git@main

Now let's import the libraries we'll need throughout this tutorial.

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import Batch
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, DenseGCNConv, global_mean_pool
from torch_geometric import seed_everything
from torch_geometric.utils import to_dense_adj, to_networkx

# Visualization imports
import matplotlib.pyplot as plt
import networkx as nx
from tgp.datasets import PyGSPDataset

from tgp.poolers import get_pooler, pooler_classes, pooler_map

# Set random seed for reproducibility
seed_everything(42)
torch.set_printoptions(threshold=2, edgeitems=2)

# Load MUTAG dataset (we'll use this throughout the tutorial)
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

## The SRC Framework

All pooling operators in [<img src="https://raw.githubusercontent.com/tgp-team/torch-geometric-pool/main/docs/source/_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp](https://github.com/tgp-team/torch-geometric-pool) follow the **SRC framework**, which decomposes pooling into three fundamental operations:

<img src="https://raw.githubusercontent.com/tgp-team/torch-geometric-pool/main/docs/source/_static/img/src_overview.png" style="width: 55%; display: block; margin: auto;">

- $\texttt{SEL}$ (**SELECT**): Determines how nodes map to supernodes in the coarsened graph. This produces a [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) containing the assignment matrix $\mathbf{S}$.
- $\texttt{RED}$ (**REDUCE**): Aggregates features of nodes assigned to the same supernode using $\mathbf{S}$.
- $\texttt{CON}$ (**CONNECT**): Creates edges between supernodes in the coarsened graph using $\mathbf{S}$.

There's also a $\texttt{LIFT}$ operation for unpooling (mapping coarsened features back to the original graph), which is useful for tasks like node classification.

This abstraction makes it easy to understand, compare, and create custom pooling operators. For more details, see the [SRC paper](https://arxiv.org/abs/2110.05292).
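To make the three operations concrete, here is a minimal sketch in plain Python, independent of tgp. It assumes a hard assignment matrix, a sum-based reduction $\mathbf{X}' = \mathbf{S}^\top \mathbf{X}$, and a connection rule $\mathbf{A}' = \mathbf{S}^\top \mathbf{A} \mathbf{S}$. These are common choices, but individual poolers may use different ones. The helpers `matmul` and `transpose` are illustrative and not part of any library.

```python
# Plain-Python sketch of SEL/RED/CON on a 4-node path graph (0-1-2-3).
# All names are illustrative; none of this is tgp code.

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]

# Adjacency matrix A (4 x 4) and node features X (4 x 2).
A = [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]]
X = [[1.0, 0.0],
     [3.0, 2.0],
     [5.0, 4.0],
     [7.0, 6.0]]

# SEL: assignment matrix S (4 nodes x 2 supernodes).
# Nodes 0 and 1 go to supernode 0; nodes 2 and 3 go to supernode 1.
S = [[1, 0],
     [1, 0],
     [0, 1],
     [0, 1]]
S_t = transpose(S)

# RED: sum the features of the nodes in each supernode, X' = S^T X.
X_pool = matmul(S_t, X)

# CON: coarsened adjacency, A' = S^T A S.
A_pool = matmul(matmul(S_t, A), S)

print(X_pool)  # [[4.0, 2.0], [12.0, 10.0]]
print(A_pool)  # [[2, 1], [1, 2]]
```

The diagonal of `A_pool` counts the edges that fall inside each supernode (once per direction), while the off-diagonal entry is the single edge between nodes 1 and 2 that crosses the two groups.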

Let's explore the available poolers:

In [None]:
print("Available pooling operators in tgp:")
for i, pooler in enumerate(pooler_classes):
    print(f"{i+1:2d}. {pooler}")

Each pooler has a convenient alias for quick instantiation using [`get_pooler`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.get_pooler):

In [None]:
print("Pooler aliases:")
for alias, cls in pooler_map.items():
    print(f"  '{alias:8s}' → {cls.__name__}")

Let's create a [`TopkPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.TopkPooling) operator and examine its SRC components:

In [None]:
from tgp.poolers import TopkPooling

# Create a TopK pooler
pooler = TopkPooling(in_channels=7, ratio=0.5)
print(f"Pooler structure:")
print(pooler)

### PoolingOutput and SelectOutput

When we apply a pooler, it returns a [`PoolingOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/src.html#tgp.src.PoolingOutput) object containing:
- `x`: Node features of the coarsened graph
- `edge_index`: Edge connectivity of the coarsened graph
- `edge_weight`: Edge weights (if computed)
- `batch`: Batch assignment for the coarsened nodes
- `so`: [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) - describes the node-to-supernode mapping
- `loss`: Auxiliary losses (if any)

Let's apply the pooler to a batch of graphs:

In [None]:
# Create a small batch from the dataset
batch = Batch.from_data_list(dataset[:5])
print(f"Input batch: {batch}")
print(f"  Nodes: {batch.x.shape[0]}")
print(f"  Edges: {batch.edge_index.shape[1]}")

# Apply pooling
pool_out = pooler(x=batch.x, adj=batch.edge_index, batch=batch.batch)
print(f"\nPooling output: {pool_out}")
print(f"  Pooled nodes: {pool_out.x.shape[0]}")
print(f"  Pooled edges: {pool_out.edge_index.shape[1]}")
print(f"\nGraph was coarsened by ratio: {pool_out.x.shape[0] / batch.x.shape[0]:.2f}")

In [None]:
print("SelectOutput:")
print(pool_out.so)
print(f"\nIt maps {pool_out.so.num_nodes} nodes to {pool_out.so.num_supernodes} supernodes")

### The LIFT Operation

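Pooling can also be undone, at least approximately. The $\texttt{LIFT}$ operation maps coarsened features back to the original node space, so every original node gets a feature vector again (information discarded by pooling is not recovered). This is useful for node-level tasks like node classification.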

See the [node classification example](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/classification_node.py) for a complete use case.

In [None]:
# Lift the pooled features back to original space
x_lifted = pooler(x=pool_out.x, so=pool_out.so, batch=pool_out.batch, lifting=True)

print(f"Original features shape: {batch.x.shape}")
print(f"Pooled features shape: {pool_out.x.shape}")
print(f"Lifted features shape: {x_lifted.shape}")
print("\nLifting successfully maps pooled features back to original node space!")
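As a rough illustration of what lifting does, the sketch below computes $\mathbf{X}_{\text{lift}} = \mathbf{S} \mathbf{X}'$ in plain Python. This is only one possible lifting rule, chosen here as an assumption; the rule a given pooler actually applies may differ (for example, it may use a pseudo-inverse of $\mathbf{S}$). The `matmul` helper is illustrative.

```python
# Plain-Python sketch of a LIFT step, X_lift = S X_pool (an assumed rule).

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Pooled features of 2 supernodes with 2 channels each.
X_pool = [[4.0, 2.0],
          [12.0, 10.0]]

# Cluster-style assignment: every node belongs to one supernode.
S_cluster = [[1, 0],
             [1, 0],
             [0, 1],
             [0, 1]]

# Selection-style assignment: nodes 1 and 3 were dropped (all-zero rows).
S_select = [[1, 0],
            [0, 0],
            [0, 1],
            [0, 0]]

# Each node copies the features of its supernode.
lift_cluster = matmul(S_cluster, X_pool)
# Dropped nodes receive zeros, because their rows in S are empty.
lift_select = matmul(S_select, X_pool)

print(lift_cluster)  # [[4.0, 2.0], [4.0, 2.0], [12.0, 10.0], [12.0, 10.0]]
print(lift_select)   # [[4.0, 2.0], [0.0, 0.0], [12.0, 10.0], [0.0, 0.0]]
```

In both cases the output has one row per original node, which is what a node-level decoder needs.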

## Visualizing Pooling Operations
 
Understanding how pooling works is easier when we can visualize it. Let's create some visualizations of the SELECT operation and the resulting graph coarsening.
 
The modular and standardized API of the poolers in [<img src="https://raw.githubusercontent.com/tgp-team/torch-geometric-pool/main/docs/source/_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp](https://github.com/tgp-team/torch-geometric-pool) allows for using the same functions to visualize the pooling results.
For testing and debugging, we can rely on the simple point clouds from the [PyGSP](https://pygsp.readthedocs.io/en/stable/) library, for which [tgp](https://github.com/tgp-team/torch-geometric-pool) provides a wrapper that converts them into a proper torch dataset.
Below, we create a batch composed of three graphs: a 2D grid, a ring, and a community graph.

In [None]:
grid = PyGSPDataset(
    root="/tmp/PyGSP/",
    name="Grid2d",
    transform=None,
    pre_transform=None,
    pre_filter=None,
    force_reload=True,
    kwargs={"N1": 5, "N2": 5},
)[0]
print(f"2D-Grid: {grid}")

ring = PyGSPDataset(
    root="/tmp/PyGSP/",
    name="Ring",
    transform=None,
    pre_transform=None,
    pre_filter=None,
    force_reload=True,
    kwargs={"N": 30},
)[0]
ring.update({"x": ring.x + 2.0})
print(f"Ring: {ring}")

community = PyGSPDataset(
    root="/tmp/PyGSP/",
    name="Community",
    transform=None,
    pre_transform=None,
    pre_filter=None,
    force_reload=True,
    kwargs={"N": 18, "Nc": 3},
)[0]
community.update(
    {"x": torch.stack([community.x[:, 0] * 0.2 + 4.0, community.x[:, 1] * 0.2], dim=1)}
)
print(f"Community: {community}")

In [None]:
# Create a batch from the three datasets
pygsp_batch = Batch.from_data_list([grid, ring, community])
print(pygsp_batch)

### Visualizing the S Matrix

The $\texttt{SEL}$ operation produces an assignment matrix $\mathbf{S}$ that maps nodes to supernodes. Let's visualize this matrix for different poolers.

In [None]:
# Create poolers
topk_pooler = TopkPooling(in_channels=2, ratio=0.5)

# Get SelectOutput
so_topk = topk_pooler.select(x=pygsp_batch.x, batch=pygsp_batch.batch)

# Visualize S matrix
plt.figure(figsize=(4, 5))
plt.imshow(so_topk.s.to_dense().detach().numpy(), cmap='viridis', aspect='auto')
plt.title(r"$\mathbf{S}$ Matrix (TopK Pooling)", fontsize=14)
plt.xlabel("Supernodes", fontsize=12)
plt.ylabel("Original Nodes", fontsize=12)
plt.colorbar(label='Assignment Weight')
plt.tight_layout()
plt.show()

print(f"S matrix shape: [{so_topk.num_nodes}, {so_topk.num_supernodes}]")
print(f"Sparse assignment: {not so_topk.is_expressive}")
print("Each kept node maps to exactly one supernode; dropped nodes have all-zero rows (sparse pooling)")
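The selection step behind a score-based pooler can be sketched in plain Python as follows. The sketch assumes that each graph keeps its `ceil(ratio * n)` highest-scoring nodes, which is the rule used by PyG's `TopKPooling`; whether tgp rounds in exactly the same way is an assumption. In a real pooler the scores are learned from the node features, whereas here they are fixed numbers, and `topk_select` is a hypothetical helper.

```python
import math

def topk_select(scores, batch, ratio):
    """Return the indices of the ceil(ratio * n) best-scoring nodes of each graph."""
    nodes_by_graph = {}
    for node, graph in enumerate(batch):
        nodes_by_graph.setdefault(graph, []).append(node)

    kept = []
    for graph in sorted(nodes_by_graph):
        nodes = nodes_by_graph[graph]
        k = math.ceil(ratio * len(nodes))
        ranked = sorted(nodes, key=lambda n: scores[n], reverse=True)
        kept.extend(sorted(ranked[:k]))
    return kept

# Seven nodes split over two graphs (4 nodes and 3 nodes).
scores = [0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7]
batch = [0, 0, 0, 0, 1, 1, 1]

kept = topk_select(scores, batch, ratio=0.5)
print(kept)  # [0, 2, 4, 6]

# Build S: one column per kept node, all-zero rows for dropped nodes.
S = [[0] * len(kept) for _ in scores]
for col, node in enumerate(kept):
    S[node][col] = 1
print(S[1])  # [0, 0, 0, 0]  (node 1 was dropped)
print(S[2])  # [0, 1, 0, 0]  (node 2 became supernode 1)
```

Every column of `S` has exactly one non-zero entry, which is why the matrix can be stored sparsely.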

In [None]:
# MinCut pooler (dense assignment)
mincut_vis = get_pooler('mincut', in_channels=2, k=20)

# Apply pooler and get SelectOutput for visualization
pool_out_mincut = mincut_vis(x=pygsp_batch.x, adj=pygsp_batch.edge_index, batch=pygsp_batch.batch)
so_mincut = pool_out_mincut.so

# Visualize S matrix
plt.figure(figsize=(4, 5))
S_matrix = so_mincut.s[0].detach().numpy()  # First graph in batch
plt.imshow(S_matrix, cmap='viridis', aspect='auto')
plt.title(r"$\mathbf{S}$ Matrix (MinCut Pooling) - First Graph", fontsize=14)
plt.xlabel("Supernodes", fontsize=12)
plt.ylabel("Original Nodes", fontsize=12)
plt.colorbar(label='Assignment Weight')
plt.tight_layout()
plt.show()

print(f"S matrix shape: {S_matrix.shape}")
print(f"Soft assignment: {so_mincut.is_expressive}")
print("Nodes can belong to multiple supernodes with different weights (dense pooling)")
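A soft assignment can be sketched in plain Python by applying a row-wise softmax to per-node logits, as below. In MinCut-style poolers the logits come from a learnable layer applied to the node features; here they are hard-coded, and the function name `softmax` is just a local helper.

```python
import math

def softmax(row):
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

# Logits for 3 nodes and 3 supernodes (hard-coded for illustration).
logits = [[2.0, 0.0, 0.0],
          [0.0, 2.0, 1.0],
          [0.0, 0.0, 3.0]]

# Soft assignment: every row is a probability distribution over supernodes.
S_soft = [softmax(row) for row in logits]
row_sums = [sum(row) for row in S_soft]

# The most likely supernode of each node, for comparison with a hard assignment.
hard = [max(range(len(row)), key=row.__getitem__) for row in S_soft]
print(hard)  # [0, 1, 2]
```

Unlike the TopK case, no entry of `S_soft` is exactly zero, so the matrix is stored densely.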

### Visualizing Graph Coarsening

Now let's visualize how pooling coarsens the graph structure.

In [None]:
# Visualize original graph
G = to_networkx(pygsp_batch, to_undirected=True)
pos = pygsp_batch.x.numpy()

plt.figure(figsize=(5,5))
nx.draw(G, pos=pos, node_size=20, node_color='lightblue', 
        edge_color='gray', with_labels=False, alpha=0.8)
plt.title("Original Graphs (3 graphs batched)", fontsize=14)
plt.axis('equal')
plt.tight_layout()
plt.show()

print(f"Total nodes: {pygsp_batch.x.shape[0]}")
print(f"Total edges: {pygsp_batch.edge_index.shape[1]}")

In [None]:
# Apply pooling and highlight selected nodes
pool_out = topk_pooler(x=pygsp_batch.x, adj=pygsp_batch.edge_index, batch=pygsp_batch.batch)

# Get indices of selected nodes
selected_indices = pool_out.so.s.indices()[0].numpy()

plt.figure(figsize=(6, 6))
# Draw all nodes in gray
nx.draw(G, pos=pos, node_size=20, node_color='lightgray', 
        edge_color='lightgray', with_labels=False, alpha=0.5)
# Highlight selected nodes in red
nx.draw_networkx_nodes(G, pos=pos, nodelist=selected_indices.tolist(),
                       node_color='red', node_size=20, alpha=0.9)
plt.title(f"After TopK Pooling (red nodes selected, {len(selected_indices)}/{pygsp_batch.x.shape[0]} nodes kept)", 
          fontsize=14)
plt.axis('equal')
plt.tight_layout()
plt.show()

print(f"Selected nodes: {len(selected_indices)} out of {pygsp_batch.x.shape[0]}")
print(f"Reduction ratio: {len(selected_indices)/pygsp_batch.x.shape[0]:.2f}")

In [None]:
# Visualize the coarsened graph structure
from torch_geometric.data import Data

# Create Data object for pooled graph
pooled_data = Data(x=pool_out.x, edge_index=pool_out.edge_index, batch=pool_out.batch)
G_pooled = to_networkx(pooled_data, to_undirected=True)

plt.figure(figsize=(5, 5))
# Use pooled features as positions (which are the original positions of selected nodes)
pos_pooled = pool_out.x.detach().numpy()
nx.draw_networkx_nodes(G_pooled, pos=pos_pooled, node_size=20, 
                       node_color=pool_out.batch.numpy(), 
                       cmap='Set3', alpha=0.8)
nx.draw_networkx_edges(G_pooled, pos=pos_pooled, edge_color='gray', 
                       alpha=0.5, width=2)
plt.title("Coarsened Graph Structure", fontsize=14)
plt.axis('equal')
plt.tight_layout()
plt.show()

print(f"Coarsened graph nodes: {pool_out.x.shape[0]}")
print(f"Coarsened graph edges: {pool_out.edge_index.shape[1]}")

We'll use the **MUTAG** dataset for graph classification tasks. This dataset contains 188 chemical compounds (molecules), where each graph represents a molecule and the task is to predict whether it has mutagenic effects on a bacterium (binary classification).

For more information about datasets in PyG, see the [PyG documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html).

In [None]:
# Dataset info
print(f'Dataset: {dataset}')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Create train and test loaders
train_dataset = dataset[:150]
test_dataset = dataset[150:]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Look at one batch
batch = next(iter(train_loader))
print(f'\nSample batch: {batch}')

## Baseline GNN without Pooling

Let's start by building a simple GNN **without** graph pooling. This will serve as our baseline to understand the benefits that pooling brings.

Our baseline architecture uses [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html) layers followed by global mean pooling:

```
[GCNConv → ReLU → GCNConv → ReLU → GlobalMeanPool → Linear]
```

In [None]:
class BaselineGCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(42)
        
        # Two GCN layers
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
        # Final classifier
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings through GCN layers
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()

        # 2. Readout layer: aggregate node features to graph-level
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

baseline_model = BaselineGCN(hidden_channels=64)
print(baseline_model)
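The readout step averages node features per graph, using the `batch` vector to tell which graph each node belongs to. A minimal plain-Python sketch of that computation (with a hypothetical helper `global_mean`, not the PyG implementation) looks like this:

```python
def global_mean(x, batch):
    """Average node feature vectors per graph; batch[i] is the graph id of node i."""
    sums, counts = {}, {}
    for feats, graph in zip(x, batch):
        acc = sums.setdefault(graph, [0.0] * len(feats))
        for i, value in enumerate(feats):
            acc[i] += value
        counts[graph] = counts.get(graph, 0) + 1
    return [[v / counts[g] for v in sums[g]] for g in sorted(sums)]

# Three nodes: the first two belong to graph 0, the last one to graph 1.
x = [[1.0, 2.0],
     [3.0, 4.0],
     [10.0, 20.0]]
batch = [0, 0, 1]

readout = global_mean(x, batch)
print(readout)  # [[2.0, 3.0], [10.0, 20.0]]
```

The result has one row per graph, so its first dimension equals the batch size regardless of how many nodes each graph contains.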

Let's train the baseline model:

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = baseline_model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

print("Training baseline GNN (no pooling)...")
for epoch in range(1, 101):
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'\nFinal Test Accuracy (Baseline): {test_acc:.4f}')

## Adding Pooling to a GNN

Now let's see how easy it is to add graph pooling to our GNN! We simply insert a pooling layer between message-passing layers.

Our new architecture with [`MaxCutPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.MaxCutPooling):

```
[GCNConv → MaxCutPool → GCNConv → GlobalPool → Linear]
```

Notice how little changes: the pooler is created with a single `get_pooler()` call, and the forward pass simply hands the coarsened graph to the next layer.

In [None]:
class GNNWithPooling(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        
        # First GCN layer
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        
        # Pooling layer - easily added with get_pooler!
        self.pooler = get_pooler('maxcut', in_channels=hidden_channels, ratio=0.5)
        
        # Second GCN layer (after pooling)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
        # Classifier
        self.lin = Linear(hidden_channels, dataset.num_classes)
    
    def forward(self, x, edge_index, edge_weight, batch):
        # First GCN layer
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        
        # Pooling: coarsen the graph
        pool_out = self.pooler(x=x, adj=edge_index, edge_weight=edge_weight, batch=batch)
        x_pooled = pool_out.x
        edge_index_pooled = pool_out.edge_index
        edge_weight_pooled = pool_out.edge_weight
        batch_pooled = pool_out.batch
        
        # Second GCN layer on the pooled graph
        x = self.conv2(x_pooled, edge_index_pooled, edge_weight_pooled)
        x = F.relu(x)
        
        # Global pooling
        x = self.pooler.global_pool(x, reduce_op='mean', batch=batch_pooled)
        
        # Classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

pooling_model = GNNWithPooling(hidden_channels=64)
print(pooling_model)
print(f"\nPooler details:")
print(pooling_model.pooler)

Notice how the pooler shows its internal SRC components: `select`, `reduce`, `lift`, and `connect`. This is the SRC framework in action!

Let's train this model:

In [None]:
model = pooling_model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_weight, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_weight, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

print("Training GNN with pooling...")
for epoch in range(1, 101):
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'\nFinal Test Accuracy (with pooling): {test_acc:.4f}')

## Dense Poolers and Auxiliary Losses

tgp supports two types of pooling operators:
- **Sparse poolers**: Work with edge lists (like TopK, SAG, ASAP)
- **Dense poolers**: Work with dense adjacency matrices (like MinCut, DiffPool, DMoN)

Dense poolers often come with **auxiliary losses** that help guide the learning process (e.g., the cut loss and orthogonality loss in [`MinCutPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.MinCutPooling)).
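For reference, the two MinCut losses are usually written as follows, with $\tilde{\mathbf{A}}$ the normalized adjacency matrix, $\tilde{\mathbf{D}}$ its degree matrix, and $K$ the number of supernodes. This is the formulation from the MinCutPool literature; the exact normalization used inside tgp is an assumption here.

$$
\mathcal{L}_c = -\frac{\mathrm{Tr}\left(\mathbf{S}^\top \tilde{\mathbf{A}} \mathbf{S}\right)}{\mathrm{Tr}\left(\mathbf{S}^\top \tilde{\mathbf{D}} \mathbf{S}\right)},
\qquad
\mathcal{L}_o = \left\lVert \frac{\mathbf{S}^\top \mathbf{S}}{\lVert \mathbf{S}^\top \mathbf{S} \rVert_F} - \frac{\mathbf{I}_K}{\sqrt{K}} \right\rVert_F
$$

The cut loss $\mathcal{L}_c$ rewards assignments that keep strongly connected nodes together, while the orthogonality loss $\mathcal{L}_o$ discourages degenerate solutions in which all nodes collapse into the same supernode.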

Let's explore MinCut pooling:

In [None]:
# Create a MinCut pooler (dense)
dense_pooler = get_pooler('mincut', in_channels=7, k=20)
print("MinCut pooler (dense):")
print(dense_pooler)
print(f"\nIs dense? {dense_pooler.is_dense}")
print("Auxiliary losses are weighted by the cut_loss_coeff and ortho_loss_coeff parameters")

Poolers accept `(x, adj, batch)` directly; dense poolers convert the sparse edge list to a dense adjacency internally.

Let's apply the dense pooler and examine the auxiliary losses:

In [None]:
# Apply dense pooling
dense_pool_out = dense_pooler(x=batch.x, adj=batch.edge_index, batch=batch.batch)

print("Dense pooling output:")
print(dense_pool_out)
print(f"\nPooled graph has {dense_pool_out.so.num_supernodes} supernodes (as specified)")
print(f"\nAuxiliary losses:")
for loss_name, loss_value in dense_pool_out.loss.items():
    print(f"  {loss_name}: {loss_value:.4f}")

### Building a GNN with Dense Pooling

When using dense poolers, we need to:
1. Use [`DenseGCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.DenseGCNConv.html) layers after pooling (instead of regular `GCNConv`)
2. Add auxiliary losses to the total loss during training

In [None]:
class GNNWithDensePooling(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        
        # First GCN layer (sparse)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        
        # Dense pooling layer
        self.pooler = get_pooler('mincut', in_channels=hidden_channels, k=20, cut_loss_coeff=0.01, ortho_loss_coeff=0.01)
        
        # Second GCN layer (dense, because pooler is dense)
        self.conv2 = DenseGCNConv(hidden_channels, hidden_channels)
        
        # Classifier
        self.lin = Linear(hidden_channels, dataset.num_classes)
    
    def forward(self, x, edge_index, edge_weight, batch):
        # First GCN layer
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        
        # Dense pooling
        pool_out = self.pooler(x=x, adj=edge_index, edge_weight=edge_weight, batch=batch)
        
        # Second GCN layer (on dense adjacency)
        x = self.conv2(pool_out.x, pool_out.edge_index)
        x = F.relu(x)
        
        # Global pooling
        x = self.pooler.global_pool(x, reduce_op='sum', batch=None)  # batch=None for dense
        
        # Classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        # Return predictions and auxiliary loss
        aux_loss = sum(pool_out.get_loss_value()) if pool_out.loss else 0.0
        return x, aux_loss

dense_model = GNNWithDensePooling(hidden_channels=64)
print(dense_model)

Training with auxiliary losses - note how we add `aux_loss` to the total loss:

In [None]:
model = dense_model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train_with_aux_loss():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out, aux_loss = model(data.x, data.edge_index, data.edge_weight, data.batch)
        loss = criterion(out, data.y) + aux_loss  # Add auxiliary loss!
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test_with_aux(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        out, _ = model(data.x, data.edge_index, data.edge_weight, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

print("Training GNN with dense pooling and auxiliary losses...")
for epoch in range(1, 21):
    loss = train_with_aux_loss()
    train_acc = test_with_aux(train_loader)
    test_acc = test_with_aux(test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'\nFinal Test Accuracy (Dense Pooling): {test_acc:.4f}')

## Dense Poolers: Input/Output Flexibility

Dense poolers in tgp are highly flexible. They support different configurations via the `batched` and `sparse_output` flags:

| Configuration | `batched` | `sparse_output` | Input | Output |
|---------------|-----------|-----------------|-------|--------|
| Batched Dense | `True` (default) | `False` (default) | Sparse edge list | Dense adjacency |
| Batched Sparse | `True` | `True` | Sparse edge list | Sparse edge list |
| Unbatched Dense | `False` | `False` | Sparse edge list | Dense adjacency |
| Unbatched Sparse | `False` | `True` | Sparse edge list | Sparse edge list |

For details on when to use each mode and the trade-offs involved, see the [tgp paper](https://arxiv.org/abs/2512.12642).

In [None]:
# Create 4 MinCut configurations
configs = [
    ('Batched Dense', dict(batched=True, sparse_output=False)),
    ('Batched Sparse', dict(batched=True, sparse_output=True)),
    ('Unbatched Dense', dict(batched=False, sparse_output=False)),
    ('Unbatched Sparse', dict(batched=False, sparse_output=True)),
]

batch = next(iter(train_loader))

for name, config in configs:
    pooler = get_pooler('mincut', in_channels=7, k=10, **config)
    out = pooler(x=batch.x, adj=batch.edge_index, batch=batch.batch)
    
    # Get output shapes
    x_shape = out.x.shape
    adj_shape = out.edge_index.shape if out.edge_index.dim() == 2 else f"{out.edge_index.shape} (dense)"
    
    print(f"\n{name}:")
    print(f"  pooler.batched = {pooler.batched}")
    print(f"  pooler.sparse_output = {pooler.sparse_output}")
    print(f"  Output x shape: {x_shape}")
    print(f"  Output edge_index shape: {adj_shape}")

**Trade-offs:**

- **Batched mode** (`batched=True`): Efficient for batches of small graphs, but requires padding to handle variable-sized graphs
- **Unbatched mode** (`batched=False`): Processes graphs one at a time, avoiding padding overhead for large graphs
- **Dense output** (`sparse_output=False`): Required if using dense layers (like `DenseGCNConv`) after pooling
- **Sparse output** (`sparse_output=True`): Allows using regular sparse layers after pooling, reducing memory for large graphs
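The padding cost of batched mode can be illustrated with a short plain-Python sketch. It pads every graph to the size of the largest one and returns a mask marking the real nodes, similar in spirit to PyG's `to_dense_batch`. The helper `to_padded` is hypothetical and is not how tgp implements batching.

```python
def to_padded(x, batch, pad_value=0.0):
    """Group node features per graph and pad each graph to the largest size."""
    graphs = {}
    for feats, graph in zip(x, batch):
        graphs.setdefault(graph, []).append(list(feats))

    max_nodes = max(len(nodes) for nodes in graphs.values())
    num_feats = len(x[0])

    padded, mask = [], []
    for graph in sorted(graphs):
        nodes = graphs[graph]
        n_pad = max_nodes - len(nodes)
        padded.append(nodes + [[pad_value] * num_feats for _ in range(n_pad)])
        mask.append([True] * len(nodes) + [False] * n_pad)
    return padded, mask

# Graph 0 has three nodes, graph 1 has only one.
x = [[1.0], [2.0], [3.0], [4.0]]
batch = [0, 0, 0, 1]

padded, mask = to_padded(x, batch)
print(padded)  # [[[1.0], [2.0], [3.0]], [[4.0], [0.0], [0.0]]]
print(mask)    # [[True, True, True], [True, False, False]]
```

Two of the three slots reserved for graph 1 are padding. When graph sizes vary a lot, this wasted space is the overhead that unbatched mode avoids.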

## Hierarchical Architectures

One of the key benefits of graph pooling is building truly **hierarchical architectures** with multiple pooling layers. This allows the network to learn features at different scales.

Architecture with 2 pooling layers:
```
[GCN → Pool₁ → GCN → Pool₂ → GCN → GlobalPool → Linear]
```
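To get a feel for how quickly the graph shrinks, the sketch below tracks the number of nodes through successive pooling layers. It assumes that each layer keeps `ceil(ratio * n)` nodes, as PyG's `TopKPooling` does; the graph sizes are hypothetical and `coarsened_sizes` is an illustrative helper.

```python
import math

def coarsened_sizes(num_nodes, ratios):
    """Node count before pooling and after each pooling layer."""
    sizes = [num_nodes]
    for ratio in ratios:
        sizes.append(math.ceil(ratio * sizes[-1]))
    return sizes

# Two pooling layers, each keeping half of the nodes.
small = coarsened_sizes(17, [0.5, 0.5])
large = coarsened_sizes(28, [0.5, 0.5])
print(small)  # [17, 9, 5]
print(large)  # [28, 14, 7]
```

With two layers at `ratio=0.5`, the last GCN layer operates on roughly a quarter of the original nodes.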

In [None]:
class HierarchicalGNN(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        
        # First GCN layer
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        
        # First pooling layer (keep 50% of nodes)
        self.pooler1 = get_pooler('topk', in_channels=hidden_channels, ratio=0.5)
        
        # Second GCN layer
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
        # Second pooling layer (keep 50% of remaining nodes)
        self.pooler2 = get_pooler('topk', in_channels=hidden_channels, ratio=0.5)
        
        # Third GCN layer
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        
        # Classifier
        self.lin = Linear(hidden_channels, dataset.num_classes)
    
    def forward(self, x, edge_index, edge_weight, batch):
        # First GCN layer
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        
        # First pooling
        out1 = self.pooler1(x=x, adj=edge_index, edge_weight=edge_weight, batch=batch)
        
        # Second GCN layer
        x = self.conv2(out1.x, out1.edge_index, out1.edge_weight)
        x = F.relu(x)
        
        # Second pooling
        out2 = self.pooler2(x=x, adj=out1.edge_index, edge_weight=out1.edge_weight, batch=out1.batch)
        
        # Third GCN layer
        x = self.conv3(out2.x, out2.edge_index, out2.edge_weight)
        x = F.relu(x)
        
        # Global pooling
        x = self.pooler2.global_pool(x, reduce_op='sum', batch=out2.batch)
        
        # Classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        # Combine auxiliary losses from both poolers (if any)
        total_aux_loss = 0.0
        if out1.loss:
            total_aux_loss += sum(out1.get_loss_value())
        if out2.loss:
            total_aux_loss += sum(out2.get_loss_value())
        
        return x, total_aux_loss

hierarchical_model = HierarchicalGNN(hidden_channels=64)
print(hierarchical_model)

Let's train the hierarchical model:

In [None]:
model = hierarchical_model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print("Training hierarchical GNN (2 pooling layers)...")
for epoch in range(1, 21):
    loss = train_with_aux_loss()
    train_acc = test_with_aux(train_loader)
    test_acc = test_with_aux(test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'\nFinal Test Accuracy (Hierarchical with 2 pooling layers): {test_acc:.4f}')

## Caching for Single-Graph Tasks

For tasks involving a **single large graph** (like node classification), the pooler is called with `(x, adj)` and no batch vector, and the decoder can reuse the original graph's adjacency.

Let's switch to the **Cora** dataset, a citation network where we classify research papers into topics:

In [None]:
# Load Cora dataset
cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
cora_data = cora_dataset[0]

print(f'Dataset: {cora_dataset}')
print(f'Data: {cora_data}')
print(f'Number of nodes: {cora_data.num_nodes}')
print(f'Number of edges: {cora_data.num_edges}')
print(f'Number of features: {cora_dataset.num_features}')
print(f'Number of classes: {cora_dataset.num_classes}')

For node classification, we often use an **encoder-decoder architecture** with pooling and $\texttt{LIFT}$ in the middle:

```
[Encoder → Pool → Bottleneck → LIFT → Decoder]
```

The pooler is called with `(x, adj, batch)` directly (with `batch=None`, since there is a single graph); for the decoder we use the original graph's adjacency in dense form (e.g. via `to_dense_adj`).
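The conversion from an edge list to a dense adjacency can be sketched in plain Python as below. The helper `dense_adjacency` is hypothetical and only mirrors the basic behavior of `to_dense_adj` for a single unweighted graph.

```python
def dense_adjacency(edge_index, num_nodes):
    """Build an N x N adjacency matrix from a [sources, targets] edge list."""
    sources, targets = edge_index
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for s, t in zip(sources, targets):
        adj[s][t] += 1.0  # duplicate edges accumulate
    return adj

# Undirected path 0-1-2, stored with both edge directions.
edge_index = [[0, 1, 1, 2],
              [1, 0, 2, 1]]

adj = dense_adjacency(edge_index, num_nodes=3)
print(adj)  # [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```

The dense form needs $N^2$ entries. For Cora's 2,708 nodes that is about 7.3 million values, roughly 29 MB in 32-bit floats, which is manageable here but grows quickly for larger graphs.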

In [None]:
class NodeClassificationGNN(torch.nn.Module):
    def __init__(self, hidden_channels=16):
        super().__init__()
        
        # Encoder
        self.conv_enc = GCNConv(cora_dataset.num_features, hidden_channels)
        
        # Pooler
        self.pooler = get_pooler(
            'mincut',
            in_channels=hidden_channels,
            k=cora_data.num_nodes // 20,  # Coarsen to ~5% of original size
            sparse_output=False
        )
        
        # Bottleneck (on pooled graph)
        self.conv_pool = DenseGCNConv(hidden_channels, hidden_channels // 2)
        
        # Decoder (after lifting)
        self.conv_dec = DenseGCNConv(hidden_channels // 2, cora_dataset.num_classes)
    
    def forward(self, x, edge_index, edge_weight):
        # Encoder
        x = self.conv_enc(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        
        # Pool
        pool_out = self.pooler(x=x, adj=edge_index, edge_weight=edge_weight, batch=None)
        
        # Bottleneck
        x_pool = self.conv_pool(pool_out.x, pool_out.edge_index)
        x_pool = F.relu(x_pool)
        x_pool = F.dropout(x_pool, p=0.5, training=self.training)
        
        # Lift back to original space
        x_lift = self.pooler(x=x_pool, so=pool_out.so, lifting=True, batch=None)
        
        # Decoder (dense adj for single graph; keep edge weights and all N nodes)
        adj_dense = to_dense_adj(edge_index, edge_attr=edge_weight, max_num_nodes=x.size(0))
        x = self.conv_dec(x_lift, adj_dense)
        
        # Extract from batch dimension
        if x.dim() == 3:
            x = x[0]
        
        # Return predictions and auxiliary loss
        aux_loss = sum(pool_out.get_loss_value()) if pool_out.loss else 0.0
        return F.log_softmax(x, dim=-1), aux_loss

node_model = NodeClassificationGNN(hidden_channels=16)
print(node_model)
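
The `x[0]` step at the end of `forward` is there because the dense adjacency of a single graph carries a leading batch dimension of size 1 (`to_dense_adj` returns shape `[1, N, N]`), and the dense convolution keeps that dimension. The dependency-free sketch below mimics the conversion with nested lists. `edge_list_to_dense` is a hypothetical helper written for illustration and is not part of PyG or tgp.

```python
def edge_list_to_dense(edge_index, num_nodes, edge_weight=None):
    """Build a [1, N, N] nested-list adjacency from a [2, E] edge list."""
    sources, targets = edge_index
    if edge_weight is None:
        edge_weight = [1.0] * len(sources)  # unweighted graph: every edge counts as 1
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for src, dst, weight in zip(sources, targets, edge_weight):
        adj[src][dst] += weight  # in this sketch, duplicate edges accumulate
    return [adj]  # leading batch dimension of size 1


# Undirected path 0 - 1 - 2, stored as directed pairs in both directions
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
adj_dense = edge_list_to_dense(edge_index, num_nodes=3)

print(len(adj_dense), len(adj_dense[0]), len(adj_dense[0][0]))
print(adj_dense[0])
```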

We now train the model for node classification. The loss is computed only on the labeled training nodes (`train_mask`):

In [None]:
model = node_model.to(device)
cora_data = cora_data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train_node():
    model.train()
    optimizer.zero_grad()
    out, aux_loss = model(cora_data.x, cora_data.edge_index, cora_data.edge_weight)
    loss = F.nll_loss(out[cora_data.train_mask], cora_data.y[cora_data.train_mask]) + aux_loss
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test_node():
    model.eval()
    out, _ = model(cora_data.x, cora_data.edge_index, cora_data.edge_weight)
    pred = out.argmax(dim=1)
    
    accs = []
    for mask in [cora_data.train_mask, cora_data.val_mask, cora_data.test_mask]:
        correct = pred[mask].eq(cora_data.y[mask]).sum().item()
        acc = correct / mask.sum().item()
        accs.append(acc)
    return accs

print("Training node classification with caching...")
for epoch in range(1, 201):
    loss = train_node()
    train_acc, val_acc, test_acc = test_node()
    if epoch % 20 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

print(f'\nFinal Test Accuracy: {test_acc:.4f}')
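
The masking used in `train_node` and `test_node` can be illustrated without tensors. The sketch below computes the mean negative log-likelihood and the accuracy over only the nodes selected by a boolean mask. `masked_nll` and `masked_accuracy` are illustrative helpers, and the log-probabilities are made-up values rather than model outputs.

```python
import math


def masked_nll(log_probs, labels, mask):
    """Mean negative log-likelihood over the nodes where mask is True."""
    picked = [-log_probs[i][labels[i]] for i in range(len(labels)) if mask[i]]
    return sum(picked) / len(picked)


def masked_accuracy(log_probs, labels, mask):
    """Fraction of masked nodes whose argmax prediction matches the label."""
    hits = [
        max(range(len(row)), key=row.__getitem__) == label
        for row, label, keep in zip(log_probs, labels, mask)
        if keep
    ]
    return sum(hits) / len(hits)


# Three nodes, two classes; only the first two nodes are labeled for training
log_probs = [
    [math.log(0.9), math.log(0.1)],
    [math.log(0.2), math.log(0.8)],
    [math.log(0.6), math.log(0.4)],
]
labels = [0, 1, 1]
train_mask = [True, True, False]

loss = masked_nll(log_probs, labels, train_mask)
acc = masked_accuracy(log_probs, labels, train_mask)
print(f"loss={loss:.4f}, acc={acc:.2f}")
```

The third node is misclassified, but it does not affect either quantity because it is outside the mask. The same mechanism keeps validation and test nodes out of the training loss on Cora.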

## Precoarsening for Non-Learnable Poolers

Some pooling operators are **non-learnable**: they compute node assignments based solely on the graph structure, not on learned node features. Examples include:
- [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling) (Node Decimation Pooling)
- [`GraclusPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.GraclusPooling)
- [`NMFPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NMFPooling)

Since the graph structure doesn't change during training, we can **precompute** the pooling operations using [`PreCoarsening`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/transforms.html#tgp.data.transforms.PreCoarsening)!

For advanced precoarsening usage (multiple poolers, different configs per level), see the [Precoarsening and transforms](https://github.com/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/precoarsening_and_transforms.ipynb) notebook.

In [None]:
from tgp.data import PreCoarsening, PoolDataLoader
from tgp.reduce import BaseReduce, global_reduce
from torch_geometric.nn import ARMAConv

# Create a non-learnable pooler
ndp_pooler = get_pooler('ndp')  # Node Decimation Pooling
print(f"NDP Pooler: {ndp_pooler}")

# Apply PreCoarsening transform to the dataset
# This will precompute pooling for 2 levels
precoarsened_dataset = TUDataset(
    root='/tmp/MUTAG_precoarsened',
    name='MUTAG',
    pre_transform=PreCoarsening(
        pooler=ndp_pooler,
        recursive_depth=2  # 2 levels of coarsening
    ),
    force_reload=True
)

print(f"\nDataset with precoarsening:")
sample = precoarsened_dataset[0]
print(sample)
print(f"\nPooled data levels: {len(sample.pooled_data)}")
for i, pooled in enumerate(sample.pooled_data):
    print(f"  Level {i+1}: {pooled}")
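
To make the idea of structure-only, precomputable coarsening concrete, here is a small sketch in plain Python. It uses a greedy edge matching as a simplified stand-in for a non-learnable selector. It is neither NDP nor tgp's implementation, and the names `greedy_matching_select`, `connect`, and `precoarsen` are hypothetical. The `depth` argument plays the role that `recursive_depth` plays above: each level coarsens the output of the previous one.

```python
def greedy_matching_select(num_nodes, edges):
    """Assign nodes to supernodes by greedily merging unmatched neighbors."""
    assignment = [-1] * num_nodes
    next_id = 0
    for u, v in edges:
        if u != v and assignment[u] == -1 and assignment[v] == -1:
            assignment[u] = assignment[v] = next_id
            next_id += 1
    for node in range(num_nodes):
        if assignment[node] == -1:  # unmatched nodes become singleton supernodes
            assignment[node] = next_id
            next_id += 1
    return assignment, next_id


def connect(assignment, edges):
    """Link two supernodes whenever any of their members were connected."""
    coarse = {tuple(sorted((assignment[u], assignment[v]))) for u, v in edges}
    return sorted((a, b) for a, b in coarse if a != b)


def precoarsen(num_nodes, edges, depth):
    """Compute every coarsening level once, using the graph structure only."""
    levels = []
    for _ in range(depth):
        assignment, num_nodes = greedy_matching_select(num_nodes, edges)
        edges = connect(assignment, edges)
        levels.append(
            {"assignment": assignment, "num_nodes": num_nodes, "edges": edges}
        )
    return levels


# Path graph with 6 nodes, coarsened twice
path_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
levels = precoarsen(6, path_edges, depth=2)
for i, level in enumerate(levels, start=1):
    print(f"Level {i}: {level}")
```

Nothing in this computation depends on node features or model parameters, so the result can be stored with the dataset and reused at every epoch.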

We need a special [`PoolDataLoader`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/loaders.html#tgp.data.loaders.PoolDataLoader) to properly batch graphs with precoarsened data:

In [None]:
# Use PoolDataLoader instead of regular DataLoader
precoarsened_train = precoarsened_dataset[:150]
precoarsened_test = precoarsened_dataset[150:]

pool_train_loader = PoolDataLoader(precoarsened_train, batch_size=32, shuffle=True)
pool_test_loader = PoolDataLoader(precoarsened_test, batch_size=32)

# Check a batch
batch = next(iter(pool_train_loader))
print(f"Batch: {batch}")
print(f"\nPooled data (properly batched):")
for i, pooled in enumerate(batch.pooled_data):
    print(f"  Level {i+1}: {pooled}")
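
Batching in PyG merges several graphs into one large disjoint graph by offsetting node indices and recording a `batch` vector. With precoarsened data the same bookkeeping is needed at every pooled level, not only for the original graphs. The sketch below shows that idea with plain Python. `collate_graphs` is a hypothetical helper, and how `PoolDataLoader` does this internally is not shown in this tutorial.

```python
def collate_graphs(graphs):
    """Merge (num_nodes, edges) graphs into one disjoint graph plus a batch vector."""
    merged_edges, batch, offset = [], [], 0
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        # Shift node indices so they do not collide with earlier graphs
        merged_edges += [(u + offset, v + offset) for u, v in edges]
        batch += [graph_id] * num_nodes
        offset += num_nodes
    return merged_edges, batch


# Two graphs at the original resolution and at one coarsened level
originals = [(3, [(0, 1), (1, 2)]), (2, [(0, 1)])]
level_1 = [(2, [(0, 1)]), (1, [])]

orig_edges, orig_batch = collate_graphs(originals)
pooled_edges, pooled_batch = collate_graphs(level_1)

print(orig_edges, orig_batch)
print(pooled_edges, pooled_batch)
```

Each level has its own node count per graph, so it needs its own offsets and its own batch vector. A regular loader that only offsets the original `edge_index` would leave the pooled levels misaligned.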

Now we can build a model that uses the precomputed pooling structure. Notice that we only need `BaseReduce`: the $\texttt{SEL}$ and $\texttt{CON}$ operations are already precomputed!

In [None]:
class PrecoarsenedGNN(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        
        num_features = precoarsened_dataset.num_features
        num_classes = precoarsened_dataset.num_classes
        
        # First MP layer
        self.conv1 = ARMAConv(num_features, hidden_channels, num_layers=2)
        
        # Reducer (only REDUCE is needed; SEL and CON are precomputed)
        self.reducer = BaseReduce()
        
        # MP layers after each pooling level
        self.conv2 = ARMAConv(hidden_channels, hidden_channels, num_layers=2)
        self.conv3 = ARMAConv(hidden_channels, hidden_channels, num_layers=2)
        
        # Classifier
        self.lin = Linear(hidden_channels, num_classes)
    
    def forward(self, data):
        # First MP layer on original graph
        x = self.conv1(data.x, data.edge_index, data.edge_weight)
        x = F.relu(x)
        
        # Apply precoarsened pooling levels
        # Level 1
        pooled_1 = data.pooled_data[0]
        x, _ = self.reducer(x=x, so=pooled_1.so)  # Just REDUCE!
        x = self.conv2(x, pooled_1.edge_index, pooled_1.edge_weight)
        x = F.relu(x)
        
        # Level 2
        pooled_2 = data.pooled_data[1]
        x, _ = self.reducer(x=x, so=pooled_2.so)  # Just REDUCE!
        x = self.conv3(x, pooled_2.edge_index, pooled_2.edge_weight)
        x = F.relu(x)
        
        # Global pooling
        x = global_reduce(x, reduce_op='sum', batch=pooled_2.batch)
        
        # Classifier
        x = self.lin(x)
        
        return F.log_softmax(x, dim=-1)

precoarsened_model = PrecoarsenedGNN(hidden_channels=64)
print(precoarsened_model)
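
The final readout in `forward` sums node features per graph, using the batch vector of the coarsest level. The plain-Python sketch below shows the computation that a sum readout is assumed to perform. `sum_readout` is a hypothetical helper, not the tgp function.

```python
def sum_readout(x, batch, num_graphs):
    """Sum node feature rows per graph, as indicated by the batch vector."""
    num_features = len(x[0])
    out = [[0.0] * num_features for _ in range(num_graphs)]
    for row, graph_id in zip(x, batch):
        for j, value in enumerate(row):
            out[graph_id][j] += value
    return out


# Three supernodes: the first two belong to graph 0, the last one to graph 1
x = [[1.0, 0.5], [2.0, 0.5], [4.0, 1.0]]
batch = [0, 0, 1]

graph_embeddings = sum_readout(x, batch, num_graphs=2)
print(graph_embeddings)
```

This is why the model passes `pooled_2.batch` rather than `data.batch`: after two pooling levels, the rows of `x` correspond to the supernodes of the second level.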

Each training step is now cheaper because the forward pass skips the $\texttt{SEL}$ and $\texttt{CON}$ operations, which were computed once during preprocessing.

In [None]:
model = precoarsened_model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train_precoarsened():
    model.train()
    total_loss = 0
    for data in pool_train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(pool_train_loader.dataset)

@torch.no_grad()
def test_precoarsened(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data)
        pred = out.argmax(dim=-1)
        correct += int(pred.eq(data.y.view(-1)).sum())
    return correct / len(loader.dataset)

print("Training with precoarsened pooling...")
for epoch in range(1, 21):
    loss = train_precoarsened()
    train_acc = test_precoarsened(pool_train_loader)
    test_acc = test_precoarsened(pool_test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'\nFinal Test Accuracy (Precoarsened): {test_acc:.4f}')

## Creating Custom Poolers

One of the most powerful features of tgp is the ability to create **custom pooling operators** by mixing and matching different $\texttt{SEL}$, $\texttt{RED}$, $\texttt{CON}$, and $\texttt{LIFT}$ components!

Available components:
- **Select**: [`TopkSelect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.TopkSelect), [`MLPSelect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.MLPSelect), [`NDPSelect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.NDPSelect), etc.
- **Connect**: [`SparseConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.SparseConnect), [`DenseConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.DenseConnect), [`KronConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.KronConnect), etc.
- **Reduce**: [`BaseReduce`](https://torch-geometric-pool.readthedocs.io/en/latest/api/reduce.html#tgp.reduce.BaseReduce)
- **Lift**: [`BaseLift`](https://torch-geometric-pool.readthedocs.io/en/latest/api/lift.html#tgp.lift.BaseLift)
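
To see why the choice of connector matters, the sketch below compares two ways of rewiring a graph after a selector has kept a subset of nodes. The first keeps only the edges between kept nodes (the induced subgraph, as in classic top-k pooling). The second applies the textbook Kron reduction of the graph Laplacian, $L_{kk} - L_{kd} L_{dd}^{-1} L_{dk}$. That `KronConnect` follows exactly this formula is an assumption based on its name. The helpers below are illustrative plain Python, not tgp code.

```python
def solve(a, b):
    """Solve a @ x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    aug = [list(a[i]) + list(b[i]) for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        scale = aug[col][col]
        aug[col] = [value / scale for value in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [vr - factor * vc for vr, vc in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def laplacian(num_nodes, edges):
    """Combinatorial Laplacian L = D - A of an unweighted, undirected graph."""
    lap = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        lap[u][u] += 1.0
        lap[v][v] += 1.0
        lap[u][v] -= 1.0
        lap[v][u] -= 1.0
    return lap


def kron_reduce(lap, kept):
    """Kron reduction: L_kk - L_kd @ inv(L_dd) @ L_dk for the kept nodes."""
    dropped = [i for i in range(len(lap)) if i not in kept]
    l_kd = [[lap[i][j] for j in dropped] for i in kept]
    l_dk = [[lap[i][j] for j in kept] for i in dropped]
    l_dd = [[lap[i][j] for j in dropped] for i in dropped]
    correction = solve(l_dd, l_dk)  # inv(L_dd) @ L_dk
    reduced = []
    for a, i in enumerate(kept):
        row = []
        for b, j in enumerate(kept):
            shortcut = sum(l_kd[a][m] * correction[m][b] for m in range(len(dropped)))
            row.append(lap[i][j] - shortcut)
        reduced.append(row)
    return reduced


def induced_edges(edges, kept):
    """Keep only the edges whose two endpoints both survive."""
    return [(u, v) for u, v in edges if u in kept and v in kept]


# Path 0 - 1 - 2, where the selector keeps the two endpoints and drops node 1
edges = [(0, 1), (1, 2)]
kept = [0, 2]

induced = induced_edges(edges, kept)
reduced = kron_reduce(laplacian(3, edges), kept)
print(induced)
print(reduced)
```

In this example the induced subgraph has no edges left, so the two kept nodes become disconnected. The Kron-reduced Laplacian has off-diagonal entries of -0.5, which corresponds to a new edge of weight 0.5 between the kept nodes, standing in for the path through the dropped node.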

### Exercise

Implement the following in the code cells below:

1. **Part 1 – Custom pooler**: Implement a custom pooler that combines **TopkSelect** with **KronConnect** (instead of the default SparseConnect). Subclass `SRCPooling`; in `__init__` pass a `TopkSelect`, `BaseReduce`, `BaseLift`, and `KronConnect`. Implement `forward` to call select → reduce → connect and return a `PoolingOutput`. Create an instance and optionally compare with `get_pooler('topk', ...)`.

2. **Part 2 – GNN**: Define a GNN that uses your custom pooler between two GCN layers (same structure as earlier in the tutorial), then global pool and classify. Use the existing `dataset`, `train_loader`, `test_loader`, and `train()` / `test()` from the notebook.

3. **Part 3 – Train**: Train the model for 100 epochs and report final test accuracy.

See the [Advanced](https://github.com/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/advanced.ipynb) notebook for the solution.

### Part 1: Implement the custom pooler


In [None]:
from tgp.select import TopkSelect
from tgp.connect import KronConnect
from tgp.reduce import BaseReduce
from tgp.lift import BaseLift
from tgp.src import SRCPooling, PoolingOutput


class CustomTopKKronPooler(SRCPooling):
    # TODO: In __init__, call super().__init__ with selector=TopkSelect(...), reducer=BaseReduce(...), lifter=BaseLift(...), connector=KronConnect()
    def __init__(self, in_channels, ratio=0.5):
        pass  # TODO

    # TODO: In forward, call self.select -> self.reduce -> self.connect, then return PoolingOutput(...)
    def forward(self, x, adj, edge_weight=None, batch=None, **kwargs):
        pass  # TODO


# TODO: Create an instance, e.g. custom_pooler = CustomTopKKronPooler(in_channels=64, ratio=0.5)
# TODO: Optionally compare with get_pooler('topk', in_channels=64, ratio=0.5)


### Part 2: Use your pooler in a GNN


In [None]:
class GNNWithCustomPooler(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        # TODO: self.conv1 = GCNConv(...), self.pooler = CustomTopKKronPooler(...), self.conv2 = GCNConv(...), self.lin = Linear(...)
        pass

    def forward(self, x, edge_index, edge_weight, batch):
        # TODO: conv1 -> ReLU -> pooler -> conv2 -> ReLU -> global_pool -> dropout -> lin
        pass  # TODO


# TODO: custom_model = GNNWithCustomPooler(hidden_channels=64)


### Part 3: Train and evaluate


In [None]:
# TODO: model = custom_model.to(device)
# TODO: optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# TODO: Training loop for 100 epochs using train() and test(train_loader) / test(test_loader)
# TODO: Print final test accuracy


## Summary

In this tutorial, we explored [tgp](https://github.com/tgp-team/torch-geometric-pool) for building hierarchical Graph Neural Networks with pooling.

### Key Concepts

1. **SRC Framework**: All pooling operators decompose into $\texttt{SEL}$, $\texttt{RED}$, and $\texttt{CON}$ operations (plus $\texttt{LIFT}$ for unpooling)

2. **Easy Integration**: Adding pooling to a GNN is as simple as inserting a pooler between message-passing layers

3. **Sparse vs Dense Poolers**:
   - Sparse poolers work with edge lists (TopK, SAG, ASAP)
   - Dense poolers work with dense adjacency matrices (MinCut, DiffPool, DMoN)
   - Dense poolers often have auxiliary losses

4. **Hierarchical Architectures**: Stack multiple pooling layers to learn multi-scale representations

5. **Performance Optimization**:
   - **Node classification**: use an encoder-pool-bottleneck-lift-decoder architecture and call the pooler with `(x, adj, batch)`
   - **Precoarsening**: For non-learnable poolers, precompute pooling structures

6. **Modularity**: Create custom poolers by combining different $\texttt{SEL}$, $\texttt{RED}$, $\texttt{CON}$, and $\texttt{LIFT}$ components

### Resources

- [Documentation](https://torch-geometric-pool.readthedocs.io)
- [GitHub Repository](https://github.com/tgp-team/torch-geometric-pool)
- [tgp Paper](https://arxiv.org/abs/2512.12642)
- [SRC Framework Paper](https://arxiv.org/abs/2110.05292)

Try different poolers on your own datasets and see which works best for your task!