# Train an All-Set TNN

In this notebook, we will create and train a two-step message passing network named AllSet (Chien et al., [2021](https://arxiv.org/abs/2106.13264)) in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph.

The equations of one layer of this network are given by:

Vertex to edge:

🟧 $\quad m_{\rightarrow z}^{(\rightarrow 1)} = AGG_{y \in \mathcal{B}(z)} (h_y^{t, (0)}, h_z^{t,(1)})$ 

🟦 $\quad h_z^{t+1,(1)} = \sigma(m_{\rightarrow z}^{(\rightarrow 1)})$ 

Edge to vertex: 

🟧 $\quad m_{\rightarrow x}^{(\rightarrow 0)} = AGG_{z \in \mathcal{C}(x)} (h_z^{t+1,(1)}, h_x^{t,(0)})$ 

🟦 $\quad h_x^{t+1,(0)} = \sigma(m_{\rightarrow x}^{(\rightarrow 0)})$
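To make the two steps concrete, here is a minimal NumPy sketch that assumes mean aggregation for AGG, ReLU for $\sigma$, and no learnable weights. These are simplifying assumptions, not the TopoModelX layer. As in the practical implementation described in the theoretical clarifications, the sketch ignores the second argument of AGG (the previous representation of the receiving cell). The function name `allset_mean_layer` is hypothetical.

```python
import numpy as np


def relu(a):
    return np.maximum(a, 0.0)


def allset_mean_layer(x, incidence):
    """One vertex -> hyperedge -> vertex pass with mean AGG and ReLU (illustrative only).

    x: |V| x F node features; incidence: unsigned |V| x |E| matrix with
    incidence[v, e] = 1 if node v belongs to hyperedge e.
    """
    nodes_per_edge = incidence.sum(axis=0)[:, None]  # |B(z)| for each hyperedge z
    edges_per_node = incidence.sum(axis=1)[:, None]  # |C(x)| for each node x
    # Vertex to edge: average the features of the nodes in each hyperedge
    z_new = relu(incidence.T @ x / np.maximum(nodes_per_edge, 1))
    # Edge to vertex: average the updated features of the hyperedges containing each node
    x_new = relu(incidence @ z_new / np.maximum(edges_per_node, 1))
    return x_new, z_new


# A triangle lifted to a hypergraph: hyperedges {0,1}, {1,2}, {0,2}, {0,1,2}
incidence = np.array(
    [[1, 0, 1, 1],
     [1, 1, 0, 1],
     [0, 1, 1, 1]],
    dtype=float,
)
x = np.array([[1.0], [2.0], [3.0]])
x_new, z_new = allset_mean_layer(x, incidence)
print(z_new.ravel())
print(x_new.ravel())
```

For example, hyperedge $\{0, 1\}$ receives $(1 + 2)/2 = 1.5$, and node 0, which belongs to hyperedges $\{0,1\}$, $\{0,2\}$ and $\{0,1,2\}$, then receives $(1.5 + 2 + 2)/3 \approx 1.83$.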

### Additional theoretical clarifications
Given a hypergraph $G=(\mathcal{V}, \mathcal{E})$, let $\textbf{X} \in \mathbb{R}^{|\mathcal{V}| \times F}$ and $\textbf{Z} \in \mathbb{R}^{|\mathcal{E}| \times F'}$ denote the hidden node and hyperedge representations, respectively. Additionally, define $V_{e, \textbf{X}} = \{\textbf{X}_{u,:}: u \in e\}$ as the multiset of hidden node representations in the hyperedge $e$ and $E_{v, \textbf{Z}} = \{\textbf{Z}_{e,:}: v \in e\}$ as the multiset of hidden representations of hyperedges containing $v$.
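The two multisets can be read directly off the list of hyperedges. The plain-Python sketch below uses a list as a stand-in for a multiset (order is irrelevant, repeated values are kept); the helper names `node_multisets` and `edge_multisets` are hypothetical.

```python
def node_multisets(X, hyperedges):
    """V_{e,X}: for each hyperedge e, the rows X[u] of its nodes u."""
    return [[X[u] for u in e] for e in hyperedges]


def edge_multisets(Z, hyperedges, n_nodes):
    """E_{v,Z}: for each node v, the rows Z[e] of the hyperedges e that contain v."""
    result = [[] for _ in range(n_nodes)]
    for e, nodes in enumerate(hyperedges):
        for v in nodes:
            result[v].append(Z[e])
    return result


hyperedges = [(0, 1), (1, 2), (0, 1, 2)]
X = [[1.0], [1.0], [3.0]]  # nodes 0 and 1 share the same hidden representation
Z = [[10.0], [20.0], [30.0]]
V = node_multisets(X, hyperedges)
E = edge_multisets(Z, hyperedges, n_nodes=3)
print(V[0])  # [[1.0], [1.0]]
print(E[1])  # [[10.0], [20.0], [30.0]]
```

The repeated value in `V[0]` is why these are multisets rather than sets: two nodes of the same hyperedge may carry identical representations, and both must be counted.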

In this setting, the two general update rules that AllSet's framework puts in place in each layer are:

🔷 $\textbf{Z}_{e,:}^{(t+1)} = f_{\mathcal{V} \rightarrow \mathcal{E}}(V_{e, \textbf{X}^{(t)}}; \textbf{Z}_{e,:}^{(t)})$

🔷 $\textbf{X}_{v,:}^{(t+1)} = f_{\mathcal{E} \rightarrow \mathcal{V}}(E_{v, \textbf{Z}^{(t+1)}}; \textbf{X}_{v,:}^{(t)})$

in which $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are two functions that are permutation invariant with respect to their first input. The matrices $\textbf{Z}^{(0)}$ and $\textbf{X}^{(0)}$ are initialized with the hyperedge and node features, respectively, if available; otherwise, they are set to all-zero matrices.

In the practical implementation of the model, $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are parametrized and *learned* for each dataset and task, and their second argument is not used.
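One common way to build a learnable function that is permutation invariant with respect to its first input is the DeepSets form $\rho\left(\sum_{s} \phi(s)\right)$; the AllSet paper instantiates its framework with DeepSets-style and Set Transformer-style modules. The NumPy sketch below uses random, untrained weights (hypothetical stand-ins for learned parameters) and only checks the invariance property; it is not the TopoModelX implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical, untrained weights standing in for the learned MLPs phi and rho
W_phi = rng.normal(size=(4, 8))
W_rho = rng.normal(size=(8, 3))


def relu(a):
    return np.maximum(a, 0.0)


def f_set(multiset):
    """DeepSets-style set function rho(sum_s phi(s)); rows of `multiset` are the elements."""
    phi = relu(multiset @ W_phi)  # encode each element independently
    pooled = phi.sum(axis=0)      # sum pooling does not depend on element order
    return relu(pooled @ W_rho)   # decode the pooled representation


# Hidden features of the 5 nodes in one hyperedge, i.e. V_{e,X}
V_e = rng.normal(size=(5, 4))
perm = rng.permutation(5)
print(np.allclose(f_set(V_e), f_set(V_e[perm])))  # True
```

Because the only interaction between elements is the sum, reordering the rows of `V_e` leaves the output unchanged up to floating-point rounding.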


In [1]:
"""
Train the AllSet hypergraph neural network on the shrec16 mesh dataset.

The AllSet model alternates vertex-to-hyperedge and hyperedge-to-vertex
set aggregations; here it is used for hypergraph-level prediction.
"""

import torch
import numpy as np
from sklearn.model_selection import train_test_split
import toponetx.datasets as datasets

from topomodelx.nn.hypergraph.allset import AllSet

If a GPU is available, we will use it. Otherwise, this will run on the CPU.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification in which each mesh is stored as a simplicial complex. We will later lift each simplicial complex into our domain of choice, a hypergraph.

We will also retrieve:
- the input signals on the nodes, edges, and faces of each simplicial complex (only the node features are fed to the model)
- the label associated with each mesh

In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.


## Define neighborhood structures and lift into hypergraph domain. ##

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on each hypergraph. In the case of this architecture, we need the incidence matrix $B_1$ between nodes and hyperedges, with shape $n_\text{nodes} \times n_\text{hyperedges}$.

To obtain it, we lift each simplicial complex into a hypergraph: the pairwise edges become pairwise hyperedges, and the faces of the simplicial complex become 3-wise hyperedges. We then record the incidence matrix of each hypergraph (note that all incidence matrices in the hypergraph domain must be unsigned).
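The sketch below builds an unsigned node-by-hyperedge incidence matrix by hand for a single triangle, to show why the number of columns equals the number of edges plus the number of faces (for the 6th complex, 750 + 500 = 1250). The helper `hypergraph_incidence` is hypothetical, and the column ordering (edges first, then faces) is an assumption; TopoNetX may order hyperedges differently.

```python
import numpy as np


def hypergraph_incidence(n_nodes, hyperedges):
    """Unsigned node-by-hyperedge incidence matrix: entry (v, e) is 1 iff v is in e."""
    incidence = np.zeros((n_nodes, len(hyperedges)), dtype=np.float32)
    for e, nodes in enumerate(hyperedges):
        incidence[list(nodes), e] = 1.0
    return incidence


# A single triangle: 3 nodes, 3 edges and 1 face
edges = [(0, 1), (1, 2), (0, 2)]
faces = [(0, 1, 2)]
# Lifting: every edge becomes a 2-node hyperedge and every face a 3-node hyperedge
B = hypergraph_incidence(3, edges + faces)
print(B.shape)  # (3, 4)
```

Each column sums to the size of its hyperedge (2 for the edges, 3 for the face), and no entry is negative, which is what "unsigned" means here.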

In [5]:
hg_list = []
incidence_1_list = []
for simplex in simplexes:
    # Lift the simplicial complex: its edges and faces become hyperedges
    hg = simplex.to_hypergraph()
    hg_list.append(hg)


# Extract the (unsigned) node-to-hyperedge incidence matrix of each hypergraph
for hg in hg_list:
    incidence_1 = hg.incidence_matrix()
    # Convert the SciPy sparse matrix into a sparse torch tensor
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_1_list.append(incidence_1)

In [6]:
i_complex = 6
print(
    f"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}."
)

The 6th hypergraph has an incidence matrix of shape torch.Size([252, 1250]).


# Create the Neural Network

In [7]:
# The model is fed node features only, so its input dimension is the node feature dimension
in_channels = x_0s[0].shape[1]
hidden_channels, out_channels = 64, 1

# Define the model with a single AllSet layer
model = AllSet(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=1,
    mlp_num_layers=1,
)
model = model.to(device)

# Train the Neural Network

We now specify the loss and an optimizer.

In [8]:
# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

Split the dataset into train and test sets.
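The three `train_test_split` calls in the next cell are made independently, so the features, incidence matrices, and labels stay aligned only because `shuffle=False` makes every split cut at the same index. The pure-Python sketch below mimics that behavior as we understand it (the last `ceil(test_size * n)` items form the test set); `ordered_split` is a hypothetical helper, not part of scikit-learn.

```python
import math


def ordered_split(items, test_size):
    """Mimic train_test_split(..., shuffle=False): the last ceil(test_size * n) items are the test set."""
    n_test = math.ceil(test_size * len(items))
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]


feats = [f"x_{i}" for i in range(10)]
labels = list(range(10))
feats_train, feats_test = ordered_split(feats, 0.2)
labels_train, labels_test = ordered_split(labels, 0.2)
print(feats_test, labels_test)  # ['x_8', 'x_9'] [8, 9]
```

If shuffling were enabled, the same alignment could be kept by passing all three sequences to a single `train_test_split` call, which splits them with one shared permutation.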

In [9]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

The following cell trains the network for a small number of epochs, evaluating on the test set after every epoch. We keep training minimal for the purpose of rapid testing.

In [10]:
test_interval = 1
num_epochs = 5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, incidence_1, y in zip(x_0_train, incidence_1_train, y_train):
        # Move node features, incidence matrix and label to the device as float tensors
        x_0 = torch.tensor(x_0, dtype=torch.float).to(device)
        incidence_1 = incidence_1.float().to(device)
        y = torch.tensor(y, dtype=torch.float).to(device)

        opt.zero_grad()
        y_hat = model(x_0, incidence_1)
        loss = loss_fn(y_hat, y)

        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        # Switch to evaluation mode and disable gradient tracking for testing
        model.eval()
        with torch.no_grad():
            for x_0, incidence_1, y in zip(x_0_test, incidence_1_test, y_test):
                x_0 = torch.tensor(x_0, dtype=torch.float).to(device)
                incidence_1 = incidence_1.float().to(device)
                y = torch.tensor(y, dtype=torch.float).to(device)
                y_hat = model(x_0, incidence_1)
                loss = loss_fn(y_hat, y)

            # Note: this reports the loss of the last test sample only
            print(f"Test_loss: {loss.item():.4f}", flush=True)

Epoch: 1 loss: 274.8176
Test_loss: 529.0000
Epoch: 2 loss: 274.6125
Test_loss: 529.0000
Epoch: 3 loss: 274.6125
Test_loss: 529.0000
Epoch: 4 loss: 274.6125
Test_loss: 529.0000
Epoch: 5 loss: 274.6125
Test_loss: 529.0000
