# Train a Simplicial 2-complex convolutional neural network (SCConv)


In this notebook, we will create and train a Simplicial 2-complex convolutional neural in the simplicial complex domain, as proposed in the paper by [Bunch et. al : Simplicial 2-Complex Convolutional Neural Networks (2020)](https://openreview.net/pdf?id=Sc8glB-k6e9).


We train the model to perform

The equations of one layer of this neural network are given by:

🟥 $\quad m_{y\rightarrow x}^{(0\rightarrow 0)} = ({\tilde{A}_{\uparrow,0}})_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0\rightarrow0)}$

🟥 $\quad m^{(1\rightarrow0)}_{y\rightarrow x}  = (B_1)_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(1\rightarrow 0)}$

🟥 $\quad m^{(0 \rightarrow 1)}_{y \rightarrow x}  = (\tilde B_1)_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow1)}$

🟥 $\quad m^{(1\rightarrow1)}_{y\rightarrow x} = ({\tilde{A}_{\downarrow,1}} + {\tilde{A}_{\uparrow,1}})_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1\rightarrow1)}$

🟥 $\quad m^{(2\rightarrow1)}_{y \rightarrow x}  = (B_2)_{xy} \cdot h_y^{t,(2)} \cdot \Theta^{t,(2 \rightarrow1)}$

🟥 $\quad m^{(1 \rightarrow 2)}_{y \rightarrow x}  = (\tilde B_2)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1 \rightarrow 2)}$

🟥 $\quad m^{(2 \rightarrow 2)}_{y \rightarrow x}  = ({\tilde{A}_{\downarrow,2}})_{xy} \cdot h_y^{t,(2)} \cdot \Theta^{t,(2 \rightarrow 2)}$

🟧 $\quad m_x^{(0 \rightarrow 0)}  = \sum_{y \in \mathcal{L}_\uparrow(x)} m_{y \rightarrow x}^{(0 \rightarrow 0)}$

🟧 $\quad m_x^{(1 \rightarrow 0)}  = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(1 \rightarrow 0)}$

🟧 $\quad m_x^{(0 \rightarrow 1)}  = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(0 \rightarrow 1)}$

🟧 $\quad m_x^{(1 \rightarrow 1)}  = \sum_{y \in (\mathcal{L}_\uparrow(x) + \mathcal{L}_\downarrow(x))} m_{y \rightarrow x}^{(1 \rightarrow 1)}$

🟧 $\quad m_x^{(2 \rightarrow 1)} = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(2 \rightarrow 1)}$

🟧 $\quad m_x^{(1 \rightarrow 2)}  = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(1 \rightarrow 2)}$

🟧 $\quad m_x^{(2 \rightarrow 2)}  = \sum_{y \in \mathcal{L}_\downarrow(x)} m_{y \rightarrow x}^{(2 \rightarrow 2)}$

🟩 $\quad m_x^{(0)}  = m_x^{(1\rightarrow0)}+ m_x^{(0\rightarrow0)}$

🟩 $\quad m_x^{(1)}  = m_x^{(2\rightarrow1)}+ m_x^{(1\rightarrow1)}$

🟦 $\quad h^{t+1, (0)}_x  = \sigma(m_x^{(0)})$

🟦 $\quad h^{t+1, (1)}_x  = \sigma(m_x^{(1)})$

🟦 $\quad h^{t+1, (2)}_x  = \sigma(m_x^{(2)})$


Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [52]:
import numpy as np
import toponetx.datasets.graph as graph
import torch
from scipy.sparse import coo_matrix, diags

from topomodelx.nn.simplicial.scconv import SCConv
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a singular graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [54]:
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)

Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4


In [55]:
dataset.shape

(34, 78, 45, 11, 2)

# Define Neighbourhood Structures

We create the neigborood structures expected by SSConv. The SSConv layer expects the following neighbourhood structures:
* incidence_1 $B_1$
* incidence_1_norm $\tilde{B}_1$
* incidence_2 $B_2$
* incidence_2_norm $\tilde{B}_1$
* adjacency_up_0_norm $\tilde{A}_{\uparrow,0}$
* adjacency_up_1_norm $\tilde{A}_{\uparrow,1}$
* adjacency_down_1_norm $\tilde{A}_{\downarrow,1}$
* adjacency_down_2_norm $\tilde{A}_{\downarrow,2}$

In [56]:
# Not working, it needs to be reviewed
def normalize_higher_order_adj(A_opt):
    """
    Args:
        A_opt is an opt that maps a j-cochain to a k-cochain.
        shape [num_of_k_simplices num_of_j_simplices]

    return:
         D^{-0.5}* (A_opt)* D^{-0.5}.
    """
    rowsum = np.array(np.abs(A_opt).sum(1))
    r_inv_sqrt = np.power(rowsum, -0.5).flatten()
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.0
    r_mat_inv_sqrt = diags(r_inv_sqrt)
    A_opt_to = A_opt.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)

    return coo_matrix(A_opt_to)

In [78]:
incidence_1_norm = incidence_0_1 = from_sparse(dataset.incidence_matrix(1))
incidence_1 = incidence_1_0 = from_sparse(dataset.coincidence_matrix(1))
incidence_2_norm = incidence_1_2 = from_sparse(dataset.incidence_matrix(2))
incidence_2 = incidence_2_1 = from_sparse(dataset.coincidence_matrix(2))
adjacency_up_0_norm = adjacency_0 = from_sparse(dataset.up_laplacian_matrix(0))
adjacency_up_1_norm = adjacency_1_up = from_sparse(dataset.up_laplacian_matrix(1))
adjacency_down_1_norm = adjacency_1_down = from_sparse(dataset.down_laplacian_matrix(1))
adjacency_down_2_norm = adjacency_2 = from_sparse(dataset.down_laplacian_matrix(2))

## Import signal ##

We retrieve an input signal on the nodes, edges and faces. The signal will have shape $n_\text{simplicial} \times$ in_channels, where in_channels is the dimension of each simplicial's feature. Here, we have in_channels = channels_nodes $ = 2$. 

In [89]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = torch.tensor(np.stack(x_2))
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.


We also pre-define the number output channels of the model, in this case the number of node classes.

In [90]:
in_channels = x_0.shape[-1]
out_channels = 2

## Define binary labels
We retrieve the labels associated to the nodes of each input simplex. In the KarateClub dataset, two social groups emerge. So we assign binary labels to the nodes indicating of which group they are a part.

We convert one-hot encode the binary labels, and keep the first four nodes for the purpose of testing.

In [91]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_train = y_true[:30]
y_test = y_true[-4:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the Neural Network

Using the SAN class, we create our neural network with stacked layers. Given the considered dataset and task (Karate Club, node classification), a linear layer at the end produces an output with shape $n_\text{nodes} \times 2$, so we can compare with our binary labels.

In [99]:
class Network(torch.nn.Module):
    def __init__(self, in_channels, out_channels, n_layers=1):
        super().__init__()
        self.base_model = SCConv(
            node_channels=in_channels,
            n_layers=n_layers,
        )
        self.linear_x0 = torch.nn.Linear(in_channels, out_channels)
        self.linear_x1 = torch.nn.Linear(in_channels, out_channels)
        self.linear_x2 = torch.nn.Linear(in_channels, out_channels)

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        incidence_1,
        incidence_1_norm,
        incidence_2,
        incidence_2_norm,
        adjacency_up_0_norm,
        adjacency_up_1_norm,
        adjacency_down_1_norm,
        adjacency_down_2_norm,
    ):
        x_0, x_1, x_2 = self.base_model(
            x_0,
            x_1,
            x_2,
            incidence_1,
            incidence_1_norm,
            incidence_2,
            incidence_2_norm,
            adjacency_up_0_norm,
            adjacency_up_1_norm,
            adjacency_down_1_norm,
            adjacency_down_2_norm,
        )
        x_0 = self.linear_x0(x_0)
        x_1 = self.linear_x1(x_1)
        x_2 = self.linear_x2(x_2)
        return torch.softmax(x_0, dim=1)

In [100]:
n_layers = 1
model = Network(
    in_channels=in_channels,
    out_channels=out_channels,
    n_layers=n_layers,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Train the Neural Network

The following cell performs the training, looping over the network for a low number of epochs.

In [102]:
test_interval = 10
num_epochs = 200
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(
        x_0,
        x_1,
        x_2,
        incidence_1,
        incidence_1_norm,
        incidence_2,
        incidence_2_norm,
        adjacency_up_0_norm,
        adjacency_up_1_norm,
        adjacency_down_1_norm,
        adjacency_down_2_norm,
    )
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(
                x_0,
                x_1,
                x_2,
                incidence_1,
                incidence_1_norm,
                incidence_2,
                incidence_2_norm,
                adjacency_up_0_norm,
                adjacency_up_1_norm,
                adjacency_down_1_norm,
                adjacency_down_2_norm,
            )
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.7134 Train_acc: 0.5667
Epoch: 2 loss: 0.7133 Train_acc: 0.5667
Epoch: 3 loss: 0.7132 Train_acc: 0.5667
Epoch: 4 loss: 0.7131 Train_acc: 0.5667
Epoch: 5 loss: 0.7128 Train_acc: 0.5667
Epoch: 6 loss: 0.7125 Train_acc: 0.5667
Epoch: 7 loss: 0.7120 Train_acc: 0.5667
Epoch: 8 loss: 0.7117 Train_acc: 0.5667
Epoch: 9 loss: 0.7113 Train_acc: 0.5667
Epoch: 10 loss: 0.7109 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 11 loss: 0.7105 Train_acc: 0.5667
Epoch: 12 loss: 0.7099 Train_acc: 0.5667
Epoch: 13 loss: 0.7093 Train_acc: 0.5667
Epoch: 14 loss: 0.7086 Train_acc: 0.5667
Epoch: 15 loss: 0.7079 Train_acc: 0.5667


Epoch: 16 loss: 0.7070 Train_acc: 0.5667
Epoch: 17 loss: 0.7062 Train_acc: 0.5667
Epoch: 18 loss: 0.7052 Train_acc: 0.5667
Epoch: 19 loss: 0.7042 Train_acc: 0.5667
Epoch: 20 loss: 0.7030 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 21 loss: 0.7017 Train_acc: 0.5667
Epoch: 22 loss: 0.7003 Train_acc: 0.5667
Epoch: 23 loss: 0.6988 Train_acc: 0.5667
Epoch: 24 loss: 0.6971 Train_acc: 0.5667
Epoch: 25 loss: 0.6954 Train_acc: 0.5667
Epoch: 26 loss: 0.6935 Train_acc: 0.5667
Epoch: 27 loss: 0.6916 Train_acc: 0.5667
Epoch: 28 loss: 0.6894 Train_acc: 0.5667
Epoch: 29 loss: 0.6872 Train_acc: 0.5667
Epoch: 30 loss: 0.6849 Train_acc: 0.5667
Test_acc: 0.5000
Epoch: 31 loss: 0.6824 Train_acc: 0.6000
Epoch: 32 loss: 0.6799 Train_acc: 0.6000
Epoch: 33 loss: 0.6772 Train_acc: 0.6333
Epoch: 34 loss: 0.6745 Train_acc: 0.6333
Epoch: 35 loss: 0.6716 Train_acc: 0.6667
Epoch: 36 loss: 0.6687 Train_acc: 0.7000
Epoch: 37 loss: 0.6657 Train_acc: 0.7333
Epoch: 38 loss: 0.6626 Train_acc: 0.7333
Epoch: 39 loss: 0.6596 