# Train a Combinatorial Complex Attention Neural Network for Mesh Classification

We create and train a combinatorial complex attention neural network for mesh classification, introduced in [Figure 35(b), Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf).

### The Neural Network:

The neural network is composed of a sequence of identical attention layers for a combinatorial complex of dimension two. Each layer is composed of two levels. In both levels, messages computed for cells of the same dimension are aggregated with a sum. All messages are computed using the attention mechanisms for squared and non-squared neighborhoods presented in [Definitions 31, 32, and 33, Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). Each layer applies the following message passing scheme, level by level:

1. First level:

🟥 $\quad m^{0\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 0}\right)$

🟥 $\quad m^{0\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 1}\right)$

🟥 $\quad  m^{1\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((B_{1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 0}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 0}\right)$

🟥 $\quad  m^{1\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 2}\right)$

🟥 $\quad m^{2\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 1}\right) h_y^{t,(2)} \Theta^t_{2\rightarrow 1}\right)$

🟧 $\quad m^{0\rightarrow 0}_{x}=\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{0\rightarrow 1}_{x}=\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 0}_{x}=\sum_{y\in B_{1}(x)} m^{1\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 2}_{x}=\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}$

🟧 $\quad m^{2\rightarrow 1}_{x}=\sum_{y\in B_{2}(x)} m^{2\rightarrow 1}_{y\rightarrow x}$

🟩 $\quad m_x^{(0)}=m^{0\rightarrow 0}_{x}+m^{1\rightarrow 0}_{x}$

🟩 $\quad m_x^{(1)}=m^{0\rightarrow 1}_{x}+m^{2\rightarrow 1}_{x}$

🟩 $\quad m_x^{(2)}=m^{1\rightarrow 2}_{x}$

🟦 $\quad i_x^{t,(0)} = m_x^{(0)}$

🟦 $\quad i_x^{t,(1)} = m_x^{(1)}$

🟦 $\quad i_x^{t,(2)} = m_x^{(2)}$

where $i_x^{t,(\cdot)}$ represents intermediate feature vectors.


2. Second level:


🟥 $\quad m^{0\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 0}\right)$

🟥 $\quad m^{1\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 1}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 1}\right)$

🟥 $\quad m^{2\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((A_{\downarrow, 2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 2}\right) i_y^{t,(2)} \Theta^t_{2\rightarrow 2}\right)$

🟥 $\quad m^{0\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 1}\right)$

🟥 $\quad m^{1\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 2}\right)$

🟧 $\quad m^{0\rightarrow 0}_{x}=\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 1}_{x}=\sum_{y\in A_{\uparrow, 1}(x)} m^{1\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{2\rightarrow 2}_{x}=\sum_{y\in A_{\downarrow, 2}(x)} m^{2\rightarrow 2}_{y\rightarrow x}$

🟧 $\quad m^{0\rightarrow 1}_{x}=\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 2}_{x}=\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}$

🟩 $\quad m_x^{(0)}=m^{0\rightarrow 0}_{x}$

🟩 $\quad m_x^{(1)}=m^{1\rightarrow 1}_{x} + m^{0\rightarrow 1}_{x}$

🟩 $\quad m_x^{(2)}=m^{1\rightarrow 2}_{x} + m^{2\rightarrow 2}_{x}$

🟦 $\quad h_x^{t+1,(0)} = m_x^{(0)}$

🟦 $\quad h_x^{t+1,(1)} = m_x^{(1)}$

🟦 $\quad h_x^{t+1,(2)} = m_x^{(2)}$
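To make the data flow of one layer concrete, here is a minimal sketch on a single triangle (3 nodes, 3 edges, 1 face) using dense PyTorch tensors. It is not the `HMCLayer` implementation: as simplifying assumptions, every attention coefficient is fixed to 1, $\phi$ is ReLU, and each $\Theta$ is a fresh random matrix. With binary neighborhood matrices and ReLU, summing $\phi(N_{xy} h_y \Theta)$ over $y$ equals the matrix product $N\,\phi(h\Theta)$, which is what the hypothetical helper `message` computes. The second level defines no message into nodes other than $m^{0\rightarrow 0}$, so the sketch updates nodes from that message alone.

```python
import torch

torch.manual_seed(0)
phi = torch.relu

# Single triangle with edges e0=(0,1), e1=(0,2), e2=(1,2) and one face.
B1 = torch.tensor([[1., 1., 0.],   # node 0 lies on e0, e1
                   [1., 0., 1.],   # node 1 lies on e0, e2
                   [0., 1., 1.]])  # node 2 lies on e1, e2
B2 = torch.ones(3, 1)              # every edge bounds the single face
A_up_0 = torch.ones(3, 3) - torch.eye(3)  # nodes sharing an edge
A_up_1 = torch.ones(3, 3) - torch.eye(3)  # edges sharing the face
A_down_2 = torch.zeros(1, 1)              # no other face shares an edge


def message(neighborhood, h, d_out=8):
    """Sum over y of phi(N_xy h_y Theta), with attention fixed to 1 (assumption)."""
    theta = torch.randn(h.shape[1], d_out)  # fresh random weights per message type
    return neighborhood @ phi(h @ theta)


# Input features on nodes, edges, and faces (arbitrary dimensions).
h0, h1, h2 = torch.randn(3, 3), torch.randn(3, 4), torch.randn(1, 5)

# First level: 0->0 and 1->0 into nodes, 0->1 and 2->1 into edges, 1->2 into faces.
i0 = message(A_up_0, h0) + message(B1, h1)
i1 = message(B1.T, h0) + message(B2, h2)
i2 = message(B2.T, h1)

# Second level: 0->0 into nodes, 1->1 and 0->1 into edges, 2->2 and 1->2 into faces.
h0_next = message(A_up_0, i0)
h1_next = message(A_up_1, i1) + message(B1.T, i0)
h2_next = message(A_down_2, i2) + message(B2.T, i1)

print(h0_next.shape, h1_next.shape, h2_next.shape)
# torch.Size([3, 8]) torch.Size([3, 8]) torch.Size([1, 8])
```

The real layer replaces the fixed coefficients with learned attention matrices and uses separate weights in each level.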

In both message passing levels, $\phi$ denotes an activation function shared by all messages. $\Theta$ and $\text{att}$ denote learnable weights and attention matrices, respectively, which differ between the two levels. Attention matrices are introduced in [Figure 35(b), Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). Here, attention matrices are computed using the LeakyReLU activation function, as in previous versions of the paper.
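The attention matrices can be illustrated with a GAT-style sketch for a squared neighborhood (cells of the same rank). This is an assumption about the general form, not a transcription of the paper's definitions or of `HMCLayer`: each pair $(x, y)$ gets the score $\text{LeakyReLU}(a_1^\top h_x\Theta + a_2^\top h_y\Theta)$, scores are normalized with a softmax over the neighbors of $x$, and non-neighbors get zero weight. The function name `squared_attention` and the vector `a` are hypothetical.

```python
import torch
import torch.nn.functional as F


def squared_attention(h, theta, a, neighborhood, negative_slope=0.2):
    """GAT-style attention over a same-rank neighborhood (hypothetical sketch)."""
    z = h @ theta  # projected features, shape (n, d)
    d = z.shape[1]
    # score[x, y] = a[:d] . z_x + a[d:] . z_y, followed by LeakyReLU
    scores = (z @ a[:d]).unsqueeze(1) + (z @ a[d:]).unsqueeze(0)
    scores = F.leaky_relu(scores, negative_slope)
    # Only neighbors of x compete in the softmax over row x.
    scores = scores.masked_fill(neighborhood == 0, float("-inf"))
    att = torch.softmax(scores, dim=1)
    # Cells with no neighbors produce all -inf rows (NaN after softmax); zero them.
    return torch.nan_to_num(att, nan=0.0)


torch.manual_seed(0)
A = torch.tensor([[0.0, 1.0, 0.0],
                  [1.0, 0.0, 1.0],
                  [0.0, 1.0, 0.0]])  # path graph 0 - 1 - 2
h = torch.randn(3, 4)
theta = torch.randn(4, 8)
a = torch.randn(16)

att = squared_attention(h, theta, a, A)
# With phi = ReLU and non-negative weights A_xy * att_xy, the summed messages
# sum_y phi(A_xy att_xy h_y Theta) equal (A * att) @ relu(h Theta).
messages = (A * att) @ torch.relu(h @ theta)
```

Because node 0 has a single neighbor, its whole attention mass goes to node 1; node 1 splits its mass between nodes 0 and 2.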

Notation and the adjacency, coadjacency, and incidence matrices are defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031). The tensor diagram for the layer can be found in the first column and last row of Figure 11 of the same paper.
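For a concrete picture of these matrices, the following sketch builds them with SciPy for a hypothetical mesh of two triangles sharing an edge, assuming unsigned incidence matrices $B_1$ (nodes × edges) and $B_2$ (edges × faces). The upper adjacencies come from $B_1 B_1^T$ and $B_2 B_2^T$, and the face coadjacency $A_{\downarrow, 2}$ from $B_2^T B_2$, keeping only off-diagonal nonzeros. The same $B^T B$ plus `setdiag(0)` pattern is used to build `coa2` in the dataset class below.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical mesh: triangles (0, 1, 2) and (1, 2, 3) sharing edge (1, 2).
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
faces = [(0, 1, 2), (1, 2, 3)]
n_nodes = 4

# Unsigned incidence matrices: B1 is nodes x edges, B2 is edges x faces.
b1_pairs = [(v, j) for j, edge in enumerate(edges) for v in edge]
b2_pairs = [(j, k) for k, face in enumerate(faces)
            for j, edge in enumerate(edges) if set(edge) <= set(face)]
B1 = sp.csr_matrix((np.ones(len(b1_pairs)), tuple(zip(*b1_pairs))),
                   shape=(n_nodes, len(edges)))
B2 = sp.csr_matrix((np.ones(len(b2_pairs)), tuple(zip(*b2_pairs))),
                   shape=(len(edges), len(faces)))


def off_diagonal_adjacency(product):
    """Binarize a product of incidence matrices and remove its diagonal."""
    adjacency = (product > 0).astype(float).tolil()
    adjacency.setdiag(0)
    adjacency = adjacency.tocsr()
    adjacency.eliminate_zeros()
    return adjacency


A_up_0 = off_diagonal_adjacency(B1 @ B1.T)    # nodes sharing an edge
A_up_1 = off_diagonal_adjacency(B2 @ B2.T)    # edges sharing a face
A_down_2 = off_diagonal_adjacency(B2.T @ B2)  # faces sharing an edge

print(A_down_2.toarray())
```

Nodes 0 and 3 are not adjacent, while the shared edge (1, 2) is upper-adjacent to all four other edges.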

### The Task:

We train this model to perform entire complex classification on the SHREC 2016 mesh dataset, loaded with `shrec_16` from `toponetx.datasets.mesh`. Each mesh is lifted to a combinatorial complex whose nodes, edges, and faces carry input features.

The task is to predict the class of each mesh.

# Set-up


In [15]:
import torch
import numpy as np
import torch.nn as nn
from toponetx import CombinatorialComplex
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from toponetx.datasets.mesh import shrec_16
from sklearn.model_selection import train_test_split
from scipy.sparse import csr_matrix
from topomodelx.nn.combinatorial.hmc_layer import HMCLayer

If GPUs are available, we use them. Otherwise, the code runs on the CPU.

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##




We then lift each mesh into our topological domain of choice, here a combinatorial complex.

We also retrieve:
- input signals `x_0`, `x_1`, and `x_2` on the nodes (0-cells), edges (1-cells), and faces (2-cells) of each complex: these will be the model's inputs,
- the neighborhood matrices (adjacency, coadjacency, and incidence) used by the attention layers,
- a class label `y` associated with each complex.

Node features (6 per node):
- Position
- Normal

Edge features (10 per edge):
- Dihedral angle
- Edge span
- 2 edge angles in the triangle
- 6 edge ratios

Face features (7 per face):
- Face area
- Face normal
- Face angle
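
As an illustration of the kind of geometry behind the face features, the sketch below computes the area, unit normal, and three interior angles of one triangle with NumPy. Concatenated, these give 1 + 3 + 3 = 7 values, matching the 7 face channels reported below; whether the dataset uses exactly this layout, ordering, or any normalization is an assumption, and `face_features` is a hypothetical helper.

```python
import numpy as np


def face_features(p0, p1, p2):
    """Area, unit normal, and interior angles of the triangle (p0, p1, p2)."""
    cross = np.cross(p1 - p0, p2 - p0)
    cross_norm = np.linalg.norm(cross)
    area = 0.5 * cross_norm
    normal = cross / cross_norm

    def angle_at(a, b, c):
        """Interior angle at vertex a between the edges a->b and a->c."""
        u, v = b - a, c - a
        cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
        return np.arccos(np.clip(cos, -1.0, 1.0))

    angles = np.array([angle_at(p0, p1, p2), angle_at(p1, p2, p0), angle_at(p2, p0, p1)])
    return np.concatenate(([area], normal, angles))


# Right triangle in the xy-plane.
feats = face_features(np.array([0.0, 0.0, 0.0]),
                      np.array([1.0, 0.0, 0.0]),
                      np.array([0.0, 1.0, 0.0]))
print(feats.shape)  # (7,)
```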


In [17]:
shrec_training, shrec_testing = shrec_16()

Loading shrec 16 full dataset...

done!


In [26]:
class SHRECDataset(Dataset):
    def __init__(self, data):
        self.complexes = [cc.to_combinatorial_complex() for cc in data["complexes"]]
        self.x_0 = data["node_feat"]
        self.x_1 = data["edge_feat"]
        self.x_2 = data["face_feat"]
        self.y = data["label"]
        self.a0, self.a1, self.coa2, self.b1, self.b2 = self._get_neighborhood_matrix()

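    def _get_neighborhood_matrix(self):
        """Compute, for each complex, the neighborhood matrices used by the HMC layer."""
        a0 = []  # node adjacency (nodes sharing an edge)
        a1 = []  # edge adjacency (edges sharing a face)
        coa2 = []  # face coadjacency (faces sharing an edge)
        b1 = []  # node-edge incidence
        b2 = []  # edge-face incidence

        for cc in self.complexes:
            a0.append(torch.from_numpy(cc.adjacency_matrix(0, 1).todense()).to_sparse())
            a1.append(torch.from_numpy(cc.adjacency_matrix(1, 2).todense()).to_sparse())

            # Face coadjacency: B^T B is nonzero for faces sharing an edge; drop the diagonal.
            B = cc.incidence_matrix(rank=2, to_rank=1)
            A = B.T @ B
            A.setdiag(0)
            coa2.append(torch.from_numpy(A.todense()).to_sparse())

            b1.append(torch.from_numpy(cc.incidence_matrix(1, 0).todense()).to_sparse())
            b2.append(torch.from_numpy(cc.incidence_matrix(2, 1).todense()).to_sparse())

        return a0, a1, coa2, b1, b2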

    def num_classes(self):
        return len(np.unique(self.y))

    def channels_dim(self):
        return self.x_0[0].shape[1], self.x_1[0].shape[1], self.x_2[0].shape[1]

    def __len__(self):
        return len(self.complexes)

    def __getitem__(self, idx):
        return (
            self.x_0[idx],
            self.x_1[idx],
            self.x_2[idx],
            self.a0[idx],
            self.a1[idx],
            self.coa2[idx],
            self.b1[idx],
            self.b2[idx],
            self.y[idx],
        )
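
`_get_neighborhood_matrix` converts every SciPy sparse matrix to a dense array (`.todense()`) before turning it back into a sparse tensor, which allocates an $n \times n$ dense array per mesh. A possible alternative, sketched below with a hypothetical helper `scipy_to_torch_sparse`, builds the torch COO tensor directly from the SciPy COO indices and values. It assumes the matrices returned by `toponetx` are SciPy sparse matrices, as the `.todense()` and `setdiag` calls above suggest.

```python
import numpy as np
import scipy.sparse as sp
import torch


def scipy_to_torch_sparse(matrix):
    """Convert a SciPy sparse matrix to a torch COO tensor without densifying it."""
    coo = sp.coo_matrix(matrix)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce()


adjacency = sp.csr_matrix(np.array([[0.0, 1.0], [1.0, 0.0]]))
tensor = scipy_to_torch_sparse(adjacency)
print(tensor.to_dense())
```

The resulting tensor is already `float32`, so the later `.float()` call in the trainer leaves it unchanged.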

In [27]:
training_dataset = SHRECDataset(shrec_training)
training_dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True)

In [28]:
testing_dataset = SHRECDataset(shrec_testing)
testing_dataloader = DataLoader(testing_dataset, batch_size=1, shuffle=False)

# Create the Neural Network


In [29]:
d_0, d_1, d_2 = training_dataset.channels_dim()
in_channels = [d_0, d_1, d_2]
print(
    f"The dimension of input features on nodes, edges and faces are: {d_0}, {d_1} and {d_2}."
)

The dimension of input features on nodes, edges and faces are: 6, 10 and 7.


In [30]:
class HoanMeshClassifier(torch.nn.Module):
    """HoanMeshClassifier.

    Parameters
    ----------
    in_channels : List[int]
        Dimension of input features on nodes, edges and faces respectively.
    intermediate_channels : List[int]
        Dimension of intermediate features on nodes, edges and faces respectively.
    out_channels : List[int]
        Dimension of output features on nodes, edges and faces respectively.
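    num_classes : int
        Number of classes.
    negative_slope : float, default=0.2
        Negative slope of the LeakyReLU used to compute the attention coefficients.
    n_layers : int, default=1
        Number of HMC layers.
    final_layer_embedding_dimension : int, default=20
        Dimension of the per-rank embeddings concatenated before the final linear layer.
    """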

    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels,
        num_classes,
        negative_slope=0.2,
        n_layers=1,
        final_layer_embedding_dimension=20,
    ) -> None:

        super().__init__()
        self.num_classes = num_classes

        self.layers = torch.nn.ModuleList(
            HMCLayer(
                in_channels=in_channels,
                intermediate_channels=intermediate_channels,
                out_channels=out_channels,
                negative_slope=negative_slope,
                softmax_attention=True,
                update_func_attention="relu",
                update_func_aggregation="relu",
            )
            for _ in range(n_layers)
        )
        self.l0 = torch.nn.Linear(out_channels[0], final_layer_embedding_dimension)
        self.l1 = torch.nn.Linear(out_channels[1], final_layer_embedding_dimension)
        self.l2 = torch.nn.Linear(out_channels[2], final_layer_embedding_dimension)
        self.final_layer = torch.nn.Linear(
            3 * final_layer_embedding_dimension, num_classes
        )

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        neighborhood_0_to_0,
        neighborhood_1_to_1,
        neighborhood_2_to_2,
        neighborhood_0_to_1,
        neighborhood_1_to_2,
    ) -> torch.Tensor:

        for layer in self.layers:
            x_0, x_1, x_2 = layer(
                x_0,
                x_1,
                x_2,
                neighborhood_0_to_0,
                neighborhood_1_to_1,
                neighborhood_2_to_2,
                neighborhood_0_to_1,
                neighborhood_1_to_2,
            )
        """
        # Take the sum of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nansum(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

        x_0 = self.l0(zero_dimensional_cells_mean)
        x_1 = self.l1(one_dimensional_cells_mean)
        x_2 = self.l2(two_dimensional_cells_mean)


        #return F.softmax(x_0 + x_1 + x_2, dim=-1)
        return x_0 + x_1 + x_2
        """
        x_0 = self.l0(x_0)
        x_1 = self.l1(x_1)
        x_2 = self.l2(x_2)
        # Sum all the elements in the dimension zero
        x_0 = torch.nanmean(x_0, dim=0)
        x_1 = torch.nanmean(x_1, dim=0)
        x_2 = torch.nanmean(x_2, dim=0)
        x = torch.cat((x_0, x_1, x_2), dim=0)
        return self.final_layer(x)
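
The readout at the end of `forward` is what lets meshes with different numbers of nodes, edges, and faces share one classifier head: each rank is projected, averaged over its cells, and the three averages are concatenated. The standalone sketch below isolates that head (`MeanReadout` is a hypothetical name, and `num_classes=5` is an arbitrary choice) and checks that a small and a large mesh both yield a single logit vector.

```python
import torch

torch.manual_seed(0)


class MeanReadout(torch.nn.Module):
    """Hypothetical standalone version of the classifier head used above."""

    def __init__(self, channels, embedding_dim, num_classes):
        super().__init__()
        # One linear projection per cell rank (nodes, edges, faces).
        self.projections = torch.nn.ModuleList(
            torch.nn.Linear(c, embedding_dim) for c in channels
        )
        self.final_layer = torch.nn.Linear(len(channels) * embedding_dim, num_classes)

    def forward(self, *cell_features):
        # Project, then average over the cells of each rank (NaNs ignored).
        pooled = [
            torch.nanmean(project(x), dim=0)
            for project, x in zip(self.projections, cell_features)
        ]
        return self.final_layer(torch.cat(pooled, dim=0))


head = MeanReadout(channels=[6, 10, 7], embedding_dim=50, num_classes=5)
small_mesh = (torch.randn(4, 6), torch.randn(5, 10), torch.randn(2, 7))
large_mesh = (torch.randn(250, 6), torch.randn(750, 10), torch.randn(500, 7))
print(head(*small_mesh).shape, head(*large_mesh).shape)
```

Averaging rather than summing keeps the scale of the pooled vector independent of mesh size.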

# Train the Neural Network

We specify the model hyperparameters and instantiate the classifier; the loss and the optimizer are initialized in the `Trainer` class below.

In [40]:
intermediate_channels = [100, 100, 100]

# The output channels of the HMC layers are kept equal to the input channels here.
model = HoanMeshClassifier(
    in_channels=in_channels,
    intermediate_channels=intermediate_channels,
    out_channels=in_channels,
    num_classes=training_dataset.num_classes(),
    negative_slope=0.2,
    n_layers=1,
    final_layer_embedding_dimension=50,
)

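We train the HoanMeshClassifier with the `Trainer` class below, which uses a cross-entropy loss and the Adam optimizer. By default it trains for `num_epochs=500` epochs and reports the test accuracy every `test_interval=25` epochs; for rapid testing, pass a smaller `num_epochs` to `train`.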

In [66]:
class Trainer:
    def __init__(
        self, model, training_dataloader, testing_dataloader, learning_rate, device
    ):
        """Initializes the trainer with model, dataloaders, learning rate, and device."""
        self.model = model.to(device)
        self.training_dataloader = training_dataloader
        self.testing_dataloader = testing_dataloader
        self.device = device
        self.crit = torch.nn.CrossEntropyLoss()
        self.opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def _to_device(self, x):
        """Converts tensors to the correct type and moves them to the device."""
        return [el[0].float().to(self.device) for el in x]

    def train(self, num_epochs=500, test_interval=25):
        """Trains the model for the specified number of epochs."""
        for epoch_i in range(num_epochs):
            training_accuracy, epoch_loss = self._train_epoch()
            print(
                f"Epoch: {epoch_i} loss: {epoch_loss:.4f} Train_acc: {training_accuracy:.4f}",
                flush=True,
            )
            if (epoch_i + 1) % test_interval == 0:
                test_accuracy = self.validate()
                print(f"Test_acc: {test_accuracy:.4f}", flush=True)

    def _train_epoch(self):
        """Trains the model for one epoch."""
        training_samples = len(self.training_dataloader.dataset)
        total_loss = 0
        correct = 0
        self.model.train()
        for sample in self.training_dataloader:
            (
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            ) = self._to_device(sample[:-1])

            self.opt.zero_grad()

            y_hat = self.model(
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            )

            y = sample[-1][0].long().to(self.device)
            total_loss += self._compute_loss_and_update(y_hat, y)
            correct += (y_hat.argmax() == y).sum().item()

        training_accuracy = correct / training_samples
        epoch_loss = total_loss / training_samples

        return training_accuracy, epoch_loss

    def _compute_loss_and_update(self, y_hat, y):
        """Computes the loss, performs backpropagation, and updates the model's parameters."""
        loss = self.crit(y_hat, y)
        loss.backward()
        self.opt.step()
        return loss.item()

    def validate(self):
        """Validates the model using the testing dataloader."""
        correct = 0
        self.model.eval()
        test_samples = len(self.testing_dataloader.dataset)
        with torch.no_grad():
            for sample in self.testing_dataloader:
                (
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                ) = self._to_device(sample[:-1])

                y_hat = self.model(
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                )
                y = sample[-1][0].long().to(self.device)
                correct += (y_hat.argmax() == y).sum().item()
            test_accuracy = correct / test_samples
            return test_accuracy
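
`_to_device` drops the batch dimension (`el[0]`) because the loaders use `batch_size=1`, so the model returns a single unbatched logit vector of shape `(num_classes,)` and the target `sample[-1][0]` is a 0-dimensional tensor. `torch.nn.CrossEntropyLoss` accepts this unbatched form; the snippet below, with made-up logits, checks that it gives the same loss as the batched form and that `argmax` yields the predicted class used for accuracy.

```python
import torch

crit = torch.nn.CrossEntropyLoss()
logits = torch.tensor([0.1, 2.0, -1.0])  # unbatched logits, shape (num_classes,)
target = torch.tensor(1)                 # 0-dim class index

unbatched_loss = crit(logits, target)
batched_loss = crit(logits.unsqueeze(0), target.unsqueeze(0))
correct = (logits.argmax() == target).sum().item()
print(unbatched_loss.item(), batched_loss.item(), correct)
```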

In [67]:
trainer = Trainer(model, training_dataloader, testing_dataloader, 0.001, device)

In [68]:
trainer.train()

Epoch: 0 loss: 1.4025 Train_acc: 0.5229
Epoch: 1 loss: 1.2714 Train_acc: 0.5583
Epoch: 2 loss: 1.2041 Train_acc: 0.5708
Epoch: 3 loss: 1.1085 Train_acc: 0.6188
Epoch: 4 loss: 0.9987 Train_acc: 0.6500
Epoch: 5 loss: 0.9877 Train_acc: 0.6604
Epoch: 6 loss: 0.9208 Train_acc: 0.6854
Epoch: 7 loss: 0.8291 Train_acc: 0.7146
Epoch: 8 loss: 0.8049 Train_acc: 0.7042
Epoch: 9 loss: 0.7337 Train_acc: 0.7500
Epoch: 10 loss: 0.7134 Train_acc: 0.7375
Epoch: 11 loss: 0.6624 Train_acc: 0.7729
Epoch: 12 loss: 0.6511 Train_acc: 0.7771
Epoch: 13 loss: 0.5759 Train_acc: 0.7958
Epoch: 14 loss: 0.6288 Train_acc: 0.7729
Epoch: 15 loss: 0.5740 Train_acc: 0.7917
Epoch: 16 loss: 0.5556 Train_acc: 0.7854
Epoch: 17 loss: 0.5091 Train_acc: 0.7958
Epoch: 18 loss: 0.4983 Train_acc: 0.8208
Epoch: 19 loss: 0.4640 Train_acc: 0.8333
Epoch: 20 loss: 0.4449 Train_acc: 0.8500
Epoch: 21 loss: 0.5013 Train_acc: 0.8167
Epoch: 22 loss: 0.4369 Train_acc: 0.8438
Epoch: 23 loss: 0.4487 Train_acc: 0.8250
Unexpected exception forma

Traceback (most recent call last):
  File "/Users/manuellecha/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/ll/bd2dgg_50_52khsw_lrcpvwc0000gn/T/ipykernel_90025/49973641.py", line 1, in <module>
    trainer.train()
  File "/var/folders/ll/bd2dgg_50_52khsw_lrcpvwc0000gn/T/ipykernel_90025/12018842.py", line 20, in train
    training_accuracy, epoch_loss = self._train_epoch()
  File "/var/folders/ll/bd2dgg_50_52khsw_lrcpvwc0000gn/T/ipykernel_90025/12018842.py", line 61, in _train_epoch
    total_loss += self._compute_loss_and_update(y_hat, y)
  File "/var/folders/ll/bd2dgg_50_52khsw_lrcpvwc0000gn/T/ipykernel_90025/12018842.py", line 72, in _compute_loss_and_update
    loss.backward()
  File "/Users/manuellecha/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    retain_graph=retain_graph,
  File "/Users/manuellecha/miniconda3/lib/pyt