In [1]:
import torch
import numpy as np
from toponetx import SimplicialComplex
from topomodelx.nn.simplicial.template_layer import TemplateLayer

# Create signal and domain

The first step is to define the topological domain on which the TNN will operate, as well as the neighborhood structures characterizing this domain. We will only define the neighborhood matrices that we plan on using.

Here, we build a simple simplicial complex domain.

In [2]:
# Simplices are given by their vertex ids: two edges and two triangular faces.
edge_set = [[1, 2], [1, 3]]
face_set = [[2, 3, 4], [2, 4, 5]]

# Build the complex from every listed simplex (edges first, then faces).
domain = SimplicialComplex([*edge_set, *face_set])

Now we retrieve the boundary matrix (or incidence matrix) associated with the faces of this complex.

In [3]:
# Rank-2 incidence matrix: rows are indexed by edges, columns by faces.
# With index=True the call also returns the row and column labels.
row, column, incidence_2 = domain.incidence_matrix(rank=2, index=True)
print(row, column, incidence_2, sep="\n")


[(1, 2), (1, 3), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]
[(2, 3, 4), (2, 4, 5)]
  (5, 0)	1.0
  (3, 0)	-1.0
  (2, 0)	1.0
  (6, 1)	1.0
  (4, 1)	-1.0
  (3, 1)	1.0


For each rank, the signal on this domain is a matrix of shape n_cells_of_rank_r x in_channels, where in_channels is the dimension of each cell's feature vector. In a heterogeneous domain, in_channels will vary by rank.

In [4]:
# Toy signals, one feature matrix per rank, with 2 channels per cell.
# Row counts follow the complex built above: 5 nodes, 7 edges, 2 faces.
x_0 = torch.tensor([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0], [4.0, 4.0], [2.0, 2.0]])
x_1 = torch.tensor([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0], [4.0, 4.0], [2.0, 2.0], [1.0, 1.0], [2.0, 2.0]])
x_2 = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
print(x_2.shape)

torch.Size([2, 2])


In [5]:
# Densify the SciPy sparse matrix, then hand it to torch and re-sparsify it
# (the printed tensor uses torch's sparse COO layout).
incidence_2_dense = incidence_2.todense()
incidence_2_torch = torch.from_numpy(incidence_2_dense).to_sparse()
print(incidence_2_torch)

tensor(indices=tensor([[2, 3, 3, 4, 5, 6],
                       [0, 0, 1, 1, 0, 1]]),
       values=tensor([ 1., -1.,  1., -1.,  1.,  1.]),
       size=(7, 2), nnz=6, layout=torch.sparse_coo)


# Create Neural Network

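We stack several TemplateLayer modules and append a final linear layer that maps each face's features to a single output value.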

In [6]:
# Feature widths of the face and edge signals, kept as plain Python ints
# because they are used below as layer sizes (e.g. torch.nn.Linear).
channels_face = x_2.shape[1]
channels_edge = x_1.shape[1]


In [7]:
class TemplateNN(torch.nn.Module):
    """Stack of ``TemplateLayer`` modules followed by a linear read-out.

    Parameters
    ----------
    channels_edge : int
        Passed to every layer as ``intermediate_channels``.
    channels_face : int
        Passed to every layer as both ``in_channels`` and ``out_channels``;
        also the input width of the final linear layer.
    incidence_2_torch : torch.Tensor
        Passed to every layer as ``incidence_matrix_2``.
    n_layers : int, default=2
        Number of ``TemplateLayer`` modules to chain.
    """

    def __init__(self, channels_edge, channels_face, incidence_2_torch, n_layers=2):
        super().__init__()
        # Every layer is configured identically, so build the kwargs once.
        layer_kwargs = {
            "in_channels": channels_face,
            "intermediate_channels": channels_edge,
            "out_channels": channels_face,
            "incidence_matrix_2": incidence_2_torch,
        }
        layers = [TemplateLayer(**layer_kwargs) for _ in range(n_layers)]
        self.sequential = torch.nn.Sequential(*layers)
        # Read-out: one output value per input row.
        self.linear = torch.nn.Linear(channels_face, 1)

    def forward(self, x_faces):
        """Run ``x_faces`` through the stacked layers, then the linear read-out."""
        hidden = self.sequential(x_faces)
        return self.linear(hidden)

In [8]:
# Seed torch's RNG so parameter initialisation is repeatable across
# Restart & Run All.
torch.manual_seed(0)

model = TemplateNN(channels_edge, channels_face, incidence_2_torch, n_layers=2)
# One binary target per face, stored as floats to match the model output:
# shape (n_faces, 1).
faces_gt_labels = torch.tensor([[0.0], [1.0]])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(5):
    optimizer.zero_grad()
    # The model output is treated as raw logits; the loss applies the sigmoid.
    faces_pred_labels = model(x_2)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(faces_pred_labels, faces_gt_labels)
    loss.backward()
    optimizer.step()

defined messages
defined messages
started forward
started forward of message
got weighted x
got message
started norm
torch.Size([7, 2])
tensor(indices=tensor([[2, 3, 3, 4, 5, 6],
                       [0, 0, 1, 1, 0, 1]]),
       values=tensor([ 1., -1.,  1., -1.,  1.,  1.]),
       size=(7, 2), nnz=6, layout=torch.sparse_coo)
got nbh size
torch.Size([7])
torch.Size([7, 2])
did norm
did update
finished first message passing step
started forward of message
got weighted x
got message
started norm
torch.Size([2, 7])
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [2, 3, 5, 3, 4, 6]]),
       values=tensor([ 1., -1.,  1.,  1., -1.,  1.]),
       size=(2, 7), nnz=6, layout=torch.sparse_coo)
got nbh size
torch.Size([2])
torch.Size([2, 2])
did norm
did update
started forward
started forward of message
got weighted x
got message
started norm
torch.Size([7, 2])
tensor(indices=tensor([[2, 3, 3, 4, 5, 6],
                       [0, 0, 1, 1, 0, 1]]),
       values=tensor([ 1., -