# Train a Simplicial Attention Network (SAN)

We create and train a Simplicial Attention Network (SAN), originally proposed in [Giusti*, Battiloro* et al.: Simplicial Attention Neural Networks (2022)](https://arxiv.org/abs/2203.07485). The aim of this notebook is to be didactic and clear; for further technical and implementation details, please refer to the original paper and the TopoModelX documentation.

### Abstract
The aim of this work is to introduce simplicial attention networks (SANs), i.e., novel neural architectures that operate on data defined on simplicial complexes leveraging masked self-attentional layers. Hinging on formal arguments from topological signal processing, we introduce a proper self-attention mechanism able to process data components at different layers (e.g., nodes, edges, triangles, and so on), while learning how to weight both upper and lower neighborhoods of the given topological domain in a totally task-oriented fashion. The proposed SANs generalize most of the current architectures available for processing data defined on simplicial complexes.

<center><a href="https://ibb.co/jVggJzK"><img src="https://i.ibb.co/PTwwDMp/SAN-architecture.jpg" alt="SAN-architecture" border="0"></a></center>

**Remark.** The notation we use is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031) and [Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). Custom symbols are introduced throughout the notebook when necessary.

### The Neural Network
The SAN layer takes rank-$r$ signals as input and gives rank-$r$ signals as output. The involved neighborhoods are:

\begin{equation}
\mathcal N = \{\mathcal N_1, \mathcal N_2,...,\mathcal N_{2p+1}\} =  \{A_{\uparrow, r}, A_{\downarrow, r}, A_{\uparrow, r}^2, A_{\downarrow, r}^2,...,A_{\uparrow, r}^p, A_{\downarrow, r}^p, Q_r\},
\end{equation}
where $Q_r$ is a sparse projection operator (weighted matrix) onto the kernel of the $r$-th Hodge Laplacian $L_r$, computed as in the original paper. $Q_r$ has the same topology as $L_r$.
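To make the role of $Q_r$ more concrete, the sketch below approximates the projector onto the kernel of a Hodge Laplacian with the truncated power $(I - \epsilon L)^J$, which is the form used by `compute_projection_matrix` later in this notebook. The toy complex (a hollow triangle), the helper name `approximate_harmonic_projector`, and the values of $\epsilon$ and $J$ are assumptions made for illustration only.

```python
import numpy as np

# Hollow triangle: 3 nodes, 3 oriented edges (0,1), (0,2), (1,2), no filled face.
# Rows index nodes, columns index edges; each column has -1 at the tail, +1 at the head.
B1 = np.array([
    [-1.0, -1.0, 0.0],
    [1.0, 0.0, -1.0],
    [0.0, 1.0, 1.0],
])

# With no 2-simplices, the rank-1 Hodge Laplacian reduces to its lower part.
L1 = B1.T @ B1


def approximate_harmonic_projector(L, epsilon=0.1, J=5):
    """Return (I - epsilon * L)^J, which damps the non-harmonic eigen-directions."""
    P = np.eye(L.shape[0]) - epsilon * L
    return np.linalg.matrix_power(P, J)


# The cycle 0 -> 1 -> 2 -> 0 is harmonic: it lies in the kernel of L1.
h = np.array([1.0, -1.0, 1.0])
exact_projector = np.outer(h, h) / (h @ h)

P_small = approximate_harmonic_projector(L1, epsilon=0.1, J=5)
P_large = approximate_harmonic_projector(L1, epsilon=0.1, J=50)

print(np.abs(P_small - exact_projector).max())  # noticeable error for small J
print(np.abs(P_large - exact_projector).max())  # error shrinks as J grows
```

The approximation only converges when $\epsilon$ is smaller than $2/\lambda_{\max}(L)$, so that every non-zero eigenvalue of $L$ is mapped strictly inside the unit interval.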

The equation of the SAN layer of this neural network is given by:

\begin{equation}
\textbf{h}_x^{t+1} =  \phi^t \Bigg ( \textbf{h}_x^{t}, \bigotimes_{\mathcal{N}_k\in\mathcal N}\bigoplus_{y \in \mathcal{N}_k(x)}  \widetilde{\alpha}_k(h_x^t,h_y^t)\Bigg ),
\end{equation}

with $\widetilde{\alpha}_k$ being either an attention function $\alpha_k$ if $\mathcal{N}_k \neq Q_r$ or a standard convolution term (affine transformation + weights) with weights given by the entries of $Q_r$ if $\mathcal{N}_k = Q_r$.

Therefore, the SAN layer consists of an attentional convolution from rank-$r$ cells to rank-$r$ cells using an adjacency message passing scheme up to $p$-hop neighborhoods:

\begin{align*}
&🟥\textrm{ Message.} &\quad m_{(y \rightarrow x),k} =&
\alpha_k(h_x^t,h_y^t) =
a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_x^{t})\quad \forall \mathcal N_k \in \mathcal{N}\\
\\
&🟧 \textrm{ Within-Neighborhood Aggregation.} &\quad m_{x,k}               =& \bigoplus_{y \in \mathcal{N}_k(x)}  m_{(y \rightarrow x),k}\\
\\
&🟩 \textrm{ Between-Neighborhood Aggregation.} &\quad m_{x} =& \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}\\
\\
&🟦 \textrm{ Update.}&\quad h_x^{t+1}                =& \phi^{t}(h_x^t, m_{x})
\end{align*}
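The four steps above can be sketched with dense matrices. This is a minimal illustration under stated assumptions, not the TopoModelX implementation: both aggregations $\bigoplus$ and $\bigotimes$ are taken to be sums, the update $\phi$ is a ReLU of the aggregated message, each $a_k$ is a fixed weight matrix rather than learned attention, and $\psi_k$ is applied to the features of the sending cell, as in the matrix product $A_k (H W_k)$. The function name `san_like_step` is hypothetical.

```python
import numpy as np


def san_like_step(h, neighborhood_weights, feature_weights):
    """One simplified message passing step over several neighborhoods."""
    messages = []
    for a_k, w_k in zip(neighborhood_weights, feature_weights):
        psi_k = h @ w_k      # transformed features, shape [n_cells, d_out]
        m_k = a_k @ psi_k    # within-neighborhood sum: row x adds a_k[x, y] * psi_k[y]
        messages.append(m_k)
    m = np.sum(messages, axis=0)  # between-neighborhood sum
    return np.maximum(m, 0.0)     # update (ReLU, an assumption of this sketch)


h = np.array([[1.0], [2.0], [3.0]])
a_1 = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # path-like neighborhood
a_2 = np.eye(3)                                                      # each cell with itself
w_1 = np.array([[1.0]])
w_2 = np.array([[-1.0]])

h_next = san_like_step(h, [a_1, a_2], [w_1, w_2])
print(h_next.ravel())  # [1. 2. 0.]
```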




### The Task:

We train this model to perform a binary node classification task using the KarateClub dataset. We use a ["GAT-like" attention function](https://arxiv.org/abs/1710.10903), in which two different sets of attention weights $a_\uparrow$ and $a_\downarrow$ are learned for the upper neighborhoods $A_{\uparrow,1}^p$ and for the lower neighborhoods $A_{\downarrow,1}^p$ ($p=1,...,P$), respectively, i.e.:

- If $\mathcal{N}_k \neq Q_r$, and supposing, as an example, that $\mathcal{N}_k = A_{\downarrow,1}^g$ is the $g$-hop lower neighborhood:
\begin{align}
&a_k(h_x^{t}, h_y^{t}) = (\textrm{softmax}_j(\textrm{LeakyReLU}(a_{\downarrow}^T[\underset{p=1}{\overset{P}{||}}h_x^{t}W_{\downarrow,p}|| \underset{p=1}{\overset{P}{||}}h_y^{t}W_{\downarrow,p}])))^g\\
& \psi_k^t(h_x^{t}) = h_x^{t}W_{\downarrow,g}.
\end{align}

- If $\mathcal{N}_k = Q_r$:
\begin{align}
&a_k(h_x^{t}, h_y^{t}) = Q_{x,y}\\
& \psi_k^t(h_x^{t}) = h_x^{t}W.
\end{align}

$W$, $a_\downarrow$, $a_\uparrow$, $\{W_{\downarrow,p}\}_{p=1}^P$ and $\{W_{\uparrow,p}\}_{p=1}^P$ are learnable weights.
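The attention coefficients for a single neighborhood can be sketched as a masked softmax. The example below is a simplification under stated assumptions: it uses one filter ($P=1$) and one hop ($g=1$), dense NumPy arrays, and random weights in place of learned ones. The function name `masked_gat_attention` and its parameters are hypothetical and do not correspond to the `SANLayer` API.

```python
import numpy as np


def masked_gat_attention(h, w, a, mask, negative_slope=0.2):
    """Softmax over each row of LeakyReLU(a^T [h_x W || h_y W]), restricted to mask."""
    z = h @ w                     # transformed features, shape [n, d]
    d = z.shape[1]
    # a^T [z_x || z_y] splits into a term depending on x and a term depending on y.
    scores = (z @ a[:d])[:, None] + (z @ a[d:])[None, :]
    scores = np.where(scores > 0, scores, negative_slope * scores)  # LeakyReLU
    scores = np.where(mask > 0, scores, -np.inf)                    # keep neighbors only
    row_max = scores.max(axis=1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)          # rows without neighbors
    exp_scores = np.exp(scores - row_max)
    denom = exp_scores.sum(axis=1, keepdims=True)
    return np.divide(exp_scores, denom, out=np.zeros_like(exp_scores), where=denom > 0)


rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))   # 4 cells with 3 input channels
w = rng.normal(size=(3, 2))   # stand-in for W_{down,1}
a = rng.normal(size=(4,))     # stand-in for a_down, length 2 * d
mask = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])

alpha = masked_gat_attention(h, w, a, mask)
messages = alpha @ (h @ w)    # attention-weighted aggregation over the neighborhood
```

Each row of `alpha` sums to one over the cells allowed by `mask`, and entries outside the neighborhood are exactly zero.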


In [None]:
import numpy as np
import torch
from torch.nn.parameter import Parameter
import toponetx.datasets.graph as graph
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.simplicial.san_layer import SANLayer

%load_ext autoreload
%autoreload 2

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [None]:
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)

Simplicial Complex with shape [34, 78, 45, 11, 2] and dimension 4


In [None]:
dataset.shape

[34, 78, 45, 11, 2]
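The shape lists the number of simplices of each rank (nodes, edges, triangles, and so on). One common way to obtain such a complex from a graph is a clique lifting, where every clique of $k+1$ nodes becomes a rank-$k$ simplex. That this is the lifting applied here is an assumption, not something this notebook states. The brute-force helper below, `clique_complex_shape`, is hypothetical and only meant for tiny graphs.

```python
from itertools import combinations


def clique_complex_shape(nodes, edges):
    """Count the cliques of each size; entry k is the number of rank-k simplices."""
    adjacent = {frozenset(edge) for edge in edges}
    shape = []
    for size in range(1, len(nodes) + 1):
        count = sum(
            1
            for candidate in combinations(nodes, size)
            if all(frozenset(pair) in adjacent for pair in combinations(candidate, 2))
        )
        if count == 0:
            break
        shape.append(count)
    return shape


# A 4-clique on nodes {0, 1, 2, 3} plus a pendant edge (3, 4).
toy_nodes = [0, 1, 2, 3, 4]
toy_edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (3, 4)]
toy_shape = clique_complex_shape(toy_nodes, toy_edges)
print(toy_shape)  # [5, 7, 4, 1]
```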

## Define neighborhood structures. ##

We now retrieve the neighborhoods (i.e., their representative matrices) that we will use to send messages on the domain. In this case, we decide w.l.o.g. to work at the edge level (thus considering a simplicial complex of order 2). We therefore need the lower and upper Laplacians of rank 1, $L_{\downarrow,1}=B_1^TB_1$ and $L_{\uparrow,1}=B_2B_2^T$, both with dimensions $n_\text{edges} \times n_\text{edges}$, where $B_1$ and $B_2$ are the incidence matrices of rank 1 and 2. Note that the binary adjacencies $A_{\downarrow,1}^p$ and $A_{\uparrow,1}^p$ encoding the $p$-hop neighborhoods are given by the support (the non-zero pattern) of $L_{\downarrow,1}^p$ and $L_{\uparrow,1}^p$, respectively. We also convert the neighborhood structures to torch tensors.

**Remark.** In the case of rank-0 simplices (nodes), there is no lower Laplacian; in this case, we simply initialize the down Laplacian as a zero matrix, and SAN automatically becomes a GAT-like architecture.
In the case of simplices of maximum rank (the order of the complex), there is no upper Laplacian. In this case we can also initialize it as a zero matrix, and SAN will only consider the lower adjacencies.

In [None]:
simplex_order_k = 1
# Down laplacian
try:
    Ldown = torch.from_numpy(
        dataset.down_laplacian_matrix(rank=simplex_order_k).todense()
    ).to_sparse()
except ValueError:
    Ldown = torch.zeros(
        (dataset.shape[simplex_order_k], dataset.shape[simplex_order_k])
    ).to_sparse()
# Up laplacian
try:
    Lup = torch.from_numpy(
        dataset.up_laplacian_matrix(rank=simplex_order_k).todense()
    ).to_sparse()
except ValueError:
    Lup = torch.zeros(
        (dataset.shape[simplex_order_k], dataset.shape[simplex_order_k])
    ).to_sparse()
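For intuition, the same quantities can be built by hand on a toy complex. The sketch below assumes a small complex with four nodes, four oriented edges and one filled triangle, and uses dense NumPy arrays instead of the dataset methods. The helper `hop_support` is hypothetical and simply returns the non-zero pattern of $L^p$.

```python
import numpy as np

# Nodes 0..3; edges e0=(0,1), e1=(0,2), e2=(1,2), e3=(2,3); one face (0,1,2).
# B1: nodes x edges, -1 at the tail and +1 at the head of each edge.
B1 = np.array([
    [-1.0, -1.0, 0.0, 0.0],
    [1.0, 0.0, -1.0, 0.0],
    [0.0, 1.0, 1.0, -1.0],
    [0.0, 0.0, 0.0, 1.0],
])
# B2: edges x faces, boundary of (0,1,2) is (1,2) - (0,2) + (0,1).
B2 = np.array([[1.0], [-1.0], [1.0], [0.0]])

L_down = B1.T @ B1   # lower Laplacian of rank 1, shape [n_edges, n_edges]
L_up = B2 @ B2.T     # upper Laplacian of rank 1, shape [n_edges, n_edges]


def hop_support(L, p):
    """Binary non-zero pattern of L^p."""
    return (np.abs(np.linalg.matrix_power(L, p)) > 1e-12).astype(int)


A_down = hop_support(L_down, 1)
A_up = hop_support(L_up, 1)
print(A_down)
print(A_up)
```

In this toy complex, edges `e0` and `e3` share no node, so they are not lower-adjacent, and `e3` belongs to no triangle, so its row in `A_up` is empty.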

## Import signal ##

We define edge features to be the gradient of the node features, i.e., given the node feature matrix $X_0$, we compute the edge feature matrix as $X_1 = B_1^TX_0$. We will finally obtain the estimated node labels from the updated edge features by multiplying them again with $B_1$, i.e., the final node features are computed as the divergence of the final edge features.

**Remark.** Note that this way of deriving edge/node features from node/edge features can also be seen as a (non-learnable) message passing between rank-0/1 cells (nodes/edges) and rank-1/0 cells (edges/nodes).
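As a small worked example of the gradient and divergence maps, the sketch below uses a path graph with three nodes and two oriented edges. The toy incidence matrix and feature values are assumptions chosen for illustration.

```python
import numpy as np

# Path graph 0 - 1 - 2 with oriented edges (0,1) and (1,2).
# B1 has shape [n_nodes, n_edges], with -1 at the tail and +1 at the head.
B1 = np.array([
    [-1.0, 0.0],
    [1.0, -1.0],
    [0.0, 1.0],
])

x_0 = np.array([[1.0], [3.0], [6.0]])  # one scalar feature per node

# Gradient: each edge receives the difference of its endpoint features.
x_1 = B1.T @ x_0
print(x_1.ravel())  # [2. 3.]

# Divergence: each node collects the signed features of its incident edges.
x_0_back = B1 @ x_1
print(x_0_back.ravel())  # [-2. -1.  3.]
```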

In [None]:
x_0 = []
for _, v in dataset.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

x_2 = []
for k, v in dataset.get_simplex_attributes("face_feat").items():
    x_2.append(v)
x_2 = np.stack(x_2)
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.


We use the incidence matrix between nodes and edges:

In [None]:
incidence_0_1 = torch.from_numpy(dataset.incidence_matrix(1).todense()).to_sparse()

The final edge features are obtained by summing the original edge features and the projection of the node features onto the edges (using the transposed incidence matrix):

In [None]:
x = x_1 + torch.sparse.mm(incidence_0_1.T, x_0)

Hence, the final input features are defined by this sum, and we also pre-define the number of hidden and output channels of the model.

In [None]:
in_channels = x.shape[-1]
hidden_channels = 16
out_channels = 2

## Define binary labels
We retrieve the labels associated with the nodes. In the KarateClub dataset, two social groups emerge, so we assign binary labels to the nodes indicating which group they belong to.

We one-hot encode the binary labels, and keep the first four nodes for the purpose of testing.

In [None]:
y = np.array(
    [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_test = y_true[:4]
y_train = y_true[-30:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the Neural Network

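Using the SANLayer class, we create a neural network with stacked layers. A linear layer at the end produces an edge-level output with shape $n_\text{edges} \times 2$; during training this output is projected onto the nodes with the incidence matrix, so we can compare it with our binary node labels.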

In [None]:
class SAN(torch.nn.Module):
    r"""Simplicial Attention Network (SAN) producing edge-level outputs for binary classification.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    hidden_channels : int
        Dimension of hidden features.
    out_channels : int
        Dimension of output features.
    num_filters_J : int, optional
        Approximation order for simplicial filters. Defaults to 2.
    J_har : int, optional
        Approximation order for harmonic convolution. Defaults to 5.
    epsilon_har : float, optional
        Epsilon value for harmonic convolution. Defaults to 1e-1.
    n_layers : int, optional
        Number of message passing layers. Defaults to 2.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_filters_J=2,
        J_har=5,
        epsilon_har=1e-1,
        n_layers=2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters_J = num_filters_J
        self.J_har = J_har
        self.epsilon_har = epsilon_har
        if n_layers == 1:
            self.layers = [
                SANLayer(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    num_filters_J=self.num_filters_J,
                )
            ]
        else:
            self.layers = [
                SANLayer(
                    in_channels=self.in_channels,
                    out_channels=self.hidden_channels,
                    num_filters_J=self.num_filters_J,
                )
            ]
            for _ in range(n_layers - 2):
                self.layers.append(
                    SANLayer(
                        in_channels=self.hidden_channels,
                        out_channels=self.hidden_channels,
                        num_filters_J=self.num_filters_J,
                    )
                )
            self.layers.append(
                SANLayer(
                    in_channels=self.hidden_channels,
                    out_channels=self.out_channels,
                    num_filters_J=self.num_filters_J,
                )
            )
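        # Register the stacked layers as sub-modules so that their parameters
        # are returned by model.parameters() and updated by the optimizer.
        self.layers = torch.nn.ModuleList(self.layers)
        self.linear = torch.nn.Linear(out_channels, 2)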

    def compute_projection_matrix(self, L):
        r"""Computation of the projection matrix which is then used
        to calculate the harmonic component in SAN layers.

        Parameters
        ---------

        L : tensor
            shape = [n_edges, n_edges]
            Hodge laplacian of rank 1.


        Returns
        --------
        _ : tensor
            shape = [n_edges, n_edges]
            Projection matrix.

        """
        P = torch.eye(L.shape[0]) - self.epsilon_har * L
        P = torch.linalg.matrix_power(P, self.J_har)
        return P

    def forward(self, x, Lup, Ldown):
        r"""Forward computation.

        Parameters
        ---------
        x : tensor
            shape = [n_edges, channels_in]
            Edge features.

        Lup : tensor
            shape = [n_edges, n_edges]
            Upper laplacian of rank 1.

        Ldown : tensor
            shape = [n_edges, n_edges]
            Down laplacian of rank 1.


        Returns
        --------
        _ : tensor
            shape = [n_edges, 2]
            Edge-level class scores, later projected onto the nodes.

        """
        # Compute the projection matrix for the harmonic component
        L = Lup + Ldown
        P = self.compute_projection_matrix(L)

        # Forward computation
        for layer in self.layers:
            x = layer(x, Lup, Ldown, P)
        return torch.sigmoid(self.linear(x))
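When layers are stacked inside a `torch.nn.Module`, the container that holds them determines whether their weights are visible to `model.parameters()`, and therefore to the optimizer. The sketch below compares a plain Python list with `torch.nn.ModuleList`, using `torch.nn.Linear` as a stand-in for `SANLayer`. The class names `ListStack` and `ModuleListStack` are hypothetical.

```python
import torch


class ListStack(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list: the layers are stored but not registered as sub-modules.
        self.layers = [torch.nn.Linear(2, 2) for _ in range(3)]


class ModuleListStack(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList: every layer is registered, so its parameters are tracked.
        self.layers = torch.nn.ModuleList([torch.nn.Linear(2, 2) for _ in range(3)])


def count_parameters(module):
    """Total number of scalar parameters exposed by module.parameters()."""
    return sum(p.numel() for p in module.parameters())


print(count_parameters(ListStack()))        # 0
print(count_parameters(ModuleListStack()))  # 18
```

Each `Linear(2, 2)` layer has 4 weights and 2 biases, so three registered layers expose 18 parameters, while the list-based variant exposes none.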

In [None]:
Lup.shape, Ldown.shape, x_1.shape

(torch.Size([78, 78]), torch.Size([78, 78]), torch.Size([78, 2]))

# Train the Neural Network

The following cell performs the training, running the network for a small number of epochs (10) and evaluating on the held-out test nodes every 2 epochs.

In [None]:
model = SAN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=1,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.4)
test_interval = 2
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat_edge = model(x, Lup=Lup, Ldown=Ldown)
    # We project the edge-level output of the model to the node-level
    # and apply softmax fn to get the final node-level classification output
    y_hat = torch.softmax(torch.sparse.mm(incidence_0_1, y_hat_edge), dim=1)
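    # y_hat already contains probabilities (softmax output), so we use the
    # probability-based loss rather than the logits-based one.
    loss = torch.nn.functional.binary_cross_entropy(
        y_hat[-len(y_train) :].float(), y_train.float()
    )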
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[-len(y_train) :] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_edge_test = model(x, Lup=Lup, Ldown=Ldown)
            # Projection to node-level
            y_hat_test = torch.softmax(
                torch.sparse.mm(incidence_0_1, y_hat_edge_test), dim=1
            )
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            # _pred_test = torch.softmax(y_hat_test,dim=1).ge(0.5).float()
            test_accuracy = (
                torch.eq(y_pred_test[: len(y_test)], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.7240 Train_acc: 0.4333
Epoch: 2 loss: 0.7114 Train_acc: 0.7333
Test_acc: 0.2500
Epoch: 3 loss: 0.6948 Train_acc: 0.7333
Epoch: 4 loss: 0.6829 Train_acc: 0.7333
Test_acc: 0.2500
Epoch: 5 loss: 0.6760 Train_acc: 0.7333
Epoch: 6 loss: 0.6721 Train_acc: 0.7333
Test_acc: 0.2500
Epoch: 7 loss: 0.6698 Train_acc: 0.7333
Epoch: 8 loss: 0.6685 Train_acc: 0.7333
Test_acc: 0.2500
Epoch: 9 loss: 0.6677 Train_acc: 0.7333
Epoch: 10 loss: 0.6672 Train_acc: 0.7333
Test_acc: 0.2500
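When choosing a loss for outputs that have already passed through a softmax or sigmoid, it helps to recall the difference between losses that expect probabilities and losses that expect raw scores (logits). The pure-Python sketch below re-implements both conventions by hand. The helper names are hypothetical, and the functions only mirror the usual definitions, not PyTorch's exact numerics.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce(probabilities, targets):
    """Mean binary cross-entropy for inputs that are already probabilities."""
    return -sum(
        t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
        for p, t in zip(probabilities, targets)
    ) / len(targets)


def bce_with_logits(scores, targets):
    """Mean binary cross-entropy for raw scores: a sigmoid is applied first."""
    return bce([sigmoid(s) for s in scores], targets)


targets = [1.0, 0.0]
confident_probabilities = [0.99, 0.01]

loss_from_probabilities = bce(confident_probabilities, targets)
loss_if_treated_as_logits = bce_with_logits(confident_probabilities, targets)

print(loss_from_probabilities)    # close to 0
print(loss_if_treated_as_logits)  # stays well above 0
```

Inputs confined to $[0, 1]$ are mapped by the extra sigmoid into roughly $[0.5, 0.73]$, so on a one-hot pair the logits-based loss cannot drop below about 0.5 even for a perfect prediction.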
