In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [3]:
import os
import sys

# Put the parent of the current working directory on the import path
# (presumably so the in-repo `stnets` package imported below resolves without
# being installed). The membership check keeps re-runs from adding duplicates.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from stnets.layers import LTN, MergeOper, SplitOper, MultiMergeOper, MultiSplitOper
from stnets.layers.attention import HigherOrderAttentionLayer, SparseHigherOrderAttentionLayer
from stnets.layers.attention import MultiHeadHigherOrderAttention, SpMultiHeadHigherOrderAttention

from stnets.layers.attention import (
    MultiHeadHigherOrderAttentionClassifer,
    SpMultiHeadHigherOrderAttentionClassifer,
)

from stnets.topology import SimplicialComplex
from stnets.modality.mesh_to_simplicial_complex import (coo_2_torch_tensor, 
                                                        mesh_2_operators)

## Introduction to Simplicial Attention:

Given a simplicial complex $\mathcal{X}$, the simplicial attention mechanism defines attention between two simplices $i$ and $j$ that are related via a simplicial structure map such as the boundary map, the k-Hodge Laplacian, or the higher-order adjacency. From this perspective, at first glance, simplicial attention seems to be a straightforward analogue of graph attention. When the structure matrix is symmetric, as is the case for the k-Hodge Laplacian, the generalization is indeed similar to the graph case. However, when considering attention among simplices of different dimensions, the attention mechanism becomes more complicated. Here we only define this more difficult case.   

To this end, let $A:C^{s}(\mathcal{X})\to C^{t}(\mathcal{X})$ be a map. Assuming a fixed ordering of the simplices of the complex $\mathcal{X}$, we denote the matrix induced by $A$ by the same symbol. Assuming the matrix $A$ is asymmetric, the simplicial attention network induced by $A$ is a cochain map $SimplAttenNet_A(W_s,W_t):C^{s}(\mathcal{X},d_{s_{in}}) \times C^{t}(\mathcal{X},d_{t_{in}}) \to C^{t}(\mathcal{X},d_{t_{out}})$ defined via:
\begin{equation}
(x_{s}^k,x_{t}^k) \to x_{t}^{k+1} =  \phi( (A \odot att^{k}) * x_{s}^k * W_s )
\end{equation}                
where $d_{s_{in}}$ and $d_{t_{in}}$ are the feature dimensions of the input cochains $x_{s}^k$ and $x_{t}^k$, $d_{t_{out}}$ is the feature dimension of the output cochain $x_{t}^{k+1}$, $W_s \in \mathbb{R}^{d_{s_{in}}\times d_{t_{out}}}$ and $W_t \in \mathbb{R}^{d_{t_{in}}\times d_{s_{out}}}$ are trainable parameters, and $att: C^{s}(\mathcal{X})\to C^{t}(\mathcal{X})$ is a simplicial attention matrix that has the same dimensions as the matrix $A$ and is defined via $e^{l}_{st}= LeakyRelu((a^{l})^T [W_s x_{s}^k||W_t x_{t}^k])$ and 

\begin{equation}
att_{s,t}^l =  \frac{e_{st}^l}{ \sum_{k \in \mathcal{N}_A(s)} e_{sk}^l }
\end{equation} 

Here $a^{l} \in \mathbb{R}^{s_{out}+t_{out}}$ is a trainable parameter and $\mathcal{N}_A(s)$ represents the neighborhood of the simplex $s$ with respect to the matrix $A$. One may think of this neighborhood as the set of simplices $t$ whose entries in column $s$ of the matrix $A$ are nonzero. 

There are several things to notice about the above definition. First, the matrix $att$ is asymmetric: only the attention $att_{st}$ is recorded, not $att_{ts}$. The reversed attention matrix has to be computed separately. The reversed attention matrix has the same dimensions as $A^T$, the transpose of the operator $A$, and we denote it by $\bar{att}$. With this, the update equation becomes:

\begin{equation}
(x_{s}^k,x_{t}^k) \to x_{s}^{k+1} =  \phi( (A^T \odot \bar{att}^{k}) * x_{t}^k * W_t )
\end{equation}  
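Note that care must be taken when computing $e^l_{ts}$ for this reversed update. Namely,

$e^{l}_{ts}= LeakyRelu((rev(a^{l}))^T [W_t x_{t}^k||W_s x_{s}^k])$, where $rev(a^{l})= ( a^l[t_{out}:]||a^l[:t_{out}] )$, i.e. the two halves of $a^l$ are swapped so that each half still multiplies the same projected features as in $e^{l}_{st}$.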

Our implementation below outputs both $x_{s}^{k+1}$ and $x_{t}^{k+1}$ when the input operator is asymmetric.

Now let's see how to define the above model using our package.


## Define a simple simplicial complex

Let's start by defining a very simple simplicial complex.

In [4]:
# Simplices that generate the toy complex, each written as a tuple of vertex
# ids: three triangles and three edges.
simplices = [
    (0, 1, 2),
    (1, 2, 3),
    (2, 3),
    (1, 2, 4),
    (5, 3),
    (0, 4),
]

In [5]:
# Build the complex from the simplex list, then record how many faces it
# reports in dimensions 0, 1 and 2 (used below to size the random cochains).
HL = SimplicialComplex(simplices)

N0, N1, N2 = (len(HL.n_faces(dim)) for dim in range(3))


## Getting the operators : 
We can easily obtain the operators defined on the complex we built above using the following function:  

In [6]:

# Build the structure operators of the complex in one call: unsigned,
# un-normalized, returned as torch objects.
# NOTE(review): going by the names, these are adjacency (Adj*), coadjacency
# (Coadj*), Hodge Laplacian (L*) and boundary (B*, B*T) matrices — confirm
# against mesh_2_operators.
(
    Adj0,
    Adj1,
    Coadj1,
    Coadj2,
    L0,
    L1,
    L2,
    B1,
    B2,
    B1T,
    B2T,
) = mesh_2_operators(
    simplices,
    signed=False,
    norm_method=None,
    output_type="torch",
)



computing the boundary matrices..

computing the Hodge Laplacians matrices..



  self._set_arrayXarray(i, j, x)


## Define the data 

Let's generate some random data that lives on the complex.

In [7]:
# Generate random input cochains: one feature row per simplex of each dimension.
import torch

# Seed torch's global RNG so that Restart Kernel -> Run All reproduces the same
# random tensors on every run.
torch.manual_seed(0)

# Kept so that any later cell relying on these names still works; neither is
# read by the cells visible in this notebook section.
feature_input_space_dim = 3
batch_size = 10

x_v = torch.rand(N0, 3)  # cochain on the nodes (N0 rows, 3 features)
x_e = torch.rand(N1, 5)  # cochain on the edges (N1 rows, 5 features)
x_f = torch.rand(N2, 3)  # cochain on the faces (N2 rows, 3 features)

## Define the model

We provide both a dense and a sparse implementation of the network described above. To see how they work, we use the classes `HigherOrderAttentionLayer` and `SparseHigherOrderAttentionLayer` imported above. 


In [8]:
# Define a dense single-head attention layer. The edge cochain x_e (5 feature
# columns) is used as the source below, hence source_in_features=5.

model_dense_arrays = HigherOrderAttentionLayer(
    source_in_features=5,
    target_in_features=3,
    source_out_features=20,
    target_out_features=20,
)

# Infer: pass the source features, no target features (None), and the
# densified L1 operator. Only the first returned value is kept.
out_e, _ = model_dense_arrays(x_e, None, L1.to_dense())

out_e.shape

# This layer works on a dense operator matrix, so it may become slow as the
# complex gets larger.

torch.Size([9, 20])

In [14]:
# Sparse models are also supported; the constructor takes the same feature
# sizes as the dense layer.
model = SparseHigherOrderAttentionLayer(
    source_in_features=5,
    target_in_features=3,
    source_out_features=20,
    target_out_features=20,
)

In [15]:
# Inference is similar to before, except that the operator is passed as the
# index tensor of the coalesced sparse matrix B1 instead of a dense matrix.
# The operator is flagged as asymmetric, so both the source (x_e) and the
# target (x_v) features are supplied.
# NOTE(review): this rebinds `out_e`, replacing the dense layer's result.
# Going by the names, the first output is node-level and the second is
# edge-level — confirm against the layer's return order.
v_out_s, out_e = model(
    hs=x_e, ht=x_v, operator_list=B1.coalesce().indices(), operator_symmetry=False
)

In [16]:
# When the operator is symmetric (here L1), no target features are passed
# (ht=None). The second returned value is expected to be empty and is discarded.
out_e_2, _ = model(
    hs=x_e, ht=None, operator_list=L1.coalesce().indices(), operator_symmetry=True
)

In [17]:
# Multi-head attention is defined in the same way, with extra arguments for the
# number of heads and the number of output classes for source and target.

multihead_model = MultiHeadHigherOrderAttentionClassifer(
    source_in_features=5,
    target_in_features=3,
    source_out_features=20,
    target_out_features=20,
    num_heads=3,
    source_n_classes=3,
    target_n_classes=3,
)

In [18]:
# Rectangular operator: edges (x_e) as source and faces (x_f) as target, with
# the transposed B2 operator passed as a dense matrix.
mh_f_out, mh_e_out = multihead_model(x_e, x_f, B2.t().to_dense())
# Square operator: edge features only (no target), with the dense L1 operator.
mh_e_out2, _ = multihead_model(x_e, None, L1.to_dense())

In [19]:
# Shape of the second value returned by the dense multi-head classifier for the
# (x_e, x_f) call.
mh_e_out.shape

torch.Size([9, 3])

In [20]:
# Sparse multi-head classifier.
# NOTE(review): unlike the dense classifier, source_n_classes/target_n_classes
# are not passed here, so whatever defaults the class defines apply — confirm
# this is intended.
multihead_model_sp = SpMultiHeadHigherOrderAttentionClassifer(
    source_in_features=5,
    target_in_features=3,
    source_out_features=20,
    target_out_features=20,
    num_heads=5,
)

In [21]:
# Rectangular operator: edges (x_e) as source and faces (x_f) as target; the
# operator is given as the index tensor of the coalesced, transposed B2 and
# flagged as not symmetric (False).
# NOTE(review): the target passed here is x_f (faces), so `mh_v_out_sp`
# presumably holds face-level output despite the `v` in its name — confirm.
mh_v_out_sp, mh_e_out_sp = multihead_model_sp(
    x_e, x_f, B2.t().coalesce().indices(), False
)
# Square operator: edge features only, with the indices of L1, flagged as
# symmetric (True); the second returned value is discarded.
mh_e_out2_sp, _ = multihead_model_sp(x_e, None, L1.coalesce().indices(), True)

In [22]:
# Shape of the first value returned by the sparse multi-head classifier for the
# L1 call.
mh_e_out2_sp.shape

torch.Size([9, 5])