# Train a Simplicial Attention Network (SAN)

TODO

In [15]:
import numpy as np
import torch
from torch.nn.parameter import Parameter
from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.simplicial.san_layer import SANLayer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [16]:
# Load the karate club dataset as a simplicial complex and print its summary.
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)

Simplicial Complex with shape [34, 78, 45, 11, 2] and dimension 4


## Define neighborhood structures. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messages on the domain. In this case, we need the lower and upper Laplacians of rank 1 (i.e. acting on the edges), $L_d^1=B_1^TB_1$ and $L_u^1=B_2B_2^T$ respectively, both with dimensions $n_\text{edges} \times n_\text{edges}$. We also convert the neighborhood structures to torch tensors.

In [17]:
# Rank-1 down and up Laplacians, each converted from the dataset's matrix
# representation to a dense array and then to a sparse torch tensor.
Ldown, Lup = (
    torch.from_numpy(laplacian.todense()).to_sparse()
    for laplacian in (
        dataset.down_laplacian_matrix(rank=1),
        dataset.up_laplacian_matrix(rank=1),
    )
)

## Import signal ##

The original task is node classification, but SAN works only with simplices of order greater than or equal to 1 (edges, triangles, tetrahedra, etc.). We will retrieve the input signal on both nodes and links, aggregate the information of nodes into link features and apply the SAN model on those resulting edge features. We will finally obtain the estimated node labels from the edge-level model output.

In [18]:
# Node signal: stack the per-node feature vectors into one tensor.
node_features = dataset.get_simplex_attributes("node_feat")
x_0 = torch.tensor(np.stack(list(node_features.values())))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

# Edge signal: same procedure on the "edge_feat" attribute.
edge_features = dataset.get_simplex_attributes("edge_feat")
x_1 = torch.tensor(np.stack(list(edge_features.values())))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

# Face signal: kept as a NumPy array (it is not converted to a torch tensor).
face_features = dataset.get_simplex_attributes("face_feat")
x_2 = np.stack(list(face_features.values()))
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.


In [19]:
# Model sizes: the input width is the feature dimension of the edge signal.
in_channels = x_1.size(-1)
# Width of the intermediate representation.
hidden_channels = 16
# Width of the output of the last SAN layer.
out_channels = 2

## Define binary labels
We retrieve the labels associated with the nodes of each input simplex. In the KarateClub dataset, two social groups emerge, so we assign each node a binary label indicating which group it belongs to.

We convert the binary labels into one-hot encoded form, and keep the first four nodes' true labels for the purpose of testing.

In [20]:
# Binary group membership of the 34 nodes, in node order.
# fmt: off
y = np.array(
    [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 1, 0, 0, 1, 1, 0, 1,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0,
    ]
)
# fmt: on

# One-hot encoding: column 0 holds the label, column 1 its complement.
y_true = np.column_stack((y, 1 - y)).astype(np.float64)

# The first four rows are held out for testing; the last thirty are used for training.
y_test = torch.from_numpy(y_true[:4])
y_train = torch.from_numpy(y_true[-30:])

# Create the Neural Network

Using the SANLayer class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{edges} \times 2$ (the model is applied to the edge features), so we can compare with our binary labels.

In [21]:
class SAN(torch.nn.Module):
    r"""Simplicial Attention Network (SAN) implementation for binary edge classification.

    The network is a stack of ``SANLayer`` modules followed by a linear layer
    with two outputs and a sigmoid.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    hidden_channels : int
        Dimension of hidden features.
    out_channels : int
        Dimension of output features.
    num_filters_J : int, optional
        Approximation order for simplicial filters. Defaults to 2.
    J_har : int, optional
        Approximation order for harmonic convolution. Defaults to 5.
    epsilon_har : float, optional
        Epsilon value for harmonic convolution. Defaults to 1e-1.
    n_layers : int, optional
        Number of message passing layers. Defaults to 2.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_filters_J=2,
        J_har=5,
        epsilon_har=1e-1,
        n_layers=2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters_J = num_filters_J
        self.J_har = J_har
        self.epsilon_har = epsilon_har

        # (input width, output width) of every layer, in order.
        if n_layers == 1:
            layer_dims = [(self.in_channels, self.out_channels)]
        else:
            layer_dims = (
                [(self.in_channels, self.hidden_channels)]
                + [(self.hidden_channels, self.hidden_channels)] * (n_layers - 2)
                + [(self.hidden_channels, self.out_channels)]
            )

        # The layers must live in a ModuleList (not a plain Python list):
        # otherwise their parameters are invisible to `parameters()`,
        # `state_dict()` and `.to(device)`, and the optimizer never updates them.
        self.layers = torch.nn.ModuleList(
            SANLayer(
                in_channels=dim_in,
                out_channels=dim_out,
                num_filters_J=self.num_filters_J,
            )
            for dim_in, dim_out in layer_dims
        )
        self.linear = torch.nn.Linear(out_channels, 2)

    def compute_projection_matrix(self, L):
        r"""Compute the projection matrix used for the harmonic component.

        The result is :math:`P = (I - \epsilon L)^{J}` with
        :math:`\epsilon` = ``epsilon_har`` and :math:`J` = ``J_har``.

        Parameters
        ----------
        L : tensor
            shape = [n_edges, n_edges]
            Hodge laplacian of rank 1.

        Returns
        -------
        _ : tensor
            shape = [n_edges, n_edges]
            Projection matrix.
        """
        P = torch.eye(L.shape[0]) - self.epsilon_har * L
        P = torch.linalg.matrix_power(P, self.J_har)
        return P

    def forward(self, x, Lup, Ldown):
        r"""Forward computation.

        Parameters
        ----------
        x : tensor
            shape = [n_edges, in_channels]
            Edge features.
        Lup : tensor
            shape = [n_edges, n_edges]
            Upper laplacian of rank 1.
        Ldown : tensor
            shape = [n_edges, n_edges]
            Down laplacian of rank 1.

        Returns
        -------
        _ : tensor
            shape = [n_edges, 2]
            Sigmoid-activated scores (values in (0, 1), not logits) for the
            two classes of each edge.
        """
        # Compute the projection matrix for the harmonic component
        L = Lup + Ldown
        P = self.compute_projection_matrix(L)

        # Forward computation
        for layer in self.layers:
            x = layer(x, Lup, Ldown, P)
        return torch.sigmoid(self.linear(x))

# Train the Neural Network

The following cell performs the training, looping over the network for a low number of epochs.

In [22]:
model = SAN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=1,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.4)
test_interval = 2
num_epochs = 5
n_train = len(y_train)
n_test = len(y_test)
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x_1, Lup=Lup, Ldown=Ldown)
    # SAN.forward already applies a sigmoid, so y_hat holds probabilities.
    # Use plain binary cross-entropy; the `_with_logits` variant would apply
    # a second sigmoid on top of them.
    loss = torch.nn.functional.binary_cross_entropy(
        y_hat[-n_train:].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    # Threshold the probabilities and compare against the training labels
    # (not against the model output itself).
    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[-n_train:] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()
        with torch.no_grad():
            y_hat_test = model(x_1, Lup=Lup, Ldown=Ldown)
            # y_hat_test is already in (0, 1): threshold it directly. Applying
            # another sigmoid would push every value above 0.5.
            y_pred_test = y_hat_test.ge(0.5).float()
            test_accuracy = (
                torch.eq(y_pred_test[:n_test], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

NameError: name 'SANConv' is not defined

In [11]:
# Debugging aid: run a single forward pass of `model` on the edge features x_1.
y_hat = model(x_1, Lup=Lup, Ldown=Ldown)

In [31]:
# Scratch parameters for prototyping the attention computation by hand.
# NOTE(review): J, output_dim and att_slice are not defined in the visible part
# of this notebook (hidden kernel state), and no initialization is applied to
# these tensors in this cell.
weight_irr = Parameter(torch.Tensor(J, channels_nodes, output_dim))
att_irr = Parameter(torch.Tensor(2 * att_slice, 1))

In [34]:
# Apply the filter weights to the edge features, then flatten the result to a
# 2-D tensor with J * output_dim columns (recorded shape: [78, 6]).
# NOTE(review): unlike SANConv.forward, no permute precedes the reshape here, so
# a row may not correspond to a single edge — confirm this is intended.
x_irr = torch.matmul(x_1, weight_irr).reshape(-1, J * output_dim)
x_irr.shape

torch.Size([78, 6])

In [36]:
# Shape check: x_irr projected by the first att_slice rows of att_irr (recorded: [78, 1]).
(x_irr @ att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [39]:
# Shape check (same expression as an earlier cell): projection by the first
# att_slice rows of att_irr (recorded: [78, 1]).
(x_irr @ att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [44]:
# Shape check: projection by the remaining rows of att_irr, transposed (recorded: [1, 78]).
(x_irr @ att_irr[att_slice:, :]).T.shape

torch.Size([1, 78])

In [79]:
# Pairwise attention scores: a column (recorded shape [78, 1]) plus a row
# (recorded shape [1, 78]) broadcast to a square matrix.
e_irr = (x_irr @ att_irr[:att_slice, :]) + (x_irr @ att_irr[att_slice:, :]).T

In [97]:
# Shape check (same expression as earlier cells): first term of e_irr (recorded: [78, 1]).
(x_irr @ att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [102]:
# Shape check (same expression as an earlier cell): second term of e_irr (recorded: [1, 78]).
(x_irr @ att_irr[att_slice:, :]).T.shape

torch.Size([1, 78])

In [98]:
# Same projection as the `@` form, written with torch.mm (recorded: [78, 1]).
torch.mm(x_irr, att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [101]:
# Same transposed projection as the `@` form, written with torch.mm (recorded: [1, 78]).
torch.mm(x_irr, att_irr[att_slice:, :]).T.shape

torch.Size([1, 78])

In [95]:
# Restrict the scores to the sparsity pattern of Ldown, then normalize with a
# sparse softmax along dim=1.
alpha_irr = torch.sparse.softmax(e_irr.sparse_mask(Ldown), dim=1)

In [93]:
# Display the normalized attention coefficients (recorded: sparse COO, size (78, 78)).
alpha_irr

tensor(indices=tensor([[ 0,  0,  0,  ..., 77, 77, 77],
                       [ 0,  1,  2,  ..., 75, 76, 77]]),
       values=tensor([0.0417, 0.0417, 0.0417,  ..., 0.0455, 0.0455, 0.0455]),
       size=(78, 78), nnz=1134, layout=torch.sparse_coo,
       grad_fn=<SparseSoftmaxBackward0>)

In [110]:
# Stack successive powers of alpha_irr along a new leading dimension.
# Slice p holds alpha_irr^(p + 1); each iteration multiplies slice p by
# alpha_irr and appends the product, so the final stack has J slices.
# NOTE(review): J is not defined in the visible part of this notebook.
alpha_exp_irr = alpha_irr.unsqueeze(0)
for p in range(J - 1):
    alpha_exp_irr = torch.cat(
        [alpha_exp_irr, torch.mm(alpha_exp_irr[p], alpha_irr).unsqueeze(0)], dim=0
    )

In [111]:
# Display the stacked attention powers (recorded: sparse COO, size (3, 78, 78)).
alpha_exp_irr

tensor(indices=tensor([[ 0,  0,  0,  ...,  2,  2,  2],
                       [ 0,  0,  0,  ..., 77, 77, 77],
                       [ 0,  1,  2,  ..., 75, 76, 77]]),
       values=tensor([0.0417, 0.0417, 0.0417,  ..., 0.0376, 0.0368, 0.0476]),
       size=(3, 78, 78), nnz=10742, layout=torch.sparse_coo,
       grad_fn=<CatBackward0>)

In [115]:
# Rebind x_irr to the un-flattened product, keeping the filter dimension first
# (recorded shape: [3, 78, 2]).
# NOTE(review): this overwrites the 2-D x_irr defined in an earlier cell, so
# re-running earlier cells after this one gives different results.
x_irr = torch.matmul(x_1, weight_irr)
x_irr.shape

torch.Size([3, 78, 2])

In [126]:
# Shape check: apply each stacked attention matrix to the matching slice of
# x_irr and sum over the leading dimension (recorded: [78, 2]).
torch.sum(torch.matmul(alpha_exp_irr.to_dense(), x_irr), dim=0).shape

torch.Size([78, 2])

In [78]:
# Inspect a tensor masked to the sparsity pattern of Ldown.
# NOTE(review): A is not defined in the visible part of this notebook.
A.sparse_mask(Ldown)

tensor(indices=tensor([[ 0,  0,  0,  ..., 77, 77, 77],
                       [ 0,  1,  2,  ..., 75, 76, 77]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(78, 78), nnz=1134, layout=torch.sparse_coo)

In [135]:
# Instantiate a single SAN layer for a manual check.
# NOTE(review): the keywords used here (channels_in, channels_out, J) differ
# from the ones the SAN class passes to SANLayer (in_channels, out_channels,
# num_filters_J) — presumably an older signature; confirm against the library.
layer = SANLayer(channels_in=channels_nodes, channels_out=output_dim, J=J)

In [137]:
# Shape check of one layer call (recorded: [78, 2]).
# NOTE(review): Lup is passed as the fourth argument, where SAN.forward passes
# the projection matrix P — presumably a placeholder; confirm.
layer(x_1, Lup, Ldown, Lup).shape

torch.Size([78, 2])

In [15]:
# Sizes used by the SANConv experiments below.
# NOTE(review): this rebinds in_channels / out_channels, which were already set
# in an earlier cell, and output_dim is not defined in the visible part of this
# notebook.
in_channels = channels_nodes
out_channels = output_dim
num_filters_J = 3

In [11]:
from topomodelx.base.aggregation import Aggregation
from topomodelx.base.conv import Conv

In [62]:
class SANConv(Conv):
    r"""Attention-based convolution used by the Simplicial Attention Network.

    Input features are transformed by ``p_filters`` weight matrices. Attention
    coefficients are computed on the non-zero entries of the neighborhood
    matrix and normalized with a softmax. The output is the sum over
    ``p = 1..p_filters`` of the p-th power of the attention matrix applied to
    the p-th transformed feature slice.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    out_channels : int
        Dimension of output features.
    p_filters : int
        Number of filters, i.e. the highest power of the attention matrix used.
    initialization : str, optional
        Initialization method forwarded to the base class.
        Defaults to "xavier_uniform".
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        p_filters,
        initialization="xavier_uniform",
    ):
        # super(Conv, self) skips Conv.__init__ and calls the initializer of
        # Conv's own base class; the parameters are then declared below.
        # NOTE(review): presumably done so that Conv does not allocate its own
        # weight — confirm against topomodelx.base.conv.
        super(Conv, self).__init__(
            att=True,
            initialization=initialization,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.p_filters = p_filters
        self.initialization = initialization

        # One [in_channels, out_channels] weight matrix per filter order.
        self.weight = Parameter(
            torch.Tensor(self.p_filters, self.in_channels, self.out_channels)
        )

        # Attention vector over the concatenated (reshaped) features.
        self.att_weight = Parameter(
            torch.Tensor(
                2 * self.out_channels * self.p_filters,
            )
        )

        self.reset_parameters()

    def forward(self, x_source, neighborhood):
        """Forward pass.

        This implements message passing:
        - from source cells with input features `x_source`,
        - via `neighborhood` defining where messages can pass,
        - to target cells, which are the same source cells.

        In practice, this will update the features on the target cells.

        Parameters
        ----------
        x_source : Tensor, shape=[..., n_source_cells, in_channels]
            Input features on source cells.
            Assumes that all source cells have the same rank r.
        neighborhood : torch.sparse, shape=[n_target_cells, n_source_cells]
            Neighborhood matrix. Only its sparsity pattern is used, as a mask
            for the attention coefficients.

        Returns
        -------
        _ : Tensor, shape=[..., n_target_cells, out_channels]
            Output features on target cells.
            Assumes that all target cells have the same rank s.
        """
        # Broadcasting against the stacked weights gives one transformed copy
        # of the features per filter, with the filter dimension first.
        x_message = torch.matmul(x_source, self.weight)
        # Reshape required to re-use the attention function of parent Conv class
        # -> [num_nodes, out_channels * p_filters]
        x_message_reshaped = x_message.permute(1, 0, 2).reshape(
            -1, self.out_channels * self.p_filters
        )

        # SAN always requires attention.
        # In SAN, neighborhood is defined by lower/upper laplacians; we only
        # use them as masks to keep only the relevant attention coeffs.
        neighborhood = neighborhood.coalesce()
        self.target_index_i, self.source_index_j = neighborhood.indices()
        attention_values = self.attention(x_message_reshaped)
        att_laplacian = torch.sparse_coo_tensor(
            indices=neighborhood.indices(),
            values=attention_values,
            size=neighborhood.shape,
        )

        # Attention coeffs are normalized using softmax. The result is
        # densified so that plain dense matrix products can build the powers.
        att_laplacian = torch.sparse.softmax(att_laplacian, dim=1).to_dense()

        # Powers A^1, A^2, ..., A^p_filters, built by repeated multiplication
        # so that filter p is paired with exactly the p-th power.
        powers = [att_laplacian]
        for _ in range(1, self.p_filters):
            powers.append(torch.matmul(powers[-1], att_laplacian))
        att_laplacian_power = torch.stack(powers)

        # Apply each power to its own feature slice and sum over the filters.
        x_message_on_target = torch.matmul(att_laplacian_power, x_message).sum(dim=0)

        return x_message_on_target

In [59]:
# Instantiate one SANConv with num_filters_J filters for a manual check.
sanconv = SANConv(
    in_channels=in_channels, out_channels=out_channels, p_filters=num_filters_J
)

In [60]:
# Run the convolution on the edge features with Lup as neighborhood.
# Despite its name, att_lap holds the output features returned by
# SANConv.forward, not an attention matrix.
att_lap = sanconv(x_1, Lup)

In [61]:
# Shape check of the convolution output (recorded: [78, 2]).
att_lap.shape

torch.Size([78, 2])

In [198]:
# One list of convolutions for the down (irr) and one for the up (sol)
# neighborhood, plus a plain Conv for the harmonic term.
# NOTE(review): range(num_filters_J) starts at 0, so the first SANConv of each
# list is built with p_filters=0, i.e. a weight whose leading dimension is 0.
# Confirm whether range(1, num_filters_J + 1) was intended.
convs_irr = [SANConv(in_channels, out_channels, p) for p in range(num_filters_J)]
convs_sol = [SANConv(in_channels, out_channels, p) for p in range(num_filters_J)]
conv_har = Conv(in_channels, out_channels)

In [199]:
# Sum the outputs of the per-neighborhood convolutions: down (irr), up (sol)
# and harmonic (har).
# NOTE(review): P is not defined anywhere in the visible part of this notebook,
# and the recorded run of this cell failed with "NameError: name 'P' is not
# defined". Inside the SAN class, P comes from compute_projection_matrix(Lup + Ldown).
z_irr = torch.stack([conv(x_1, Ldown) for conv in convs_irr]).sum(dim=0)
z_sol = torch.stack([conv(x_1, Lup) for conv in convs_sol]).sum(dim=0)
z_har = conv_har(x_1, P)

NameError: name 'P' is not defined

tensor([[ 0.0004,  0.0030],
        [ 0.0038, -0.0032],
        [-0.0014,  0.0211],
        [ 0.0010,  0.0031],
        [ 0.0013,  0.0038],
        [ 0.0002,  0.0057],
        [ 0.0195, -0.0203],
        [ 0.0034,  0.0041],
        [ 0.0016,  0.0020],
        [ 0.0013,  0.0027],
        [-0.0266,  0.0401],
        [ 0.0019,  0.0081],
        [ 0.0018,  0.0024],
        [ 0.0158, -0.0140],
        [ 0.0018,  0.0024],
        [-0.0006,  0.0011],
        [ 0.0018, -0.0033],
        [-0.0070,  0.0321],
        [ 0.0261, -0.0297],
        [-0.0020,  0.0137],
        [-0.0033,  0.0058],
        [ 0.0206, -0.0203],
        [-0.0033,  0.0058],
        [-0.0002,  0.0114],
        [-0.0040,  0.0236],
        [ 0.0268, -0.0354],
        [ 0.0043,  0.0022],
        [ 0.0066, -0.0129],
        [-0.0003, -0.0007],
        [-0.0097,  0.0384],
        [-0.0042,  0.0695],
        [-0.0027,  0.0474],
        [-0.0024,  0.0114],
        [-0.0029,  0.0138],
        [ 0.0015,  0.0017],
        [-0.0074,  0