# Train a HyperSAGE TNN

In this notebook, we will create and train a HyperSAGE layer (Arya et al., [2020](https://arxiv.org/abs/2010.04558)), a two-level message passing strategy for learning on hypergraphs. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph.

Following the notation of the "awesome-tnns" [GitHub repo](https://github.com/awesome-tnns/awesome-tnns/blob/main/Hypergraphs.md), the message passing scheme of HyperSAGE reads:

🟥 $\quad m_{y \rightarrow z}^{(0 \rightarrow 1)} = (B_1)^T_{zy} \cdot w_y \cdot (h_y^{(0)})^p$ 

🟥 $\quad m_z^{(0 \rightarrow 1)}  = \left(\frac{1}{\vert \mathcal{B}(z)\vert}\sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)}\right)^{\frac{1}{p}}$

🟥 $\quad m_{z \rightarrow x}^{(1 \rightarrow 0)} =  (B_1)_{xz} \cdot w_z  \cdot (m_z^{(0 \rightarrow 1)})^p$

🟧 $\quad m_x^{(1 \rightarrow 0)}  = \left(\frac{1}{\vert \mathcal{C}(x) \vert}\sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}\right)^{\frac{1}{p}}$

🟩 $\quad m_x^{(0)}  = m_x^{(1 \rightarrow 0)}$ 

🟦 $\quad h_x^{t+1, (0)} = \sigma \left(\frac{m_x^{(0)} + h_x^{t,(0)}}{\lvert m_x^{(0)} + h_x^{t,(0)}\rvert} \cdot \Theta^t\right)$ 
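
The steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the `topomodelx` implementation. It assumes unit weights ($w_y = w_z = 1$), non-negative features (so that the $p$-th power and root are well defined), an unsigned 0/1 incidence matrix in which every node belongs to at least one hyperedge, the Euclidean norm for the normalization, and ReLU for $\sigma$. The function name `generalized_mean_step` is hypothetical.

```python
import numpy as np


def generalized_mean_step(x, incidence, theta, p=2.0):
    """One two-level generalized-mean update (illustrative sketch only).

    x:         (n_nodes, d) non-negative node features
    incidence: (n_nodes, n_hyperedges) unsigned 0/1 incidence matrix
    theta:     (d, d_out) weight matrix
    """
    edge_sizes = incidence.sum(axis=0)    # |B(z)|: nodes per hyperedge
    node_degrees = incidence.sum(axis=1)  # |C(x)|: hyperedges per node

    # Nodes -> hyperedges: power mean of member node features.
    m_edge = (incidence.T @ x**p / edge_sizes[:, None]) ** (1.0 / p)
    # Hyperedges -> nodes: power mean of incident hyperedge messages.
    m_node = (incidence @ m_edge**p / node_degrees[:, None]) ** (1.0 / p)

    # Skip connection, per-node normalization, linear map, ReLU.
    combined = m_node + x
    combined = combined / np.linalg.norm(combined, axis=1, keepdims=True)
    return np.maximum(combined @ theta, 0.0)


# Toy hypergraph: 3 nodes, hyperedges {0, 1} and {1, 2}.
toy_incidence = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
toy_x = np.array([[1.0, 0.0], [3.0, 4.0], [0.0, 2.0]])
toy_out = generalized_mean_step(toy_x, toy_incidence, np.eye(2), p=2.0)
print(toy_out.shape)
```

With $p = 1$ both aggregations reduce to arithmetic means; larger $p$ gives more weight to large feature values.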

### Additional theoretical clarifications

Arya et al. propose to interpret the propagation of information in a given hypergraph as a two-level aggregation problem, where the neighborhood of any node is divided into intra-edge neighbors and inter-edge neighbors. Given a hypergraph $H=(\mathcal{V}, \mathcal{E})$, let $\textbf{X}$ denote the feature matrix, such that $\textbf{x}_{i} \in \textbf{X}$ is the feature set for node $v_{i} \in \mathcal{V}$. For two-level aggregation, let $\mathcal{F}_{1}(\cdot)$ and $\mathcal{F}_{2}(\cdot)$ denote the intra-edge and inter-edge aggregation functions, respectively. Message passing at node $v_{i}$ for aggregation of information at the $l^{th}$ layer can then be stated as

$\textbf{x}_{i,l}^{(e)} \leftarrow \mathcal{F}_{1}(\{ \textbf{x}_{j,l-1} \mid v_{j} \in \mathcal{N}(v_{i}, \textbf{e}, \alpha) \}),$

$\textbf{x}_{i,l} \leftarrow \textbf{x}_{i,l-1} + \mathcal{F}_{2}(\{ \textbf{x}_{i,l}^{(e)} \mid \textbf{e} \in E(v_{i}) \}),$

where $\textbf{x}_{i,l}^{(e)}$ refers to the aggregated feature set at $v_{i}$ obtained with intra-edge aggregation for edge $\textbf{e}$, and $E(v_{i})$ is the set of hyperedges containing $v_{i}$.

In [1]:
import torch
import numpy as np
import toponetx.datasets as datasets
from sklearn.model_selection import train_test_split

from topomodelx.nn.hypergraph.hypersage import HyperSAGE

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec 16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a hypergraph.

We will also retrieve:
- the input signal on the nodes of each of these hypergraphs, as that is what we feed to the model as input
- the label associated with each hypergraph

In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.


## Define neighborhood structures and lift into hypergraph domain. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\text{nodes} \times n_\text{edges}$.

Once we have recorded the incidence matrix (note that all incidence matrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The pairwise edges become pairwise hyperedges, and the faces of the simplicial complex become 3-wise hyperedges.

In [5]:
hg_list = []
incidence_1_list = []
for simplex in simplexes:
    incidence_1 = simplex.incidence_matrix(rank=1, signed=False)
    hg = simplex.to_hypergraph()
    hg_list.append(hg)

# Extract the incidence matrices of the collected hypergraphs
for hg in hg_list:
    incidence_1 = hg.incidence_matrix()
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_1_list.append(incidence_1)

In [6]:
i_complex = 6
print(
    f"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}."
)

The 6th hypergraph has an incidence matrix of shape torch.Size([252, 1250]).
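
This shape is consistent with the lifting rule described earlier: the 750 edges and 500 faces of the 6th complex give 1250 hyperedges over its 252 nodes. As a toy illustration (not the `toponetx` implementation, whose column ordering may differ), the sketch below builds the unsigned incidence matrix of a single filled triangle, whose three edges and one face become four hyperedges. The helper name `lift_to_incidence` is hypothetical.

```python
import numpy as np


def lift_to_incidence(n_nodes, edges, faces):
    """Build an unsigned node-by-hyperedge incidence matrix (sketch).

    Every edge becomes a 2-node hyperedge and every face a 3-node hyperedge.
    """
    hyperedges = [tuple(e) for e in edges] + [tuple(f) for f in faces]
    incidence = np.zeros((n_nodes, len(hyperedges)), dtype=np.float32)
    for col, members in enumerate(hyperedges):
        incidence[list(members), col] = 1.0  # mark each member node
    return incidence


triangle_edges = [(0, 1), (0, 2), (1, 2)]
triangle_faces = [(0, 1, 2)]
toy_incidence = lift_to_incidence(3, triangle_edges, triangle_faces)
print(toy_incidence.shape)        # (3, 4)
print(toy_incidence.sum(axis=0))  # hyperedge sizes: [2. 2. 2. 3.]
```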


# Define the Neural Network



In [9]:
# Define the model
in_channels = x_0s[0].shape[1]
out_channels = 10
p = 2
initialization = "xavier_uniform"
n_layers = 2
model = HyperSAGE(
    in_channels=in_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    device="cpu",
    initialization=initialization,
)

# Train the Neural Network

We specify the loss and an optimizer for the model defined above.

In [10]:
# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

Split the dataset into train and test sets.

In [11]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
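
Because `shuffle=False` is used, each call keeps the original order and takes the test samples from the end, so the three separate splits stay aligned sample by sample. A minimal pure-Python sketch of that behaviour, assuming the test count is rounded up (the helper `ordered_split` is hypothetical and only illustrates the idea):

```python
import math


def ordered_split(items, test_size):
    """Split a sequence into (train, test) without shuffling (sketch)."""
    n_test = math.ceil(len(items) * test_size)  # assumption: round up
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]


toy_features = [f"x{i}" for i in range(10)]
toy_labels = [f"y{i}" for i in range(10)]
feat_train, feat_test = ordered_split(toy_features, 0.2)
lab_train, lab_test = ordered_split(toy_labels, 0.2)
print(feat_test, lab_test)  # ['x8', 'x9'] ['y8', 'y9']
```

If shuffling were enabled, all arrays would need to be passed to a single call (or share a fixed random seed) to keep features, incidence matrices, and labels matched.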

The following cell performs the training, looping over the training set for a small number of epochs. We keep training minimal for the purpose of rapid testing.

In [12]:
test_interval = 1
num_epochs = 5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, incidence_1, y in zip(x_0_train, incidence_1_train, y_train):
        x_0 = torch.tensor(x_0)
        x_0, incidence_1, y = (
            x_0.float().to(device),
            incidence_1.float().to(device),
            torch.tensor(y, dtype=torch.float).to(device),
        )
        opt.zero_grad()
        # Extract edge_index from sparse incidence matrix
        # edge_index, _ = to_edge_index(incidence_1)
        y_hat = model(x_0, incidence_1)
        loss = loss_fn(y_hat, y)

        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_0, incidence_1, y in zip(x_0_test, incidence_1_test, y_test):
                x_0 = torch.tensor(x_0)
                x_0, incidence_1, y = (
                    x_0.float().to(device),
                    incidence_1.float().to(device),
                    torch.tensor(y, dtype=torch.float).to(device),
                )
                y_hat = model(x_0, incidence_1)
                loss = loss_fn(y_hat, y)

            print(f"Test_loss: {loss:.4f}", flush=True)

Epoch: 1 loss: 275.2170
Test_loss: 529.0000
Epoch: 2 loss: 274.6125
Test_loss: 529.0000
Epoch: 3 loss: 274.6125
Test_loss: 529.0000
Epoch: 4 loss: 274.6125
Test_loss: 529.0000
Epoch: 5 loss: 274.6125
Test_loss: 529.0000
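
Note that the loop above prints `loss` after the inner test loop has finished, so the reported value is the loss of the last test sample only. One possible way to report an average instead is sketched below. `mean_loss` is a hypothetical helper, and the toy model and loss merely stand in for the HyperSAGE model and `MSELoss`.

```python
import numpy as np


def mean_loss(model, loss_fn, samples):
    """Average loss over an iterable of (inputs, target) pairs (sketch).

    `inputs` is a tuple that is unpacked into the model call, e.g.
    (x_0, incidence_1) for the model used in this notebook.
    """
    losses = [float(loss_fn(model(*inputs), target)) for inputs, target in samples]
    return float(np.mean(losses))


def toy_model(x):
    # Stand-in model: doubles its input.
    return 2.0 * x


def squared_error(pred, target):
    # Stand-in loss: squared difference of two scalars.
    return (pred - target) ** 2


toy_samples = [((1.0,), 2.0), ((2.0,), 2.0), ((3.0,), 2.0)]
print(mean_loss(toy_model, squared_error, toy_samples))
```

With PyTorch tensors the same pattern applies inside `torch.no_grad()`, collecting `loss.item()` for each test sample before averaging.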
