# Train a Simplicial High-Skip Network (HSN)

We create and train a High Skip Network (HSN) in the simplicial complex domain, as proposed in [Hajij et al.: High Skip Networks: A Higher Order Generalization of Skip Connections (2022)](https://openreview.net/pdf?id=Sc8glB-k6e9).

The equations of one layer of this neural network are given by:

🟥 $\quad m_{y \rightarrow z}^{(0 \rightarrow 0)} = \sigma ((A_{\uparrow,0})_{zy} \cdot h^{t,(0)}_y \cdot \Theta^{t,(0)1})$    (level 1)

🟥 $\quad m_{z \rightarrow x}^{(0 \rightarrow 0)}  = (A_{\uparrow,0})_{xz} \cdot m_{y \rightarrow z}^{(0 \rightarrow 0)} \cdot \Theta^{t,(0)2}$    (level 2)

🟥 $\quad m_{{y \rightarrow z}}^{(0 \rightarrow 1)}  = \sigma((B_1^T)_{zy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow 1)})$    (level 1)

🟥 $\quad m_{z \rightarrow x}^{(1 \rightarrow 0)}  = (B_1)_{xz} \cdot m_{y \rightarrow z}^{(0 \rightarrow 1)} \cdot \Theta^{t, (1 \rightarrow 0)}$    (level 2)

🟧 $\quad m_{x}^{(0 \rightarrow 0)}  = \sum_{z \in \mathcal{L}_\uparrow(x)} m_{z \rightarrow x}^{(0 \rightarrow 0)}$

🟧 $\quad m_{x}^{(1 \rightarrow 0)}  = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}$

🟩 $\quad m_x^{(0)}  = m_x^{(0 \rightarrow 0)} + m_x^{(1 \rightarrow 0)}$

🟦 $\quad h_x^{t+1,(0)}  = I(m_x^{(0)})$

The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
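
To make the layer equations concrete, here is a minimal NumPy sketch of one forward pass in matrix form. It is not the `HSNLayer` implementation: the function name `hsn_layer_sketch`, the weight dictionary keys, the choice of a sigmoid for $\sigma$, and the identity update $I$ are all assumptions made for illustration. The matrix products sum over the sending cells, which plays the role of the within-neighborhood aggregation (🟧).

```python
import numpy as np


def sigmoid(a):
    """Elementwise logistic function, used here as the nonlinearity sigma."""
    return 1.0 / (1.0 + np.exp(-a))


def hsn_layer_sketch(h0, adjacency_0, incidence_1, theta):
    """One HSN-style update of node features (illustrative only).

    h0 : (n_nodes, channels) node features
    adjacency_0 : (n_nodes, n_nodes) node adjacency matrix
    incidence_1 : (n_nodes, n_edges) boundary matrix B_1
    theta : dict of (channels, channels) weight matrices (hypothetical keys)
    """
    # Level 1: nodes -> nodes and nodes -> edges, each followed by sigma.
    m_00_level1 = sigmoid(adjacency_0 @ h0 @ theta["00_1"])
    m_01_level1 = sigmoid(incidence_1.T @ h0 @ theta["01"])
    # Level 2: nodes -> nodes again, and edges -> nodes (the "skip" path).
    m_00 = adjacency_0 @ m_00_level1 @ theta["00_2"]
    m_10 = incidence_1 @ m_01_level1 @ theta["10"]
    # Between-neighborhood aggregation: sum the two messages arriving at nodes.
    return m_00 + m_10


rng = np.random.default_rng(0)
# Toy domain: a path graph 0 - 1 - 2 with 3 nodes and 2 edges.
adjacency_0 = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
incidence_1 = np.array([[-1, 0], [1, -1], [0, 1]], dtype=float)
channels = 4
h0 = rng.normal(size=(3, channels))
theta = {
    key: rng.normal(size=(channels, channels))
    for key in ("00_1", "00_2", "01", "10")
}
h1 = hsn_layer_sketch(h0, adjacency_0, incidence_1, theta)
print(h1.shape)  # (3, 4)
```

The output keeps the shape of the input node features, which is what allows several such layers to be stacked.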

We train the model to perform four-class node classification using the [KarateClub benchmark dataset](https://en.wikipedia.org/wiki/Zachary%27s_karate_club).

In [1]:
import numpy as np
import sklearn.preprocessing  # explicit submodule import, needed for OneHotEncoder
import toponetx as tnx
import torch
from torch_geometric.datasets.karate import KarateClub
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.simplicial.hsn_layer import HSNLayer

## Import dataset

The first step is to import the [Karate Club](https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to four different social groups. Every node is labeled with one of four classes obtained via modularity-based clustering, following the paper “Semi-supervised Classification with Graph Convolutional Networks” by Kipf & Welling (2016).

We will use these groups for the task of node-level classification.

In [2]:
dataset = KarateClub()[0]
print(dataset)
i_node = 5
print(
    f"\nThe {i_node}-th node:\n"
    f"- has feature: {dataset.x[i_node]}\n"
    f"- and label: {dataset.y[i_node]}."
)

Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

The 5-th node:
- has feature: tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
- and label: 3.


## Choose Topological Domain

To train an HSN, we must first lift our graph dataset into the simplicial complex domain.

In [3]:
simplicial_complex = tnx.transform.graph_2_clique_complex(
    to_networkx(dataset).to_undirected(), max_dim=3
)
print(simplicial_complex)

Simplicial Complex with shape [34, 78, 45] and dimension 2
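
To illustrate what a clique-complex lifting does, the following pure-Python sketch enumerates the cliques of a small toy graph and treats each clique as a simplex. The helper `clique_complex` and the toy edge list are hypothetical and only meant to show the idea; the notebook itself relies on the `toponetx` transform used above.

```python
from itertools import combinations


def clique_complex(edges, max_size=3):
    """Return {clique size: list of cliques} for cliques up to max_size nodes."""
    nodes = sorted({v for edge in edges for v in edge})
    edge_set = {frozenset(edge) for edge in edges}
    simplices = {1: [(v,) for v in nodes]}
    for size in range(2, max_size + 1):
        # A node subset is a simplex if every pair inside it is an edge.
        simplices[size] = [
            candidate
            for candidate in combinations(nodes, size)
            if all(frozenset(pair) in edge_set for pair in combinations(candidate, 2))
        ]
    return simplices


# Toy graph: a triangle (0, 1, 2) with a pendant edge (2, 3).
toy_edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
sc = clique_complex(toy_edges)
shape = [len(sc[k]) for k in (1, 2, 3)]
print(shape)  # [4, 4, 1]
```

The printed list has the same meaning as the shape reported above: the number of nodes, edges, and triangles in the lifted complex.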


## Define neighborhood structures

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on the domain. 

For the HSN model, we need the boundary matrix (or incidence matrix) $B_1$ and the adjacency matrix $A_{\uparrow,0}$ on the nodes. 

As a sanity check, we verify that $B_1$ has shape $n_\text{nodes} \times n_\text{edges}$ and that $A_{\uparrow,0}$ has shape $n_\text{nodes} \times n_\text{nodes}$. We also convert the neighborhood structures to torch tensors.

In [4]:
incidence_1 = simplicial_complex.incidence_matrix(rank=1)
adjacency_0 = simplicial_complex.adjacency_matrix(rank=0)
incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()

print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The adjacency matrix A0 has shape: {adjacency_0.shape}.")

The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix A0 has shape: torch.Size([34, 34]).
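
Beyond checking shapes, the two matrices are related: for a signed boundary matrix, $B_1 B_1^T = D - A_{\uparrow,0}$, where $D$ is the diagonal degree matrix. The NumPy sketch below verifies this on a hypothetical toy graph. The orientation chosen here (each edge points from its lower-index node to its higher-index node) is an assumption; the identity holds for any orientation.

```python
import numpy as np

# Toy graph: a triangle (0, 1, 2) with a pendant edge (2, 3).
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
n_nodes = 4

# Signed boundary matrix B_1: one column per edge, -1 at the tail, +1 at the head.
B1 = np.zeros((n_nodes, len(edges)))
for j, (tail, head) in enumerate(edges):
    B1[tail, j] = -1.0
    B1[head, j] = 1.0

# Node adjacency matrix: two nodes are adjacent if they share an edge.
A0 = np.zeros((n_nodes, n_nodes))
for u, v in edges:
    A0[u, v] = A0[v, u] = 1.0

degree = np.diag(A0.sum(axis=1))
print(B1.shape, A0.shape)  # (4, 4) (4, 4)
print(np.allclose(B1 @ B1.T, degree - A0))  # True
```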


## Import signal

Since our task is node classification, we must retrieve an input signal on the nodes. The signal has shape $n_\text{nodes} \times$ in_channels, where in_channels is the dimension of each node's feature vector. Here, in_channels = channels_nodes $= 34$, because the Karate Club dataset encodes the identity of each of the 34 nodes as a one-hot vector.

In [5]:
x_nodes = dataset.x
print(x_nodes.shape)
channels_nodes = x_nodes.shape[-1]

torch.Size([34, 34])
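
One consequence of one-hot identity features is worth spelling out: multiplying them by a weight matrix simply selects rows of that matrix, so each node effectively gets its own learnable embedding. The NumPy sketch below shows this with a random stand-in weight matrix; the variable names are illustrative and not taken from the library.

```python
import numpy as np

n_nodes = 34
channels = 34

# One-hot identity features: row i is the indicator vector of node i.
x_identity = np.eye(n_nodes)

rng = np.random.default_rng(0)
weight = rng.normal(size=(n_nodes, channels))  # stand-in for a learnable matrix

# With identity features, the linear map returns the weight rows themselves.
projected = x_identity @ weight
print(np.allclose(projected, weight))  # True
```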


## Define classification labels

We retrieve the labels associated with the nodes. In the KarateClub dataset, four social groups emerge, so each node is labeled with the group it belongs to.

We convert the classification labels into one-hot form, hold out the first four nodes' true labels for testing, and train on the remaining 30.

In [6]:
y_true = dataset.y
y_true[:3]

tensor([1, 1, 1])

In [7]:
y_true = (
    sklearn.preprocessing.OneHotEncoder().fit_transform(y_true.view(-1, 1)).todense()
)
y_true[:3]

matrix([[0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]])

In [8]:
y_test = y_true[:4]
y_train = y_true[-30:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
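
The encoding and the split can be reproduced on a small example. The sketch below uses NumPy instead of scikit-learn and a made-up label vector, so the numbers are hypothetical; it only illustrates that slicing the first rows for testing and the last rows for training yields two disjoint sets that together cover every node.

```python
import numpy as np

# Hypothetical integer labels for 10 nodes and 4 classes.
labels = np.array([1, 1, 1, 1, 3, 3, 3, 1, 0, 2])
n_classes = 4

# Row i of the identity matrix is the one-hot vector of class i.
one_hot = np.eye(n_classes)[labels]

n_test = 4
n_train = len(labels) - n_test
y_test_demo = one_hot[:n_test]      # first n_test nodes
y_train_demo = one_hot[-n_train:]   # remaining nodes

print(one_hot[0])  # [0. 1. 0. 0.]
print(y_test_demo.shape, y_train_demo.shape)  # (4, 4) (6, 4)
```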

# Create the Neural Network

Using the HSNLayer class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{nodes} \times 4$, so we can compare with our classification labels.

In [9]:
class HSN(torch.nn.Module):
    """High Skip Network implementation for multi-class node classification.

    Parameters
    ----------
    channels : int
        Dimension of features.
    n_layers : int
        Number of message passing layers.
    n_classes : int
        Number of output classes.
    """

    def __init__(self, channels, n_layers=2, n_classes=4):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                HSNLayer(
                    channels=channels,
                )
            )
        self.linear = torch.nn.Linear(channels, n_classes)
        # ModuleList registers the layers, so their weights are included in
        # model.parameters() and updated by the optimizer.
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x_0, incidence_1, adjacency_0):
        """Forward computation.

        Parameters
        ----------
        x_0 : tensor, shape = [n_nodes, channels]
            Node features.
        incidence_1 : tensor, shape = [n_nodes, n_edges]
            Boundary matrix of rank 1.
        adjacency_0 : tensor, shape = [n_nodes, n_nodes]
            Adjacency matrix (up) of rank 0.

        Returns
        -------
        _ : tensor, shape = [n_nodes, n_classes]
            Per-class scores in (0, 1) assigned to nodes.
        """
        for layer in self.layers:
            x_0 = layer(x_0, incidence_1, adjacency_0)
        return torch.sigmoid(self.linear(x_0))
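
When stacking layers inside a `torch.nn.Module`, how the layers are stored determines whether the optimizer can see their weights. The sketch below compares a plain Python list with `torch.nn.ModuleList`, using `torch.nn.Linear` as a stand-in for a message passing layer. The class names `PlainListStack` and `ModuleListStack` are hypothetical.

```python
import torch


class PlainListStack(torch.nn.Module):
    """Stores layers in a plain list: they are NOT registered as submodules."""

    def __init__(self, channels, n_layers):
        super().__init__()
        self.layers = [torch.nn.Linear(channels, channels) for _ in range(n_layers)]


class ModuleListStack(torch.nn.Module):
    """Stores layers in a ModuleList: they are registered as submodules."""

    def __init__(self, channels, n_layers):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(channels, channels) for _ in range(n_layers)]
        )


def count_parameters(module):
    """Total number of scalar parameters returned by module.parameters()."""
    return sum(p.numel() for p in module.parameters())


n_plain = count_parameters(PlainListStack(channels=4, n_layers=3))
n_registered = count_parameters(ModuleListStack(channels=4, n_layers=3))
print(n_plain, n_registered)  # 0 60
```

Only parameters returned by `model.parameters()` are passed to the optimizer, so layers kept in a plain list would stay at their initial values during training.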

# Train the Neural Network

We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [10]:
model = HSN(
    channels=channels_nodes,
    n_layers=10,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.4)

The following cell performs the training, looping over the network for a low number of epochs.

In [11]:
test_interval = 2
num_epochs = 5
for i_epoch in range(num_epochs):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x_nodes, incidence_1, adjacency_0)
    loss = torch.nn.functional.cross_entropy(
        y_hat[-len(y_train) :].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    # Compare the thresholded predictions with the training labels.
    accuracy = (y_pred[-len(y_train) :] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {i_epoch} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if i_epoch % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(x_nodes, incidence_1, adjacency_0)
            # The model output already went through a sigmoid; threshold it directly.
            y_pred_test = y_hat_test.ge(0.5).float()
            test_accuracy = (
                torch.eq(y_pred_test[: len(y_test)], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
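
The accuracy used in this loop is an exact-match criterion: a node counts as correct only if all of its thresholded outputs agree with its one-hot label. The NumPy sketch below illustrates this on made-up scores and contrasts it with the more common argmax accuracy. The function `exact_match_accuracy` and the numbers are hypothetical.

```python
import numpy as np


def exact_match_accuracy(scores, one_hot_targets, threshold=0.5):
    """Fraction of rows whose thresholded scores equal the one-hot target."""
    predictions = (scores > threshold).astype(float)
    return float((predictions == one_hot_targets).all(axis=1).mean())


scores = np.array(
    [
        [0.9, 0.1, 0.2, 0.1],  # exactly one score above 0.5, correct class
        [0.6, 0.7, 0.1, 0.1],  # two scores above 0.5: not an exact match
        [0.2, 0.3, 0.4, 0.1],  # no score above 0.5: not an exact match
    ]
)
targets = np.array(
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
    dtype=float,
)

exact_acc = exact_match_accuracy(scores, targets)
argmax_acc = float((scores.argmax(axis=1) == targets.argmax(axis=1)).mean())
print(round(exact_acc, 4), argmax_acc)  # 0.3333 1.0
```

Exact matching is stricter than argmax accuracy, which explains why the reported accuracy can stay low even when the largest score points to the right class.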


 x_source.shape = torch.Size([34, 34])
self.weight.shape = torch.Size([34, 34])
neighborhood.shape = torch.Size([34, 34])
x_message.shape = torch.Size([34, 34])

 x_source.shape = torch.Size([34, 34])
self.weight.shape = torch.Size([34, 34])
neighborhood.shape = torch.Size([78, 34])
x_message.shape = torch.Size([34, 34])

 x_source.shape = torch.Size([34, 34])
self.weight.shape = torch.Size([34, 34])
neighborhood.shape = torch.Size([34, 34])
x_message.shape = torch.Size([34, 34])

 x_source.shape = torch.Size([78, 34])
self.weight.shape = torch.Size([34, 34])
neighborhood.shape = torch.Size([34, 78])
x_message.shape = torch.Size([78, 34])

 x_source.shape = torch.Size([34, 34])
self.weight.shape = torch.Size([34, 34])
neighborhood.shape = torch.Size([34, 34])
x_message.shape = torch.Size([34, 34])

 x_source.shape = torch.Size([34, 34])
self.weight.shape = torch.Size([34, 34])
neighborhood.shape = torch.Size([78, 34])
x_message.shape = torch.Size([34, 34])

 x_source.shape = torch.Siz