# Train a Hypergraph Neural Network

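In this notebook, we will create and train a two-step message passing network in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph.

Each layer applies the same set function in both message passing steps, from nodes to hyperedges and from hyperedges to nodes. In the equations below, $S$ holds the input features of the set being aggregated, $\theta$ is a learnable query, LN is layer normalization, and $\mathbin\Vert$ concatenates the outputs of the $h$ attention heads: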

$ f_{\mathcal{V} \rightarrow \mathcal{E}} = f_{\mathcal{E} \rightarrow \mathcal{V}} = Y + \text{MLP}(\text{LN}(Y)) $

$ Y = \theta + \text{MH}_{h}(\theta, S, S) $

$ \text{MH}_{h}(\theta, S, S) = \mathbin\Vert_{i=1}^{h} w ( \theta^{i} (K^{i})^{T} ) V^{i} $

$ K = \text{MLP}(S), \quad V = \text{MLP}(S) $
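The equations above can be traced step by step in a few lines of NumPy. This is only a sketch under stated assumptions, not the `topomodelx` implementation: the activation $w$ is taken to be a row-wise softmax, each MLP is reduced to a single linear map (with a ReLU on the last one), all weights are random, and the function name `allset_transformer_block` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(z, axis=-1):
    # Numerically stable softmax (assumed choice for the activation w).
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(y, eps=1e-5):
    # LN without learnable scale and shift, for brevity.
    mu = y.mean(axis=-1, keepdims=True)
    var = y.var(axis=-1, keepdims=True)
    return (y - mu) / np.sqrt(var + eps)


def allset_transformer_block(S, theta, W_k, W_v, W_mlp, num_heads):
    """S: (n, d) set features; theta: (1, d) query; W_*: (d, d) weights."""
    K = S @ W_k  # K = MLP(S), here a single linear map
    V = S @ W_v  # V = MLP(S), here a single linear map
    d_head = S.shape[1] // num_heads
    heads = []
    for i in range(num_heads):
        sl = slice(i * d_head, (i + 1) * d_head)
        attn = softmax(theta[:, sl] @ K[:, sl].T)  # w(theta^i (K^i)^T), shape (1, n)
        heads.append(attn @ V[:, sl])              # weighted sum of V^i, shape (1, d_head)
    Y = theta + np.concatenate(heads, axis=-1)     # Y = theta + MH_h(theta, S, S)
    return Y + np.maximum(layer_norm(Y) @ W_mlp, 0.0)  # Y + MLP(LN(Y))


n, d, h = 5, 8, 4
S = rng.normal(size=(n, d))
theta = rng.normal(size=(1, d))
W_k, W_v, W_mlp = (rng.normal(size=(d, d)) for _ in range(3))

out = allset_transformer_block(S, theta, W_k, W_v, W_mlp, num_heads=h)
# Reordering the elements of the set should not change the result.
out_perm = allset_transformer_block(S[::-1], theta, W_k, W_v, W_mlp, num_heads=h)
```

Because the attention weights are computed per element and then summed, `out` and `out_perm` agree up to floating-point error, which is the property that makes this block usable as a function on sets.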


In [115]:
import numpy as np
from sklearn.model_selection import train_test_split

from toponetx import SimplicialComplex
import toponetx.datasets as datasets
from topomodelx.nn.hypergraph.allsettransformer_layer import AllSetTransformerLayer

import torch
from torch_geometric.utils import to_edge_index
# Automatically reload edited .py modules inside the notebook
%load_ext autoreload
%autoreload 2



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

In [116]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. Each mesh is loaded as a simplicial complex, which we then lift into our domain of choice, a hypergraph.

We will also retrieve:
- the input signal on the nodes of each of these hypergraphs, as that is what we feed to the model as input
- the label associated with each hypergraph

In [117]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading dataset...

done!


In [118]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.


## Define neighborhood structures and lift into hypergraph domain. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\text{nodes} \times n_\text{edges}$.
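To make the shape and the signed/unsigned distinction concrete, here is a small NumPy sketch on a toy complex (a single triangle, not taken from shrec16). The sign convention used for the boundary matrix is one common choice and is an assumption here; it may differ from the one `toponetx` uses internally.

```python
import numpy as np

# Toy complex: three nodes joined by three edges (illustrative only).
nodes = [0, 1, 2]
edges = [(0, 1), (0, 2), (1, 2)]

# Signed boundary matrix B_1: the column of edge (u, v) has -1 at u and +1 at v.
B1_signed = np.zeros((len(nodes), len(edges)), dtype=int)
for j, (u, v) in enumerate(edges):
    B1_signed[u, j] = -1
    B1_signed[v, j] = 1

# Unsigned incidence matrix: 1 wherever a node belongs to an edge.
B1_unsigned = np.abs(B1_signed)
```

Both matrices have shape $n_\text{nodes} \times n_\text{edges}$; every column of the unsigned version sums to 2 because each edge touches exactly two nodes.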

Once we have recorded the incidence matrix (note that all incidence matrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The pairwise edges will become pairwise hyperedges, and faces in the simplicial complex will become 3-wise hyperedges.

In [119]:
hg_list = []
incidence_1_list = []
for simplex in simplexes:
    incidence_1 = simplex.incidence_matrix(rank=1, signed=False)
    # incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    # incidence_1_list.append(incidence_1)
    hg = simplex.to_hypergraph()
    hg_list.append(hg)


# Extract the incidence matrix of each lifted hypergraph as a sparse tensor
for hg in hg_list:
    incidence_1 = hg.incidence_matrix()
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_1_list.append(incidence_1)

In [120]:
i_complex = 6
print(
    f"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}."
)

The 6th hypergraph has an incidence matrix of shape torch.Size([252, 1250]).
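The shape printed above is consistent with the lifting rule: 750 edges plus 500 faces give 1250 hyperedges over 252 nodes. The sketch below reproduces that rule on a toy complex with NumPy only. The column ordering (edges first, then faces) is an assumption made for illustration and need not match the ordering produced by `to_hypergraph`.

```python
import numpy as np

# Toy complex: two triangles glued along the edge (1, 2). Illustrative only.
nodes = [0, 1, 2, 3]
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
faces = [(0, 1, 2), (1, 2, 3)]

# Lifting: every edge becomes a 2-node hyperedge, every face a 3-node hyperedge.
hyperedges = edges + faces

# Unsigned incidence matrix of the hypergraph, shape (n_nodes, n_hyperedges).
incidence = np.zeros((len(nodes), len(hyperedges)), dtype=int)
for j, hyperedge in enumerate(hyperedges):
    incidence[list(hyperedge), j] = 1
```

Here `incidence` has one column per hyperedge, so its second dimension is `len(edges) + len(faces)`, and the column sums give the size of each hyperedge.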


# Create the Neural Network

Using the AllSetTransformerLayer class, we create a neural network with stacked layers, followed by max pooling over the nodes and a linear readout.

In [121]:
channels_edge = x_1s[0].shape[1]
channels_node = x_0s[0].shape[1]
hid_dim, out_dim = 32, 1
num_heads = 4

x_0 = torch.tensor(x_0s[0], dtype=torch.float32)
incidence_1 = incidence_1_list[0]
Q_n = 1

In [122]:
class AllSetTransformerModel(torch.nn.Module):
    """AllSetTransformer Neural Network Module.

    A module that stacks AllSetTransformer layers and adds a linear readout.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.
    out_channels : int
        Dimension of the output features.
    n_layers : int, optional
        Number of AllSetTransformer layers in the network. Defaults to 2.
    dropout : float, optional
        Dropout probability. Defaults to 0.2.
    mlp_num_layers : int, optional
        Number of layers in the MLP. Defaults to 1.
    mlp_dropout : float, optional
        Dropout probability in the MLP. Defaults to 0.0.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        n_layers=2,
        dropout=0.2,
        mlp_num_layers=1,
        mlp_dropout=0.0,
    ):
        super().__init__()
        # Every layer maps to hidden_channels so that the linear readout
        # below always receives hidden_channels-dimensional node features.
        layers = [
            AllSetTransformerLayer(
                in_channels=in_channels,
                hidden_channels=hidden_channels,
                out_channels=hidden_channels,
                dropout=dropout,
                mlp_num_layers=mlp_num_layers,
                mlp_dropout=mlp_dropout,
            )
        ]

        for _ in range(n_layers - 1):
            layers.append(
                AllSetTransformerLayer(
                    in_channels=hidden_channels,
                    hidden_channels=hidden_channels,
                    out_channels=hidden_channels,
                    dropout=dropout,
                    mlp_num_layers=mlp_num_layers,
                    mlp_dropout=mlp_dropout,
                )
            )
        self.layers = torch.nn.ModuleList(layers)
        # Use the constructor arguments rather than notebook-level globals.
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x_0, incidence_1):
        """
        Forward computation.

        Parameters
        ----------
        x_0 : torch.Tensor
            Node features, of shape (n_nodes, in_channels).
        incidence_1 : torch.Tensor
            Unsigned incidence matrix of the hypergraph,
            of shape (n_nodes, n_hyperedges).

        Returns
        -------
        torch.Tensor
            Output prediction for the whole hypergraph.
        """

        for layer in self.layers:
            x_0 = layer(x_0, incidence_1)
        pooled_x = torch.max(x_0, dim=0)[0]
        return torch.sigmoid(self.linear(pooled_x))[0]
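
The last two lines of `forward` turn per-node features into a single prediction for the whole hypergraph. The NumPy sketch below mirrors that readout (max pooling over nodes, a linear map, then a sigmoid) with random weights; the function name `readout` and all values are illustrative assumptions, not part of the model above.

```python
import numpy as np

rng = np.random.default_rng(0)


def readout(x_nodes, w, b):
    """Max-pool over nodes, apply a linear map, then squash with a sigmoid."""
    pooled = x_nodes.max(axis=0)  # shape (channels,): one value per feature
    logit = pooled @ w + b        # scalar score for the whole hypergraph
    return 1.0 / (1.0 + np.exp(-logit))


x_nodes = rng.normal(size=(5, 4))  # 5 nodes with 4 features each
w = rng.normal(size=4)
b = 0.1

p = readout(x_nodes, w, b)
# Shuffling the node order leaves the pooled features unchanged.
p_perm = readout(x_nodes[rng.permutation(5)], w, b)
```

Since the maximum over nodes does not depend on node order, `p` and `p_perm` are equal, so the prediction depends on the hypergraph and not on how its nodes happen to be indexed.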

# Train the Neural Network

We specify the model, the loss, and an optimizer.

In [123]:
hid_dim, out_dim = 64, 1

# Define the model
model = AllSetTransformerModel(
    in_channels=channels_node,
    hidden_channels=hid_dim,
    out_channels=out_dim,
    n_layers=1,
    mlp_num_layers=1,
)
model = model.to(device)

# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

Split the dataset into train and test sets.

In [124]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
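
With `shuffle=False`, `train_test_split` keeps the original order: the first part of each list goes to the training set and the remainder to the test set. Because the split is deterministic, calling it separately on the features, incidence matrices, and labels keeps the three lists aligned. The plain-Python sketch below mimics that behavior for a float `test_size`; `ordered_split` is a hypothetical helper written for illustration, not a scikit-learn function.

```python
import math


def ordered_split(items, test_size):
    """Split a sequence without shuffling: head for training, tail for testing."""
    n_test = math.ceil(test_size * len(items))
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]


samples = list(range(10))
train, test = ordered_split(samples, test_size=0.2)
```

With ten samples and `test_size=0.2`, the first eight samples end up in `train` and the last two in `test`.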

The following cell performs the training, looping over the training set for a small number of epochs. We keep training minimal for the purpose of rapid testing.

In [None]:
test_interval = 10
num_epochs = 50
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, incidence_1, y in zip(x_0_train, incidence_1_train, y_train):
        x_0 = torch.tensor(x_0)
        x_0, incidence_1, y = (
            x_0.float().to(device),
            incidence_1.float().to(device),
            torch.tensor(y, dtype=torch.float).to(device),
        )
        opt.zero_grad()
        # Extract edge_index from sparse incidence matrix
        # edge_index, _ = to_edge_index(incidence_1)
        y_hat = model(x_0, incidence_1)
        loss = loss_fn(y_hat, y)

        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
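    if epoch_i % test_interval == 0:
        # Disable dropout and report the loss averaged over the whole test set
        # (previously only the loss of the last test sample was printed).
        model.eval()
        test_loss = []
        with torch.no_grad():
            for x_0, incidence_1, y in zip(x_0_test, incidence_1_test, y_test):
                x_0 = torch.tensor(x_0)
                x_0, incidence_1, y = (
                    x_0.float().to(device),
                    incidence_1.float().to(device),
                    torch.tensor(y, dtype=torch.float).to(device),
                )
                y_hat = model(x_0, incidence_1)
                test_loss.append(loss_fn(y_hat, y).item())

            print(f"Test_loss: {np.mean(test_loss):.4f}", flush=True)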

In [64]:
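# Quick sanity check of a single layer on a toy hypergraph.
# AllSetLayer is not imported at the top of this notebook; the import path
# below is an assumption that mirrors the AllSetTransformerLayer import.
from topomodelx.nn.hypergraph.allset_layer import AllSetLayer

x_0 = torch.randn(3, 10)
incidence_1 = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)

in_dim = 10
hid_dim = 40
out_dim = 10
layer = AllSetLayer(in_dim, hid_dim, out_dim)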

In [66]:
layer(x_0, incidence_1)

tensor([[0.5924, 0.9794, 0.0289, 0.0000, 0.1268, 0.1217, 1.2585, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0090, 0.0000, 0.4814, 0.8803, 0.0000, 0.0000, 0.6687, 0.0000,
         0.9917],
        [0.0000, 0.0000, 0.0000, 0.1937, 0.0000, 0.0000, 0.0000, 1.3111, 0.0000,
         0.1433]], grad_fn=<MulBackward0>)