# Train a Simplicial Convolutional Network (SCN)

In [6]:
import torch
import numpy as np
import toponetx.datasets as datasets

from sklearn.model_selection import train_test_split
from topomodelx.nn.simplicial.scn_layer import SCNLayer

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import dataset ##

According to the original paper, SCN is good at simplex classification. Thus, I chose shrec_16, a benchmark dataset for 3D mesh classification.

In [4]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
ys = ys.reshape((100, 1))
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [8]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.


## Define neighborhood structures. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on the domain. In this case, we need the normalized Laplacian matrix on nodes, edges, and faces. We also convert the neighborhood structures to torch tensors.

In [5]:
A_0s = []
A_1s = []
A_2s = []
for x in simplexes:
    A_0 = x.normalized_laplacian_matrix(rank=0)
    A_1 = x.normalized_laplacian_matrix(rank=1)
    A_2 = x.normalized_laplacian_matrix(rank=2)

    A_0 = torch.from_numpy(A_0.todense()).to_sparse()
    A_1 = torch.from_numpy(A_1.todense()).to_sparse()
    A_2 = torch.from_numpy(A_2.todense()).to_sparse()

    A_0s.append(A_0)
    A_1s.append(A_1)
    A_2s.append(A_2)

# Create the Neural Network

In [7]:
class SCN(torch.nn.Module):
    """Simplex Convolutional Network Implementation for binary node classification.

    Parameters
    ---------
    in_channels_0 : int
        Dimension of input features on nodes.
    in_channels_1 : int
        Dimension of input features on edges.
    in_channels_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    n_layers : int
        Amount of message passing layers.

    """

    def __init__(
        self, in_channels_0, in_channels_1, in_channels_2, num_classes, n_layers=2
    ):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                SCNLayer(
                    in_channels_0=in_channels_0,
                    in_channels_1=in_channels_1,
                    in_channels_2=in_channels_2,
                )
            )
        self.layers = torch.nn.ModuleList(layers)
        self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)
        # self.layers = layers

    def forward(self, x_0, x_1, x_2, A_0, A_1, A_2):
        """Forward computation.

        Parameters
        ---------
        x_0 : tensor
            shape = [n_nodes, channels]
            Node features.

        Returns
        --------
        _ : tensor
            shape = [n_nodes, 2]
            One-hot labels assigned to nodes.

        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, x_2, A_0, A_1, A_2)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)

        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )

# Train the Neural Network

We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [9]:
in_channels_0 = x_0s[i_complex].shape[1]
in_channels_1 = x_1s[i_complex].shape[1]
in_channels_2 = x_2s[i_complex].shape[1]

In [10]:
model = SCN(in_channels_0, in_channels_1, in_channels_2, num_classes=1)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

In [12]:
test_size = 0.2
x_0s_train, x_0s_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1s_train, x_1s_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2s_train, x_2s_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

A_0s_train, A_0s_test = train_test_split(A_0s, test_size=test_size, shuffle=False)
A_1s_train, A_1s_test = train_test_split(A_1s, test_size=test_size, shuffle=False)
A_2s_train, A_2s_test = train_test_split(A_2s, test_size=test_size, shuffle=False)

y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

The following cell performs the training, looping over the network for a low number of epochs.

In [18]:
test_interval = 2
num_epochs = 8
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, A_0, A_1, A_2, y in zip(
        x_0s_train, x_1s_train, x_2s_train, A_0s_train, A_1s_train, A_2s_train, y_train
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        A_0, A_1, A_2 = (
            A_0.float().to(device),
            A_1.float().to(device),
            A_2.float().to(device),
        )
        optimizer.zero_grad()
        y_hat = model(x_0, x_1, x_2, A_0, A_1, A_2)
        # loss = criterion(y_hat, y)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_0, x_1, x_2, A_0, A_1, A_2, y in zip(
                x_0s_test, x_1s_test, x_2s_test, A_0s_test, A_1s_test, A_2s_test, y_test
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                A_0, A_1, A_2 = (
                    A_0.float().to(device),
                    A_1.float().to(device),
                    A_2.float().to(device),
                )
                y_hat = model(x_0, x_1, x_2, A_0, A_1, A_2)
                test_loss = loss_fn(y_hat, y)
            print(f"Test_loss: {test_loss:.4f}", flush=True)

Epoch: 1 loss: 76.6313
Epoch: 2 loss: 76.5828
Test_loss: 94.7222
Epoch: 3 loss: 76.5361
Epoch: 4 loss: 76.4920
Test_loss: 94.5671
Epoch: 5 loss: 76.4495
Epoch: 6 loss: 76.4081
Test_loss: 94.4065
Epoch: 7 loss: 76.3676
Epoch: 8 loss: 76.3283
Test_loss: 94.2507
