# Train a Convolutional Cell Complex Network (CCXN)

We create and train a simplified version of the CCXN originally proposed in [Hajij et al.: Cell Complex Neural Networks (2020)](https://arxiv.org/pdf/2010.00743.pdf).

### The Neural Network:

The equations of one layer of this neural network are given by:

1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)} = M_{\mathcal{L}_\uparrow}^t(h_x^{t,(0)}, h_y^{t,(0)}, \Theta^{t,(y \rightarrow x)})$ 

🟧 $\quad m_x^{(0 \rightarrow 1 \rightarrow 0)} = AGG_{y \in \mathcal{L}_\uparrow(x)}(m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)})$ 

🟩 $\quad m_x^{(0)} = m_x^{(0 \rightarrow 1 \rightarrow 0)}$ 

🟦 $\quad h_x^{t+1,(0)} = U^{t}(h_x^{t,(0)}, m_x^{(0)})$

2. A convolution from edges to faces using a cohomology message passing scheme:

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)$ 

🟧 $\quad m_x^{(r' \rightarrow r)}  = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$ 

🟩 $\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}$ 

🟦 $\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})$

The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
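The two steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the `topomodelx` implementation: it assumes that messages are linear maps of the source features, that $AGG$ is a sum over neighbors, and that the update $U$ is a ReLU applied to the aggregated message. The names `ccxn_layer_sketch`, `w_0` and `w_1_to_2` are hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def ccxn_layer_sketch(x_0, x_1, adjacency_0, incidence_2_t, w_0, w_1_to_2):
    """One simplified CCXN layer without attention (illustrative only).

    Assumptions: linear messages, sum aggregation, ReLU update.
    """
    # Nodes -> nodes: each node sums the transformed features of its neighbors.
    x_0_new = relu(adjacency_0 @ x_0 @ w_0)
    # Edges -> faces: each face sums the transformed features of its boundary edges.
    x_2_new = relu(incidence_2_t @ x_1 @ w_1_to_2)
    # Edge features are passed through unchanged in this sketch.
    return x_0_new, x_1, x_2_new


# Toy complex: a filled triangle with 3 nodes, 3 edges and one 2-cell.
adjacency_0 = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
incidence_2_t = np.array([[1.0, -1.0, 1.0]])  # shape [n_faces, n_edges]
x_0 = np.eye(3)  # 3 nodes, 3 features each
x_1 = np.ones((3, 2))  # 3 edges, 2 features each
w_0 = np.eye(3)  # shape [in_channels_0, in_channels_0]
w_1_to_2 = np.ones((2, 4))  # shape [in_channels_1, in_channels_2]

x_0_new, x_1_new, x_2_new = ccxn_layer_sketch(
    x_0, x_1, adjacency_0, incidence_2_t, w_0, w_1_to_2
)
print(x_0_new.shape, x_1_new.shape, x_2_new.shape)
```

Note that the face features are produced by the layer itself: only node and edge features are needed as inputs.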

### The Task:

We train this model to classify entire complexes on [`MUTAG` from the TUDataset](https://paperswithcode.com/dataset/mutag). This dataset contains:
- 188 samples of chemical compounds represented as graphs,
- with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella typhimurium.

# Set-up


In [1]:
import importlib
import random

import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx

import topomodelx.base.conv
from topomodelx.nn.cell.ccxn_layer import CCXNLayer

If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

We import a subset (the first 20 graphs) of MUTAG, a benchmark dataset for graph classification.

We then lift each graph into our topological domain of choice, here: a cell complex.

We also retrieve:
- input signals `x_0` and `x_1` on the nodes (0-cells) and edges (1-cells) for each complex: these will be the model's inputs,
- a binary classification label `y` associated with the cell complex.

In [3]:
dataset = TUDataset(
    root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True, use_node_attr=True
)
dataset = dataset[:20]
cc_list = []
x_0_list = []
x_1_list = []
y_list = []
for graph in dataset:
    cell_complex = tnx.transform.graph_2_neighbor_complex(
        to_networkx(graph).to_undirected()
    ).to_cell_complex()
    cc_list.append(cell_complex)
    x_0_list.append(graph.x)
    x_1_list.append(graph.edge_attr)
    y_one_hot = torch.zeros(2)
    y_one_hot[int(graph.y)] = 1.0
    y_list.append(y_one_hot.float())

In [24]:
import networkx as nx

G = to_networkx(dataset[1]).to_undirected()
G1 = tnx.transform.graph_2_neighbor_complex(
    to_networkx(dataset[1]).to_undirected()
).to_cell_complex()

In [27]:
len(G.edges)

14

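Lifting to a cell complex creates additional edges (here 27 in the complex, versus 14 in the original graph), so the edge features of the graph no longer align with the edges of the complex. How should we put features on these edges?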
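One possible answer, sketched below under stated assumptions, is to keep the original feature for every edge that already existed in the graph and to give zero features to the edges created by the lift. This is a hypothetical convention, not the approach taken in this notebook or in `topomodelx`; the helper name `edge_features_for_complex` is made up for illustration.

```python
import numpy as np


def edge_features_for_complex(edge_index, edge_attr, complex_edges):
    """Build one feature row per edge of the lifted complex.

    edge_index : list of directed (u, v) pairs from the original graph.
    edge_attr : array of shape [len(edge_index), n_features].
    complex_edges : list of (u, v) pairs, the edges of the lifted complex.
    """
    n_features = edge_attr.shape[1]
    # Collapse both directions of each graph edge onto one undirected key.
    lookup = {}
    for (u, v), feat in zip(edge_index, edge_attr):
        lookup.setdefault(frozenset((u, v)), feat)
    # Hypothetical convention: edges created by the lift get zero features.
    zero = np.zeros(n_features)
    return np.stack([lookup.get(frozenset(e), zero) for e in complex_edges])


# Toy graph: the path 0-1-2, stored in both directions (4 directed edges).
edge_index = [(0, 1), (1, 0), (1, 2), (2, 1)]
edge_attr = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
# Suppose the lift adds the edge (0, 2).
complex_edges = [(0, 1), (0, 2), (1, 2)]

x_1 = edge_features_for_complex(edge_index, edge_attr, complex_edges)
print(x_1.shape)  # one row per edge of the complex
```

With this convention the number of feature rows always equals the number of edges in the complex, whatever the lift adds.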

In [30]:
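# Note: `G1.edges and G.edges` evaluates to `G.edges` whenever `G1.edges` is
# non-empty, so this returns len(G.edges), not the size of an intersection.
len(G1.edges and G.edges)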

14

In [5]:
g1 = tnx.transform.graph_2_neighbor_complex(
    to_networkx(dataset[1]).to_undirected()
).to_cell_complex()

In [6]:
to_networkx(dataset[1]).edges

OutEdgeView([(0, 1), (0, 9), (1, 0), (1, 2), (2, 1), (2, 3), (2, 7), (3, 2), (3, 4), (4, 3), (4, 5), (5, 4), (5, 6), (6, 5), (6, 7), (7, 2), (7, 6), (7, 8), (8, 7), (8, 9), (8, 10), (9, 0), (9, 8), (10, 8), (10, 11), (10, 12), (11, 10), (12, 10)])

In [7]:
g1.edges

EdgeView([(0, 1), (0, 9), (0, 2), (0, 8), (1, 9), (1, 2), (1, 7), (9, 8), (9, 10), (2, 3), (2, 4), (2, 8), (2, 6), (3, 4), (3, 5), (3, 7), (4, 5), (4, 6), (5, 6), (5, 7), (6, 7), (7, 8), (7, 10), (8, 10), (8, 12), (10, 11), (11, 12)])

In [8]:
i_cc = 1
print(f"\nThe {i_cc}th cell complex has the following structure: \n{cc_list[i_cc]}.")

print("\nIt supports the following features:")
print(f"On nodes: {x_0_list[i_cc].shape}.")
print(f"On edges: {x_1_list[i_cc].shape}.")

print(f"\nIt has the label: {y_list[i_cc]}")


The 1th cell complex has the following structure: 
Cell Complex with 13 nodes, 27 edges  and 11 2-cells .

It supports the following features:
On nodes: torch.Size([13, 7]).
On edges: torch.Size([28, 4]).

It has the label: tensor([1., 0.])


## Define neighborhood structures. ##

Implementing the CCXN architecture requires message passing along neighborhood structures of the cell complexes.

We now retrieve these neighborhood structures (i.e. the matrices that represent them), which we will use to send messages.

For the CCXN, we need the adjacency matrix $A_{\uparrow, 0}$ and the coboundary matrix $B_2^T$ of each cell complex.
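To make these two matrices concrete, here is a small hand-built example on a filled triangle. It is a NumPy sketch that does not use `toponetx`; the orientation conventions (each edge points from its lower to its higher node index, the face is traversed as 0 → 1 → 2 → 0) are assumptions chosen for illustration.

```python
import numpy as np

# Filled triangle: nodes 0, 1, 2; edges (0, 1), (0, 2), (1, 2); one 2-cell.
# B1 has shape [n_nodes, n_edges]: -1 at the tail and +1 at the head of each edge.
b_1 = np.array(
    [
        [-1, -1, 0],
        [1, 0, -1],
        [0, 1, 1],
    ]
)
# B2 has shape [n_edges, n_faces]: sign of each edge along the face boundary.
b_2 = np.array([[1], [-1], [1]])

# Two nodes are upper-adjacent when they share an edge.
# B1 @ B1.T holds node degrees on the diagonal and -1 for each shared edge.
laplacian_0 = b_1 @ b_1.T
adjacency_0 = np.diag(np.diag(laplacian_0)) - laplacian_0

# Coboundary from edges to faces, shape [n_faces, n_edges].
incidence_2_t = b_2.T

print(adjacency_0)
print(incidence_2_t.shape)
print(b_1 @ b_2)  # all zeros: the boundary of a boundary vanishes
```

The shapes are the ones the model expects: `adjacency_0` is `[n_nodes, n_nodes]` and `incidence_2_t` is `[n_faces, n_edges]`.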

In [9]:
incidence_2_t_list = []
adjacency_0_list = []
for cell_complex in cc_list:
    incidence_2_t = cell_complex.incidence_matrix(rank=2).T
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    incidence_2_t = torch.from_numpy(incidence_2_t.todense()).to_sparse()
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    incidence_2_t_list.append(incidence_2_t)
    adjacency_0_list.append(adjacency_0)

In [10]:
i_cc = 5
print(f"Incidence_2_t B2T of the {i_cc}-th complex: {incidence_2_t_list[i_cc].shape}.")
print(f"Adjacency A0 of the {i_cc}-th complex: {adjacency_0_list[i_cc].shape}.")

Incidence_2_t B2T of the 5-th complex: torch.Size([20, 59]).
Adjacency A0 of the 5-th complex: torch.Size([28, 28]).


# Create the Neural Network

Using the CCXNLayer class, we create a neural network with stacked layers.

In [11]:
in_channels_0 = x_0_list[0].shape[-1]
in_channels_1 = x_1_list[0].shape[-1]
in_channels_2 = 5
print(
    f"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}."
)

The dimension of input features on nodes, edges and faces are: 7, 4 and 5.


In [12]:
class CCXN(torch.nn.Module):
    """CCXN.

    Parameters
    ----------
    in_channels_0 : int
        Dimension of input features on nodes.
    in_channels_1 : int
        Dimension of input features on edges.
    in_channels_2 : int
        Dimension of input features on faces.
    n_classes : int
        Number of classes.
    n_layers : int
        Number of CCXN layers.
    att : bool
        Whether to use attention.
    """

    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        in_channels_2,
        n_classes,
        n_layers=2,
        att=False,
    ):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                CCXNLayer(
                    in_channels_0=in_channels_0,
                    in_channels_1=in_channels_1,
                    in_channels_2=in_channels_2,
                    att=att,
                )
            )
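        # Register the layers so that their parameters are seen by the optimizer
        # and moved by `model.to(device)`.
        self.layers = torch.nn.ModuleList(layers)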
        self.lin_0 = torch.nn.Linear(in_channels_0, n_classes)
        self.lin_1 = torch.nn.Linear(in_channels_1, n_classes)
        self.lin_2 = torch.nn.Linear(in_channels_2, n_classes)

    def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
        """Forward computation through layers, then linear layers, then avg pooling.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = [n_nodes, in_channels_0]
            Input features on the nodes (0-cells).
        x_1 : torch.Tensor, shape = [n_edges, in_channels_1]
            Input features on the edges (1-cells).
        neighborhood_0_to_0 : torch.Tensor, shape = [n_nodes, n_nodes]
            Adjacency matrix of rank 0 (up).
        neighborhood_1_to_2 : torch.Tensor, shape = [n_faces, n_edges]
            Transpose of boundary matrix of rank 2.

        Returns
        -------
        _ : torch.Tensor, shape = [n_classes]
            Class logits assigned to the whole complex.
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        return (
            torch.nanmean(x_2, dim=0) + torch.mean(x_1, dim=0) + torch.mean(x_0, dim=0)
        )
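When layers are stacked in a loop, how they are stored matters: PyTorch only registers sub-modules that are assigned as module attributes or held in a container such as `torch.nn.ModuleList`. The sketch below illustrates the difference with plain `torch.nn.Linear` layers standing in for `CCXNLayer`; the class names `PlainListModel` and `ModuleListModel` are hypothetical.

```python
import torch


class PlainListModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list hides the layers from PyTorch.
        self.layers = [torch.nn.Linear(4, 4) for _ in range(2)]


class ModuleListModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A ModuleList registers every layer as a sub-module.
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])


n_plain = len(list(PlainListModel().parameters()))
n_registered = len(list(ModuleListModel().parameters()))
print(n_plain, n_registered)  # 0 4 (two layers, each with a weight and a bias)
```

An optimizer built from `model.parameters()` can only update the registered parameters, so layers kept in a plain list would never be trained.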

# Train the Neural Network

We instantiate the model, define the loss, and specify an optimizer. We first try it without any attention mechanism.

In [13]:
model = CCXN(in_channels_0, in_channels_1, in_channels_2, n_classes=2, n_layers=2)
model = model.to(device)
cross_entropy_loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
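`torch.nn.CrossEntropyLoss` takes raw logits and applies a log-softmax internally. As a rough illustration of what it computes for a single sample with a one-hot target, here is a NumPy sketch; the function name `cross_entropy_from_logits` is hypothetical, and the logits are example values of the kind the model returns.

```python
import numpy as np


def cross_entropy_from_logits(logits, target_probs):
    """Cross-entropy between softmax(logits) and a target distribution."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return float(-(target_probs * log_probs).sum())


logits = np.array([2.2376, -1.8781])  # example model output for one complex
target = np.array([0.0, 1.0])  # one-hot label for class 1

loss = cross_entropy_from_logits(logits, target)
print(round(loss, 4))
```

The loss is large here because the model puts most of its probability mass on class 0 while the label is class 1.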

We split the dataset into train and test sets.

In [14]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
incidence_2_t_train, incidence_2_t_test = train_test_split(
    incidence_2_t_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)
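Splitting each list separately only keeps the samples aligned because `shuffle=False` makes every split cut at the same index. An alternative sketch, in plain Python, bundles everything that belongs to one complex into a tuple and splits once. The helper `split_aligned` and the placeholder strings are hypothetical, and its rounding rule is an assumption that need not match `train_test_split` in every case.

```python
def split_aligned(samples, test_size=0.2):
    """Split a list of per-sample tuples into train and test, without shuffling."""
    n_test = max(1, round(len(samples) * test_size))
    return samples[:-n_test], samples[-n_test:]


# Hypothetical stand-ins for (x_0, x_1, incidence_2_t, adjacency_0, y).
samples = [(f"x_0_{i}", f"x_1_{i}", f"b2t_{i}", f"a0_{i}", i % 2) for i in range(20)]

train, test = split_aligned(samples, test_size=0.2)
print(len(train), len(test))  # 16 4
```

Keeping one tuple per complex makes it impossible for features, neighborhood matrices and labels to drift out of sync, even if shuffling is added later.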

We train the CCXN for a small number of epochs: we keep training minimal for the purpose of rapid testing.

In [15]:
test_interval = 2
n_epochs = 5
for i_epoch in range(n_epochs):
    epoch_loss = []
    n_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train, x_1_train, incidence_2_t_train, adjacency_0_train, y_train
    ):
        print(f"N-samples = {n_samples}")
        optimizer.zero_grad()

        y_hat = model(
            x_0.float(), x_1.float(), adjacency_0.float(), incidence_2_t.float()
        )
        print(f"y = {y}")
        print(f"y_hat = {y_hat}")
        loss = cross_entropy_loss(y_hat, y)
        # Compare predicted and true class indices (y is one-hot).
        correct += int(y_hat.argmax() == y.argmax())
        n_samples += 1
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    train_acc = correct / n_samples
    print(
        f"Epoch: {i_epoch} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if i_epoch % test_interval == 0:
        with torch.no_grad():
            n_samples = 0
            correct = 0
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test, x_1_test, incidence_2_t_test, adjacency_0_test, y_test
            ):
                y_hat = model(
                    x_0.float(), x_1.float(), adjacency_0.float(), incidence_2_t.float()
                )
                # Compare predicted and true class indices (y is one-hot).
                correct += int(y_hat.argmax() == y.argmax())
                n_samples += 1
            test_acc = correct / n_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

N-samples = 0

 x_source.shape = torch.Size([17, 7])
self.weight.shape = torch.Size([7, 7])
neighborhood.shape = torch.Size([17, 17])
x_message.shape = torch.Size([17, 7])

 x_source.shape = torch.Size([38, 4])
self.weight.shape = torch.Size([4, 5])
neighborhood.shape = torch.Size([15, 38])
x_message.shape = torch.Size([38, 5])

 x_source.shape = torch.Size([17, 7])
self.weight.shape = torch.Size([7, 7])
neighborhood.shape = torch.Size([17, 17])
x_message.shape = torch.Size([17, 7])

 x_source.shape = torch.Size([38, 4])
self.weight.shape = torch.Size([4, 5])
neighborhood.shape = torch.Size([15, 38])
x_message.shape = torch.Size([38, 5])
y = tensor([0., 1.])
y_hat = tensor([ 2.2376, -1.8781], grad_fn=<AddBackward0>)
N-samples = 1

 x_source.shape = torch.Size([13, 7])
self.weight.shape = torch.Size([7, 7])
neighborhood.shape = torch.Size([13, 13])
x_message.shape = torch.Size([13, 7])

 x_source.shape = torch.Size([28, 4])
self.weight.shape = torch.Size([4, 5])
neighborhood.shape = tor

RuntimeError: addmm: Argument #3 (dense): Expected dim 0 size 27, got 28
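The error comes from a shape mismatch in the second training sample: the edge features have 28 rows (one per directed edge of the original graph), while the lifted complex has 27 edges, so $B_2^T$ has 27 columns. A small check run before training would surface this with a clearer message. The helper `check_edge_features` below is a hypothetical sketch, shown with NumPy arrays that only reproduce the shapes involved.

```python
import numpy as np


def check_edge_features(x_1, incidence_2_t):
    """Raise a readable error when edge features and B2^T disagree."""
    n_faces, n_edges = incidence_2_t.shape
    if x_1.shape[0] != n_edges:
        raise ValueError(
            f"x_1 has {x_1.shape[0]} rows but the complex has {n_edges} edges."
        )


# Shapes of the failing sample: 28 feature rows, 27 edges, 11 2-cells.
x_1 = np.zeros((28, 4))
incidence_2_t = np.zeros((11, 27))

try:
    check_edge_features(x_1, incidence_2_t)
    message = "ok"
except ValueError as err:
    message = str(err)
print(message)
```

Training can only proceed once every complex has exactly one feature row per edge.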

# Train the Neural Network with Attention


We now create a new neural network that uses the attention mechanism.

In [None]:
model = CCXN(
    in_channels_0, in_channels_1, in_channels_2, n_classes=2, n_layers=2, att=True
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

We run the training for this neural network:

In [None]:
test_interval = 2
for i_epoch in range(1, 5):
    epoch_loss = []
    n_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train, x_1_train, incidence_2_t_train, adjacency_0_train, y_train
    ):
        optimizer.zero_grad()

        y_hat = model(
            x_0.float(), x_1.float(), adjacency_0.float(), incidence_2_t.float()
        )
        # y is a one-hot float tensor, which CrossEntropyLoss accepts as class probabilities.
        loss = crit(y_hat, y)
        correct += int(y_hat.argmax() == y.argmax())
        n_samples += 1
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    train_acc = correct / n_samples
    print(
        f"Epoch: {i_epoch} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if i_epoch % test_interval == 0:
        with torch.no_grad():
            n_samples = 0
            correct = 0
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test, x_1_test, incidence_2_t_test, adjacency_0_test, y_test
            ):
                y_hat = model(
                    x_0.float(), x_1.float(), adjacency_0.float(), incidence_2_t.float()
                )
                # Compare predicted and true class indices (y is one-hot).
                correct += int(y_hat.argmax() == y.argmax())
                n_samples += 1
            test_acc = correct / n_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)