# Train a Simplicial Complex Convolutional Network (SCCN)

*TODO: more explanation, also in later cells. Refer to hsn example notebook*

We train the model to perform binary node classification using the KarateClub benchmark dataset. 

The equations of one layer of this neural network are given by:

*TODO*

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [1]:
import torch
import numpy as np

from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph

from topomodelx.nn.simplicial.sccn_layer import SCCNLayer

# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a singular graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [2]:
dataset = graph.karate_club(complex_type="simplicial", feat_dim=8)
print(dataset)

Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4


In [3]:
max_rank = 3  # There are features up to tetrahedron order in the dataset

## Define neighborhood structures. ##


In [4]:
def sparse_to_torch(X):
    return torch.from_numpy(X.todense()).to_sparse()

incidences = {
    f'rank_{r}': sparse_to_torch(
        dataset.incidence_matrix(rank=r)
    )
    for r in range(1, max_rank + 1)
}

adjacencies = {}
adjacencies['rank_0'] = (sparse_to_torch(dataset.adjacency_matrix(rank=0)) 
                  + torch.eye(dataset.shape[0]).to_sparse())
for r in range(1, max_rank):
    adjacencies[f'rank_{r}'] = (
        sparse_to_torch(
            dataset.adjacency_matrix(rank=r)
            + dataset.coadjacency_matrix(rank=r)
        )
        + 2 * torch.eye(dataset.shape[r]).to_sparse()
    )
adjacencies[f'rank_{max_rank}'] = (sparse_to_torch(dataset.coadjacency_matrix(rank=max_rank)) 
                         + torch.eye(dataset.shape[max_rank]).to_sparse())

for r in range(max_rank + 1):
    print(f"The adjacency matrix H{r} has shape: {adjacencies[f'rank_{r}'].shape}.")
    if r > 0:
        print(f"The incidence matrix B{r} has shape: {incidences[f'rank_{r}'].shape}.")
    

The adjacency matrix H0 has shape: torch.Size([34, 34]).
The adjacency matrix H1 has shape: torch.Size([78, 78]).
The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix H2 has shape: torch.Size([45, 45]).
The incidence matrix B2 has shape: torch.Size([78, 45]).
The adjacency matrix H3 has shape: torch.Size([11, 11]).
The incidence matrix B3 has shape: torch.Size([45, 11]).


  self._set_arrayXarray(i, j, x)


## Import signal ##

Since our task will be node classification, we must retrieve an input signal on the nodes. The signal will have shape $n_\text{nodes} \times$ in_channels, where in_channels is the dimension of each cell's feature. Here, we have in_channels = channels_nodes.

In [21]:
x_0 = []
for _, v in dataset.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]

In [22]:
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

There are 34 nodes with features of dimension 8.


Load edge features.

In [23]:
x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = torch.tensor(np.stack(x_1))

In [24]:
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

There are 78 edges with features of dimension 8.


Similarly for face features:

In [25]:
x_2 = []
for k, v in dataset.get_simplex_attributes("face_feat").items():
    x_2.append(v)
x_2 = torch.tensor(np.stack(x_2))

In [26]:
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 45 faces with features of dimension 8.


Higher order features:

In [27]:
x_3 = []
for k, v in dataset.get_simplex_attributes("tetrahedron_feat").items():
    x_3.append(v)
x_3 = torch.tensor(np.stack(x_3))

In [28]:
print(f"There are {x_3.shape[0]} tetrahedrons with features of dimension {x_3.shape[1]}.")

There are 11 tetrahedrons with features of dimension 8.


In [35]:
features = {'rank_0': x_0, 'rank_1': x_1, 'rank_2': x_2, 'rank_3': x_3}

## Define binary labels
We retrieve the labels associated to the nodes of each input simplex. In the KarateClub dataset, two social groups emerge. So we assign binary labels to the nodes indicating of which group they are a part.

We convert the binary labels into one-hot encoder form, and keep the last four nodes' true labels for the purpose of testing.

In [30]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_test = y_true[-4:]
y_train = y_true[:30]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the Neural Network

Using the HSNLayer class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{nodes} \times 2$, so we can compare with our binary labels.

In [16]:
class SCCN(torch.nn.Module):
    """Simplicial Complex Convolutional Network Implementation for binary node classification.

    Parameters
    ---------
    channels : int
        Dimension of features
    n_layers : int
        Amount of message passing layers.

    """

    def __init__(self, channels, max_rank, n_layers=2, n_classes=2):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                SCCNLayer(
                    channels=channels,
                    max_rank=max_rank,
                )
            )
        self.linear = torch.nn.Linear(channels, n_classes)
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, features, incidences, adjacencies):
        """Forward computation.

        Parameters
        ---------
        TODO: same as in individual layers

        Returns
        --------
        _ : tensor
            shape = [n_nodes, 2]
            One-hot labels assigned to nodes.

        """
        for layer in self.layers:
            features = layer(features, incidences, adjacencies)
        logits = self.linear(features['rank_0'])
        return logits

# Train the Neural Network

We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [17]:
model = SCCN(
    channels=channels_nodes,
    max_rank=max_rank,
    n_layers=3,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

The following cell performs the training, looping over the network for a low number of epochs.

In [20]:
test_interval = 100
num_epochs = 500
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(features, incidences, adjacencies)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[-len(y_train) :] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(features, incidences, adjacencies)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.0000 Train_acc: 0.7000
Epoch: 2 loss: 0.0000 Train_acc: 0.7000
Epoch: 3 loss: 0.0000 Train_acc: 0.7000
Epoch: 4 loss: 0.0000 Train_acc: 0.7000
Epoch: 5 loss: 0.0000 Train_acc: 0.7000
Epoch: 6 loss: 0.0000 Train_acc: 0.7000
Epoch: 7 loss: 0.0000 Train_acc: 0.7000
Epoch: 8 loss: 0.0000 Train_acc: 0.7000
Epoch: 9 loss: 0.0000 Train_acc: 0.7000
Epoch: 10 loss: 0.0000 Train_acc: 0.7000
Epoch: 11 loss: 0.0000 Train_acc: 0.7000
Epoch: 12 loss: 0.0000 Train_acc: 0.7000
Epoch: 13 loss: 0.0000 Train_acc: 0.7000
Epoch: 14 loss: 0.0000 Train_acc: 0.7000
Epoch: 15 loss: 0.0000 Train_acc: 0.7000
Epoch: 16 loss: 0.0000 Train_acc: 0.7000
Epoch: 17 loss: 0.0000 Train_acc: 0.7000
Epoch: 18 loss: 0.0000 Train_acc: 0.7000
Epoch: 19 loss: 0.0000 Train_acc: 0.7000
Epoch: 20 loss: 0.0000 Train_acc: 0.7000
Epoch: 21 loss: 0.0000 Train_acc: 0.7000
Epoch: 22 loss: 0.0000 Train_acc: 0.7000
Epoch: 23 loss: 0.0000 Train_acc: 0.7000
Epoch: 24 loss: 0.0000 Train_acc: 0.7000
Epoch: 25 loss: 0.0000 Tr