# Precomputed pooling operations

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/preprocessing_and_transforms.ipynb)

Some pooling operators, such as [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling), [`GraclusPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.GraclusPooling), [`NMFPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NMFPooling) (and some configurations of [`KMISPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.KMISPooling) and [`LaPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.LaPooling)), compute the $\texttt{SEL}$ operation based only on the graph topology, i.e., on the adjacency matrix. 
Unlike the node features, which are modified by each layer of the GNN and evolve during training, the adjacency matrix remains fixed. 
Therefore, the $\texttt{SEL}$ and the $\texttt{CON}$ operations of these poolers always produce the same result and can be **precomputed** before starting to train the GNN.
This saves a lot of time during training, because the only operation left to compute is the $\texttt{RED}$, which produces the features of the supernodes.
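
To make the split concrete, here is a minimal plain-Python sketch (no tgp calls). It assumes a toy cluster-style assignment with made-up values, and a hypothetical helper `reduce_sum` that sums the features of the nodes mapped to each supernode. The point is only that the assignment is computed once, while the reduction is repeated whenever the features change.

```python
# Toy stand-in for a precomputed SEL result: node i belongs to supernode assignment[i].
# (Hypothetical values, not produced by any tgp pooler.)
assignment = [0, 0, 1, 1, 2]
num_supernodes = 3


def reduce_sum(x, assignment, num_supernodes):
    """RED: sum the feature vectors of the nodes mapped to each supernode."""
    pooled = [[0.0] * len(x[0]) for _ in range(num_supernodes)]
    for node, supernode in enumerate(assignment):
        for d, value in enumerate(x[node]):
            pooled[supernode][d] += value
    return pooled


# The features change at every training step, the assignment does not.
x_step1 = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0], [1.0, 1.0], [5.0, 5.0]]
x_step2 = [[0.5, 0.5], [0.5, 0.5], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

pooled_step1 = reduce_sum(x_step1, assignment, num_supernodes)
pooled_step2 = reduce_sum(x_step2, assignment, num_supernodes)
```

Both calls reuse the same `assignment`, which is the part that a precomputed pooler no longer has to recompute during training.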

Let's start by loading some data. 

In [1]:
import sys
import torch
if 'google.colab' in sys.modules:
    import os
    os.environ["TORCH"] = torch.__version__
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
    !pip install -q torch_geometric_pool[notebook]

In [2]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", force_reload=True)

Processing...
Done!


Let's now take the first graph.

In [3]:
data = dataset[0]
print(data)

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])


Let's consider [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling): its $\texttt{SEL}$ operation only looks at the graph connectivity. 
This means that we can compute the [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) without having to pass the node features.

In [4]:
from tgp.connect import KronConnect
from tgp.select import NDPSelect

selector = NDPSelect()
connector = KronConnect()

# Compute the selection from the connectivity only (no node features needed)
so = selector(data.edge_index)
print(so)

SelectOutput(num_nodes=17, num_supernodes=8, extra={'L'})


This also means that we can compute the coarsened graph connectivity with the $\texttt{CON}$ operation.

In [5]:
edge_index_pool = connector(data.edge_index, so)
print(edge_index_pool)

(tensor([[0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6,
         7, 7],
        [1, 2, 0, 2, 4, 0, 1, 3, 4, 2, 4, 6, 1, 2, 3, 5, 6, 4, 6, 7, 3, 4, 5, 7,
         5, 6]]), tensor([0.5000, 0.5000, 0.5000, 0.3333, 0.3333, 0.5000, 0.3333, 0.5000, 0.3333,
        0.5000, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.3333, 0.5000,
        0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333]))


```{note}
[`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling) uses the Kron reduction implemented by [`KronConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.KronConnect) to compute the $\texttt{CON}$ operation. However, once the [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) is computed, other $\texttt{CON}$ operations, e.g., [`SparseConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.SparseConnect), can be used.
```

At this point, we can apply the $\texttt{SEL}$ and the $\texttt{CON}$ operations one more time, this time on the pooled graph. 
This is useful if we want to use a GNN architecture that applies pooling multiple times.

In [6]:
so2 = selector(edge_index_pool[0], edge_index_pool[1])
print(so2)

edge_index_pool2 = connector(edge_index_pool[0], so2, edge_index_pool[1])
print(edge_index_pool2)

SelectOutput(num_nodes=8, num_supernodes=4, extra={'L'})
(tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]]), tensor([0.2045, 0.3182, 0.2045, 0.5979, 0.1154, 0.3182, 0.5979, 0.3077, 0.1154,
        0.3077]))


We can repeat the procedure iteratively for all the pooling levels that we want to have in our GNN. 
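
The iteration can be sketched as a simple loop. The snippet below is plain Python and does not use tgp: `build_hierarchy`, `toy_select`, and `toy_connect` are hypothetical names, and the toy select/connect rules (merge nodes in consecutive pairs, then relabel edges and drop self-loops) are stand-ins chosen only to keep the example self-contained.

```python
def toy_select(num_nodes, edges):
    """Stand-in for SEL: merge nodes 2i and 2i+1 into supernode i."""
    return [node // 2 for node in range(num_nodes)]


def toy_connect(edges, assignment):
    """Stand-in for CON: relabel edges by supernode, drop self-loops and duplicates."""
    pooled = {(assignment[u], assignment[v]) for u, v in edges}
    return sorted((u, v) for u, v in pooled if u != v)


def build_hierarchy(num_nodes, edges, select_fn, connect_fn, depth):
    """Apply SEL and CON `depth` times, feeding each pooled graph to the next level."""
    levels = []
    for _ in range(depth):
        assignment = select_fn(num_nodes, edges)
        edges = connect_fn(edges, assignment)
        num_nodes = max(assignment) + 1
        levels.append((assignment, edges))
    return levels


# Undirected path graph with 8 nodes, stored as directed edge pairs.
path_edges = [(i, i + 1) for i in range(7)] + [(i + 1, i) for i in range(7)]
levels = build_hierarchy(8, path_edges, toy_select, toy_connect, depth=2)
```

Each entry of `levels` holds the selection and the pooled connectivity of one level, which is the information that has to be stored per sample when pooling is precomputed.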

## The PreCoarsening transform

Precomputing pooling allows us to save a lot of time because we only need to do it once before starting to train our GNN.
However, for each sample in our dataset we end up having an instance of [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) and a pooled connectivity for each pooling level. 
Handling all of them during training, while keeping the correct association between data structures when we shuffle the data, is cumbersome.

<img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp provides a couple of tools to handle precomputed pooled graphs efficiently. 
The first is the [`PreCoarsening`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/transforms.html#tgp.data.transforms.PreCoarsening) transform, which can be directly applied to the dataset like all the other [PyG `transforms`](https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/transforms.html).

In [7]:
from tgp.poolers import NDPPooling
from tgp.data import PreCoarsening

dataset = TUDataset(
    root="/tmp/MUTAG",
    name="MUTAG",
    pre_transform=PreCoarsening(
        pooler=NDPPooling(), recursive_depth=2
    ),
    force_reload=True,
)

data = dataset[0]
print(data)

Processing...


Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1], pooled_data=[2])


Done!


Once again we look at the first element of the dataset and this time we see that, compared to the standard [`Data`](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.data.Data.html) structure, there is an additional field, `pooled_data`, which is a list of length `recursive_depth`.
The elements in the list form the hierarchy of pooled graphs, computed with the selector and connector of the pooler that we passed to the [`PreCoarsening`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/transforms.html#tgp.data.transforms.PreCoarsening) transform. 

Each pooled graph is a [`Data`](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.data.Data.html) structure containing the [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) and the pooled connectivity matrix. 
Since we are using [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling), which internally calls [`NDPSelect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.NDPSelect), we get an extra argument `L` representing the Laplacian matrix used by [`KronConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.KronConnect).

In [8]:
for pooled_data in data.pooled_data:
    print(pooled_data)

Data(edge_index=[2, 26], edge_weight=[26], so=SelectOutput(num_nodes=17, num_supernodes=8, extra={'L'}), num_nodes=8)
Data(edge_index=[2, 12], edge_weight=[12], so=SelectOutput(num_nodes=8, num_supernodes=4, extra={'L'}), num_nodes=4)


This new `Data` structure is very convenient, as it carries all the information that the GNN needs to perform pooling at each coarsening level.
With it, we do not need to manually keep track of the association between data samples and their pooled graphs.

## The PoolDataLoader

The field `pooled_data` in these custom Data structures is *not* handled properly by the standard [`DataLoader`](https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/loader.html#torch_geometric.loader.DataLoader) of <img src="https://raw.githubusercontent.com/TorchSpatiotemporal/tsl/main/docs/source/_static/img/logos/pyg.svg" width="20px" align="center"/> PyG.
While the node features `x`, the edge indices, the edge attributes, etc. are batched correctly, the pooled graphs are just collected in a list rather than being combined into a single batched graph for each pooling level.

In [9]:
from torch_geometric.loader import DataLoader

pyg_loader = DataLoader(dataset, batch_size=4, shuffle=True)

next_batch = next(iter(pyg_loader))
print(next_batch)
print(next_batch.pooled_data[0])

DataBatch(edge_index=[2, 152], x=[70, 7], edge_attr=[152, 4], y=[4], pooled_data=[4], batch=[70], ptr=[5])
[Data(edge_index=[2, 30], edge_weight=[30], so=SelectOutput(num_nodes=15, num_supernodes=9, extra={'L'}), num_nodes=9), Data(edge_index=[2, 20], edge_weight=[20], so=SelectOutput(num_nodes=9, num_supernodes=5, extra={'L'}), num_nodes=5)]


To obtain well-formed batches with precomputed pooled graphs, <img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp provides the [`PoolDataLoader`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/loaders.html#tgp.data.loaders.PoolDataLoader).
Now, the field `pooled_data` in the batch is a list containing a single batched graph for each coarsening level (2 in our case).

In [10]:
from tgp.data import PoolDataLoader

tgp_loader = PoolDataLoader(dataset, batch_size=16, shuffle=True)

next_batch = next(iter(tgp_loader))
print(next_batch)
print(next_batch.pooled_data[0])

DataPooledBatch(edge_index=[2, 588], x=[270, 7], edge_attr=[588, 4], y=[16], pooled_data=[2], batch=[270], ptr=[17])
DataPooledBatch(edge_index=[2, 426], edge_weight=[426], so=SelectOutput(num_nodes=270, num_supernodes=137, extra={'L'}), num_nodes=137, batch=[137], ptr=[17])
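
The idea of producing one batched graph per coarsening level can be illustrated with a small plain-Python sketch. This is not how `PoolDataLoader` is implemented: `batch_level` and `collate_pooled` are hypothetical helpers, graphs are represented as `(num_nodes, edges)` pairs, and the handling of the `SelectOutput` objects is left out entirely.

```python
def batch_level(graphs):
    """Merge the graphs of one pooling level into a single disjoint-union graph.

    graphs: list of (num_nodes, edges) pairs, one per sample.
    """
    edges, batch, offset = [], [], 0
    for sample_id, (num_nodes, sample_edges) in enumerate(graphs):
        # Shift node ids so that nodes of different samples do not collide.
        edges.extend((u + offset, v + offset) for u, v in sample_edges)
        # Remember which sample each node came from.
        batch.extend([sample_id] * num_nodes)
        offset += num_nodes
    return {"num_nodes": offset, "edges": edges, "batch": batch}


def collate_pooled(samples):
    """samples[i][l] is the level-l pooled graph of sample i.

    zip(*samples) regroups the graphs by level instead of by sample.
    """
    return [batch_level(level_graphs) for level_graphs in zip(*samples)]


# Two samples, each with two pooling levels (toy values).
samples = [
    [(2, [(0, 1), (1, 0)]), (1, [])],
    [(3, [(0, 2), (2, 0)]), (2, [(0, 1), (1, 0)])],
]
batched = collate_pooled(samples)
```

The result has one entry per level, each with its own `batch` vector, which mirrors the shape of the `pooled_data` list printed above.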


A complete example of usage can be found [here](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/pre_coarsening.py).

## Other data transforms

Some pooling layers come with custom transforms that should be applied to the data before starting to train the GNN.
For example, [`JustBalancePooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.JustBalancePooling) transforms the connectivity matrix $\mathbf{A}$ as follows:

$$\mathbf{A} \to \mathbf{I} - \delta \mathbf{L}$$
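
As a worked illustration, the sketch below evaluates this expression on a three-node path graph in plain Python. It rests on two assumptions that are not stated above: that $\mathbf{L}$ is the symmetrically normalized Laplacian $\mathbf{I} - \mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}$, and an arbitrary value $\delta = 0.85$ picked for the example. The function name `one_minus_delta_laplacian` is hypothetical and is not part of tgp.

```python
import math


def one_minus_delta_laplacian(adj, delta):
    """Return I - delta * L, ASSUMING L = I - D^{-1/2} A D^{-1/2}."""
    n = len(adj)
    deg = [sum(row) for row in adj]
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Normalized adjacency entry; skip the division where there is no edge.
            norm = adj[i][j] / math.sqrt(deg[i] * deg[j]) if adj[i][j] else 0.0
            identity = 1.0 if i == j else 0.0
            laplacian = identity - norm
            out[i][j] = identity - delta * laplacian
    return out


# Path graph 0 - 1 - 2 as a dense adjacency matrix.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
transformed = one_minus_delta_laplacian(adj, delta=0.85)  # delta is an illustrative choice
```

Under these assumptions the diagonal becomes $1 - \delta$ and each edge weight becomes $\delta / \sqrt{d_i d_j}$, so the transform amounts to a rescaled normalized adjacency with self-loops.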

The transforms associated with a given pooling operator are returned by its `data_transforms()` method. 
They can be accessed and passed to the dataset as any other <img src="https://raw.githubusercontent.com/TorchSpatiotemporal/tsl/main/docs/source/_static/img/logos/pyg.svg" width="20px" align="center"/> PyG [`transform`](https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/transforms.html).

In [11]:
from tgp.poolers import JustBalancePooling

pooler = JustBalancePooling(in_channels=dataset.num_features, k=10)
print(pooler.data_transforms())

NormalizeAdj()


In [12]:
dataset = TUDataset(
    root="/tmp/MUTAG",
    name="MUTAG",
    force_reload=True,
    pre_transform=pooler.data_transforms(),  # transform specific for the pooler
)

Processing...
Done!
