# Train a Simplicial High-Skip Network (HSN)

In this notebook, we will create and train a High Skip Network in the simplicial complex domain, as proposed in the paper by [Hajij et. al : High Skip Networks: A Higher Order Generalization of Skip Connections (2022)](https://openreview.net/pdf?id=Sc8glB-k6e9). 

We train the model to perform binary node classification using the KarateClub benchmark dataset. 

The equations of one layer of this neural network are given by:

🟥 $\quad m_{{y \rightarrow z}}^{(0 \rightarrow 0)} = \sigma ((A_{\uparrow,0})_{xy} \cdot h^{t,(0)}_y \cdot \Theta^{t,(0)1})$    (level 1)

🟥 $\quad m_{z \rightarrow x}^{(0 \rightarrow 0)}  = (A_{\uparrow,0})_{xy} \cdot m_{y \rightarrow z}^{(0 \rightarrow 0)} \cdot \Theta^{t,(0)2}$    (level 2)

🟥 $\quad m_{{y \rightarrow z}}^{(0 \rightarrow 1)}  = \sigma((B_1^T)_{zy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow 1)})$    (level 1)

🟥 $\quad m_{z \rightarrow x)}^{(1 \rightarrow 0)}  = (B_1)_{xz} \cdot m_{z \rightarrow x}^{(0 \rightarrow 1)} \cdot \Theta^{t, (1 \rightarrow 0)}$    (level 2)

🟧 $\quad m_{x}^{(0 \rightarrow 0)}  = \sum_{z \in \mathcal{L}_\uparrow(x)} m_{z \rightarrow x}^{(0 \rightarrow 0)}$

🟧 $\quad m_{x}^{(1 \rightarrow 0)}  = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}$

🟩 $\quad m_x^{(0)}  = m_x^{(0 \rightarrow 0)} + m_x^{(1 \rightarrow 0)}$

🟦 $\quad h_x^{t+1,(0)}  = I(m_x^{(0)})$

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [1]:
import torch
import numpy as np

from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph

from topomodelx.nn.simplicial.hsn_layer import HSNLayer

# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a singular graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [2]:
dataset = graph.karate_club(complex_type="simplicial", feat_dim=8)
print(dataset)

Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4


## Define neighborhood structures. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on the domain. In this case, we need the boundary matrix (or incidence matrix) $B_1$ and the adjacency matrix $A_{\uparrow,0}$ on the nodes. For a santiy check, we show that the shape of the $B_1 = n_\text{nodes} \times n_\text{edges}$ and $A_{\uparrow,0} = n_\text{nodes} \times n_\text{nodes}$. We also convert the neighborhood structures to torch tensors.

In [3]:
incidence_1 = dataset.incidence_matrix(rank=1)
adjacency_0 = dataset.adjacency_matrix(rank=0)

incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()

print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The adjacency matrix A0 has shape: {adjacency_0.shape}.")

The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix A0 has shape: torch.Size([34, 34]).


## Import signal ##

Since our task will be node classification, we must retrieve an input signal on the nodes. The signal will have shape $n_\text{nodes} \times$ in_channels, where in_channels is the dimension of each cell's feature. Here, we have in_channels = channels_nodes $ = 34$. This is because the Karate dataset encodes the identity of each of the 34 nodes as a one hot encoder.

In [4]:
x_0 = []
for _, v in dataset.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]

In [5]:
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

There are 34 nodes with features of dimension 8.


To load edge features, this is how we would do it (note that we will not use these features for this model, and this serves simply as a demonstration).

In [6]:
x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = np.stack(x_1)

In [7]:
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

There are 78 edges with features of dimension 8.


Similarly for face features:

In [8]:
x_2 = []
for k, v in dataset.get_simplex_attributes("face_feat").items():
    x_2.append(v)
x_2 = np.stack(x_2)

In [9]:
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 45 faces with features of dimension 8.


## Define binary labels
We retrieve the labels associated to the nodes of each input simplex. In the KarateClub dataset, two social groups emerge. So we assign binary labels to the nodes indicating of which group they are a part.

We convert the binary labels into one-hot encoder form, and keep the last four nodes' true labels for the purpose of testing.

In [10]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_test = y_true[-4:]
y_train = y_true[:30]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the Neural Network

Using the HSNLayer class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{nodes} \times 2$, so we can compare with our binary labels.

In [11]:
class HSN(torch.nn.Module):
    """High Skip Network Implementation for binary node classification.

    Parameters
    ---------
    channels : int
        Dimension of features
    n_layers : int
        Amount of message passing layers.

    """

    def __init__(self, channels, n_layers=2):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                HSNLayer(
                    channels=channels,
                )
            )
        self.linear = torch.nn.Linear(channels, 2)
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x_0, incidence_1, adjacency_0):
        """Forward computation.

        Parameters
        ---------
        x_0 : tensor
            shape = [n_nodes, channels]
            Node features.

        incidence_1 : tensor
            shape = [n_nodes, n_edges]
            Boundary matrix of rank 1.

        adjacency_0 : tensor
            shape = [n_nodes, n_nodes]
            Adjacency matrix (up) of rank 0.

        Returns
        --------
        _ : tensor
            shape = [n_nodes, 2]
            One-hot labels assigned to nodes.

        """
        for layer in self.layers:
            x_0 = layer(x_0, incidence_1, adjacency_0)
        logits = self.linear(x_0)
        return logits

# Train the Neural Network

We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [12]:
model = HSN(
    channels=channels_nodes,
    n_layers=10,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

In [13]:
print(model)

HSN(
  (linear): Linear(in_features=8, out_features=2, bias=True)
  (layers): ModuleList(
    (0-9): 10 x HSNLayer(
      (conv_level1_0_to_0): Conv()
      (conv_level1_0_to_1): Conv()
      (conv_level2_0_to_0): Conv()
      (conv_level2_1_to_0): Conv()
      (aggr_on_nodes): Aggregation()
    )
  )
)


The following cell performs the training, looping over the network for a low number of epochs.

In [14]:
test_interval = 100
num_epochs = 500
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x_0, incidence_1, adjacency_0)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[-len(y_train) :] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(x_0, incidence_1, adjacency_0)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.7293 Train_acc: 0.1667
Epoch: 2 loss: 0.6926 Train_acc: 0.0000
Epoch: 3 loss: 0.6711 Train_acc: 0.0000
Epoch: 4 loss: 0.6475 Train_acc: 0.0667
Epoch: 5 loss: 0.6277 Train_acc: 0.3000
Epoch: 6 loss: 0.6088 Train_acc: 0.3333
Epoch: 7 loss: 0.5909 Train_acc: 0.3333
Epoch: 8 loss: 0.5681 Train_acc: 0.3667
Epoch: 9 loss: 0.5495 Train_acc: 0.4333
Epoch: 10 loss: 0.5237 Train_acc: 0.4333
Epoch: 11 loss: 0.4991 Train_acc: 0.5333
Epoch: 12 loss: 0.4753 Train_acc: 0.5333
Epoch: 13 loss: 0.4512 Train_acc: 0.5333
Epoch: 14 loss: 0.4269 Train_acc: 0.5333
Epoch: 15 loss: 0.4111 Train_acc: 0.5333
Epoch: 16 loss: 0.3984 Train_acc: 0.5333
Epoch: 17 loss: 0.3787 Train_acc: 0.5000
Epoch: 18 loss: 0.3638 Train_acc: 0.5000
Epoch: 19 loss: 0.3521 Train_acc: 0.5333
Epoch: 20 loss: 0.3374 Train_acc: 0.5333
Epoch: 21 loss: 0.3295 Train_acc: 0.5667
Epoch: 22 loss: 0.3166 Train_acc: 0.5667
Epoch: 23 loss: 0.3052 Train_acc: 0.5667
Epoch: 24 loss: 0.2985 Train_acc: 0.5667
Epoch: 25 loss: 0.2934 Tr

Epoch: 199 loss: 0.0013 Train_acc: 0.6333
Epoch: 200 loss: 0.0013 Train_acc: 0.6333
Test_acc: 0.2500
Epoch: 201 loss: 0.0013 Train_acc: 0.6333
Epoch: 202 loss: 0.0013 Train_acc: 0.6333
Epoch: 203 loss: 0.0013 Train_acc: 0.6333
Epoch: 204 loss: 0.0013 Train_acc: 0.6333
Epoch: 205 loss: 0.0013 Train_acc: 0.6333
Epoch: 206 loss: 0.0013 Train_acc: 0.6333
Epoch: 207 loss: 0.0013 Train_acc: 0.6333
Epoch: 208 loss: 0.0013 Train_acc: 0.6333
Epoch: 209 loss: 0.0012 Train_acc: 0.6333
Epoch: 210 loss: 0.0012 Train_acc: 0.6333
Epoch: 211 loss: 0.0012 Train_acc: 0.6333
Epoch: 212 loss: 0.0012 Train_acc: 0.6333
Epoch: 213 loss: 0.0012 Train_acc: 0.6333
Epoch: 214 loss: 0.0012 Train_acc: 0.6333
Epoch: 215 loss: 0.0012 Train_acc: 0.6333
Epoch: 216 loss: 0.0012 Train_acc: 0.6333
Epoch: 217 loss: 0.0012 Train_acc: 0.6333
Epoch: 218 loss: 0.0012 Train_acc: 0.6333
Epoch: 219 loss: 0.0012 Train_acc: 0.6333
Epoch: 220 loss: 0.0012 Train_acc: 0.6333
Epoch: 221 loss: 0.0011 Train_acc: 0.6333
Epoch: 222 loss: 

Epoch: 394 loss: 0.0005 Train_acc: 0.6333
Epoch: 395 loss: 0.0005 Train_acc: 0.6333
Epoch: 396 loss: 0.0005 Train_acc: 0.6333
Epoch: 397 loss: 0.0005 Train_acc: 0.6333
Epoch: 398 loss: 0.0005 Train_acc: 0.6333
Epoch: 399 loss: 0.0005 Train_acc: 0.6333
Epoch: 400 loss: 0.0005 Train_acc: 0.6333
Test_acc: 0.2500
Epoch: 401 loss: 0.0005 Train_acc: 0.6333
Epoch: 402 loss: 0.0005 Train_acc: 0.6333
Epoch: 403 loss: 0.0005 Train_acc: 0.6333
Epoch: 404 loss: 0.0005 Train_acc: 0.6333
Epoch: 405 loss: 0.0005 Train_acc: 0.6333
Epoch: 406 loss: 0.0005 Train_acc: 0.6333
Epoch: 407 loss: 0.0005 Train_acc: 0.6333
Epoch: 408 loss: 0.0005 Train_acc: 0.6333
Epoch: 409 loss: 0.0005 Train_acc: 0.6333
Epoch: 410 loss: 0.0005 Train_acc: 0.6333
Epoch: 411 loss: 0.0005 Train_acc: 0.6333
Epoch: 412 loss: 0.0005 Train_acc: 0.6333
Epoch: 413 loss: 0.0005 Train_acc: 0.6333
Epoch: 414 loss: 0.0005 Train_acc: 0.6333
Epoch: 415 loss: 0.0004 Train_acc: 0.6333
Epoch: 416 loss: 0.0004 Train_acc: 0.6333
Epoch: 417 loss: 