# Train a Combinatorial Complex Attention Neural Network for Mesh Classification.

We create and train a mesh classification high order attentional neural network operating over combinatorial complexes. The model was introduced in [Figure 35(b), Hajij et. al : Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf).

## The Neural Network:

The neural network is composed of a sequence of identical attention layers for a dimension two combinatorial complex. Each layer is composed of two levels. In both levels, messages computed for the cells of identical dimension are aggregated using a sum operation. All the messages are computed using the attention mechanisms for squared and non-squared neighborhoods presented in [Definitions 31, 32, and 33, Hajij et. al : Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). The following message passing scheme is followed in each of the levels for each layer:

1. First level:

🟥 $\quad m^{0\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 0}\right)$

🟥 $\quad m^{0\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 1}\right)$

🟥 $\quad  m^{1\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((B_{1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 0}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 0}\right)$

🟥 $\quad  m^{1\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 2}\right)$

🟥 $\quad m^{2\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 1}\right) h_y^{t,(2)} \Theta^t_{2\rightarrow 1}\right)$

🟧 $\quad m^{0\rightarrow 0}_{x}=\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{0\rightarrow 1}_{x}=\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 0}_{x}=\sum_{y\in B_{1}(x)} m^{1\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 2}_{x}=\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}$

🟧 $\quad m^{2\rightarrow 1}_{x}=\sum_{y\in B_{2}(x)} m^{2\rightarrow 1}_{y\rightarrow x}$

🟩 $\quad m_x^{(0)}=m^{0\rightarrow 0}_{x}+m^{1\rightarrow 0}_{x}$

🟩 $\quad m_x^{(1)}=m^{0\rightarrow 1}_{x}+m^{2\rightarrow 1}_{x}$

🟩 $\quad m_x^{(2)}=m^{1\rightarrow 2}_{x}$

🟦 $\quad i_x^{t,(0)} = m_x^{(0)}$

🟦 $\quad i_x^{t,(1)} = m_x^{(1)}$

🟦 $\quad i_x^{t,(2)} = m_x^{(2)}$

where $i_x^{t,(\cdot)}$ represents intermediate feature vectors.


2. Second level:


🟥 $\quad m^{0\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 0}\right)$

🟥 $\quad m^{1\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 1}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 1}\right)$

🟥 $\quad m^{2\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((A_{\downarrow, 2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 2}\right) i_y^{t,(2)} \Theta^t_{2\rightarrow 2}\right)$

🟥 $\quad m^{0\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 1}\right)$

🟥 $\quad m^{1\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 2}\right)$

🟧 $\quad m^{0\rightarrow 0}_{x}=\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 1}_{x}=\sum_{y\in A_{\uparrow, 1}(x)} m^{1\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{2\rightarrow 2}_{x}=\sum_{y\in A_{\downarrow, 2}(x)} m^{2\rightarrow 2}_{y\rightarrow x}$

🟧 $\quad m^{0\rightarrow 1}_{x}=\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 2}_{x}=\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}$

🟩 $\quad m_x^{(0)}=m^{0\rightarrow 0}_{x}+m^{1\rightarrow 0}_{x}$

🟩 $\quad m_x^{(1)}=m^{1\rightarrow 1}_{x} + m^{0\rightarrow 1}_{x}$

🟩 $\quad m_x^{(2)}=m^{1\rightarrow 2}_{x} + m^{2\rightarrow 2}_{x}$

🟦 $\quad h_x^{t+1,(0)} = m_x^{(0)}$

🟦 $\quad h_x^{t+1,(1)} = m_x^{(1)}$

🟦 $\quad h_x^{t+1,(2)} = m_x^{(2)}$

In both message passing levels, $\phi$ represents a common activation function. Also, $\Theta$ and $\text{att}$ represent learnable weights and attention matrices, respectively, that are different in each level. Attention matrices are introduced in [Figure 35(b), Hajij et. al : Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). In this case, attention matrices are computed using the LeakyReLU activation function, as in previous versions of the paper.

Notations, adjacency, coadjacency, and incidence matrices are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031). The tensor diagram for the layer can be found in the first column and last row of Figure 11, from the same paper.



## The Task:

We train this model to perform entire mesh classification on [`SHREC 2016` from the ShapeNet Dataset](http://shapenet.cs.stanford.edu/shrec16/). This dataset contains 480 3D mesh samples belonging to 30 distinct classes and represented as simplicial complexes.

Each mesh contains a set of vertices, edges, and faces. Each of the latter entities have a set of features associated to them:

- Node features $v \in R^6$ defined as the direct sum of the following features:
    - Position $p_v \in \mathbb{R}^3$ coordinates.
    - Normal $n_v \in \mathbb{R}^3$ coordinates.
- Edge features $e \in \mathbb{R}^{10}$ defined as the direct sum of the following features:
    - Dihedral angle $\phi \in \mathbb{R}$.
    - Edge span $l \in \mathbb{R}$.
    - 2 edge angle in the triangle that $\theta_e \in \mathbb{R}^2$.
    - 6 edge ratios $r \in \mathbb{R}^6$.
- Face features
    - Face area $a \in \mathbb{R}$.
    - Face normal $n_f \in \mathbb{R}^3$.
    - 3 face angles $\theta_f \in \mathbb{R}^3$.

We lift the simplicial complexes representing each mesh to a topologically equivalent combinatorial complex representation.

The task is to predict the class that a certain mesh belongs to, given its combinatorial complex representation. For this purpose we implement the Higher Order Attention Model for Mesh Classification first introduced in [Hajij et. al : Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf).

# Set-up


In [15]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from toponetx import CombinatorialComplex
from toponetx.datasets.mesh import shrec_16
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from topomodelx.nn.combinatorial.hmc_layer import HMCLayer

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [82]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

We then lift each graph into our topological domain of choice, here: a cell complex.

We also retrieve:
- input signals `x_0` and `x_1` on the nodes (0-cells) and edges (1-cells) for each complex: these will be the model's inputs,
- a binary classification label `y` associated to the cell complex.

In [17]:
shrec_training, shrec_testing = shrec_16()

Loading shrec 16 full dataset...

done!


In [26]:
class SHRECDataset(Dataset):
    def __init__(self, data):
        self.complexes = [cc.to_combinatorial_complex() for cc in data["complexes"]]
        self.x_0 = data["node_feat"]
        self.x_1 = data["edge_feat"]
        self.x_2 = data["face_feat"]
        self.y = data["label"]
        self.a0, self.a1, self.coa2, self.b1, self.b2 = self._get_neighborhood_matrix()

    def _get_neighborhood_matrix(self):

        a0 = []
        a1 = []
        coa2 = []
        b1 = []
        b2 = []

        for cc in self.complexes:

            a0.append(torch.from_numpy(cc.adjacency_matrix(0, 1).todense()).to_sparse())
            a1.append(torch.from_numpy(cc.adjacency_matrix(1, 2).todense()).to_sparse())

            B = cc.incidence_matrix(rank=2, to_rank=1)
            A = B.T @ B
            A.setdiag(0)
            coa2.append(torch.from_numpy(A.todense()).to_sparse())

            b1.append(torch.from_numpy(cc.incidence_matrix(1, 0).todense()).to_sparse())
            b2.append(torch.from_numpy(cc.incidence_matrix(2, 1).todense()).to_sparse())

        return a0, a1, coa2, b1, b2

    def num_classes(self):
        return len(np.unique(self.y))

    def channels_dim(self):
        return self.x_0[0].shape[1], self.x_1[0].shape[1], self.x_2[0].shape[1]

    def __len__(self):
        return len(self.complexes)

    def __getitem__(self, idx):
        return (
            self.x_0[idx],
            self.x_1[idx],
            self.x_2[idx],
            self.a0[idx],
            self.a1[idx],
            self.coa2[idx],
            self.b1[idx],
            self.b2[idx],
            self.y[idx],
        )

In [27]:
training_dataset = SHRECDataset(shrec_training)
training_dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]


In [28]:
testing_dataset = SHRECDataset(shrec_training)
testing_dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True)

# Create the Neural Network


In [29]:
d_0, d_1, d_2 = training_dataset.channels_dim()
in_channels = [d_0, d_1, d_2]
print(
    f"The dimension of input features on nodes, edges and faces are: {d_0}, {d_1} and {d_2}."
)

The dimension of input features on nodes, edges and faces are: 6, 10 and 7.


In [69]:
class HoanMeshClassifier(torch.nn.Module):
    """HoanMeshClassifier.

    Parameters
    ----------
    in_channels : List[int]
        Dimension of input features on nodes, edges and faces respectively.
    intermediate_channels : List[int]
        Dimension of intermediate features on nodes, edges and faces respectively.
    out_channels : List[int]
        Dimension of output features on nodes, edges and faces respectively.
    num_classes : int
        Number of classes.
    n_layers : int
        Number of CCXN layers.
    """

    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels,
        num_classes,
        negative_slope=0.2,
        n_layers=1,
        final_layer_embedding_dimension=20,
    ) -> None:

        super().__init__()
        self.num_classes = num_classes

        self.layers = torch.nn.ModuleList(
            HMCLayer(
                in_channels=in_channels,
                intermediate_channels=intermediate_channels,
                out_channels=out_channels,
                negative_slope=negative_slope,
                softmax_attention=True,
                update_func_attention="relu",
                update_func_aggregation="relu",
            )
            for _ in range(n_layers)
        )
        self.l0 = torch.nn.Linear(out_channels[0], num_classes)
        self.l1 = torch.nn.Linear(out_channels[1], num_classes)
        self.l2 = torch.nn.Linear(out_channels[2], num_classes)
        self.final_layer = torch.nn.Linear(
            3 * final_layer_embedding_dimension, num_classes
        )

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        neighborhood_0_to_0,
        neighborhood_1_to_1,
        neighborhood_2_to_2,
        neighborhood_0_to_1,
        neighborhood_1_to_2,
    ) -> torch.Tensor:

        for layer in self.layers:
            x_0, x_1, x_2 = layer(
                x_0,
                x_1,
                x_2,
                neighborhood_0_to_0,
                neighborhood_1_to_1,
                neighborhood_2_to_2,
                neighborhood_0_to_1,
                neighborhood_1_to_2,
            )

        x_0 = self.l0(x_0)
        x_1 = self.l1(x_1)
        x_2 = self.l2(x_2)
        # Sum all the elements in the dimension zero
        x_0 = torch.nanmean(x_0, dim=0)
        x_1 = torch.nanmean(x_1, dim=0)
        x_2 = torch.nanmean(x_2, dim=0)
        # x = torch.cat((x_0, x_1, x_2), dim=0)
        return x_0 + x_1 + x_2

# Train the Neural Network

We specify the model, initialize loss, and specify an optimizer. We first try it without any attention mechanism.

In [71]:
class Trainer:
    def __init__(
        self, model, training_dataloader, testing_dataloader, learning_rate, device
    ):
        """Initializes the trainer with model, dataloaders, learning rate, and device."""
        self.model = model.to(device)
        self.training_dataloader = training_dataloader
        self.testing_dataloader = testing_dataloader
        self.device = device
        self.crit = torch.nn.CrossEntropyLoss()
        self.opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def _to_device(self, x):
        """Converts tensors to the correct type and moves them to the device."""
        return [el[0].float().to(self.device) for el in x]

    def train(self, num_epochs=500, test_interval=25):
        """Trains the model for the specified number of epochs."""
        for epoch_i in range(num_epochs):
            training_accuracy, epoch_loss = self._train_epoch()
            print(
                f"Epoch: {epoch_i} loss: {epoch_loss:.4f} Train_acc: {training_accuracy:.4f}",
                flush=True,
            )
            if (epoch_i + 1) % test_interval == 0:
                test_accuracy = self.validate()
                print(f"Test_acc: {test_accuracy:.4f}", flush=True)

    def _train_epoch(self):
        """Trains the model for one epoch."""
        training_samples = len(self.training_dataloader.dataset)
        total_loss = 0
        correct = 0
        self.model.train()
        for sample in self.training_dataloader:
            (
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            ) = self._to_device(sample[:-1])

            self.opt.zero_grad()

            y_hat = self.model.forward(
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            )

            y = sample[-1][0].long().to(self.device)
            total_loss += self._compute_loss_and_update(y_hat, y)
            correct += (y_hat.argmax() == y).sum().item()

        training_accuracy = correct / training_samples
        epoch_loss = total_loss / training_samples

        return training_accuracy, epoch_loss

    def _compute_loss_and_update(self, y_hat, y):
        """Computes the loss, performs backpropagation, and updates the model's parameters."""
        loss = self.crit(y_hat, y)
        loss.backward()
        self.opt.step()
        return loss.item()

    def validate(self):
        """Validates the model using the testing dataloader."""
        correct = 0
        self.model.eval()
        test_samples = len(self.testing_dataloader.dataset)
        with torch.no_grad():
            for sample in self.testing_dataloader:
                (
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                ) = self._to_device(sample[:-1])

                y_hat = self.model(
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                )
                y = sample[-1][0].long().to(self.device)
                correct += (y_hat.argmax() == y).sum().item()
            test_accuracy = correct / test_samples
            return test_accuracy

In [77]:
intermediate_channels = [100, 100, 100]
final_channels = [100, 100, 100]

model = HoanMeshClassifier(
    in_channels,
    intermediate_channels,
    in_channels,
    negative_slope=0.2,
    num_classes=training_dataset.num_classes(),
    n_layers=1,
    final_layer_embedding_dimension=50,
)

trainer = Trainer(model, training_dataloader, testing_dataloader, 0.001, "cpu")

We train the HoanMeshClassifier using low amount of epochs: we keep training minimal for the purpose of rapid testing.

In [78]:
trainer.train(num_epochs=5)

Epoch: 0 loss: 3.5006 Train_acc: 0.0292
Epoch: 1 loss: 3.4322 Train_acc: 0.0250
Epoch: 2 loss: 3.4171 Train_acc: 0.0479
Epoch: 3 loss: 3.3384 Train_acc: 0.0729
Epoch: 4 loss: 3.1561 Train_acc: 0.1042
