# Precoarsening and transforms

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/precoarsening_and_transforms.ipynb)

Some pooling operators, such as [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling), [`GraclusPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.GraclusPooling), and [`NMFPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NMFPooling) (as well as some configurations of [`KMISPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.KMISPooling) and [`LaPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.LaPooling)), compute the $\texttt{SEL}$ operation based only on the topology of the adjacency matrix.
Unlike the node features, which are modified by each layer of the GNN and evolve during training, the adjacency matrix remains fixed.
Therefore, the $\texttt{SEL}$ and the $\texttt{CON}$ operations of these poolers always produce the same result and can be **precomputed** before starting to train the GNN.
This saves time during training, because the only operation left to compute is the $\texttt{RED}$, which yields the features of the supernodes.
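
As a rough illustration of this split, the sketch below uses a hand-written cluster assignment in place of a real $\texttt{SEL}$ output. The assignment, the feature values, and the helper `reduce_sum` are all hypothetical and not part of tgp; the point is only that the assignment is fixed once, while the sum-based $\texttt{RED}$ is recomputed whenever the features change.

```python
# Toy illustration (not tgp code): the node-to-supernode assignment depends
# only on the fixed topology, so it is computed once and reused.
assignment = [0, 0, 1, 1, 2]  # hypothetical SEL result for a 5-node graph


def reduce_sum(features, assignment, num_supernodes):
    """RED step: sum the feature vectors of the nodes in each supernode."""
    dim = len(features[0])
    pooled = [[0.0] * dim for _ in range(num_supernodes)]
    for node, cluster in enumerate(assignment):
        for j in range(dim):
            pooled[cluster][j] += features[node][j]
    return pooled


# Node features change at every training step; the assignment does not.
features_step1 = [[1.0], [2.0], [3.0], [4.0], [5.0]]
features_step2 = [[0.5], [0.5], [1.0], [1.0], [2.0]]

pooled_step1 = reduce_sum(features_step1, assignment, num_supernodes=3)
pooled_step2 = reduce_sum(features_step2, assignment, num_supernodes=3)
print(pooled_step1)  # [[3.0], [7.0], [5.0]]
print(pooled_step2)  # [[1.0], [2.0], [2.0]]
```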

Let's start by loading some data. 

In [None]:
import sys
import torch
if 'google.colab' in sys.modules:
    import os
    os.environ["TORCH"] = torch.__version__
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
    !pip install -q torch_geometric_pool[notebook]

In [None]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", force_reload=True)

Let's now take the first graph.

In [None]:
data = dataset[0]
print(data)

Let's consider [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling): its $\texttt{SEL}$ operation only looks at the graph connectivity. 
This means that we can compute the [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) without having to pass the node features.

In [None]:
from tgp.connect import KronConnect
from tgp.select import NDPSelect

selector = NDPSelect()
connector = KronConnect()

# Compute pooled graph
so = selector(data.edge_index)
print(so)

This also means that we can compute the coarsened graph connectivity with the $\texttt{CON}$ operation.

In [None]:
edge_index_pool = connector(data.edge_index, so)
print(edge_index_pool)

```{note}
[`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling) uses the Kron reduction implemented by [`KronConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.KronConnect) to compute the $\texttt{CON}$ operation. However, once the [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) is computed, other $\texttt{CON}$ operations, e.g., [`SparseConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.SparseConnect), can be used.
```
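
For intuition about what a Kron reduction does, here is a minimal pure-Python sketch that eliminates one dropped node at a time through a Schur complement of the graph Laplacian. It is not the `KronConnect` implementation; the helper `kron_reduce` and the 3-node path graph are made up for illustration.

```python
def kron_reduce(L, drop):
    """Kron-reduce a dense Laplacian by eliminating the nodes in `drop`.

    Each elimination is a Schur complement with a scalar pivot:
    L'[i][j] = L[i][j] - L[i][d] * L[d][j] / L[d][d].
    """
    L = [row[:] for row in L]  # work on a copy
    nodes = list(range(len(L)))  # original labels of the remaining rows
    for d in drop:
        p = nodes.index(d)  # current position of node d
        pivot = L[p][p]
        keep = [i for i in range(len(L)) if i != p]
        L = [
            [L[i][j] - L[i][p] * L[p][j] / pivot for j in keep]
            for i in keep
        ]
        nodes.pop(p)
    return L, nodes


# Laplacian of the path graph 0 - 1 - 2 with unit edge weights.
L = [
    [1.0, -1.0, 0.0],
    [-1.0, 2.0, -1.0],
    [0.0, -1.0, 1.0],
]
L_reduced, kept = kron_reduce(L, drop=[1])
print(kept)       # [0, 2]
print(L_reduced)  # [[0.5, -0.5], [-0.5, 0.5]]
```

The off-diagonal entry `-0.5` means that the two kept nodes are joined by a new edge of weight 0.5, even though they were not adjacent in the original graph.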

At this point, we can apply the $\texttt{SEL}$ and the $\texttt{CON}$ operations one more time on the pooled graph.
This is useful if we want to use a GNN architecture that applies pooling multiple times.

In [None]:
so2 = selector(edge_index_pool[0], edge_index_pool[1])
print(so2)

edge_index_pool2 = connector(edge_index_pool[0], so2, edge_index_pool[1])
print(edge_index_pool2)

We can repeat the procedure iteratively for all the pooling levels that we want to have in our GNN. 
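
The loop structure can be sketched independently of tgp. The functions `toy_select` and `toy_connect` below are hypothetical stand-ins (a greedy edge matching and an edge relabeling), not `NDPSelect` or `KronConnect`; the sketch only shows that each level consumes the connectivity produced by the previous one.

```python
def toy_select(edges, num_nodes):
    """Toy SEL: greedily merge the endpoints of unmatched edges."""
    assignment = [None] * num_nodes
    next_id = 0
    for u, v in edges:
        if assignment[u] is None and assignment[v] is None:
            assignment[u] = assignment[v] = next_id
            next_id += 1
    for node in range(num_nodes):  # unmatched nodes become singletons
        if assignment[node] is None:
            assignment[node] = next_id
            next_id += 1
    return assignment, next_id


def toy_connect(edges, assignment):
    """Toy CON: relabel edges by supernode, drop self-loops and duplicates."""
    pooled = {tuple(sorted((assignment[u], assignment[v]))) for u, v in edges}
    return sorted(e for e in pooled if e[0] != e[1])


# Path graph with 8 nodes: 0 - 1 - 2 - ... - 7
edges = [(i, i + 1) for i in range(7)]
num_nodes = 8

hierarchy = []
for level in range(2):
    assignment, num_nodes = toy_select(edges, num_nodes)
    edges = toy_connect(edges, assignment)
    hierarchy.append((assignment, edges))
    print(f"level {level + 1}: {num_nodes} supernodes, edges {edges}")
# level 1: 4 supernodes, edges [(0, 1), (1, 2), (2, 3)]
# level 2: 2 supernodes, edges [(0, 1)]
```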

## The Precoarsening transform

Precomputing pooling allows us to save a lot of time because we only need to do it once before starting to train our GNN.
However, for each sample in our dataset we end up having an instance of [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) and a pooled connectivity for each pooling level. 
Handling all of them during training, while keeping the correct association between data structures when we shuffle the data, is cumbersome.

<img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp provides a couple of tools to handle precomputed pooled graphs efficiently. 
The first is the [`PreCoarsening`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/transforms.html#tgp.data.transforms.PreCoarsening) transform, which can be directly applied to the dataset like all the other [PyG `transforms`](https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/transforms.html).

In [None]:
from tgp.poolers import NDPPooling
from tgp.data import PreCoarsening

dataset = TUDataset(
    root="/tmp/MUTAG",
    name="MUTAG",
    pre_transform=PreCoarsening(
        poolers=[NDPPooling(), NDPPooling()]
    ),
    force_reload=True,
)

data = dataset[0]
print(data)

Once again we look at the first element of the dataset and this time we see that, compared to the standard [`Data`](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.data.Data.html) structure, there is an additional field, `pooled_data`, which is a list with one entry per pooling level (two in our case, one for each pooler passed to the transform).
The elements in the list are the hierarchy of pooled graphs computed with the poolers that we passed to the [`PreCoarsening`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/transforms.html#tgp.data.transforms.PreCoarsening) transform.

Each pooled graph is a [`Data`](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.data.Data.html) structure containing the [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput) and the pooled connectivity matrix. 
Since we are using [`NDPPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.NDPPooling), which internally calls [`NDPSelect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.NDPSelect), we get an extra argument `L` representing the Laplacian matrix used by [`KronConnect`](https://torch-geometric-pool.readthedocs.io/en/latest/api/connect.html#tgp.connect.KronConnect).

In [None]:
for pooled_data in data.pooled_data:
    print(pooled_data)

This new `Data` structure is very convenient as it carries all the information that the GNN needs to perform pooling at each coarsening level.
With it, we do not need to keep track manually of the association between data samples and their pooled graph.

## The PoolDataLoader

The field `pooled_data` in these custom Data structures is *not* handled properly by the standard [`DataLoader`](https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/loader.html#torch_geometric.loader.DataLoader) of <img src="https://raw.githubusercontent.com/TorchSpatiotemporal/tsl/main/docs/source/_static/img/logos/pyg.svg" width="20px" align="center"/> PyG.
While the node features `x`, the edge indices, the edge attributes, etc. are batched correctly, the pooled graphs are just concatenated in a list rather than being combined into a single batched graph for each pooling level.

In [None]:
from torch_geometric.loader import DataLoader

pyg_loader = DataLoader(dataset, batch_size=4, shuffle=True)

next_batch = next(iter(pyg_loader))
print(next_batch)
print(next_batch.pooled_data[0])

To obtain well-formed batches with precomputed pooled graphs, <img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp provides the [`PoolDataLoader`](https://torch-geometric-pool.readthedocs.io/en/latest/api/data/loaders.html#tgp.data.loaders.PoolDataLoader).
Now, the field `pooled_data` in the batch is a list containing a single batched graph for each coarsening level (2 in our case).

In [None]:
from tgp.data import PoolDataLoader

tgp_loader = PoolDataLoader(dataset, batch_size=16, shuffle=True)

next_batch = next(iter(tgp_loader))
print(next_batch)
print(next_batch.pooled_data[0])

A complete example of usage can be found [here](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/pre_coarsening.py).
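
What "a single batched graph for each coarsening level" means can be sketched in plain Python. This is the usual disjoint-union batching idea, written with a hypothetical helper `batch_graphs` and made-up graphs; it is not the `PoolDataLoader` implementation.

```python
def batch_graphs(graphs):
    """Merge (num_nodes, edges) graphs into one disconnected graph.

    Node indices of each graph are shifted by the number of nodes that
    precede it, and `batch` records which graph every node came from.
    """
    edges, batch, offset = [], [], 0
    for graph_id, (num_nodes, graph_edges) in enumerate(graphs):
        edges += [(u + offset, v + offset) for u, v in graph_edges]
        batch += [graph_id] * num_nodes
        offset += num_nodes
    return edges, batch


# Two hypothetical pooled graphs taken from the same coarsening level.
level1_graphs = [
    (3, [(0, 1), (1, 2)]),  # pooled graph of sample A
    (2, [(0, 1)]),          # pooled graph of sample B
]
edges, batch = batch_graphs(level1_graphs)
print(edges)  # [(0, 1), (1, 2), (3, 4)]
print(batch)  # [0, 0, 0, 1, 1]
```

Repeating this once per level yields a list with one batched graph per level, which is the layout described above for `pooled_data`.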

## Advanced usage: multiple poolers and different configs per level

You can build a **hierarchical network** with **more than one pooler** and **different poolers or configs at each level**. Pass a list of level specs to `PreCoarsening(poolers=...)`: each element is either a string alias (e.g. `"ndp"`, `"graclus"`) or a tuple `(alias, kwargs)` for that level.

Examples:
- Same pooler, same config: `["ndp", "ndp"]`
- Same pooler, different config: `[("nmf", {"k": 8}), ("nmf", {"k": 4})]`
- Mixed poolers: `["ndp", ("eigen", {"k": 4, "num_modes": 3})]`
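
The two accepted forms can be thought of as shorthand for `(alias, kwargs)` pairs. The sketch below uses a hypothetical helper, `normalize_spec`, to make that explicit; it is not how `PreCoarsening` parses its argument. It also assumes, as in the model later in this section, that a level with `num_modes = m` multiplies the feature width by `m`.

```python
def normalize_spec(spec):
    """Turn "alias" or ("alias", kwargs) into an (alias, kwargs) pair."""
    if isinstance(spec, str):
        return spec, {}
    alias, kwargs = spec
    return alias, dict(kwargs)


level_specs = ["ndp", ("eigen", {"k": 4, "num_modes": 3})]
normalized = [normalize_spec(s) for s in level_specs]
print(normalized)
# [('ndp', {}), ('eigen', {'k': 4, 'num_modes': 3})]

# Width of the features leaving each level, assuming that a level with
# `num_modes = m` multiplies the hidden width by m (1 if absent).
hidden_channels = 64
widths = [hidden_channels * kw.get("num_modes", 1) for _, kw in normalized]
print(widths)  # [64, 192]
```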

In the model, use a `ModuleList` of reducers from `pre_transform.poolers`, one conv per level (accounting for `num_modes` for EigenPooling), and in `forward` loop over `data.pooled_data` and call `reducer(x=x, so=pooled.so)` then the next conv. Below is a full runnable example (mixed poolers: NDP then Eigen). For more schedules and a standalone script, see [pre_coarsening.py](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/pre_coarsening.py).

In [None]:
# Full runnable example: mixed poolers (NDP -> Eigen), 2 levels
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import ARMAConv
from tgp.data import PreCoarsening, PoolDataLoader
from tgp.reduce import readout

level_specs = ["ndp", ("eigen", {"k": 4, "num_modes": 3})]
pre_transform_adv = PreCoarsening(poolers=level_specs)
level_poolers = pre_transform_adv.poolers
num_levels = len(level_poolers)

dataset_adv = TUDataset(
    root="/tmp/MUTAG_ndp_eigen",
    name="MUTAG",
    pre_transform=pre_transform_adv,
    force_reload=True,
)
train_loader_adv = PoolDataLoader(dataset_adv[:150], batch_size=32, shuffle=True)
test_loader_adv = PoolDataLoader(dataset_adv[150:], batch_size=32)

level_num_modes = [getattr(p, "num_modes", 1) for p in level_poolers]


class PrecoarsenedNet(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        self.conv1 = ARMAConv(
            dataset_adv.num_features, hidden_channels, num_layers=2
        )
        self.reducers = torch.nn.ModuleList([p.reducer for p in level_poolers])
        self.next_conv = torch.nn.ModuleList()
        for num_modes in level_num_modes:
            in_ch = hidden_channels * num_modes
            self.next_conv.append(
                ARMAConv(in_ch, hidden_channels, num_layers=2)
            )
        self.lin = torch.nn.Linear(hidden_channels, dataset_adv.num_classes)

    def forward(self, data):
        x = self.conv1(data.x, data.edge_index, data.edge_weight)
        x = F.relu(x)
        for pooled, conv, reducer in zip(data.pooled_data, self.next_conv, self.reducers):
            x, _ = reducer(x=x, so=pooled.so)
            x = conv(x, pooled.edge_index, pooled.edge_weight)
            x = F.relu(x)
        x = readout(x, reduce_op="sum", batch=pooled.batch)
        return F.log_softmax(self.lin(x), dim=-1)


device_adv = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_adv = PrecoarsenedNet(hidden_channels=64).to(device_adv)
optimizer_adv = torch.optim.Adam(model_adv.parameters(), lr=1e-4)


def train_adv():
    model_adv.train()
    total = 0
    for data in train_loader_adv:
        data = data.to(device_adv)
        optimizer_adv.zero_grad()
        loss = F.nll_loss(model_adv(data), data.y.view(-1))
        loss.backward()
        optimizer_adv.step()
        total += data.y.size(0) * float(loss)
    # Average over the training split only, not the full dataset
    return total / len(train_loader_adv.dataset)


@torch.no_grad()
def test_adv(loader):
    model_adv.eval()
    correct = 0
    for data in loader:
        data = data.to(device_adv)
        correct += int(model_adv(data).argmax(dim=-1).eq(data.y.view(-1)).sum())
    return correct / len(loader.dataset)


for epoch in range(1, 11):
    loss = train_adv()
    acc = test_adv(test_loader_adv)
    print(f"Epoch {epoch:2d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}")
print(f"Final test accuracy: {test_adv(test_loader_adv):.4f}")

## Other data transforms

Some pooling layers come with custom transforms that should be applied to the data before starting to train the GNN.
For example, [`JustBalancePooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.JustBalancePooling) transforms the connectivity matrix $\mathbf{A}$ as follows:

$$\mathbf{A} \to \mathbf{I} - \delta \mathbf{L}$$
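
To make the formula concrete, here is a small pure-Python sketch on a 3-node path graph. The text above does not say which Laplacian $\mathbf{L}$ is meant, so the sketch assumes the symmetrically normalized Laplacian $\mathbf{I} - \mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}$; the value $\delta = 0.85$ and the helper `rescale_adjacency` are likewise illustrative choices, not taken from tgp.

```python
import math


def rescale_adjacency(adj, delta):
    """Return I - delta * L, with L = I - D^{-1/2} A D^{-1/2} (assumed)."""
    n = len(adj)
    deg = [sum(row) for row in adj]
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            identity = 1.0 if i == j else 0.0
            norm_adj = 0.0
            if adj[i][j] != 0 and deg[i] > 0 and deg[j] > 0:
                norm_adj = adj[i][j] / math.sqrt(deg[i] * deg[j])
            laplacian = identity - norm_adj
            out[i][j] = identity - delta * laplacian
    return out


# Path graph 0 - 1 - 2 with unit weights; delta is an illustrative value.
A = [
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]
transformed = rescale_adjacency(A, delta=0.85)
for row in transformed:
    print([round(v, 3) for v in row])
```

Under these assumptions, and for a graph without self-loops, every diagonal entry becomes $1 - \delta$ and every edge $(i, j)$ is rescaled to $\delta\, a_{ij} / \sqrt{d_i d_j}$.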

The transforms associated with a given pooling operator are returned by its `data_transforms()` method.
They can be accessed and passed to the dataset as any other <img src="https://raw.githubusercontent.com/TorchSpatiotemporal/tsl/main/docs/source/_static/img/logos/pyg.svg" width="20px" align="center"/> PyG [`transform`](https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/transforms.html).

In [None]:
from tgp.poolers import JustBalancePooling

pooler = JustBalancePooling(in_channels=dataset.num_features, k=10)
print(pooler.data_transforms())

In [None]:
dataset = TUDataset(
    root="/tmp/MUTAG",
    name="MUTAG",
    force_reload=True,
    pre_transform=pooler.data_transforms(),  # transform specific for the pooler
)