# Train a Simplicial Attention Network (SAN)

TODO

In [38]:
import numpy as np
import torch
from torch.nn.parameter import Parameter
from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.simplicial.san_layer import SANLayer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a singular graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [40]:
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)

Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4


In [41]:
dataset.shape

(34, 78, 45, 11, 2)

## Define neighborhood structures. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on the domain. In this case, we need the down and upper laplacians of the nodes, $L_d^1=B_1^TB_1$ and $L_u^1=B_2B_2^T$ respectively, both with dimensions $n_\text{edges} \times n_\text{edges}$. We also convert the neighborhood structures to torch tensors.

### Note:
We can generalize to arbitrary _k_-simplex order; we just need to compute the _k_-th order down and upper laplacians. There are though two particular scenarios:
- When we want to work with 0-simplices (nodes), as there is no down laplacian; in this case, we just initialize the down laplacian as a 0-matrix, and SAN automatically becomes a GAT-based architecture.
- When dealing with the higher simplex dimension of the dataset (_k_=4 in Karate Club); just the opposite case of nodes, where now the upper laplacian cannot be computed. In this case we can also initialize it as a 0 matrix and SAN will only consider the lower boundary info for  those _k_-simplices.

In [58]:
simplex_order_k = 1
# Down laplacian
try:
    Ldown = torch.from_numpy(
        dataset.down_laplacian_matrix(rank=simplex_order_k).todense()
    ).to_sparse()
except ValueError:
    Ldown = torch.zeros(
        (dataset.shape[simplex_order_k], dataset.shape[simplex_order_k])
    ).to_sparse()
# Up laplacian
try:
    Lup = torch.from_numpy(
        dataset.up_laplacian_matrix(rank=simplex_order_k).todense()
    ).to_sparse()
except ValueError:
    Lup = torch.zeros(
        (dataset.shape[simplex_order_k], dataset.shape[simplex_order_k])
    ).to_sparse()

## Import signal ##

The original task is node classification, but SAN work only with simplices of order higher or equal to 1 (edges, tetrahedrons, etc.). We will retrieve the input signal on both nodes and links, aggregate the information of nodes into link features and apply the SAN model on those resulting edge features. We will finally obtain the estimated node labels from the edge-level model output.

In [59]:
x_0 = []
for _, v in dataset.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

x_2 = []
for k, v in dataset.get_simplex_attributes("face_feat").items():
    x_2.append(v)
x_2 = np.stack(x_2)
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.


In [60]:
in_channels = x_1.shape[-1]
hidden_channels = 16
out_channels = 2

## Define binary labels
We retrieve the labels associated to the nodes of each input simplex. In the KarateClub dataset, two social groups emerge. So we assign binary labels to the nodes indicating of which group they are a part.

We convert the binary labels into one-hot encoder form, and keep the first four nodes' true labels for the purpose of testing.

In [61]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_test = y_true[:4]
y_train = y_true[-30:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the Neural Network

Using the HSNLayer class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{nodes} \times 2$, so we can compare with our binary labels.

In [62]:
class SAN(torch.nn.Module):
    r"""Simplicial Attention Network (SAN) implementation for binary edge classification.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    hidden_channels : int
        Dimension of hidden features.
    out_channels : int
        Dimension of output features.
    simplex_order_k : int
        Order r of the considered simplices. Default to 1 (edges).
    num_filters_J : int, optional
        Approximation order for simplicial filters. Defaults to 2.
    J_har : int, optional
        Approximation order for harmonic convolution. Defaults to 5.
    epsilon_har : float, optional
        Epsilon value for harmonic convolution. Defaults to 1e-1.
    n_layers : int, optional
        Number of message passing layers. Defaults to 2.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_filters_J=2,
        J_har=5,
        epsilon_har=1e-1,
        n_layers=2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters_J = num_filters_J
        self.J_har = J_har
        self.epsilon_har = epsilon_har
        if n_layers == 1:
            self.layers = [
                SANLayer(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    num_filters_J=self.num_filters_J,
                )
            ]
        else:
            self.layers = [
                SANLayer(
                    in_channels=self.in_channels,
                    out_channels=self.hidden_channels,
                    num_filters_J=self.num_filters_J,
                )
            ]
            for _ in range(n_layers - 2):
                self.layers.append(
                    SANLayer(
                        in_channels=self.hidden_channels,
                        out_channels=self.hidden_channels,
                        num_filters_J=self.num_filters_J,
                    )
                )
            self.layers.append(
                SANLayer(
                    in_channels=self.hidden_channels,
                    out_channels=self.out_channels,
                    num_filters_J=self.num_filters_J,
                )
            )
        self.linear = torch.nn.Linear(out_channels, 2)

    def compute_projection_matrix(self, L):
        r"""Computation of the projection matrix which is then used
        to calculate the harmonic component in SAN layers.

        Parameters
        ---------

        L : tensor
            shape = [n_edges, n_edges]
            Hodge laplacian of rank 1.


        Returns
        --------
        _ : tensor
            shape = [n_edges, n_edges]
            Projection matrix.

        """
        P = torch.eye(L.shape[0]) - self.epsilon_har * L
        P = torch.linalg.matrix_power(P, self.J_har)
        return P

    def forward(self, x, Lup, Ldown):
        r"""Forward computation.

        Parameters
        ---------
        x : tensor
            shape = [n_nodes, channels_in]
            Node features.

        Lup : tensor
            shape = [n_edges, n_edges]
            Upper laplacian of rank 1.

        Ld : tensor
            shape = [n_edges, n_edges]
            Down laplacian of rank 1.


        Returns
        --------
        _ : tensor
            shape = [n_nodes, 2]
            One-hot labels assigned to edges.

        """
        # Compute the projection matrix for the harmonic component
        L = Lup + Ldown
        P = self.compute_projection_matrix(L)

        # Forward computation
        for layer in self.layers:
            x = layer(x, Lup, Ldown, P)
        return torch.sigmoid(self.linear(x))

In [63]:
Lup.shape, Ldown.shape, x_1.shape

(torch.Size([78, 78]), torch.Size([78, 78]), torch.Size([78, 2]))

# Train the Neural Network

The following cell performs the training, looping over the network for a low number of epochs.

In [64]:
model = SAN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=1,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.4)
test_interval = 2
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x_1, Lup=Lup, Ldown=Ldown)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[-len(y_train) :].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred == y_hat).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(x_1, Lup=Lup, Ldown=Ldown)
            y_pred_test = torch.sigmoid(y_hat_test).ge(0.5).float()
            test_accuracy = (
                torch.eq(y_pred_test[: len(y_test)], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.7296 Train_acc: 0.0000
Epoch: 2 loss: 0.7156 Train_acc: 0.0000
Test_acc: 0.0000
Epoch: 3 loss: 0.7045 Train_acc: 0.0000
Epoch: 4 loss: 0.6968 Train_acc: 0.0000
Test_acc: 0.0000
Epoch: 5 loss: 0.6923 Train_acc: 0.0000
Epoch: 6 loss: 0.6898 Train_acc: 0.0000
Test_acc: 0.0000
Epoch: 7 loss: 0.6884 Train_acc: 0.0000
Epoch: 8 loss: 0.6876 Train_acc: 0.0000
Test_acc: 0.0000
Epoch: 9 loss: 0.6869 Train_acc: 0.0000
Epoch: 10 loss: 0.6862 Train_acc: 0.0000
Test_acc: 0.0000
