# Train a Convolutional Cell Complex Network (CCXN)

We create and train a simplified version of the CCXN originally proposed in [Hajij et al.: Cell Complex Neural Networks (2020)](https://arxiv.org/pdf/2010.00743.pdf).

### The Neural Network:

The equations of one layer of this neural network are given by:

1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)} = M_{\mathcal{L}_\uparrow}^t(h_x^{t,(0)}, h_y^{t,(0)}, \Theta^{t,(y \rightarrow x)})$ 

🟧 $\quad m_x^{(0 \rightarrow 1 \rightarrow 0)} = AGG_{y \in \mathcal{L}_\uparrow(x)}(m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)})$

🟩 $\quad m_x^{(0)} = m_x^{(0 \rightarrow 1 \rightarrow 0)}$ 

🟦 $\quad h_x^{t+1,(0)} = U^{t}(h_x^{t,(0)}, m_x^{(0)})$

2. A convolution from edges to faces using a cohomology message passing scheme:

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)$ 

🟧 $\quad m_x^{(r' \rightarrow r)}  = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$ 

🟩 $\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}$ 

🟦 $\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})$

The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
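
As a rough illustration of the two steps above, the following NumPy sketch applies one adjacency-based convolution on nodes and one coboundary-based convolution from edges to faces. It assumes that the message function is a linear map, that the aggregation is a sum over neighbors, and that the update is a ReLU that ignores the previous state. The toy complex (a single triangle), the weight names `theta_0` and `theta_1`, and the function `ccxn_like_layer` are illustrative and not part of any library.

```python
import numpy as np


def ccxn_like_layer(x_0, x_1, adjacency_0, incidence_2_t, theta_0, theta_1):
    """One simplified layer: nodes -> nodes and edges -> faces.

    Assumptions (not taken from the source): linear messages, sum
    aggregation, and a ReLU update that ignores the previous state.
    """
    # Nodes to nodes: each node sums the transformed features of its neighbors.
    m_0 = adjacency_0 @ x_0 @ theta_0
    # Edges to faces: each face sums the transformed features of its boundary edges.
    m_2 = incidence_2_t @ x_1 @ theta_1
    return np.maximum(m_0, 0.0), np.maximum(m_2, 0.0)


# Toy complex: one triangle with 3 nodes, 3 edges, and 1 face.
adjacency_0 = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
incidence_2_t = np.ones((1, 3))  # B_2^T: the single face contains all three edges

rng = np.random.default_rng(0)
x_0 = rng.normal(size=(3, 4))  # 4 features per node
x_1 = rng.normal(size=(3, 5))  # 5 features per edge
theta_0 = rng.normal(size=(4, 8))
theta_1 = rng.normal(size=(5, 8))

h_0, h_2 = ccxn_like_layer(x_0, x_1, adjacency_0, incidence_2_t, theta_0, theta_1)
print(h_0.shape, h_2.shape)
```

The node output keeps one row per node and the face output has one row per face, which is the shape bookkeeping the rest of the notebook relies on.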

### The Task:

We train this model to perform entire complex classification on [`MUTAG` from the TUDataset](https://paperswithcode.com/dataset/mutag). This dataset contains:
- 188 samples of chemical compounds represented as graphs,
- with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella typhimurium.

# Set-up


In [37]:
import torch
import numpy as np
import torch.nn as nn
from toponetx import CombinatorialComplex
from toponetx.datasets.mesh import shrec_16
from sklearn.model_selection import train_test_split
from topomodelx.nn.combinatorial.hmc_layer import HMCLayer

If a GPU is available, we use it; otherwise, the notebook runs on the CPU.

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

We load the `shrec_16` mesh dataset from TopoNetX, which comes with a training split and a testing split.

We then convert each complex into our topological domain of choice, here a combinatorial complex.

We also retrieve:
- input signals on the nodes (0-cells), edges (1-cells), and faces (2-cells) of each complex: these will be the model's inputs,
- a classification label associated with each complex.

In [117]:
shrec_training, shrec_testing = shrec_16()

# training dataset
training_complexes = [
    cc.to_combinatorial_complex() for cc in shrec_training["complexes"]
]
training_labels = shrec_training["label"]
training_node_feat = shrec_training["node_feat"]
training_edge_feat = shrec_training["edge_feat"]
training_face_feat = shrec_training["face_feat"]

# testing dataset
testing_complexes = [cc.to_combinatorial_complex() for cc in shrec_testing["complexes"]]
testing_labels = shrec_testing["label"]
testing_node_feat = shrec_testing["node_feat"]
testing_edge_feat = shrec_testing["edge_feat"]
testing_face_feat = shrec_testing["face_feat"]

unzipping the files...

done!
Loading dataset...

done!
<class 'toponetx.classes.simplicial_complex.SimplicialComplex'>


In [123]:
training_complexes[0].incidence_matrix(
    rank=2, to_rank=None, incidence_type="down"
).todense()
# what is going on with self.skeleton??

TypeError: CombinatorialComplex.skeleton() got an unexpected keyword argument 'level'

In [110]:
"""
Node features:
    - Position
    - Normal
"""

training_node_feat[0]

array([[ 0.567542  ,  0.570995  , -0.210023  , -0.06894332,  0.72564355,
         0.6846081 ],
       [ 0.603908  ,  0.557968  , -0.223298  ,  0.91155493,  0.37192256,
         0.17533174],
       [ 0.591494  ,  0.464737  ,  0.350567  ,  0.33390166,  0.55019138,
         0.76537515],
       ...,
       [-0.331773  , -0.33001   ,  0.049767  , -0.28529544, -0.95405566,
        -0.09156586],
       [-0.401883  , -0.263047  ,  0.136766  , -0.66302596, -0.62547973,
         0.41130485],
       [ 0.261249  ,  0.128185  ,  0.304997  ,  0.7728916 , -0.44845904,
         0.44891321]])

In [111]:
"""
Edge features:
    - Dihedral angle
    - Edge span
    - 2 edge angle in the triangle
    - 6 edge ratios
"""

training_edge_feat[0]

array([[1.22957341, 0.13581241, 0.45991269, ..., 2.25198965, 2.89552337,
        2.00058261],
       [1.43636325, 0.14296384, 0.58913082, ..., 1.74326848, 2.89552337,
        0.51764337],
       [0.96944867, 0.11326229, 0.70546737, ..., 2.4000534 , 2.25383585,
        0.468195  ],
       ...,
       [0.39824646, 0.12369258, 0.55931499, ..., 0.945035  , 0.59602236,
        4.14184985],
       [0.19510851, 0.11228415, 1.07026534, ..., 3.67055053, 2.41476323,
        0.51704672],
       [0.0621548 , 0.122271  , 1.95575932, ..., 2.90562928, 1.5314299 ,
        0.66267758]])

In [112]:
"""
Face features:
    - Face area
    - Face normal
    - Face angle
"""

training_face_feat[0]

array([[ 6.16026555e-04,  3.42332662e-01,  9.39444198e-01, ...,
         9.55859971e-01,  4.28478956e-01,  1.75725373e+00],
       [ 6.81088138e-04,  3.97494225e-01,  1.96188417e-01, ...,
         4.59912686e-01,  1.33996485e+00,  1.34171512e+00],
       [ 2.02495554e-03, -7.65949901e-01,  4.13444548e-01, ...,
         1.32871011e+00,  1.22375172e+00,  5.89130819e-01],
       ...,
       [ 1.88118714e-03,  1.16577970e-01, -9.84528932e-01, ...,
         3.64363423e-01,  1.93855641e+00,  8.38672818e-01],
       [ 1.50431667e-03, -4.34561991e-01, -8.48251445e-01, ...,
         1.54860901e+00,  5.77114229e-01,  1.01586942e+00],
       [ 1.74908308e-03, -8.97968600e-01, -3.36066581e-01, ...,
         1.04615878e+00,  1.50159627e+00,  5.93837603e-01]])

In [43]:
# training_complexes[0].to_trimesh().show()

## Define neighborhood structures. ##

Implementing the CCXN architecture requires message passing along the neighborhood structures of the cell complexes.

We therefore retrieve the matrices that represent these neighborhood structures, which we will use to send messages.

For the CCXN, we need the adjacency matrix $A_{\uparrow, 0}$ and the coboundary matrix $B_2^T$ of each cell complex. The cells below also compute the adjacency matrix on edges, the coadjacency matrix on faces, and the incidence matrices $B_1$ and $B_2$, which the `HMCLayer` used later takes as inputs.
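
To make the shapes of these matrices concrete, here is a small NumPy sketch on a hand-built complex of two triangles glued along an edge. It assumes unsigned incidence matrices and defines two cells as neighbors when they share an incident cell; the conventions used by TopoNetX may differ, and the helper `off_diagonal` is illustrative only.

```python
import numpy as np

# Toy complex: two triangles (0,1,2) and (1,2,3) glued along the edge (1,2).
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
faces = [(0, 1, 2), (1, 2, 3)]
n_nodes = 4

# Unsigned incidence B_1 (nodes x edges): 1 where a node is an endpoint of an edge.
incidence_1 = np.zeros((n_nodes, len(edges)), dtype=int)
for j, (u, v) in enumerate(edges):
    incidence_1[u, j] = 1
    incidence_1[v, j] = 1

# Unsigned incidence B_2 (edges x faces): 1 where an edge lies on a face boundary.
incidence_2 = np.zeros((len(edges), len(faces)), dtype=int)
for k, face in enumerate(faces):
    for j, (u, v) in enumerate(edges):
        if u in face and v in face:
            incidence_2[j, k] = 1


def off_diagonal(matrix):
    """Zero the diagonal so that a cell is not counted as its own neighbor."""
    result = matrix.copy()
    np.fill_diagonal(result, 0)
    return result


adjacency_0 = off_diagonal(incidence_1 @ incidence_1.T)    # nodes sharing an edge
adjacency_1 = off_diagonal(incidence_2 @ incidence_2.T)    # edges sharing a face
coadjacency_2 = off_diagonal(incidence_2.T @ incidence_2)  # faces sharing an edge

print(adjacency_0.shape, adjacency_1.shape, coadjacency_2.shape)
```

In this construction each neighborhood matrix is square, with one row per cell of the rank it acts on, so the face-to-face matrix has one row per face.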

In [113]:
adjacency_0_train_list = []
adjacency_1_train_list = []
coadjacency_2_train_list = []

incidence_1_train_list = []
incidence_2_train_list = []

for cc in training_complexes:

    adjacency_0_train = torch.from_numpy(
        cc.adjacency_matrix(0, 1).todense()
    ).to_sparse()
    adjacency_1_train = torch.from_numpy(
        cc.adjacency_matrix(1, 2).todense()
    ).to_sparse()
    coadjacency_2_train = torch.from_numpy(
        cc.coadjacency_matrix(2, 1).todense()
    ).to_sparse()

    adjacency_0_train_list.append(adjacency_0_train)
    adjacency_1_train_list.append(adjacency_1_train)
    coadjacency_2_train_list.append(coadjacency_2_train)

    incidence_1_train = torch.from_numpy(
        cc.incidence_matrix(1, 0).todense()
    ).to_sparse()
    incidence_2_train = torch.from_numpy(
        cc.incidence_matrix(2, 1).todense()
    ).to_sparse()

    incidence_1_train_list.append(incidence_1_train)
    incidence_2_train_list.append(incidence_2_train)

In [116]:
# training_complexes[0].coadjacency_matrix(rank=2,via_rank=1).todense()
from scipy.sparse import csr_matrix

B = training_complexes[0].incidence_matrix(
    rank=2, to_rank=None, incidence_type="down", sparse=True
)
print(B.shape)
A = csr_matrix(B.T) @ csr_matrix(B)
print(A.shape)

TypeError: CombinatorialComplex.skeleton() got an unexpected keyword argument 'level'

In [61]:
i_cc = 0
print(f"Adjacency_0 of the {i_cc}-th complex: {adjacency_0_train_list[i_cc].shape}.")
print(f"Adjacency_1 of the {i_cc}-th complex: {adjacency_1_train_list[i_cc].shape}.")
print(
    f"Coadjacency_2 of the {i_cc}-th complex: {coadjacency_2_train_list[i_cc].shape}."
)
print(f"Incidence_1 of the {i_cc}-th complex: {incidence_1_train_list[i_cc].shape}.")
print(f"Incidence_2 of the {i_cc}-th complex: {incidence_2_train_list[i_cc].shape}.")

Adjacency_0 of the 0-th complex: torch.Size([252, 252]).
Adjacency_1 of the 0-th complex: torch.Size([750, 750]).
Coadjacency_2 of the 0-th complex: torch.Size([750, 750]).
Incidence_1 of the 0-th complex: torch.Size([252, 750]).
Incidence_2 of the 0-th complex: torch.Size([750, 500]).


In [46]:
adjacency_0_test_list = []
adjacency_1_test_list = []
coadjacency_2_test_list = []
incidence_1_test_list = []
incidence_2_test_list = []

for cc in testing_complexes:

    adjacency_0_test = torch.from_numpy(cc.adjacency_matrix(0, 1).todense()).to_sparse()
    adjacency_1_test = torch.from_numpy(cc.adjacency_matrix(1, 2).todense()).to_sparse()
    coadjacency_2_test = torch.from_numpy(
        cc.coadjacency_matrix(2, 1).todense()
    ).to_sparse()

    adjacency_0_test_list.append(adjacency_0_test)
    adjacency_1_test_list.append(adjacency_1_test)
    coadjacency_2_test_list.append(coadjacency_2_test)

    incidence_1_test = torch.from_numpy(cc.incidence_matrix(1, 0).todense()).to_sparse()
    incidence_2_test = torch.from_numpy(cc.incidence_matrix(2, 1).todense()).to_sparse()

    incidence_1_test_list.append(incidence_1_test)
    incidence_2_test_list.append(incidence_2_test)

In [47]:
i_cc = 0
print(f"Adjacency_0 of the {i_cc}-th complex: {adjacency_0_test_list[i_cc].shape}.")
print(f"Coadjacency_2 of the {i_cc}-th complex: {coadjacency_2_test_list[i_cc].shape}.")
print(f"Adjacency_1 of the {i_cc}-th complex: {adjacency_1_test_list[i_cc].shape}.")
print(f"Incidence_1 of the {i_cc}-th complex: {incidence_1_test_list[i_cc].shape}.")
print(f"Incidence_2 of the {i_cc}-th complex: {incidence_2_test_list[i_cc].shape}.")

Adjacency_0 of the 0-th complex: torch.Size([252, 252]).
Coadjacency_2 of the 0-th complex: torch.Size([750, 750]).
Adjacency_1 of the 0-th complex: torch.Size([750, 750]).
Incidence_1 of the 0-th complex: torch.Size([252, 750]).
Incidence_2 of the 0-th complex: torch.Size([750, 500]).


# Create the Neural Network

Using the `HMCLayer` class, we create a neural network. The classifier below uses a single layer; the stacked-layer variant is left commented out in the code.

In [56]:
d_0 = training_node_feat[0].shape[-1]
d_1 = training_edge_feat[0].shape[-1]
d_2 = training_face_feat[0].shape[-1]

in_channels = [d_0, d_1, d_2]

print(
    f"The dimension of input features on nodes, edges and faces are: {d_0}, {d_1} and {d_2}."
)

The dimension of input features on nodes, edges and faces are: 6, 10 and 7.


In [57]:
class HoanMeshClassifier(torch.nn.Module):
    """HoanMeshClassifier.

    Parameters
    ----------
    in_channels : List[int]
        Dimension of input features on nodes, edges and faces respectively.
    intermediate_channels : List[int]
        Dimension of intermediate features on nodes, edges and faces respectively.
    out_channels : List[int]
        Dimension of output features on nodes, edges and faces respectively.
    num_classes : int
        Number of classes.
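    negative_slope : float
        Negative slope passed to the HMCLayer.
    n_layers : int
        Number of HMC layers. Currently unused: a single HMCLayer is always built.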
    """

    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels,
        num_classes,
        negative_slope=0.2,
        n_layers=1,
    ) -> None:

        super().__init__()
        self.num_classes = num_classes
        self.layer = HMCLayer(
            in_channels=in_channels,
            intermediate_channels=intermediate_channels,
            out_channels=out_channels,
            negative_slope=negative_slope,
        )
        """self.layers = torch.nn.ModuleList(
            HMCLayer(
                in_channels=in_channels,
                intermediate_channels=intermediate_channels,
                out_channels=out_channels,
                negative_slope=negative_slope
            ) for _ in range(n_layers)
        )"""

        self.l0 = torch.nn.Linear(out_channels[0], num_classes)
        self.l1 = torch.nn.Linear(out_channels[1], num_classes)
        self.l2 = torch.nn.Linear(out_channels[2], num_classes)

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        neighborhood_0_to_0,
        neighborhood_1_to_1,
        neighborhood_2_to_2,
        neighborhood_0_to_1,
        neighborhood_1_to_2,
    ) -> torch.Tensor:

        x_0, x_1, x_2 = self.layer(
            x_0,
            x_1,
            x_2,
            neighborhood_0_to_0,
            neighborhood_1_to_1,
            neighborhood_2_to_2,
            neighborhood_0_to_1,
            neighborhood_1_to_2,
        )
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0,
                x_1,
                x_2,
                neighborhood_0_to_0,
                neighborhood_1_to_1,
                neighborhood_2_to_2,
                neighborhood_0_to_1,
                neighborhood_1_to_2
                                  )"""

        x_0 = self.l0(x_0)
        x_1 = self.l1(x_1)
        x_2 = self.l2(x_2)

        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )
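
The readout at the end of `forward` can be mimicked in NumPy to see what it computes: a NaN-ignoring mean over the cells of each rank, with any remaining NaN replaced by 0, followed by a sum over ranks. The helper name `pooled_logits` and the toy arrays are illustrative only.

```python
import numpy as np


def pooled_logits(x_0, x_1, x_2):
    """Sum over ranks of the NaN-ignoring per-rank mean of cell logits."""
    total = np.zeros(x_0.shape[-1])
    for x in (x_0, x_1, x_2):
        # Sum and count only the non-NaN entries of each column.
        valid = ~np.isnan(x)
        counts = valid.sum(axis=0)
        sums = np.where(valid, x, 0.0).sum(axis=0)
        # A column with no valid entry contributes 0 instead of NaN.
        total += np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    return total


x_0 = np.array([[1.0, 2.0], [3.0, 4.0]])        # node logits, mean = [2, 3]
x_1 = np.array([[np.nan, 1.0], [np.nan, 3.0]])  # first column is all NaN
x_2 = np.array([[0.5, 0.5]])                    # a single face
print(pooled_logits(x_0, x_1, x_2))
```

The result has one entry per class regardless of how many nodes, edges, or faces the complex contains, which is what allows complexes of different sizes to share one classifier.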

# Train the Neural Network

We instantiate the model, define the loss, and choose an optimizer. We first try it without any attention mechanism.

In [58]:
num_classes = np.unique(training_labels).shape[0]
model = HoanMeshClassifier(
    in_channels,
    in_channels,
    in_channels,
    negative_slope=0.2,
    num_classes=num_classes,
    n_layers=1,
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
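
Because each complex is processed on its own, the model returns one logit vector of length `num_classes`, and the label is a single class index. As a sketch of what the loss then evaluates, assuming no class weights and no label smoothing, the cross-entropy for one sample is the negative log-softmax at the true class. The function name `cross_entropy_single` is illustrative.

```python
import math


def cross_entropy_single(logits, target):
    """Cross-entropy of one logit vector against one class index."""
    # Subtract the maximum for numerical stability before exponentiating.
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return log_sum_exp - logits[target]


# With equal logits, the loss equals log(number of classes).
print(cross_entropy_single([0.0, 0.0, 0.0], target=1))
```

A loss close to `log(num_classes)` during training therefore indicates predictions that are no better than uniform guessing.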

We train the `HoanMeshClassifier` for a small number of epochs: we keep training minimal for the purpose of rapid testing.

In [60]:
test_interval = 2
num_epochs = 4
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for (
        x_0,
        x_1,
        x_2,
        adjacency_0,
        adjacency_1,
        coadjacency_2,
        incidence_1,
        incidence_2,
        y,
    ) in zip(
        training_node_feat,
        training_edge_feat,
        training_face_feat,
        adjacency_0_train_list,
        adjacency_1_train_list,
        coadjacency_2_train_list,
        incidence_1_train_list,
        incidence_2_train_list,
        training_labels,
    ):
        x_0, x_1, x_2 = (
            torch.tensor(x_0, dtype=torch.float).to(device),
            torch.tensor(x_1, dtype=torch.float).to(device),
            torch.tensor(x_2, dtype=torch.float).to(device),
        )
        y = torch.tensor(y, dtype=torch.long).to(device)
        adjacency_0, adjacency_1, coadjacency_2 = (
            adjacency_0.float().to(device),
            adjacency_1.float().to(device),
            coadjacency_2.float().to(device),
        )
        incidence_1, incidence_2 = incidence_1.float().to(
            device
        ), incidence_2.float().to(device)
        opt.zero_grad()
        y_hat = model(
            x_0,
            x_1,
            x_2,
            adjacency_0,
            adjacency_1,
            coadjacency_2,
            incidence_1,
            incidence_2,
        )
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()  # evaluation mode; model.train() is called again at the start of the next epoch
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for (
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
                y,
            ) in zip(
                testing_node_feat,
                testing_edge_feat,
                testing_face_feat,
                adjacency_0_test_list,
                adjacency_1_test_list,
                coadjacency_2_test_list,
                incidence_1_test_list,
                incidence_2_test_list,
                testing_labels,
            ):
                x_0, x_1, x_2 = (
                    torch.tensor(x_0, dtype=torch.float).to(device),
                    torch.tensor(x_1, dtype=torch.float).to(device),
                    torch.tensor(x_2, dtype=torch.float).to(device),
                )
                y = torch.tensor(y, dtype=torch.long).to(device)
                adjacency_0, adjacency_1, coadjacency_2 = (
                    adjacency_0.float().to(device),
                    adjacency_1.float().to(device),
                    coadjacency_2.float().to(device),
                )
                incidence_1, incidence_2 = incidence_1.float().to(
                    device
                ), incidence_2.float().to(device)
                y_hat = model(
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                )
                loss = crit(y_hat, y)
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (750x750 and 500x500)
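
A shape mismatch like the one above can be narrowed down before calling the model by comparing each neighborhood matrix with the number of cells it is expected to act on. The sketch below is one way to do that in plain Python. It assumes that a rank-r to rank-r neighborhood matrix should be square with one row per r-cell, and that an incidence matrix between ranks r and r+1 has shape (number of r-cells, number of (r+1)-cells), as the printed `Incidence_1` and `Incidence_2` shapes suggest. The helper `find_shape_mismatches` is hypothetical, the shapes are copied from the output printed earlier, and the source does not confirm that this is the actual cause of the error.

```python
def find_shape_mismatches(n_cells, shapes):
    """Return the names of matrices whose shape disagrees with the cell counts.

    n_cells: (n_nodes, n_edges, n_faces)
    shapes:  mapping from matrix name to its (rows, cols) shape
    """
    n_0, n_1, n_2 = n_cells
    expected = {
        "adjacency_0": (n_0, n_0),
        "adjacency_1": (n_1, n_1),
        "coadjacency_2": (n_2, n_2),
        "incidence_1": (n_0, n_1),
        "incidence_2": (n_1, n_2),
    }
    return [name for name, shape in shapes.items() if tuple(shape) != expected[name]]


# Shapes printed for the first training complex (252 nodes, 750 edges, 500 faces).
printed_shapes = {
    "adjacency_0": (252, 252),
    "adjacency_1": (750, 750),
    "coadjacency_2": (750, 750),
    "incidence_1": (252, 750),
    "incidence_2": (750, 500),
}
print(find_shape_mismatches((252, 750, 500), printed_shapes))
```

Under these assumptions only `coadjacency_2` is flagged: it is 750x750, whereas a face-to-face matrix for 500 faces would be 500x500.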

In [45]:
torch.tensor([3, 4, 5]).float()

tensor([3., 4., 5.])

In [40]:
a = [3, 5, 6]
a.float()

AttributeError: 'list' object has no attribute 'float'

In [82]:
hmc = HMCLayer([3, 3, 3], [3, 3, 3], [3, 3, 3], negative_slope=0.2)