# Train a Convolutional Cell Complex Network (CCXN)

We create and train a simplified version of the CCXN originally proposed in [Hajij et al.: Cell Complex Neural Networks (2020)](https://arxiv.org/pdf/2010.00743.pdf).

### The Neural Network:

The equations of one layer of this neural network are given by:

1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)} = M_{\mathcal{L}_\uparrow}^t(h_x^{t,(0)}, h_y^{t,(0)}, \Theta^{t,(y \rightarrow x)})$ 

🟧 $\quad m_x^{(0 \rightarrow 1 \rightarrow 0)} = AGG_{y \in \mathcal{L}_\uparrow(x)}(m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)})$ 

🟩 $\quad m_x^{(0)} = m_x^{(0 \rightarrow 1 \rightarrow 0)}$ 

🟦 $\quad h_x^{t+1,(0)} = U^{t}(h_x^{t,(0)}, m_x^{(0)})$

2. A convolution from edges to faces (here $r' = 1$ and $r = 2$) using a cohomology message passing scheme:

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)$ 

🟧 $\quad m_x^{(r' \rightarrow r)}  = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$ 

🟩 $\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}$ 

🟦 $\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})$

The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

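Both convolutions follow the same template: transform the source features, aggregate them through a neighborhood matrix, then update. The NumPy sketch below makes this concrete under simplifying assumptions, and is not the `CCXNLayer` implementation: the message is a shared linear map whose weight matrix stands in for $\Theta$, $AGG$ is a sum over the rows of the neighborhood matrix, and the update $U$ applies a ReLU to the aggregated message while ignoring $h_x$. The helper name `conv` is hypothetical. For the edge-to-face step, the neighborhood is the coboundary matrix $B_2^T$, whose row for a face holds the signed incidences of its boundary edges.

```python
import numpy as np

# Minimal sketch of the two CCXN convolutions, under stated assumptions:
# - messages are a shared linear map (theta stands in for Theta),
# - AGG is a sum over the neighborhood (rows of the neighborhood matrix),
# - the update U is a ReLU of the aggregated message (h_x itself is ignored).
# `conv` is a hypothetical helper, not the CCXNLayer code.


def conv(x_source, neighborhood, theta):
    messages = x_source @ theta            # M: transform each source cell's features
    aggregated = neighborhood @ messages   # AGG: sum over the neighbors of each target cell
    return np.maximum(aggregated, 0.0)     # U: ReLU update


# 1. Node-to-node over the upper adjacency A_{up,0} of the path graph 0 - 1 - 2.
adjacency_0 = np.array([[0, 1, 0],
                        [1, 0, 1],
                        [0, 1, 0]], dtype=float)
h_0 = np.array([[1.0], [2.0], [3.0]])
h_0_next = conv(h_0, adjacency_0, np.eye(1))
print(h_0_next.ravel())  # node 0 gets h_1, node 1 gets h_0 + h_2, node 2 gets h_1

# 2. Edge-to-face (r' = 1, r = 2) over B_2^T for one triangle with edges
#    e0 = (0, 1), e1 = (1, 2), e2 = (0, 2) and oriented boundary e0 + e1 - e2.
incidence_2_t = np.array([[1.0, 1.0, -1.0]])  # shape [n_faces, n_edges]
h_1 = np.array([[1.0, 0.0],
                [2.0, 1.0],
                [0.5, 3.0]])
h_2 = conv(h_1, incidence_2_t, np.eye(2))
print(h_2)  # signed sum over the boundary edges, then ReLU
```

In this toy example the face receives 2.5 on its first feature, while the second feature is negative before the update and is clipped to 0 by the ReLU.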
### The Task:

We train this model to perform whole-complex classification on [`MUTAG` from the TUDataset](https://paperswithcode.com/dataset/mutag). This dataset contains:
- 188 samples of chemical compounds represented as graphs,
- with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella typhimurium.

# Set-up


In [1]:
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from toponetx import CellComplex
from toponetx.datasets.mesh import shrec_16
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from topomodelx.nn.cell.ccxn_layer import CCXNLayer

If a GPU is available, we use it; otherwise, the code runs on the CPU.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

We import a subset of MUTAG (its first 20 graphs), a benchmark dataset for graph classification.

We then lift each graph into our topological domain of choice, here a cell complex.

We also retrieve:
- input signals `x_0` and `x_1` on the nodes (0-cells) and edges (1-cells) of each complex: these will be the model's inputs,
- a binary classification label `y` associated with each cell complex.

In [3]:
dataset = TUDataset(
    root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True, use_node_attr=True
)
dataset = dataset[:20]  # keep only the first 20 graphs for rapid testing
cc_list = []
x_0_list = []
x_1_list = []
y_list = []
for graph in dataset:
    # Lift the graph to a cell complex: its nodes and edges become 0-cells and 1-cells.
    cell_complex = CellComplex(to_networkx(graph))
    cc_list.append(cell_complex)
    x_0_list.append(graph.x)
    x_1_list.append(graph.edge_attr)
    y_list.append(int(graph.y))

i_cc = 0
print(f"Features on nodes for the {i_cc}-th cell complex: {x_0_list[i_cc].shape}.")
print(f"Features on edges for the {i_cc}-th cell complex: {x_1_list[i_cc].shape}.")
print(f"Label of the {i_cc}-th cell complex: {y_list[i_cc]}.")

We also load the SHREC mesh dataset (`shrec_16`) available through TopoNetX, and inspect the features it provides on nodes, edges and faces.

In [4]:
shrec_training, shrec_testing = shrec_16()

# training dataset
training_complexes = shrec_training["complexes"]
training_labels = shrec_training["label"]
training_node_feat = shrec_training["node_feat"]
training_edge_feat = shrec_training["edge_feat"]
training_face_feat = shrec_training["face_feat"]

# testing dataset
testing_complexes = shrec_testing["complexes"]
testing_labels = shrec_testing["label"]
testing_node_feat = shrec_testing["node_feat"]
testing_edge_feat = shrec_testing["edge_feat"]
testing_face_feat = shrec_testing["face_feat"]

unzipping the files...

Loading dataset...

done!

In [5]:
"""
Node features:
    - Position
    - Normal
"""
training_node_feat[0]

array([[ 0.819765  , -0.316374  ,  0.2162    ,  0.8503737 , -0.28702822,
         0.44099816],
       [ 0.803785  , -0.319039  ,  0.134566  ,  0.59733338, -0.38828796,
        -0.70173022],
       [ 0.763605  , -0.199287  ,  0.171764  ,  0.7594386 ,  0.58320488,
        -0.28831422],
       ...,
       [-0.779561  ,  0.136582  ,  0.158549  , -0.66767593, -0.44223474,
         0.59886333],
       [-0.821902  ,  0.245018  ,  0.144421  , -0.78971004,  0.28455973,
         0.54349224],
       [-0.802178  ,  0.129943  ,  0.042674  , -0.86820958, -0.45792179,
        -0.19110143]])

In [6]:
training_complexes[0].to_trimesh().show()

array([[1.42782995, 0.15384225, 0.62795728, ..., 1.77421153, 1.55949484,
        0.89653369],
       [0.40949755, 0.17308214, 0.70043063, ..., 2.93316508, 1.63301163,
        0.65236704],
       [1.27168475, 0.17205171, 0.47988115, ..., 2.57336372, 1.55949484,
        1.19406235],
       ...,
       [0.81137404, 0.17767169, 0.86427246, ..., 1.70529674, 0.75708422,
        1.49625958],
       [0.42962673, 0.219632  , 1.11560732, ..., 1.82331083, 2.07733926,
        0.61800127],
       [0.88637135, 0.21321849, 0.57161096, ..., 2.37766529, 1.43387299,
        2.04679451]])

In [7]:
"""
Edge features:
    - Dihedral angle
    - Edge span
    - 2 edge angles in the triangle
    - 6 edge ratios
"""

training_edge_feat[0]

array([[ 1.50264134e-03,  9.11383142e-01,  3.64910901e-01, ...,
         1.19306589e+00,  1.32056948e+00,  6.27957277e-01],
       [ 8.20956861e-04,  4.90164562e-01, -8.69006027e-01, ...,
         8.22450325e-01,  1.31123620e+00,  1.00790613e+00],
       [ 1.92993056e-03,  8.53634969e-01,  4.83226000e-01, ...,
         5.19895316e-01,  1.92126670e+00,  7.00430634e-01],
       ...,
       [ 1.91996648e-03, -9.25504404e-01, -3.39513148e-01, ...,
         8.65261466e-01,  8.64272463e-01,  1.41205872e+00],
       [ 2.91073733e-03, -7.61350157e-01, -6.21618281e-01, ...,
         1.38266651e+00,  6.43318827e-01,  1.11560732e+00],
       [ 1.37702838e-03, -9.35632685e-01, -2.91318121e-01, ...,
         1.68782078e+00,  5.71610958e-01,  8.82160917e-01]])

In [13]:
"""
Face features:
    - Face area
    - Face normal
    - Face angle
"""

training_face_feat[0]

<252x252 sparse matrix of type '<class 'numpy.int64'>'
	with 1500 stored elements in Compressed Sparse Column format>

## Define neighborhood structures. ##

Implementing the CCXN architecture requires message passing along the neighborhood structures of each cell complex.

We therefore retrieve these neighborhood structures, i.e. their representative matrices, which we use to send messages.

For the CCXN, we need the adjacency matrix $A_{\uparrow, 0}$ and the coboundary matrix $B_2^T$ of each cell complex.
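For intuition, here is how both matrices arise for a single filled triangle, built by hand in NumPy. The sketch assumes a signed node-to-edge boundary matrix $B_1$ (tail $-1$, head $+1$) and uses the identity $B_1 B_1^T = D - A_{\uparrow,0}$, where $D$ is the diagonal degree matrix. It does not claim to reproduce how TopoNetX's `adjacency_matrix` or `incidence_matrix` are computed internally.

```python
import numpy as np

# Hand-built example (assumption: one filled triangle with an arbitrary orientation).
# Nodes 0, 1, 2; edges e0 = (0, 1), e1 = (1, 2), e2 = (0, 2); one face [0, 1, 2].
# B_1 maps edges to their endpoints: -1 at the tail, +1 at the head.
B1 = np.array([[-1,  0, -1],
               [ 1, -1,  0],
               [ 0,  1,  1]])
# B_2 maps the face to its oriented boundary e0 + e1 - e2.
B2 = np.array([[1], [1], [-1]])

assert not np.any(B1 @ B2)  # the boundary of a boundary is zero

# Upper Laplacian on nodes: B_1 B_1^T = D - A_up,0.
L_up0 = B1 @ B1.T
degree = np.diag(np.diag(L_up0))
A_up0 = degree - L_up0  # two nodes are upper adjacent iff they share an edge

# Coboundary matrix used for the edge-to-face messages: shape [n_faces, n_edges].
B2_t = B2.T
print(A_up0)
print(B2_t.shape)
```

Every pair of nodes in the triangle shares an edge, so $A_{\uparrow,0}$ is all ones off the diagonal, and $B_2^T$ has one row (the face) and three columns (its edges), matching the `[n_faces, n_edges]` shape expected by the model.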

In [89]:
incidence_2_t_list = []
adjacency_0_list = []
for cell_complex in cc_list:
    incidence_2_t = cell_complex.incidence_matrix(rank=2).T
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    incidence_2_t = torch.from_numpy(incidence_2_t.todense()).to_sparse()
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    incidence_2_t_list.append(incidence_2_t)
    adjacency_0_list.append(adjacency_0)

i_cc = 0
print(f"Incidence_2_t of the {i_cc}-th complex: {incidence_2_t_list[i_cc].shape}.")
print(f"Adjacency of the {i_cc}-th complex: {adjacency_0_list[i_cc].shape}.")

Incidence_2_t of the 0-th complex: torch.Size([500, 750]).
Adjacency of the 0-th complex: torch.Size([252, 252]).


# Create the Neural Network

Using the CCXNLayer class, we create a neural network with stacked layers.

In [79]:
in_channels_0 = x_0_list[0].shape[-1]
in_channels_1 = x_1_list[0].shape[-1]
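# Faces carry no input features here; in_channels_2 sets the dimension of the
# face features computed by the edge-to-face convolution inside each layer.
in_channels_2 = 5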
print(
    f"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}."
)


In [35]:
class CCXN(torch.nn.Module):
    """CCXN.

    Parameters
    ----------
    in_channels_0 : int
        Dimension of input features on nodes.
    in_channels_1 : int
        Dimension of input features on edges.
    in_channels_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    n_layers : int
        Number of CCXN layers.
    att : bool
        Whether to use attention.
    """

    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        in_channels_2,
        num_classes,
        n_layers=2,
        att=False,
    ):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                CCXNLayer(
                    in_channels_0=in_channels_0,
                    in_channels_1=in_channels_1,
                    in_channels_2=in_channels_2,
                    att=att,
                )
            )
        self.layers = torch.nn.ModuleList(layers)
        self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)

    def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
        """Forward computation through layers, then linear layers, then avg pooling.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = [n_nodes, in_channels_0]
            Input features on the nodes (0-cells).
        x_1 : torch.Tensor, shape = [n_edges, in_channels_1]
            Input features on the edges (1-cells).
        neighborhood_0_to_0 : tensor, shape = [n_nodes, n_nodes]
            Adjacency matrix of rank 0 (up).
        neighborhood_1_to_2 : tensor, shape = [n_faces, n_edges]
            Transpose of boundary matrix of rank 2.
        x_2 : torch.Tensor, shape = [n_faces, in_channels_2]
            Input features on the faces (2-cells).
            Optional. Use for attention mechanism between edges and faces.

        Returns
        -------
        _ : tensor, shape = [1]
            Label assigned to whole complex.
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return two_dimensional_cells_mean + one_dimensional_cells_mean + zero_dimensional_cells_mean
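The readout at the end of `forward` can be mirrored in NumPy to show why the NaN handling matters: when a complex has no 2-cells (presumably the case here, since the MUTAG graphs are lifted without adding any faces), the mean over an empty `x_2` is NaN, and replacing it with zeros makes the face branch contribute nothing to the sum. The helper `readout` below is a hypothetical sketch of that logic, not part of TopoModelX.

```python
import warnings

import numpy as np


def readout(x_0, x_1, x_2):
    """Hypothetical helper: mean-pool each rank, zero out NaN means, and sum."""
    pooled = []
    for x in (x_0, x_1, x_2):
        with warnings.catch_warnings():
            # np.nanmean warns "Mean of empty slice" and returns NaN for an empty rank.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean = np.nanmean(x, axis=0)
        pooled.append(np.nan_to_num(mean, nan=0.0))
    return pooled[0] + pooled[1] + pooled[2]


x_0 = np.array([[1.0, 0.0], [3.0, 2.0]])  # 2 nodes, 2 logits each
x_1 = np.array([[0.5, 0.5]])              # 1 edge
x_2 = np.empty((0, 2))                    # no faces at all
print(readout(x_0, x_1, x_2))             # node mean [2, 1] + edge mean [0.5, 0.5] + zeros
```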

# Train the Neural Network

We instantiate the model, the loss function, and the optimizer. We first train the model without the attention mechanism.

In [36]:
model = CCXN(in_channels_0, in_channels_1, in_channels_2, num_classes=2, n_layers=2)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
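The model returns one vector of `num_classes` logits per complex, and the label is a single class index; `torch.nn.CrossEntropyLoss` accepts this unbatched form. For reference, the loss for one sample is $-\log \operatorname{softmax}(\hat{y})_y$, sketched here in NumPy with a hypothetical helper `cross_entropy`. With two classes, an uninformative prediction gives $\log 2 \approx 0.693$, a useful baseline when reading the training log.

```python
import numpy as np


def cross_entropy(logits, target):
    """Hypothetical helper: cross-entropy of one sample, -log softmax(logits)[target]."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[target]


logits = np.array([0.0, 0.0])  # an undecided two-class output
print(round(float(cross_entropy(logits, 1)), 4))  # 0.6931, i.e. log(2)
```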

We split the dataset into train and test sets.

In [37]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
incidence_2_t_train, incidence_2_t_test = train_test_split(
    incidence_2_t_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)
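Because each list is split in a separate call, `shuffle=False` is what keeps the node features, edge features, neighborhoods and label of a given complex aligned across the splits. With a float `test_size`, scikit-learn puts the last `ceil(test_size * n)` samples in the test set; the hypothetical helper `sequential_split` below reproduces that behavior in plain Python as a sketch.

```python
import math


def sequential_split(items, test_size):
    """Hypothetical helper mimicking train_test_split(..., shuffle=False)."""
    n_test = math.ceil(test_size * len(items))  # test set size is rounded up
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]  # order is preserved, test set comes last


train, test = sequential_split(list(range(20)), test_size=0.2)
print(len(train), len(test))  # 16 4
```

With the 20 complexes loaded above, this gives 16 training and 4 test complexes, which is consistent with the reported accuracies being multiples of 1/16 (train) and 1/4 (test).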

We train the CCXN for only a few epochs, keeping training minimal for rapid testing.

In [38]:
test_interval = 2
num_epochs = 4
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train, x_1_train, incidence_2_t_train, adjacency_0_train, y_train
    ):
        x_0, x_1, y = x_0.float().to(device), x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
        incidence_2_t, adjacency_0 = incidence_2_t.float().to(device), adjacency_0.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()  # switch to evaluation mode before testing
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test, x_1_test, incidence_2_t_test, adjacency_0_test, y_test
            ):
                x_0, x_1, y = x_0.float().to(device), x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
                incidence_2_t, adjacency_0 = incidence_2_t.float().to(device), adjacency_0.float().to(device)
                y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

Epoch: 1 loss: 1.5594 Train_acc: 0.3750
Epoch: 2 loss: 0.6690 Train_acc: 0.7500
Test_acc: 0.7500
Epoch: 3 loss: 0.5635 Train_acc: 0.6875
Epoch: 4 loss: 0.6332 Train_acc: 0.7500
Test_acc: 0.7500


# Train the Neural Network with Attention


Now we create a new neural network that uses the attention mechanism.

In [39]:
model = CCXN(
    in_channels_0, in_channels_1, in_channels_2, num_classes=2, n_layers=2, att=True
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
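This notebook does not show how `CCXNLayer` computes attention when `att=True`. As a generic illustration only, the sketch below uses GAT-style scoring, which is an assumption rather than the library's formula: each neighbor pair gets a score $\mathrm{LeakyReLU}(a_{dst} \cdot h_x + a_{src} \cdot h_y)$, and a softmax over the neighbors of each cell turns these scores into aggregation weights, replacing the plain sum used without attention. The function `attention_aggregate` and the vectors `a_src`, `a_dst` are hypothetical.

```python
import numpy as np


def attention_aggregate(h, adjacency, a_src, a_dst, negative_slope=0.2):
    """Hypothetical GAT-style attention over a neighborhood matrix.

    Assumes every row of `adjacency` has at least one nonzero entry.
    """
    # Raw score for every (x, y) pair: a_dst . h_x + a_src . h_y
    scores = (h @ a_dst)[:, None] + (h @ a_src)[None, :]
    scores = np.where(scores > 0, scores, negative_slope * scores)  # LeakyReLU
    scores = np.where(adjacency > 0, scores, -np.inf)  # non-neighbors get no weight
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize the softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)  # softmax over neighbors
    return weights @ h, weights


# Path graph 0 - 1 - 2 with one feature per node.
adjacency = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
h = np.array([[1.0], [2.0], [3.0]])
zeros = np.zeros(1)
out, weights = attention_aggregate(h, adjacency, a_src=zeros, a_dst=zeros)
print(out.ravel())  # with all-zero attention vectors, each node averages its neighbors
```

With all-zero attention vectors every neighbor gets the same weight, so the sketch reduces to averaging over neighbors; learned vectors let a layer emphasize some neighbors over others.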

We run the training for this neural network:

In [82]:
test_interval = 2
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train, x_1_train, incidence_2_t_train, adjacency_0_train, y_train
    ):
        x_0, x_1, y = x_0.float().to(device), x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
        incidence_2_t, adjacency_0 = incidence_2_t.float().to(device), adjacency_0.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()  # switch to evaluation mode before testing
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test, x_1_test, incidence_2_t_test, adjacency_0_test, y_test
            ):
                x_0, x_1, y = x_0.float().to(device), x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
                incidence_2_t, adjacency_0 = incidence_2_t.float().to(device), adjacency_0.float().to(device)
                y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)

                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)