# Train a Combinatorial Complex Attention Neural Network for Mesh Classification.

We create and train a combinatorial complex attention neural network for mesh classification, introduced in [Figure 35(b), Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf).

### The Neural Network:

The neural network is composed of a sequence of identical attention layers for a two-dimensional combinatorial complex. Each layer is composed of two levels. In both levels, messages computed for cells of the same dimension are aggregated using a sum operation. All messages are computed using the attention mechanisms for squared and non-squared neighborhoods presented in [Definitions 31, 32, and 33, Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). Each level of each layer follows this message passing scheme:

1. First level:

🟥 $\quad m^{0\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 0}\right)$

🟥 $\quad m^{0\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 1}\right)$

🟥 $\quad  m^{1\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((B_{1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 0}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 0}\right)$

🟥 $\quad  m^{1\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 2}\right)$

🟥 $\quad m^{2\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 1}\right) h_y^{t,(2)} \Theta^t_{2\rightarrow 1}\right)$

🟧 $\quad m^{0\rightarrow 0}_{x}=\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{0\rightarrow 1}_{x}=\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 0}_{x}=\sum_{y\in B_{1}(x)} m^{1\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 2}_{x}=\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}$

🟧 $\quad m^{2\rightarrow 1}_{x}=\sum_{y\in B_{2}(x)} m^{2\rightarrow 1}_{y\rightarrow x}$

🟩 $\quad m_x^{(0)}=m^{0\rightarrow 0}_{x}+m^{1\rightarrow 0}_{x}$

🟩 $\quad m_x^{(1)}=m^{0\rightarrow 1}_{x}+m^{2\rightarrow 1}_{x}$

🟩 $\quad m_x^{(2)}=m^{1\rightarrow 2}_{x}$

🟦 $\quad i_x^{t,(0)} = m_x^{(0)}$

🟦 $\quad i_x^{t,(1)} = m_x^{(1)}$

🟦 $\quad i_x^{t,(2)} = m_x^{(2)}$

where $i_x^{t,(\cdot)}$ represents intermediate feature vectors.
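To make the shapes in the first level concrete, here is a minimal sketch on a toy complex (a single filled triangle). It is a simplification and not the `HMCLayer` implementation: the attention coefficients are fixed to 1 and $\phi$ is assumed to be ReLU, so each summed message reduces to `N @ relu(h @ weight)` for a binary neighborhood matrix `N`. The feature dimension `d`, the weights in `theta` and the helper `summed_messages` are hypothetical placeholders.

```python
import torch

torch.manual_seed(0)

# Toy complex: one filled triangle with nodes {0, 1, 2},
# edges {01, 02, 12} and a single face.
B1 = torch.tensor([[1.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0],
                   [0.0, 1.0, 1.0]])  # nodes x edges
B2 = torch.ones(3, 1)                 # edges x faces
A0 = B1 @ B1.T
A0.fill_diagonal_(0)                  # node adjacency A_{up,0}

d = 4  # hypothetical feature dimension, shared by all ranks here
h0, h1, h2 = torch.randn(3, d), torch.randn(3, d), torch.randn(1, d)
theta = {k: torch.randn(d, d) for k in ("00", "01", "10", "12", "21")}


def summed_messages(N, h, weight):
    """Sum over y of relu(N_xy * h_y @ weight), for binary N and att = 1."""
    return N @ torch.relu(h @ weight)


# Inter-neighborhood aggregation (green step): add messages per target rank.
m0 = summed_messages(A0, h0, theta["00"]) + summed_messages(B1, h1, theta["10"])
m1 = summed_messages(B1.T, h0, theta["01"]) + summed_messages(B2, h2, theta["21"])
m2 = summed_messages(B2.T, h1, theta["12"])

print(m0.shape, m1.shape, m2.shape)
```

Each `m` has one row per target cell: three rows for nodes and edges, one row for the single face.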


2. Second level:


🟥 $\quad m^{0\rightarrow 0}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 0}\right)$

🟥 $\quad m^{1\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((A_{\uparrow, 1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 1}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 1}\right)$

🟥 $\quad m^{2\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((A_{\downarrow, 2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 2}\right) i_y^{t,(2)} \Theta^t_{2\rightarrow 2}\right)$

🟥 $\quad m^{0\rightarrow 1}_{y\rightarrow x} = \phi\left(\left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 1}\right)$

🟥 $\quad m^{1\rightarrow 2}_{y\rightarrow x} = \phi\left(\left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 2}\right)$

🟧 $\quad m^{0\rightarrow 0}_{x}=\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 1}_{x}=\sum_{y\in A_{\uparrow, 1}(x)} m^{1\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{2\rightarrow 2}_{x}=\sum_{y\in A_{\downarrow, 2}(x)} m^{2\rightarrow 2}_{y\rightarrow x}$

🟧 $\quad m^{0\rightarrow 1}_{x}=\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}$

🟧 $\quad m^{1\rightarrow 2}_{x}=\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}$

🟩 $\quad m_x^{(0)}=m^{0\rightarrow 0}_{x}$

🟩 $\quad m_x^{(1)}=m^{1\rightarrow 1}_{x} + m^{0\rightarrow 1}_{x}$

🟩 $\quad m_x^{(2)}=m^{1\rightarrow 2}_{x} + m^{2\rightarrow 2}_{x}$

🟦 $\quad h_x^{t+1,(0)} = m_x^{(0)}$

🟦 $\quad h_x^{t+1,(1)} = m_x^{(1)}$

🟦 $\quad h_x^{t+1,(2)} = m_x^{(2)}$

In both message passing levels, $\phi$ represents a common activation function. Also, $\Theta$ and $\text{att}$ represent learnable weights and attention matrices, respectively, that are different in each level. Attention matrices are introduced in [Figure 35(b), Hajij et al.: Topological Deep Learning: Going Beyond Graph Data (2023)](https://arxiv.org/pdf/2206.00606.pdf). Here, the attention matrices are computed using the LeakyReLU activation function, as in previous versions of the paper.
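As an illustration of how an attention matrix and the resulting messages might be computed on a squared neighborhood, here is a dense sketch. It assumes a GAT-style score (a LeakyReLU applied to a learned linear form of the transformed source and target features, normalized with a softmax over neighbors) and ReLU for $\phi$; the exact form in the paper's definitions and in `HMCLayer` may differ. The function `attention_step` and the vector `a` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def attention_step(h, A, theta, a, negative_slope=0.2):
    """Dense sketch of one attention step on a squared neighborhood A (n x n)."""
    z = h @ theta                      # (n, d_out): rows are h_y @ Theta
    d_out = z.shape[1]
    # Assumed score: e_xy = LeakyReLU(a[:d_out] . z_x + a[d_out:] . z_y).
    e = F.leaky_relu(
        (z @ a[:d_out]).unsqueeze(1) + (z @ a[d_out:]).unsqueeze(0),
        negative_slope,
    )
    # Restrict to neighbors and normalize over the neighbors y of each x.
    e = e.masked_fill(A == 0, float("-inf"))
    att = torch.nan_to_num(torch.softmax(e, dim=1))  # cells without neighbors get 0
    # m_{y->x} = phi(A_xy * att_xy * h_y Theta), with phi = ReLU here.
    messages = torch.relu((A * att).unsqueeze(-1) * z.unsqueeze(0))  # (n, n, d_out)
    return messages.sum(dim=1), att    # sum over y


torch.manual_seed(0)
# Path 0 - 1 - 2 plus an isolated node 3.
A = torch.tensor([[0.0, 1.0, 0.0, 0.0],
                  [1.0, 0.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 0.0, 0.0]])
h = torch.randn(4, 5)
theta = torch.randn(5, 3)
a = torch.randn(6)
m, att = attention_step(h, A, theta, a)
print(m.shape, att.sum(dim=1))
```

The dense `(n, n, d_out)` message tensor is only practical for small examples; for meshes with hundreds of cells, sparse operations are preferable.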

Notation, adjacency, coadjacency, and incidence matrices are defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031). The tensor diagram for the layer can be found in the first column and last row of Figure 11 of the same paper.

### The Task:

We train this model to perform entire-complex classification on the SHREC 16 mesh dataset, loaded with `shrec_16` from `toponetx.datasets.mesh`. This dataset contains:
- meshes, lifted here to combinatorial complexes with nodes, edges and faces,
- with features on nodes, edges and faces, and one class label per mesh.

The task is to predict the class of each mesh.

# Set-up


In [3]:
import torch
import numpy as np
import torch.nn as nn
from toponetx import CombinatorialComplex
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from toponetx.datasets.mesh import shrec_16
from sklearn.model_selection import train_test_split
from scipy.sparse import csr_matrix
from topomodelx.nn.combinatorial.hmc_layer import HMCLayer

ModuleNotFoundError: No module named 'topomodelx'

If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

We import the SHREC 16 dataset, a dataset of 3D meshes.

We then lift each mesh into our topological domain of choice, here: a combinatorial complex.

We also retrieve:
- input signals `x_0`, `x_1` and `x_2` on the nodes (0-cells), edges (1-cells) and faces (2-cells) for each complex: these will be the model's inputs,
- a classification label `y` associated to the combinatorial complex.

In [5]:
shrec_training, shrec_testing = shrec_16()

downloading shrec 16 full dataset...

done!
Loading shrec 16 full dataset...

done!
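Later cells refer to `training_complexes`, `training_node_feat`, `training_edge_feat`, `training_face_feat` and `training_labels`. A minimal sketch of how these could be unpacked from one split is given below, assuming each split is a dict with the keys used in `SHRECDataset` (`"complexes"`, `"node_feat"`, `"edge_feat"`, `"face_feat"`, `"label"`). The helper `unpack_split` and the stand-in `toy_split` are hypothetical and only illustrate the call.

```python
def unpack_split(data):
    """Return complexes, per-rank features and labels from one split dict."""
    return (
        data["complexes"],
        data["node_feat"],
        data["edge_feat"],
        data["face_feat"],
        data["label"],
    )


# Stand-in split with the same keys, used only to illustrate the call.
toy_split = {
    "complexes": ["cc_0", "cc_1"],
    "node_feat": [[0.1], [0.2]],
    "edge_feat": [[1.0], [2.0]],
    "face_feat": [[3.0], [4.0]],
    "label": [0, 1],
}
(
    training_complexes,
    training_node_feat,
    training_edge_feat,
    training_face_feat,
    training_labels,
) = unpack_split(toy_split)
print(len(training_complexes), training_labels)
```

With the real data, the same call would be applied to `shrec_training` and `shrec_testing`, and each complex would additionally be converted with `to_combinatorial_complex()` as in `SHRECDataset`.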


In [6]:
class SHRECDataset(Dataset):
    def __init__(self, data):
        self.complexes = [cc.to_combinatorial_complex() for cc in data["complexes"]]
        self.x_0 = data["node_feat"]
        self.x_1 = data["edge_feat"]
        self.x_2 = data["face_feat"]
        self.y = data["label"]
        self.a0, self.a1, self.coa2, self.b1, self.b2 = self._get_neighborhood_matrix()

    def _get_neighborhood_matrix(self):
        
        a0 = []
        a1 = []
        coa2 = []
        b1 = []
        b2 = []
        
        for cc in self.complexes:

            a0.append(torch.from_numpy(cc.adjacency_matrix(0,1).todense()).to_sparse())
            a1.append(torch.from_numpy(cc.adjacency_matrix(1,2).todense()).to_sparse())
            
            B = cc.incidence_matrix(rank=2,to_rank=1)
            A = B.T @ B
            A.setdiag(0)
            coa2.append(torch.from_numpy(A.todense()).to_sparse())
            
            b1.append(torch.from_numpy(cc.incidence_matrix(1,0).todense()).to_sparse())
            b2.append(torch.from_numpy(cc.incidence_matrix(2,1).todense()).to_sparse())
            
        return a0, a1, coa2, b1, b2

    def num_classes(self):
        return len(np.unique(self.y))
    
    def channels_dim(self):
        return self.x_0[0].shape[1], self.x_1[0].shape[1], self.x_2[0].shape[1]
    
    def __len__(self):
        return len(self.complexes)

    def __getitem__(self, idx):
        return self.x_0[idx], self.x_1[idx], self.x_2[idx], self.y[idx], self.a0[idx], self.a1[idx], self.coa2[idx], self.b1[idx], self.b2[idx]


In [210]:
"""
Node features:
    - Position
    - Normal
"""

training_node_feat[0]

array([[ 0.567542  ,  0.570995  , -0.210023  , -0.06894332,  0.72564355,
         0.6846081 ],
       [ 0.603908  ,  0.557968  , -0.223298  ,  0.91155493,  0.37192256,
         0.17533174],
       [ 0.591494  ,  0.464737  ,  0.350567  ,  0.33390166,  0.55019138,
         0.76537515],
       ...,
       [-0.331773  , -0.33001   ,  0.049767  , -0.28529544, -0.95405566,
        -0.09156586],
       [-0.401883  , -0.263047  ,  0.136766  , -0.66302596, -0.62547973,
         0.41130485],
       [ 0.261249  ,  0.128185  ,  0.304997  ,  0.7728916 , -0.44845904,
         0.44891321]])

In [152]:
"""
Edge features:
    - Dihedral angle
    - Edge span
    - 2 edge angle in the triangle
    - 6 edge ratios
"""

training_edge_feat[0]

In [89]:
"""
Face features:
    - Face area
    - Face normal
    - Face angle
"""

training_face_feat[0]

tensor(indices=tensor([[  0,   0,   0,  ..., 499, 499, 499],
                       [  1,   2,   4,  ..., 465, 478, 495]]),
       values=tensor([1, 1, 1,  ..., 1, 1, 1]),
       size=(500, 500), nnz=1500, layout=torch.sparse_coo)


In [68]:
# training_complexes[0].to_trimesh().show()

## Define neighborhood structures. ##

Implementing the HMC architecture requires message passing along neighborhood structures of the combinatorial complexes.

We now retrieve these neighborhood structures (i.e. their representative matrices), which we will use to send messages.

For the HMC layer, we need the adjacency matrices $A_{\uparrow, 0}$ and $A_{\uparrow, 1}$, the coadjacency matrix $A_{\downarrow, 2}$, and the incidence matrices $B_1$ and $B_2$ of each combinatorial complex.
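The face coadjacency matrix can be derived from the edge-to-face incidence matrix as $B_2^T B_2$ with the diagonal removed. The sketch below shows this on two triangles sharing one edge, and converts the result to a PyTorch sparse tensor without building a dense matrix first. The helper names `coadjacency_from_incidence` and `scipy_to_torch_sparse` are hypothetical, and the toy incidence matrix is written by hand rather than obtained from TopoNetX.

```python
import numpy as np
import torch
from scipy.sparse import csr_matrix


def coadjacency_from_incidence(B):
    """Return B^T B with a zero diagonal, as a SciPy COO matrix."""
    B = csr_matrix(B)
    A = (B.T @ B).tocsr()
    A.setdiag(0)          # a cell is not its own neighbor
    A.eliminate_zeros()   # drop the explicit zeros left on the diagonal
    return A.tocoo()


def scipy_to_torch_sparse(A):
    """Convert a SciPy sparse matrix to a coalesced torch sparse COO tensor."""
    coo = A.tocoo()
    indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
    values = torch.from_numpy(coo.data).float()
    return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce()


# Two triangles sharing edge 2: rows are the 5 edges, columns the 2 faces.
B2 = np.array([[1, 0],
               [1, 0],
               [1, 1],
               [0, 1],
               [0, 1]], dtype=float)
A_down_2 = scipy_to_torch_sparse(coadjacency_from_incidence(B2))
print(A_down_2.to_dense())
```

Skipping the `.todense()` round trip keeps memory proportional to the number of nonzero entries, which matters as meshes grow.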

In [70]:
adjacency_0_train_list = []
adjacency_1_train_list = []
coadjacency_2_train_list = []

incidence_1_train_list = []
incidence_2_train_list = []

for cc in training_complexes:

    adjacency_0_train = torch.from_numpy(cc.adjacency_matrix(0,1).todense()).to_sparse()
    adjacency_1_train = torch.from_numpy(cc.adjacency_matrix(1,2).todense()).to_sparse()
    B = cc.incidence_matrix(rank=2,to_rank=1,incidence_type="down", sparse=True)
    A = csr_matrix(B.T) @ csr_matrix(B)
    coadjacency_2_train = torch.from_numpy(A.todense()).to_sparse()

    adjacency_0_train_list.append(adjacency_0_train)
    adjacency_1_train_list.append(adjacency_1_train)
    coadjacency_2_train_list.append(coadjacency_2_train)

    incidence_1_train = torch.from_numpy(
        cc.incidence_matrix(1, 0).todense()
    ).to_sparse()
    incidence_2_train = torch.from_numpy(
        cc.incidence_matrix(2, 1).todense()
    ).to_sparse()

    incidence_1_train_list.append(incidence_1_train)
    incidence_2_train_list.append(incidence_2_train)

In [19]:
# training_complexes[0].coadjacency_matrix(rank=2,via_rank=1).todense()
from scipy.sparse import csr_matrix
B = training_complexes[0].incidence_matrix(rank=2,to_rank=1,incidence_type="down", sparse=True)
print(B.shape)
A = csr_matrix(B.T) @ csr_matrix(B)
print(A.shape)


In [71]:
i_cc = 0
print(f"Adjacency_0 of the {i_cc}-th complex: {adjacency_0_train_list[i_cc].shape}.")
print(f"Adjacency_1 of the {i_cc}-th complex: {adjacency_1_train_list[i_cc].shape}.")
print(
    f"Coadjacency_2 of the {i_cc}-th complex: {coadjacency_2_train_list[i_cc].shape}."
)
print(f"Incidence_1 of the {i_cc}-th complex: {incidence_1_train_list[i_cc].shape}.")
print(f"Incidence_2 of the {i_cc}-th complex: {incidence_2_train_list[i_cc].shape}.")

Adjacency_0 of the 0-th complex: torch.Size([252, 252]).
Adjacency_1 of the 0-th complex: torch.Size([750, 750]).
Coadjacency_2 of the 0-th complex: torch.Size([500, 500]).
Incidence_1 of the 0-th complex: torch.Size([252, 750]).
Incidence_2 of the 0-th complex: torch.Size([750, 500]).

In [122]:
adjacency_0_test_list = []
adjacency_1_test_list = []
coadjacency_2_test_list = []
incidence_1_test_list = []
incidence_2_test_list = []

for cc in testing_complexes:

    adjacency_0_test = torch.from_numpy(cc.adjacency_matrix(0,1).todense()).to_sparse()
    adjacency_1_test = torch.from_numpy(cc.adjacency_matrix(1,2).todense()).to_sparse()
    B = cc.incidence_matrix(rank=2,to_rank=1,incidence_type="down", sparse=True)
    A = csr_matrix(B.T) @ csr_matrix(B)
    coadjacency_2_test = torch.from_numpy(A.todense()).to_sparse()
    #torch.from_numpy(cc.coadjacency_matrix(2,1).todense()).to_sparse()

    adjacency_0_test_list.append(adjacency_0_test)
    adjacency_1_test_list.append(adjacency_1_test)
    coadjacency_2_test_list.append(coadjacency_2_test)

    incidence_1_test = torch.from_numpy(cc.incidence_matrix(1, 0).todense()).to_sparse()
    incidence_2_test = torch.from_numpy(cc.incidence_matrix(2, 1).todense()).to_sparse()

    incidence_1_test_list.append(incidence_1_test)
    incidence_2_test_list.append(incidence_2_test)

The dimension of input features on nodes, edges and faces are: 6, 10 and 7.


In [143]:
i_cc = 0
print(f"Adjacency_0 of the {i_cc}-th complex: {adjacency_0_test_list[i_cc].shape}.")
print(f"Coadjacency_2 of the {i_cc}-th complex: {coadjacency_2_test_list[i_cc].shape}.")
print(f"Adjacency_1 of the {i_cc}-th complex: {adjacency_1_test_list[i_cc].shape}.")
print(f"Incidence_1 of the {i_cc}-th complex: {incidence_1_test_list[i_cc].shape}.")
print(f"Incidence_2 of the {i_cc}-th complex: {incidence_2_test_list[i_cc].shape}.")


# Create the Neural Network

Using the `HMCLayer` class, we create a neural network. The stacked-layer variant is left commented out in the code, so a single layer is applied.

In [197]:
d_0 = training_node_feat[0].shape[-1]
d_1 = training_edge_feat[0].shape[-1]
d_2 = training_face_feat[0].shape[-1]

in_channels = [d_0, d_1, d_2]

print(
    f"The dimension of input features on nodes, edges and faces are: {d_0}, {d_1} and {d_2}."
)

In [125]:
class HoanMeshClassifier(torch.nn.Module):
    """Mesh classifier built on a single HMC layer.

    Parameters
    ----------
    in_channels : List[int]
        Dimension of input features on nodes, edges and faces respectively.
    intermediate_channels : List[int]
        Dimension of intermediate features on nodes, edges and faces respectively.
    out_channels : List[int]
        Dimension of output features on nodes, edges and faces respectively.
    num_classes : int
        Number of classes.
    negative_slope : float
        Negative slope of the LeakyReLU used in the attention mechanism.
    n_layers : int
        Number of HMC layers. Currently unused: a single layer is applied.
    """

    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels,
        num_classes,
        negative_slope=0.2,
        n_layers=1,
    ) -> None:

        super().__init__()
        self.num_classes = num_classes
        self.layer = HMCLayer(
            in_channels=in_channels,
            intermediate_channels=intermediate_channels,
            out_channels=out_channels,
            negative_slope=negative_slope,
        )
        """self.layers = torch.nn.ModuleList(
            HMCLayer(
                in_channels=in_channels,
                intermediate_channels=intermediate_channels,
                out_channels=out_channels,
                negative_slope=negative_slope
            ) for _ in range(n_layers)
        )"""

        self.l0 = torch.nn.Linear(out_channels[0], num_classes)
        self.l1 = torch.nn.Linear(out_channels[1], num_classes)
        self.l2 = torch.nn.Linear(out_channels[2], num_classes)

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        neighborhood_0_to_0,
        neighborhood_1_to_1,
        neighborhood_2_to_2,
        neighborhood_0_to_1,
        neighborhood_1_to_2,
    ) -> torch.Tensor:

        x_0, x_1, x_2 = self.layer(
            x_0,
            x_1,
            x_2,
            neighborhood_0_to_0,
            neighborhood_1_to_1,
            neighborhood_2_to_2,
            neighborhood_0_to_1,
            neighborhood_1_to_2,
        )
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0,
                x_1,
                x_2,
                neighborhood_0_to_0,
                neighborhood_1_to_1,
                neighborhood_2_to_2,
                neighborhood_0_to_1,
                neighborhood_1_to_2
                                  )"""

        x_0 = self.l0(x_0)
        x_1 = self.l1(x_1)
        x_2 = self.l2(x_2)

        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )
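The readout at the end of `forward` can be looked at in isolation: per-cell class scores are averaged over the cells of each rank while ignoring NaNs, any remaining NaN is set to 0, and the three rank-wise averages are summed. The following is a small sketch of that step with made-up inputs; the function name `readout` is hypothetical.

```python
import torch


def readout(x_0, x_1, x_2):
    """Mean-pool per-cell scores for each rank (ignoring NaNs), then sum the ranks."""
    pooled = []
    for x in (x_2, x_1, x_0):
        mean = torch.nanmean(x, dim=0)
        # A class whose scores are all NaN for this rank contributes 0.
        mean = torch.where(torch.isnan(mean), torch.zeros_like(mean), mean)
        pooled.append(mean)
    return pooled[0] + pooled[1] + pooled[2]


nan = float("nan")
x_0 = torch.tensor([[1.0, nan], [3.0, nan]])  # 2 nodes, 2 classes
x_1 = torch.tensor([[0.0, 2.0]])              # 1 edge
x_2 = torch.tensor([[1.0, 1.0], [3.0, 3.0]])  # 2 faces
logits = readout(x_0, x_1, x_2)
print(logits)
```

Because the result has one entry per class regardless of the number of cells, meshes of different sizes map to outputs of the same shape.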



# Train the Neural Network

We specify the model, initialize the loss, and specify an optimizer. We pass `negative_slope=0.2` for the LeakyReLU used in the attention mechanism.

In [199]:
num_classes = np.unique(training_labels).shape[0]
model = HoanMeshClassifier(
    in_channels,
    in_channels,
    in_channels,
    negative_slope=0.2,
    num_classes=num_classes,
    n_layers=1,
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
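Since the model returns a single vector of class scores per complex, the loss is evaluated on unbatched inputs: a 1-D tensor of logits and a scalar class index. The sketch below checks this against a manual computation, assuming a PyTorch version in which `CrossEntropyLoss` accepts unbatched inputs; the logits are made up for illustration.

```python
import torch

crit = torch.nn.CrossEntropyLoss()
y_hat = torch.tensor([2.0, 0.5, -1.0])  # hypothetical logits for 3 classes
y = torch.tensor(0)                     # scalar class index

loss = crit(y_hat, y)
# Cross-entropy for one sample: negative log-softmax at the true class.
manual = -torch.log_softmax(y_hat, dim=0)[y]
prediction = y_hat.argmax()

print(loss.item(), manual.item(), prediction.item())
```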



We train the HoanMeshClassifier and evaluate it on the test set every `test_interval` epochs. Reduce `num_epochs` to keep training minimal for rapid testing.

In [200]:
test_interval = 25
num_epochs = 500
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for (
        x_0,
        x_1,
        x_2,
        adjacency_0,
        adjacency_1,
        coadjacency_2,
        incidence_1,
        incidence_2,
        y,
    ) in zip(
        training_node_feat,
        training_edge_feat,
        training_face_feat,
        adjacency_0_train_list,
        adjacency_1_train_list,
        coadjacency_2_train_list,
        incidence_1_train_list,
        incidence_2_train_list,
        training_labels,
    ):
        x_0, x_1, x_2 = (
            torch.tensor(x_0, dtype=torch.float).to(device),
            torch.tensor(x_1, dtype=torch.float).to(device),
            torch.tensor(x_2, dtype=torch.float).to(device),
        )
        y = torch.tensor(y, dtype=torch.long).to(device)
        adjacency_0, adjacency_1, coadjacency_2 = (
            adjacency_0.float().to(device),
            adjacency_1.float().to(device),
            coadjacency_2.float().to(device),
        )
        incidence_1, incidence_2 = incidence_1.float().to(
            device
        ), incidence_2.float().to(device)
        opt.zero_grad()
        y_hat = model.forward(
            x_0,
            x_1,
            x_2,
            adjacency_0,
            adjacency_1,
            coadjacency_2,
            incidence_1,
            incidence_2,
        )
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()  # evaluation mode; model.train() is restored at the next epoch
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, x_2, adjacency_0, adjacency_1, coadjacency_2, incidence_1, incidence_2, y in zip(
                testing_node_feat, testing_edge_feat, testing_face_feat, adjacency_0_test_list, adjacency_1_test_list, coadjacency_2_test_list, incidence_1_test_list, incidence_2_test_list, testing_labels
            ):
                x_0, x_1, x_2 = (
                    torch.tensor(x_0, dtype=torch.float).to(device),
                    torch.tensor(x_1, dtype=torch.float).to(device),
                    torch.tensor(x_2, dtype=torch.float).to(device),
                )
                y = torch.tensor(y, dtype=torch.long).to(device)
                adjacency_0, adjacency_1, coadjacency_2 = (
                    adjacency_0.float().to(device),
                    adjacency_1.float().to(device),
                    coadjacency_2.float().to(device),
                )
                incidence_1, incidence_2 = incidence_1.float().to(
                    device
                ), incidence_2.float().to(device)
                # No gradients or loss are needed when only measuring accuracy.
                y_hat = model(x_0, x_1, x_2, adjacency_0, adjacency_1, coadjacency_2, incidence_1, incidence_2)
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)
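The evaluation part of the loop can also be factored into a helper that switches the model to evaluation mode, disables gradient tracking, and restores the previous mode afterwards. The sketch below shows one way to do this; `evaluate` and the stand-in model `MeanLogits` are hypothetical and are not part of TopoModelX.

```python
import torch


def evaluate(model, samples):
    """Return accuracy over an iterable of (inputs, label) pairs, one complex at a time."""
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, y in samples:
            y_hat = model(*inputs)
            correct += int(y_hat.argmax().item() == int(y))
            total += 1
    model.train(was_training)  # restore the mode the model was in
    return correct / max(total, 1)


class MeanLogits(torch.nn.Module):
    """Hypothetical stand-in model: a linear map followed by mean pooling over cells."""

    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, num_classes)

    def forward(self, x):
        return self.linear(x).mean(dim=0)


torch.manual_seed(0)
toy_model = MeanLogits(in_dim=3, num_classes=2)
toy_samples = [((torch.randn(5, 3),), 0), ((torch.randn(4, 3),), 1)]
acc = evaluate(toy_model, toy_samples)
print(f"Test_acc: {acc:.4f}")
```

With the classifier above, each `inputs` tuple would hold the three feature tensors followed by the five neighborhood matrices, in the order expected by `forward`.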

Epoch: 1 loss: 654795.3756 Train_acc: 0.0250
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/manuellecha/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/ll/bd2dgg_50_52khsw_lrcpvwc0000gn/T/ipykernel_63000/3112448459.py", line 38, in <module>
    y_hat = model.forward(
  File "/var/folders/ll/bd2dgg_50_52khsw_lrcpvwc0000gn/T/ipykernel_63000/4123705235.py", line 58, in forward
    x_0, x_1, x_2 = layer(x_0,
  File "/Users/manuellecha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/manuellecha/PycharmProjects/TopoModelX_UBTeam/topomodelx/nn/combinatorial/hmc_layer.py", line 175, in forward
    )
  File "/Users/manuellecha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/manuellecha/PycharmProjects/TopoModelX_

30
