# Tutorial: Set up, create, and train a two-step message passing network (Template)

In this notebook, we create and train a two-step message passing network in the simplicial complex domain. We build a small toy dataset from scratch using TopoNetX and train the model to perform binary classification of the faces (2-cells) of the complex.

```python
import torch
import numpy as np
from toponetx import SimplicialComplex
from topomodelx.nn.simplicial.template_layer import TemplateLayer
```

# Pre-processing

## Create domain ##

The first step is to define the topological domain on which the topological neural network (TNN) will operate, as well as the neighborhood structures that characterize this domain. We only define the neighborhood matrices we plan to use.

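Here, we build a simple simplicial complex with 5 nodes that form two triangular faces. We specify two edges and two faces explicitly. TopoNetX then adds every edge and node that lies on the boundary of a face, because a simplicial complex must contain all faces of each of its simplices. The resulting complex has 7 edges in total.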
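The closure step can be sketched in plain Python without TopoNetX: every simplex contributes all of its non-empty vertex subsets. The helper `close_under_faces` is hypothetical. It only illustrates the idea and is not how TopoNetX builds the complex internally.

```python
from itertools import combinations

# Same input as the tutorial: two explicit edges and two triangular faces.
edge_set = [[1, 2], [1, 3]]
face_set = [[2, 3, 4], [2, 4, 5]]


def close_under_faces(simplices):
    """Hypothetical helper: return every non-empty sub-simplex, grouped by rank."""
    by_rank = {}
    for simplex in simplices:
        vertices = sorted(simplex)
        for size in range(1, len(vertices) + 1):
            for sub in combinations(vertices, size):
                # A subset with `size` vertices is a simplex of rank size - 1.
                by_rank.setdefault(size - 1, set()).add(sub)
    return {rank: sorted(cells) for rank, cells in by_rank.items()}


cells = close_under_faces(edge_set + face_set)
for rank, rank_cells in sorted(cells.items()):
    print(rank, len(rank_cells), rank_cells)
```

The sketch reports 5 nodes, 7 edges and 2 faces. The seven edges in lexicographic order correspond to the seven rows of the incidence matrix $B_2$ in the next step.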

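```python
# Two explicit edges and two triangular faces on the nodes 1-5.
edge_set = [[1, 2], [1, 3]]
face_set = [[2, 3, 4], [2, 4, 5]]

# TopoNetX closes the complex under taking faces, adding the missing edges.
domain = SimplicialComplex(edge_set + face_set)
```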

## Create neighborhood structures ##

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on the domain. In this case, we need only the boundary matrix (or incidence matrix) $B_2$, which relates edges to faces. As a sanity check, we print $B_2$: its shape should be $n_\text{edges} \times n_\text{faces} = 7 \times 2$.

```python
incidence_2 = domain.incidence_matrix(rank=2)
print("incidence_2\n", incidence_2.todense())
```

Output:

```
incidence_2
 [[ 0.  0.]
 [ 0.  0.]
 [ 1.  0.]
 [-1.  1.]
 [ 0. -1.]
 [ 1.  0.]
 [ 0.  1.]]
```
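To see where these entries come from, the sketch below rebuilds $B_2$ in plain Python. It rests on two assumptions, chosen because together they reproduce the printed matrix. First, edges are listed in lexicographic order. Second, every simplex is oriented by increasing vertex label, so deleting the $i$-th vertex of a face $[v_0, v_1, v_2]$ gives a boundary edge with sign $(-1)^i$. The helper `boundary_matrix` is hypothetical.

```python
# Edges in lexicographic order and faces with sorted vertices (assumed orientation).
edges = [(1, 2), (1, 3), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]
faces = [(2, 3, 4), (2, 4, 5)]


def boundary_matrix(edges, faces):
    """Hypothetical helper: signed edge-by-face incidence matrix B_2 as a list of rows."""
    row_of = {edge: i for i, edge in enumerate(edges)}
    matrix = [[0] * len(faces) for _ in edges]
    for j, face in enumerate(faces):
        # Deleting vertex i of the sorted face yields a boundary edge with sign (-1)^i.
        for i in range(len(face)):
            edge = face[:i] + face[i + 1:]
            matrix[row_of[edge]][j] = (-1) ** i
    return matrix


B2 = boundary_matrix(edges, faces)
for row in B2:
    print(row)
```

For example, $\partial[2,3,4] = [3,4] - [2,4] + [2,3]$, which is exactly the first column of the printed matrix.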


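We convert the neighborhood matrix to a sparse PyTorch tensor and cast it to `float32`, so that its dtype matches the input features and the layer weights.

```python
# SciPy sparse -> dense NumPy array -> float32 tensor -> sparse COO tensor.
incidence_2_torch = torch.from_numpy(incidence_2.toarray()).float().to_sparse()
```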
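`Tensor.to_sparse()` produces a sparse COO (coordinate) tensor. This format stores only the nonzero entries: a $2 \times \text{nnz}$ array of (row, column) indices plus a matching array of values. The plain-Python sketch below extracts the same information from $B_2$. The helper `to_coo` is hypothetical and is not a PyTorch function.

```python
# B_2 as printed above, stored densely as a list of rows (n_edges x n_faces).
B2 = [[0, 0], [0, 0], [1, 0], [-1, 1], [0, -1], [1, 0], [0, 1]]


def to_coo(dense):
    """Hypothetical helper: return ([rows, cols], values) for the nonzero entries, row-major."""
    rows, cols, values = [], [], []
    for i, row in enumerate(dense):
        for j, value in enumerate(row):
            if value != 0:
                rows.append(i)
                cols.append(j)
                values.append(value)
    return [rows, cols], values


indices, values = to_coo(B2)
print(indices)
print(values)
```

Only six of the 14 entries of $B_2$ are nonzero, three per triangle. On larger complexes the fraction of nonzeros in an incidence matrix is far smaller, which is what makes the sparse format worthwhile.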

## Create signal ##

Since our task is face classification, we must define an input signal (at least one data point) on the faces. The signal has shape $n_\text{faces} \times$ `in_channels`, where `in_channels` is the dimension of each face's feature vector. Here, we take `in_channels` = `channels_face` = 2.

```python
# One 2-dimensional feature vector per face: shape (n_faces, channels_face) = (2, 2).
x_2 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
```

# Create the Neural Network

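Using the `TemplateLayer` class, we create a neural network with stacked layers. Each layer performs two message passing steps, with the intermediate features living on the edges. We give the edge rank a different number of channels (4) than the face rank (2), which makes this a heterogeneous network.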
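One way to picture the two message passing steps inside a layer is the plain-Python sketch below. It assumes that each step applies a learnable linear map to the source features, aggregates the result through $B_2$ (faces to edges) or $B_2^\top$ (edges to faces), and applies an elementwise sigmoid. In formulas: $X_1 = \sigma(B_2 X_2 W_1)$ and $X_2' = \sigma(B_2^\top X_1 W_2)$. This formulation is our assumption and is not taken from `TemplateLayer`'s source. The weight values are arbitrary, and the real layer may normalize or update differently.

```python
import math

B2 = [[0, 0], [0, 0], [1, 0], [-1, 1], [0, -1], [1, 0], [0, 1]]  # (n_edges, n_faces)
x_2 = [[1.0, 1.0], [2.0, 2.0]]  # (n_faces, channels_face)


def matmul(a, b):
    """Dense matrix product of two lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def sigmoid_all(a):
    # Elementwise sigmoid (an assumed update function, for illustration only).
    return [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in a]


# Hypothetical weights: channels_face = 2 -> channels_edge = 4 -> channels_face = 2.
W_face_to_edge = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
W_edge_to_face = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]

# Step 1: faces -> edges, aggregating through B_2. Result shape (7, 4).
x_1 = sigmoid_all(matmul(B2, matmul(x_2, W_face_to_edge)))
# Step 2: edges -> faces, aggregating through B_2^T. Result shape (2, 2).
x_2_new = sigmoid_all(matmul(transpose(B2), matmul(x_1, W_edge_to_face)))
print(x_2_new)
```

Edges $[1,2]$ and $[1,3]$ bound no face, so their rows of $B_2$ are zero and they receive no messages. Under this sketch their features are simply $\sigma(0) = 0.5$.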

```python
channels_face = x_2.shape[1]  # 2 input channels per face (a plain int)
channels_edge = 4  # width of the intermediate edge features
```
channels_edge = 4

```python
class TemplateNN(torch.nn.Module):
    """Stack of TemplateLayers followed by a linear readout giving one logit per face."""

    def __init__(self, channels_face, channels_edge, n_layers=2):
        super().__init__()
        # Use a ModuleList (not a plain list) so the layers' parameters are
        # registered and returned by model.parameters().
        self.layers = torch.nn.ModuleList(
            TemplateLayer(
                in_channels=channels_face,
                intermediate_channels=channels_edge,
                out_channels=channels_face,
            )
            for _ in range(n_layers)
        )
        self.linear = torch.nn.Linear(channels_face, 1)

    def forward(self, x_2, incidence_2_torch):
        for layer in self.layers:
            x_2 = layer(x_2, incidence_2_torch)
        return self.linear(x_2)  # raw logits, shape (n_faces, 1)
```

# Train the Neural Network

We instantiate the model, assign ground-truth labels for the face classification task, and specify an optimizer. The incidence matrix is not stored in the model; it is passed in at every forward call.

```python
model = TemplateNN(channels_face, channels_edge, n_layers=2)
faces_gt_labels = torch.tensor([[0.0], [1.0]])  # (n_faces, 1): one binary label per face
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
```

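The following cell trains the model for 5 epochs. Each epoch runs a forward pass over the whole complex, computes the binary cross-entropy loss on the logits, backpropagates, and takes one optimizer step.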

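```python
for epoch in range(5):
    optimizer.zero_grad()
    # Forward pass: one logit per face.
    faces_pred_labels = model(x_2, incidence_2_torch)
    # The sigmoid is applied inside the loss, so the model outputs raw logits.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        faces_pred_labels, faces_gt_labels
    )
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```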
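`binary_cross_entropy_with_logits` applies the sigmoid inside the loss, which is why `TemplateNN` ends in a `Linear` layer with no activation. The sketch below computes the same quantity in plain Python with the default mean reduction. It uses the standard numerically stable rewrite $\max(z, 0) - zy + \log(1 + e^{-|z|})$ of $-[y \log \sigma(z) + (1-y)\log(1-\sigma(z))]$. It illustrates the formula and is not PyTorch's implementation.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy from raw logits, in the numerically stable form."""
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def bce_naive(logits, targets):
    """Same loss via an explicit sigmoid; fails when sigmoid(z) rounds to 0 or 1."""
    total = 0.0
    for z, y in zip(logits, targets):
        s = 1.0 / (1.0 + math.exp(-z))
        total -= y * math.log(s) + (1 - y) * math.log(1 - s)
    return total / len(logits)


print(bce_with_logits([0.0, 0.0], [0.0, 1.0]))  # log(2) ≈ 0.693
```

At a very confident logit such as $z = 1000$, the naive version tries to take $\log 0$. The stable form still returns a finite loss.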