In [1]:
import torch
import numpy as np
import topomodelx.nn as nn
from topomodelx.nn import MessagePassingConv
from toponetx import SimplicialComplex as sc

The first step is to define the topological domain on which the TNN will operate, as well as the neighborhood structures characterizing this domain. We will only define the neighborhood matrices that we plan on using.

Here, we build a simple simplicial complex domain.

In [11]:
edge_set = [[1, 2], [1, 3]]
face_set = [[2, 3, 4], [2, 4, 5]]

domain = sc(edge_set + face_set)

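Now we retrieve the boundary matrix (or incidence matrix) associated with the faces of this complex. With index=True, the call also returns the edges that index its rows and the faces that index its columns.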

In [12]:
row, column, incidence_2 = domain.incidence_matrix(rank=2, index=True)
print(row)
print(column)
print(incidence_2)


[(1, 2), (1, 3), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]
[(2, 3, 4), (2, 4, 5)]
  (5, 0)	1.0
  (3, 0)	-1.0
  (2, 0)	1.0
  (6, 1)	1.0
  (4, 1)	-1.0
  (3, 1)	1.0
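To make the signs in this output less opaque, here is a minimal sketch that rebuilds the same matrix by hand in plain Python. It assumes the standard orientation convention, in which removing the i-th vertex of a face contributes the sign (-1)**i; the names faces, extra_edges, edges and boundary_2 are illustrative and are not part of the toponetx API.

```python
from itertools import combinations

# Cells used to build the complex above.
faces = [(2, 3, 4), (2, 4, 5)]
extra_edges = [(1, 2), (1, 3)]

# Every 2-element subset of a face is also an edge of the complex.
edges = sorted(set(extra_edges) | {e for f in faces for e in combinations(f, 2)})

# Signed boundary: dropping the i-th vertex of a face contributes (-1)**i.
boundary_2 = {}
for col, face in enumerate(faces):
    for i in range(len(face)):
        edge = face[:i] + face[i + 1:]
        boundary_2[(edges.index(edge), col)] = (-1) ** i

print(edges)
for (row, col), sign in sorted(boundary_2.items()):
    print((row, col), sign)
```

The edge list and the six nonzero entries agree with the printed output: edges (1, 2) and (1, 3) bound no face, so rows 0 and 1 stay empty.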


For each rank r, the signal on this domain is a matrix of shape n_cells_of_rank_r x in_channels, where in_channels is the dimension of each cell's feature vector. In a heterogeneous domain, in_channels can vary by rank.

In [13]:
domain.simplices

# One row per cell, one column per feature channel.
x_0 = torch.tensor([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0], [4.0, 4.0], [2.0, 2.0]])  # 5 nodes
x_1 = torch.tensor([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0], [4.0, 4.0], [2.0, 2.0], [1.0, 1.0], [2.0, 2.0]])  # 7 edges
print(np.array(x_1.shape[1]))  # number of edge channels; x_1 must be defined first
x_2 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])  # 2 faces
print(x_2.shape)

2
torch.Size([2, 2])
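The row counts of x_0, x_1 and x_2 follow from the number of cells of each rank. As a rough check, the sketch below enumerates every sub-simplex of the cells used to build the complex and counts them per rank, using only the standard library. The per-rank channel sizes in in_channels are hypothetical and only illustrate the heterogeneous case; the tutorial itself uses 2 channels for every rank.

```python
from itertools import combinations

# Cells passed to the SimplicialComplex constructor above.
maximal = [(1, 2), (1, 3), (2, 3, 4), (2, 4, 5)]

# Collect every non-empty subset of each cell, grouped by rank (size - 1).
cells = {}
for simplex in maximal:
    for size in range(1, len(simplex) + 1):
        for sub in combinations(simplex, size):
            cells.setdefault(size - 1, set()).add(sub)

n_cells = {rank: len(cells[rank]) for rank in sorted(cells)}
print(n_cells)

# Hypothetical channel sizes for a heterogeneous signal (not the tutorial's values).
in_channels = {0: 3, 1: 2, 2: 4}
signal_shapes = {rank: (n_cells[rank], in_channels[rank]) for rank in n_cells}
print(signal_shapes)
```

The counts come out as 5 nodes, 7 edges and 2 faces, which is why x_0, x_1 and x_2 have 5, 7 and 2 rows respectively.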


In [14]:
incidence_2_torch = torch.from_numpy(incidence_2.todense()).to_sparse()
print(incidence_2_torch)


tensor(indices=tensor([[2, 3, 3, 4, 5, 6],
                       [0, 0, 1, 1, 0, 1]]),
       values=tensor([ 1., -1.,  1., -1.,  1.,  1.]),
       size=(7, 2), nnz=6, layout=torch.sparse_coo)


In [16]:
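# Feature dimension of each rank (number of columns of the signal).
channels_face = x_2.shape[1]
channels_edge = x_1.shape[1]

# Send face features to edges through the boundary matrix (edges x faces).
send_face_to_edge = MessagePassingConv(in_channels=channels_face, out_channels=channels_edge)
message_face_to_edge = send_face_to_edge(x_2, incidence_2_torch)

# Layer for the reverse direction (edges to faces); defined here but not applied.
send_edge_to_face = MessagePassingConv(in_channels=channels_edge, out_channels=channels_face)
print(message_face_to_edge)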


torch.Size([2, 2])
torch.Size([2, 2])
tensor([[ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [-0.1412, -1.8953],
        [-0.1412, -1.8953],
        [ 0.2825,  3.7905],
        [-0.1412, -1.8953],
        [-0.2825, -3.7905]], grad_fn=<MmBackward0>)
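The pattern in this output (two zero rows, then rows that are multiples of one another) is consistent with the layer first applying a learned linear map to the face features and then aggregating them through the incidence matrix, i.e. message = incidence_2 @ (x_2 @ weight). The sketch below reproduces that structure in plain Python. It is an assumption about what the layer computes, not the library's implementation, and the weight values are hypothetical since the real ones are randomly initialized.

```python
# Rows follow the edge order printed earlier; columns are the two faces.
incidence_2 = [
    [0, 0],   # (1, 2): bounds no face
    [0, 0],   # (1, 3): bounds no face
    [1, 0],   # (2, 3)
    [-1, 1],  # (2, 4): shared by both faces
    [0, -1],  # (2, 5)
    [1, 0],   # (3, 4)
    [0, 1],   # (4, 5)
]
x_2 = [[1.0, 1.0], [2.0, 2.0]]        # face features from the tutorial
weight = [[0.5, -1.0], [0.25, 2.0]]   # hypothetical weights (in_channels x out_channels)


def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


transformed = matmul(x_2, weight)           # linear map applied to each face
message = matmul(incidence_2, transformed)  # signed sum over incident faces

for row in message:
    print(row)
```

As in the printed tensor, rows 0 and 1 are zero, rows 2, 3 and 5 coincide, and rows 4 and 6 are minus and plus twice that common row, because the second face carries features twice as large as the first.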
