# Train a Simplicial Complex Network (SCoNe)

In this notebook, we will create and train a High Skip Network in the simplicial complex domain, as proposed in the paper by [Hajij et. al : High Skip Networks: A Higher Order Generalization of Skip Connections (2022)](https://openreview.net/pdf?id=Sc8glB-k6e9). 

We train the model to perform binary node classification using the KarateClub benchmark dataset. 

The equations of one layer of this neural network are given by:

🟥 $\quad m_{{y \rightarrow z \rightarrow x}}^{(1 \rightarrow 2 \rightarrow 1)} = \sigma ((L_{\uparrow,1})_{xy} \cdot h^{t,(1)}_y \cdot \Theta^{t,(1)1})$

🟥 $\quad m_{y \rightarrow z \rightarrow x}^{(1 \rightarrow 0 \rightarrow 1)}  = (L_{\downarrow,1})_{xy} \cdot h^{t, (1)}_y \cdot \Theta^{t,(1)2}$    

🟥 $\quad m_{{x \rightarrow x}}^{(1 \rightarrow 1)}  = h_x^{t,(1)} \cdot \Theta^{t,(1)3}$    


🟧 $\quad m^{(1 \rightarrow 2 \rightarrow 1)} = \sum_{y \in \mathcal{L}_\uparrow(x)} m_{{y \rightarrow x}}^{(1 \rightarrow 2 \rightarrow 1)}$

🟧 $\quad m^{(1 \rightarrow 0 \rightarrow 1)}  = \sum_{y \in \mathcal{L}_\downarrow(x)} m_{y \rightarrow z \rightarrow x}^{(1 \rightarrow 0 \rightarrow 1)}$

🟩 $\quad m_x^{(1)}  = m_x^{(1 \rightarrow 2 \rightarrow 1)} + m_x^{(1 \rightarrow 0 \rightarrow 1)} + m_{x \rightarrow x}^{1 \rightarrow 1}$

🟦 $\quad h_x^{t+1,(1)}  = \sigma(m_x^{(1)})$

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split

import toponetx.datasets.graph as graph

from topomodelx.nn.simplicial.scone_layer_bis import SCoNeLayer

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import dataset
The first step is to import the shrec16 (https://github.com/pyt-team/TopoNetX/blob/0090625d547af9536d9c30001ecfa1f19517921a/toponetx/datasets/mesh.py#L64) dataset. This dataset is a graph with 6 node features, 10 edge features, and the face normals, angles, and areas.

In [3]:
dataset = graph.karate_club(complex_type="simplicial")
dataset.shape

(34, 78, 45, 11, 2)

# Define Neighborhood Structures
Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on the domain. In this case, we need the lower Laplacian matrix $L_{\downarrow, 1}$ and the upper Laplacian matrix $L_{\uparrow,1}$ on the edges. For a santiy check, we show that the shape of the $L_{\downarrow, 1} = n_\text{edges} \times n_\text{edges}$ and $L_{\uparrow,1} = n_\text{edges} \times n_\text{edges}$. We also convert the neighborhood structures to torch tensors.

In [4]:
up_lap1 = dataset.up_laplacian_matrix(rank=1)
down_lap1 = dataset.down_laplacian_matrix(rank=1)

up_lap1 = torch.from_numpy(up_lap1.todense()).to_sparse()
down_lap1 = torch.from_numpy(down_lap1.todense()).to_sparse()

print(f"The upper laplacian matrix has shape: {up_lap1.shape}.")
print(f"The lower laplacian matrix has shape: {down_lap1.shape}.")

The upper laplacian matrix has shape: torch.Size([78, 78]).
The lower laplacian matrix has shape: torch.Size([78, 78]).


# Defining Labels and Preparing Input
Gathering the edge features and using the second of the two features as a label based on its sign. The first feature will act as the input for the neural net.

In [5]:
xy = []
edges = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    xy.append(v)
xy = np.stack(xy)
x_1 = []
y = []

for pair in xy:
    x_1.append([pair[0]])
    if pair[1] > 0:
        y.append([1, 0])
    else:
        y.append([0, 1])
x_1 = np.stack(x_1)
x_1 = torch.tensor(x_1).to(device)
y = torch.tensor(y).to(device)

In [6]:
x_1.shape

torch.Size([78, 1])

# Train/Test Split
We split the labels into test and train sets keeping the indices to be able to calculate loss later.

In [7]:
test_size = 0.2
indices = np.arange(78)
y_train, y_test, train_indices, test_indices = train_test_split(
    y, indices, test_size=test_size, shuffle=True
)

# Creating Neural Network
Creating a stacked neural network that uses the SCoNeLayer class. The linear layer at the end produces an output of shape n_{nodes} x 2 so we can compare to our binary labels.

In [8]:
class SCoNeNN(torch.nn.Module):
    """Neural network implementation of classification using SCoNe.

    Parameters
    ---------
    channels : int
        Dimension of features.
    n_layers : int
        Amount of message passing layers.
    """

    def __init__(self, channels, n_layers=2):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                SCoNeLayer(
                    channels=channels,
                )
            )
        self.linear = torch.nn.Linear(channels, 2)
        self.layers = layers

    def forward(self, x_1, up_lap1, down_lap1, iden):
        """Forward computation.

        Parameters
        ---------
        x_0 : tensor
            shape = [n_nodes, channels]
            Node features.

        up_lap1 : tensor
            shape = [n_edges, n_edges]
            Upper Laplacian matrix of rank 1.

        down_lap1 : tensor
            shape = [n_edges, n_edges]
            Laplacian matrix (down) of rank 1.

        Returns
        --------
        _ : tensor
            shape = [n_nodes, 2]
            One-hot labels assigned to nodes.
        """

        for layer in self.layers:
            x_1 = layer(x_1, up_lap1, down_lap1, iden)
        x_1 = self.linear(x_1)
        return torch.softmax(x_1, dim=-1)

# Training the Neural Network

In [9]:
edge_channels = 1
model = SCoNeNN(
    channels=edge_channels,
    n_layers=10,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.4)

Looping and training over data for a low number of epochs.

In [10]:
test_interval = 2
num_epochs = 6
# WHAT ARE YOU INPUTTING/ MAKE SURE OF PROPER SIZES
# edge classification
dim = 78
iden = torch.eye(dim).to_sparse()

for epoch in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x_1, up_lap1, down_lap1, iden)

    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[train_indices].float(), y_train.float()
    )

    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[train_indices] == y_train).all(dim=1).float().mean().item()

    print(
        f"Epoch: {epoch} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )

    if epoch % test_interval == 0:
        with torch.no_grad():
            y_hat_test = torch.tensor(
                model(x_1, up_lap1, down_lap1, iden), dtype=torch.float
            ).to(device)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[test_indices], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.7211 Train_acc: 0.5323
Epoch: 2 loss: 0.7185 Train_acc: 0.5484
Test_acc: 0.4375
Epoch: 3 loss: 0.7205 Train_acc: 0.5484
Epoch: 4 loss: 0.7210 Train_acc: 0.5484
Test_acc: 0.4375
Epoch: 5 loss: 0.7205 Train_acc: 0.5484
Epoch: 6 loss: 0.7194 Train_acc: 0.5484
Test_acc: 0.3750


  y_hat_test = torch.tensor(
