# Introduction by examples

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/intro.ipynb)

In the following, we will go through a few examples that showcase the main functionalities of <img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> TGP.
Let's start by importing the required libraries and checking the pooling operators that are available.

In [1]:
import sys
import torch
if 'google.colab' in sys.modules:
    import os
    os.environ["TORCH"] = torch.__version__
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
    !pip install -q torch_geometric_pool[notebook]

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGCNConv, GCNConv

from tgp.poolers import TopkPooling, get_pooler, pooler_classes, pooler_map

torch.set_printoptions(threshold=2, edgeitems=2)

print("Available poolers:")
for i, pooler_cls in enumerate(pooler_classes, start=1):
    print(f"{i}. {pooler_cls}")

Available poolers:
1. ASAPooling
2. AsymCheegerCutPooling
3. BNPool
4. DiffPool
5. DMoNPooling
6. EdgeContractionPooling
7. GraclusPooling
8. HOSCPooling
9. LaPooling
10. JustBalancePooling
11. KMISPooling
12. MaxCutPooling
13. MinCutPooling
14. NDPPooling
15. NMFPooling
16. PANPooling
17. SAGPooling
18. TopkPooling


For example, let's create a [`TopkPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.TopkPooling) object.

In [3]:
pooler = TopkPooling(in_channels=16)
print(f"Pooler: {pooler}")

Pooler: TopkPooling(
	select=TopkSelect(in_channels=16, ratio=0.5, act=Tanh(), s_inv_op=transpose)
	reduce=BaseReduce(reduce_op=sum)
	lift=BaseLift(matrix_op=precomputed, reduce_op=sum)
	connect=SparseConnect(reduce_op=sum, remove_self_loops=True)
	multiplier=1.0
)


Each pooler is associated with an alias that can be used to quickly instantiate a pooler.

In [4]:
print("Available poolers:")
for alias, cls in pooler_map.items():
    print(f"'{alias}' --> {cls.__name__}")

Available poolers:
'asap' --> ASAPooling
'acc' --> AsymCheegerCutPooling
'bnpool' --> BNPool
'diff' --> DiffPool
'dmon' --> DMoNPooling
'ec' --> EdgeContractionPooling
'graclus' --> GraclusPooling
'hosc' --> HOSCPooling
'lap' --> LaPooling
'jb' --> JustBalancePooling
'kmis' --> KMISPooling
'maxcut' --> MaxCutPooling
'mincut' --> MinCutPooling
'ndp' --> NDPPooling
'nmf' --> NMFPooling
'pan' --> PANPooling
'sag' --> SAGPooling
'topk' --> TopkPooling


We can instantiate the same object of class [`TopkPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.TopkPooling) by passing the alias and a dict with the parameters needed to initialize the pooler.

In [5]:
params = {
    "in_channels": 3,  # Number of input features
    "ratio": 0.25,  # Percentage of nodes to keep
}

pooler = get_pooler("topk", **params)  # Get the pooler by alias
print(pooler)

TopkPooling(
	select=TopkSelect(in_channels=3, ratio=0.25, act=Tanh(), s_inv_op=transpose)
	reduce=BaseReduce(reduce_op=sum)
	lift=BaseLift(matrix_op=precomputed, reduce_op=sum)
	connect=SparseConnect(reduce_op=sum, remove_self_loops=True)
	multiplier=1.0
)
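Conceptually, resolving an alias amounts to a dictionary lookup followed by a constructor call. The sketch below illustrates that pattern in plain Python. The names `ToyTopK`, `ToyMinCut`, `toy_pooler_map`, and `toy_get_pooler` are hypothetical stand-ins introduced only for this illustration; they are not the actual implementation of `get_pooler` or `pooler_map`.

```python
# Hypothetical stand-ins: these classes and this registry only mimic an
# alias -> class lookup; they are not tgp's actual implementation.
class ToyTopK:
    def __init__(self, in_channels, ratio=0.5):
        self.in_channels = in_channels
        self.ratio = ratio


class ToyMinCut:
    def __init__(self, in_channels, k):
        self.in_channels = in_channels
        self.k = k


# Registry mapping each alias to the class it stands for.
toy_pooler_map = {"topk": ToyTopK, "mincut": ToyMinCut}


def toy_get_pooler(alias, **kwargs):
    """Look up the class registered under `alias` and instantiate it."""
    try:
        cls = toy_pooler_map[alias]
    except KeyError:
        raise ValueError(f"Unknown pooler alias: {alias!r}") from None
    return cls(**kwargs)


toy = toy_get_pooler("topk", in_channels=3, ratio=0.25)
print(type(toy).__name__, toy.in_channels, toy.ratio)
# ToyTopK 3 0.25
```

Keeping the parameters in a dict, as in the cell above, makes it easy to switch poolers from a configuration file without touching the model code.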


We see that each pooling layer implements specific select ($\texttt{SEL}$), reduce ($\texttt{RED}$), and connect ($\texttt{CON}$) operations, as defined by the [SRC framework](https://arxiv.org/abs/2110.05292).

<img src="../_static/img/src_overview.png" style="width: 55%; display: block; margin: auto;">

- The $\texttt{SEL}$ operation is what sets most pooling methods apart and defines how the nodes are assigned to the supernodes of the pooled graph. 
- The $\texttt{RED}$ operation specifies how to compute the features of the supernodes in the pooled graph. 
- Finally, $\texttt{CON}$ creates the connectivity matrix of the pooled graph. 

The pooling operators also have a $\texttt{LIFT}$ function, which is used by some GNN architectures to map the pooled node features back to the node space of the original graph.
See [here](../content/src.md) for an introduction to the SRC(L) framework.
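As a rough illustration of the three operations, the sketch below assumes one common instantiation in which $\texttt{SEL}$ produces an assignment matrix $S \in \mathbb{R}^{N \times K}$, $\texttt{RED}$ computes $X' = S^\top X$, and $\texttt{CON}$ computes $A' = S^\top A S$. Individual poolers may implement each step differently, and the helper functions and the toy graph are assumptions made for this example only.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


# Toy path graph 0 - 1 - 2 - 3 with two features per node.
adj = [
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
]
x = [[1, 0], [2, 0], [0, 3], [0, 4]]

# SEL: nodes {0, 1} go to supernode 0, nodes {2, 3} go to supernode 1.
s = [[1, 0], [1, 0], [0, 1], [0, 1]]

# RED: supernode features are the sum of the features of the assigned nodes.
x_pool = matmul(transpose(s), x)

# CON: pooled connectivity; entry (i, j) sums the edges between supernodes i and j.
adj_pool = matmul(matmul(transpose(s), adj), s)

print(x_pool)    # [[3, 0], [0, 7]]
print(adj_pool)  # [[2, 1], [1, 2]]
```

The diagonal of `adj_pool` counts the edges that fall inside each supernode. Some poolers discard these self-loops, as the `remove_self_loops=True` option printed above suggests.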

## Calling a pooling layer

A pooling layer can be called similarly to a message-passing layer in PyG.
Let's start by loading some data and creating a data batch.

In [6]:
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
print(f"Dataset: {dataset}")
loader = DataLoader(dataset, batch_size=32, shuffle=True)
data_batch = next(iter(loader))
print(f"Data batch: {data_batch}")

Dataset: ENZYMES(600)
Data batch: DataBatch(edge_index=[2, 3964], x=[1109, 3], y=[32], batch=[1109], ptr=[33])


```{attention}
Pooling operators support **edge weights**, i.e., scalar values stored in an `edge_weight` attribute.
However, some datasets have **edge features** stored in the `edge_attr` field.
In <img src="../_static/img/tgp-logo.svg" width="40px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp we assume that the edge attributes are processed by the message-passing layers before pooling, which embed the attributes into the node features that reach the pooling operators.
```

In [7]:
pooling_output = pooler(
    x=data_batch.x,
    adj=data_batch.edge_index,
    edge_weight=data_batch.edge_weight,
    batch=data_batch.batch,
)
print(pooling_output)

PoolingOutput(so=[1109, 290], x=[290, 3], edge_index=[2, 524], edge_weight=None, batch=[290], loss=None)


The output of a pooling layer is an object of class [`PoolingOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/src.html#tgp.src.PoolingOutput) that contains different fields:
- the node features of the pooled graph (`x`), 
- the indices and weights of the pooled adjacency matrix (`edge_index`, `edge_weight`), 
- the batch indices of the pooled graphs (`batch`). 

In addition, `so` is an object of class [`SelectOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/select.html#tgp.select.SelectOutput), i.e., the output of the $\texttt{SEL}$ operation that describes how the nodes of the original graph are assigned to the supernodes of the pooled graph.

In [8]:
print(pooling_output.so)

SelectOutput(num_nodes=1109, num_clusters=290)
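The number of supernodes (290) is slightly larger than $0.25 \times 1109 \approx 277$. One plausible explanation, assuming that the number of kept nodes is computed separately for each graph as $\lceil \texttt{ratio} \cdot N_i \rceil$ (the rule used by top-k pooling in PyG), is that the rounding is applied once per graph. The sketch below illustrates this with hypothetical graph sizes; it does not call the library.

```python
import math

ratio = 0.25
graph_sizes = [10, 21, 37]  # hypothetical number of nodes in each graph of a batch

# Assumed rule: round up separately for every graph in the batch.
kept_per_graph = [math.ceil(ratio * n) for n in graph_sizes]
print(kept_per_graph, sum(kept_per_graph))  # [3, 6, 10] 19

# Rounding once on the whole batch would give a smaller number.
print(math.ceil(ratio * sum(graph_sizes)))  # 17
```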


```{note}
Some pooling operators save additional data structures in the [`PoolingOutput`](https://torch-geometric-pool.readthedocs.io/en/latest/api/src.html#tgp.src.PoolingOutput), to be used downstream by $\texttt{RED}$ and $\texttt{CON}$.
```

The pooling layer can also be used to perform $\texttt{LIFT}$, i.e., to map the pooled features back to the original node space.

In [9]:
x_lift = pooler(
    x=pooling_output.x, so=pooling_output.so, batch=pooling_output.batch, lifting=True
)

print(f"original x shape: {data_batch.x.shape}")
print(f"pooled x shape: {pooling_output.x.shape}")
print(f"x_lift shape: {x_lift.shape}")

original x shape: torch.Size([1109, 3])
pooled x shape: torch.Size([290, 3])
x_lift shape: torch.Size([1109, 3])


$\texttt{LIFT}$ is typically used by GNNs with an autoencoder architecture that perform node-level tasks (e.g., node classification).
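To make the effect of $\texttt{LIFT}$ concrete, the sketch below assumes that lifting multiplies the pooled features by the assignment matrix, $X_{\text{lift}} = S X'$, which is what a transposed inverse (`s_inv_op=transpose`) would give. The matrices are toy values chosen for illustration, and any score-based scaling applied by a specific pooler is ignored.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Toy selection on a 4-node graph: nodes 1 and 3 are kept as supernodes 0 and 1.
s = [
    [0, 0],  # node 0 is dropped
    [1, 0],  # node 1 -> supernode 0
    [0, 0],  # node 2 is dropped
    [0, 1],  # node 3 -> supernode 1
]

x_pool = [[5, 6], [7, 8]]  # features of the two supernodes

# Lifting places each supernode's features back at the nodes assigned to it.
x_lift = matmul(s, x_pool)
print(x_lift)  # [[0, 0], [5, 6], [0, 0], [7, 8]]
```

With a selection of this kind, the dropped nodes receive zero features after lifting, which is why architectures of this type often combine the lifted features with skip connections from the encoder.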

## Types of pooling operator

One of the main differences between the pooling operators in <img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp is whether they are **dense** or **sparse**. [`TopkPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.TopkPooling), which we just saw, is a sparse method. Let's now look at a dense pooler: [`MinCutPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.MinCutPooling).

In [10]:
params = {
    "in_channels": 3,  # Number of input features
    "k": 10,  # Number of supernodes in the pooled graph
}

dense_pooler = get_pooler("mincut", **params)
print(dense_pooler)

MinCutPooling(
	select=DenseSelect(in_channels=[3], k=10, act=None, dropout=0.0, s_inv_op=transpose)
	reduce=BaseReduce(reduce_op=sum)
	lift=BaseLift(matrix_op=precomputed, reduce_op=sum)
	connect=DenseConnect(remove_self_loops=False, degree_norm=False, adj_transpose=True)
	cut_loss_coeff=1.0
	ortho_loss_coeff=1.0
)


Something that sets sparse and dense pooling methods apart is the format of the data that they take as input.
In particular, dense methods take as input graphs whose connectivity matrix is a dense tensor.
Luckily, we do not need to keep track of which method we are using to do the right preprocessing.
Each pooling operator in <img src="../_static/img/tgp-logo.svg" width="20px" align="center" style="display: inline-block; height: 1.3em; width: unset; vertical-align: text-top;"/> tgp provides a preprocessing function that converts the data into the format accepted by the operator.

In [11]:
x_dense, adj_dense, mask = dense_pooler.preprocessing(
    x=data_batch.x,
    edge_index=data_batch.edge_index,
    batch=data_batch.batch,
)
print(f"x_dense shape: {x_dense.shape}")
print(f"adj_dense shape: {adj_dense.shape}")
print(f"mask shape: {mask.shape}")

x_dense shape: torch.Size([32, 126, 3])
adj_dense shape: torch.Size([32, 126, 126])
mask shape: torch.Size([32, 126])


The processed data can now be fed into the dense pooling operator to compute the output.

In [12]:
dense_pooling_output = dense_pooler(
    x=x_dense,
    adj=adj_dense,
    batch=data_batch.batch,
)
print(dense_pooling_output)

PoolingOutput(so=[126, 10], x=[32, 10, 3], edge_index=[32, 10, 10], edge_weight=None, batch=None, loss=['cut_loss', 'ortho_loss'])


The connectivity of the coarsened graphs generated by a dense pooling operator is also a dense tensor.

In [13]:
print(dense_pooling_output.edge_index[0])

tensor([[0.0000, 0.0819,  ..., 0.1121, 0.1049],
        [0.0819, 0.0000,  ..., 0.1031, 0.0983],
        ...,
        [0.1121, 0.1031,  ..., 0.0000, 0.1231],
        [0.1049, 0.0983,  ..., 0.1231, 0.0000]], grad_fn=<SelectBackward0>)


Another difference w.r.t. [`TopkPooling`](https://torch-geometric-pool.readthedocs.io/en/latest/api/poolers.html#tgp.poolers.TopkPooling) is the presence of one or more loss terms in the pooling output.

In [14]:
for key, value in dense_pooling_output.loss.items():
    print(f"{key}: {value:.3f}")

cut_loss: -0.964
ortho_loss: 1.164


These are *auxiliary losses* that must be minimized along with the task losses used to train the GNN.
Most dense pooling methods have an auxiliary loss. 
A few sparse methods have an auxiliary loss too.
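To give an intuition for the two terms printed above, the sketch below computes them on a toy graph using the basic definitions from the MinCutPool paper: the cut loss $-\mathrm{Tr}(S^\top A S) / \mathrm{Tr}(S^\top D S)$ and the orthogonality loss $\lVert S^\top S / \lVert S^\top S \rVert_F - I_K / \sqrt{K} \rVert_F$. Batching, adjacency normalization, and any other detail of the library's implementation are left out, so this is an illustration under those assumptions rather than the code used by `MinCutPooling`.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def trace(a):
    return sum(a[i][i] for i in range(len(a)))


def frobenius(a):
    return math.sqrt(sum(v * v for row in a for v in row))


def mincut_losses(adj, s):
    """Cut and orthogonality losses for a single graph (no batching)."""
    n, k = len(s), len(s[0])
    st = transpose(s)
    # Diagonal degree matrix D.
    deg = [[sum(adj[i]) if i == j else 0.0 for j in range(n)] for i in range(n)]
    cut = -trace(matmul(matmul(st, adj), s)) / trace(matmul(matmul(st, deg), s))
    sts = matmul(st, s)
    norm = frobenius(sts)
    diff = [
        [sts[i][j] / norm - (1.0 / math.sqrt(k) if i == j else 0.0) for j in range(k)]
        for i in range(k)
    ]
    return cut, frobenius(diff)


# Toy path graph 0 - 1 - 2 - 3.
adj = [
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0],
]
balanced = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]   # two clusters of two nodes
collapsed = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]  # every node in cluster 0

for name, s in [("balanced", balanced), ("collapsed", collapsed)]:
    cut, ortho = mincut_losses(adj, s)
    print(f"{name}: cut={cut:.3f}, ortho={ortho:.3f}")
# balanced: cut=-0.667, ortho=0.000
# collapsed: cut=-1.000, ortho=0.765
```

The collapsed assignment attains the best possible cut loss by placing every node in the same cluster, and the orthogonality term is what penalizes this degenerate solution.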

## GNN model with pooling layers

Let's create a simple GNN for graph classification with the following architecture: 

$$[\texttt{MP}-\texttt{Pool}-\texttt{MP}-\texttt{GlobalPool}-\texttt{Linear}]$$


### Initialization
First, in `__init__()` we specify the architecture by instantiating the MP layers, the pooling layer (from its alias and parameters), and the readout.

In [15]:
class GNN(torch.nn.Module):
    def __init__(
        self, in_channels, out_channels, pooler_type, pooler_kwargs, hidden_channels=64
    ):
        super().__init__()

        # First MP layer
        self.conv1 = GCNConv(in_channels=in_channels, out_channels=hidden_channels)

        # Pooling: add the input size to a copy of the user-provided parameters
        pooler_kwargs = {**pooler_kwargs, "in_channels": hidden_channels}
        self.pooler = get_pooler(pooler_type, **pooler_kwargs)

        # Second MP layer
        if self.pooler.is_dense:
            self.conv2 = DenseGCNConv(
                in_channels=hidden_channels, out_channels=hidden_channels
            )
        else:
            self.conv2 = GCNConv(
                in_channels=hidden_channels, out_channels=hidden_channels
            )

        # Readout layer
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

Note that the type of pooling operator determines what kind of MP layer is used after pooling. 
A sparse pooler is followed by a sparse MP operator such as [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/2.6.1/generated/torch_geometric.nn.conv.GCNConv.html#torch_geometric.nn.conv.GCNConv). 
On the other hand, a dense pooling operator that returns a dense connectivity matrix must be followed by a dense MP layer such as [`DenseGCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.DenseGCNConv.html).
The type of pooling operator can be checked through the `is_dense` property.

### Forward pass

Next, we define the forward pass of the GNN.

In [16]:
def forward(self, x, edge_index, edge_weight, batch=None):
    # First MP layer
    x = self.conv1(x, edge_index, edge_weight)
    x = F.relu(x)

    # Pooling
    x, edge_index, mask = self.pooler.preprocessing(
        x=x,
        edge_index=edge_index,
        edge_weight=edge_weight,
        batch=batch,
        use_cache=False,
    )
    out = self.pooler(
        x=x, adj=edge_index, edge_weight=edge_weight, batch=batch, mask=mask
    )
    x_pool, adj_pool = out.x, out.edge_index

    # Second MP layer
    x = self.conv2(x_pool, adj_pool)
    x = F.relu(x)

    # Global pooling
    x = self.pooler.global_pool(x, reduce_op="sum", batch=out.batch)

    # Readout layer
    x = self.lin(x)

    if out.loss is not None:
        return F.log_softmax(x, dim=-1), sum(out.get_loss_value())
    else:
        return F.log_softmax(x, dim=-1), torch.tensor(0.0)


GNN.forward = forward

There are a few things to discuss.

#### Preprocessing
In the `forward()` function, before calling the pooling layer, we must preprocess the data. If the pooler is sparse, preprocessing has no effect: `x` and `edge_index` will be returned as-is, and `mask` will be `None`.

In [17]:
x, edge_index, mask = pooler.preprocessing(
    x=data_batch.x,
    edge_index=data_batch.edge_index,
    edge_weight=data_batch.edge_weight,
    batch=data_batch.batch,
)

print(f"x shape: {x.shape}")
print(f"edge_index shape: {edge_index.shape}")
print(f"mask: {mask}")

x shape: torch.Size([1109, 3])
edge_index shape: torch.Size([2, 3964])
mask: None


Conversely, if the pooler is dense, `x` will be a tensor of size $[B, N, F]$, where $B$ is the batch size, $N$ is the maximum number of nodes in the batch, and $F$ is the size of the node features. 
Graphs with fewer than $N$ nodes will be padded, and `mask` is a boolean tensor of shape $[B, N]$ indicating which nodes are valid.
Similarly, `edge_index` will be a dense tensor of shape $[B, N, N]$. 
Internally, `preprocessing()` of a dense pooler calls the functions [`to_dense_batch`](https://pytorch-geometric.readthedocs.io/en/2.4.0/_modules/torch_geometric/utils/to_dense_batch.html) and [`to_dense_adj`](https://pytorch-geometric.readthedocs.io/en/2.4.0/_modules/torch_geometric/utils/to_dense_adj.html) of <img src="https://raw.githubusercontent.com/TorchSpatiotemporal/tsl/main/docs/source/_static/img/logos/pyg.svg" width="20px" align="center"/> PyG.

Finally, setting `use_cache=True` avoids recomputing the densified version of `edge_index`.
This is useful in tasks such as [node classification](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/node_class.py) and [clustering](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/clustering.py), where there is only one underlying graph that is usually large and costly to densify.

In [18]:
x, edge_index, mask = dense_pooler.preprocessing(
    x=data_batch.x,
    edge_index=data_batch.edge_index,
    edge_weight=data_batch.edge_weight,
    batch=data_batch.batch,
)

print(f"x shape: {x.shape}")
print(f"edge_index shape: {edge_index.shape}")
print(f"mask shape: {mask.shape}")

x shape: torch.Size([32, 126, 3])
edge_index shape: torch.Size([32, 126, 126])
mask shape: torch.Size([32, 126])
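The padding behind these shapes can be sketched in plain Python. The function `to_dense` below is a hypothetical, simplified stand-in written for this illustration: it covers only the node features and the mask (not the adjacency matrix) and is not the PyG or tgp implementation.

```python
def to_dense(x, batch):
    """Pad per-graph node features into [B, N_max, F] and build a validity mask."""
    num_graphs = max(batch) + 1
    sizes = [batch.count(g) for g in range(num_graphs)]
    n_max = max(sizes)
    num_feats = len(x[0])

    # Start from all-zero features and an all-False mask.
    dense = [[[0.0] * num_feats for _ in range(n_max)] for _ in range(num_graphs)]
    mask = [[False] * n_max for _ in range(num_graphs)]

    position = [0] * num_graphs  # next free slot in each graph
    for feats, g in zip(x, batch):
        i = position[g]
        dense[g][i] = list(feats)
        mask[g][i] = True
        position[g] += 1
    return dense, mask


# Two graphs: the first has two nodes, the second has one.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]

dense, mask = to_dense(x, batch)
print(dense)  # [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [0.0, 0.0]]]
print(mask)   # [[True, True], [True, False]]
```

The second graph is padded with one all-zero row, and the corresponding mask entry is `False`.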


#### Global pooling
The global pooling operation combines all the features in the current graph and is implemented differently depending on whether the pooler is sparse or dense.
In the sparse case, we have a `batch` tensor indicating which graph each node belongs to.
In this case, global pooling should combine the features of the nodes belonging to the same graph. 
The output is a tensor of shape $[B, F]$.

In [19]:
# Sparse case
print(f"Input shape: {data_batch.x.shape}")
out_global_sparse = pooler.global_pool(
    data_batch.x, reduce_op="sum", batch=data_batch.batch
)
print(f"Output shape: {out_global_sparse.shape}")

Input shape: torch.Size([1109, 3])
Output shape: torch.Size([32, 3])


In the dense case, the features of the pooled graph are stored in a tensor of shape $[B, K, F]$, and global pooling can be done, e.g., by summing or averaging across the node dimension, yielding a tensor of shape $[B, F]$. In this case, `batch` is not needed.

In [20]:
# Dense case
print(f"Input shape: {x_dense.shape}")
out_global_dense = dense_pooler.global_pool(x_dense, reduce_op="sum", batch=None)
print(f"Output shape: {out_global_dense.shape}")

Input shape: torch.Size([32, 126, 3])
Output shape: torch.Size([32, 3])


Note that both cases give the same output, since the padded entries do not contribute to the sum.

In [21]:
torch.allclose(out_global_sparse, out_global_dense)

True
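The agreement between the two results can be reproduced on toy data. The sketch below uses plain Python lists and assumes that padded rows are filled with zeros; it also shows that a naive average over the padded node dimension would differ from the per-graph mean unless the mask is taken into account. How `global_pool` handles reductions other than the sum is not covered here.

```python
# Sparse representation: three nodes, the first two belong to graph 0.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]

# Dense representation of the same batch, padded with zeros (an assumption).
dense = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [0.0, 0.0]]]
mask = [[True, True], [True, False]]

num_graphs = max(batch) + 1
num_feats = len(x[0])

# Sparse sum: accumulate each node into the row of its graph.
sparse_sum = [[0.0] * num_feats for _ in range(num_graphs)]
for feats, g in zip(x, batch):
    for f, value in enumerate(feats):
        sparse_sum[g][f] += value

# Dense sum: reduce over the node dimension; zero padding adds nothing.
dense_sum = [[sum(node[f] for node in graph) for f in range(num_feats)] for graph in dense]
print(sparse_sum == dense_sum)  # True

# A naive mean divides by the padded size, a masked mean by the true size.
naive_mean = [[v / len(graph) for v in sums] for sums, graph in zip(dense_sum, dense)]
masked_mean = [[v / sum(m) for v in sums] for sums, m in zip(dense_sum, mask)]
print(naive_mean)   # [[2.0, 3.0], [2.5, 3.0]]
print(masked_mean)  # [[2.0, 3.0], [5.0, 6.0]]
```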

#### Auxiliary losses
As we saw earlier, some pooling operators return an auxiliary loss, while others do not.
In the forward pass, we check whether `out.loss` is not `None` and, if so, return the sum of all the auxiliary losses, to be added to the training loss.

### Testing the model

Let's first test our GNN when configured with a sparse pooler.

In [22]:
num_features = dataset.num_features
num_classes = dataset.num_classes

sparse_params = {
    "ratio": 0.25,  # Percentage of nodes to keep
}

sparse_pool_gnn = GNN(
    in_channels=num_features,
    out_channels=num_classes,
    pooler_type="topk",
    pooler_kwargs=sparse_params,
)

sparse_gnn_out = sparse_pool_gnn(
    x=data_batch.x,
    edge_index=data_batch.edge_index,
    edge_weight=data_batch.edge_weight,
    batch=data_batch.batch,
)
print(f"Sparse GNN output shape: {sparse_gnn_out[0].shape}")
print(f"Sparse GNN loss: {sparse_gnn_out[1]:.3f}")

Sparse GNN output shape: torch.Size([32, 6])
Sparse GNN loss: 0.000


Since there is no auxiliary loss, the second output of the GNN is simply a constant zero-valued tensor that does not affect the gradient computation.

Next, we create the GNN with the dense pooling layer.

In [23]:
dense_params = {
    "k": 10,  # Number of supernodes in the pooled graph
}
dense_pool_gnn = GNN(
    in_channels=num_features,
    out_channels=num_classes,
    pooler_type="mincut",
    pooler_kwargs=dense_params,
)
dense_gnn_out = dense_pool_gnn(
    x=data_batch.x,
    edge_index=data_batch.edge_index,
    edge_weight=data_batch.edge_weight,
    batch=data_batch.batch,
)
print(f"Dense GNN output shape: {dense_gnn_out[0].shape}")
print(f"Dense GNN loss: {dense_gnn_out[1]:.3f}")

Dense GNN output shape: torch.Size([32, 6])
Dense GNN loss: 0.170


This time, we get an auxiliary loss that should be added to the other losses, e.g. the classification loss of the downstream task.

In [24]:
total_loss = F.nll_loss(dense_gnn_out[0], data_batch.y.view(-1)) + dense_gnn_out[1]
print(f"Loss: {total_loss:.3f}")

Loss: 3.165


And that's it! We can train this GNN as any other that we normally build with <img src="https://raw.githubusercontent.com/TorchSpatiotemporal/tsl/main/docs/source/_static/img/logos/pyg.svg" width="20px" align="center"/> PyG.

You can check the complete graph classification example [here](https://github.com/tgp-team/torch-geometric-pool/blob/main/examples/classification.py).