# Tutorial: Set-up, create and train a convolutional CXN

In this notebook, we create and train a simplified, non-attentional version of the CXN network originally proposed by Hajij et al. in "Cell Complex Neural Networks" (https://arxiv.org/pdf/2010.00743.pdf). We build a small synthetic cell complex dataset and train the model to perform classification on it.

In [21]:
import torch
import numpy as np
import argparse
import random
from toponetx import CellComplex

from topomodelx.nn.cell.cxn_layer import CXNLayer

If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

In [22]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
device = torch.device(device)

cpu


We specify the hyperparameters for training.

In [23]:
parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--num_epochs", default=3, type=int)
parser.add_argument("--with_rotation", default=1, type=int, choices=[0, 1])

args, unknown = parser.parse_known_args()
training_cfg = {
    "lr": args.lr,
    "num_epochs": args.num_epochs,
}
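
The cell above calls `parse_known_args` rather than `parse_args`. A likely reason is that a notebook kernel is typically started with extra command-line arguments (for example `-f <connection file>`), which `parse_args` would reject. The sketch below illustrates the difference with a simulated argument list; the `argv` values are made up for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--num_epochs", default=3, type=int)

# Simulated argv: "-f" and the path stand in for arguments a kernel
# launcher might pass (hypothetical values, not taken from the notebook).
argv = ["--lr", "0.01", "-f", "/tmp/kernel.json"]

# parse_known_args returns the recognized options plus the leftovers,
# instead of exiting with an error on unrecognized arguments.
args, unknown = parser.parse_known_args(argv)
print(args.lr, args.num_epochs, unknown)
```

Options that are not supplied keep their defaults, so `num_epochs` stays at 3 here.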

# Pre-processing

## Create synthetic data and load its neighborhood structures

We start by creating the cell complex on which the neural network will operate.

In [24]:
edge_set = [
    [1, 2],
    [1, 3],
    [2, 4],
    [3, 4],
    [4, 5],
    [1, 6],
]  # two edges stick out on diag
node_set = [1, 2, 3, 4, 5, 6]
face_set = [[1, 2, 3, 4]]
complex = CellComplex(edge_set + node_set + face_set)
print(complex)
print(len(complex.edges))

Cell Complex with 6 nodes, 8 edges  and 1 2-cells 
8
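
The complex reports 8 edges although `edge_set` lists only 6. Judging by the printed output, adding the 2-cell `[1, 2, 3, 4]` also adds the edges along its boundary. The following plain-Python sketch reproduces that count without TopoNetX; the helper `boundary_edges` is a hypothetical name introduced here, not a library function.

```python
def boundary_edges(cycle):
    """Return the undirected edges traversed by the closed walk `cycle`."""
    n = len(cycle)
    return {frozenset((cycle[i], cycle[(i + 1) % n])) for i in range(n)}


edge_set = [[1, 2], [1, 3], [2, 4], [3, 4], [4, 5], [1, 6]]
face_set = [[1, 2, 3, 4]]

given = {frozenset(e) for e in edge_set}
all_edges = set(given)
for face in face_set:
    # The face 1 -> 2 -> 3 -> 4 -> 1 contributes its four boundary edges.
    all_edges |= boundary_edges(face)

# Edges that were not in edge_set but are needed to close the face.
added = sorted(tuple(sorted(e)) for e in all_edges - given)
print(len(all_edges), added)
```

Under this assumption the two extra edges are `(1, 4)` and `(2, 3)`, which gives the 8 edges reported above.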


We will need the adjacency matrix $A_{\uparrow, 0}$ and the coboundary matrix $B_2^T$.

In [25]:
incidence_2_t = complex.incidence_matrix(rank=2).T
adjacency_0 = complex.adjacency_matrix(rank=0)
incidence_2_t = torch.from_numpy(incidence_2_t.todense()).to_sparse().float()
adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse().float()

Now we create data on this complex. Specifically, we need node features and edge features for both train and test datasets.

In [26]:
x_0_train = []
num_features_0 = 4
for _ in range(100):
    x_0_train.append(torch.Tensor(np.random.rand(len(complex.nodes), num_features_0)))

x_1_train = []
num_features_1 = 5
for _ in range(100):
    x_1_train.append(torch.Tensor(np.random.rand(len(complex.edges), num_features_1)))

x_0_test = []
num_features_0 = 4
for _ in range(10):
    x_0_test.append(torch.Tensor(np.random.rand(len(complex.nodes), num_features_0)))

x_1_test = []
num_features_1 = 5
for _ in range(10):
    x_1_test.append(torch.Tensor(np.random.rand(len(complex.edges), num_features_1)))

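Now we must define the labels associated with these datasets. Each sample (one set of node and edge features on the complex) receives a single binary label, so the task is binary classification at the level of the whole complex. For the purposes of the tutorial, we leave these labels completely random.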

In [34]:
labels_train = [random.randint(0, 1) for _ in range(100)]
labels_test = [random.randint(0, 1) for _ in range(10)]

Let's define the input feature dimensions as channel dimensions. We will use this to define our model.

In [28]:
in_ch_0 = num_features_0
in_ch_1 = num_features_1
in_ch_2 = 5
print(f"in_ch_v {in_ch_0} in_ch_e {in_ch_1} in_ch_f {in_ch_2}")

in_ch_v 4 in_ch_e 5 in_ch_f 5


# Create the Neural Network

Using the CXNLayer class, we create a neural network with stacked layers. The number of channels is allowed to differ across ranks (here, 4 on nodes and 5 on edges and faces), making this a heterogeneous network.

In [29]:
class CXN(torch.nn.Module):
    """Convolutional CXN.

    Parameters
    ----------
    in_ch_0 : int
        Dimension of input features on nodes.
    in_ch_1 : int
        Dimension of input features on edges.
    in_ch_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    n_layers : int
        Number of CXN layers.
    """

    def __init__(self, in_ch_0, in_ch_1, in_ch_2, num_classes, n_layers=2):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                CXNLayer(
                    in_channels_0=in_ch_0,
                    in_channels_1=in_ch_1,
                    in_channels_2=in_ch_2,
                )
            )
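        # ModuleList registers the layers so that their parameters are seen
        # by model.parameters() and moved by model.to(device).
        self.layers = torch.nn.ModuleList(layers)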
        self.lin_0 = torch.nn.Linear(in_ch_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_ch_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_ch_2, num_classes)

    def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
        """Forward computation through CXN layers then linear layers."""
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        return torch.mean(x_2, dim=0) + torch.mean(x_1, dim=0) + torch.mean(x_0, dim=0)
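
The last line of `forward` is the readout: the per-cell class scores are averaged over the cells of each rank, and the three rank-wise averages are summed into a single vector with one entry per class. The plain-Python sketch below mirrors that arithmetic on made-up scores; the numbers and the helper `mean_over_cells` are hypothetical and only serve to show the shapes.

```python
def mean_over_cells(x):
    """Average a (num_cells x num_classes) nested list over its cells."""
    num_cells = len(x)
    num_classes = len(x[0])
    return [sum(row[c] for row in x) / num_cells for c in range(num_classes)]


# Hypothetical per-cell class scores after the linear layers (2 classes).
x_0 = [[1.0, 0.0], [3.0, 2.0]]  # 2 nodes
x_1 = [[0.5, 0.5], [1.5, 2.5], [1.0, 0.0]]  # 3 edges
x_2 = [[2.0, -1.0]]  # 1 face

# Sum of the rank-wise means: one logit per class for the whole complex.
logits = [
    a + b + c
    for a, b, c in zip(
        mean_over_cells(x_0), mean_over_cells(x_1), mean_over_cells(x_2)
    )
]
print(logits)
```

Because each rank is averaged before summing, the output size does not depend on how many nodes, edges, or faces the complex has.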

# Train the Neural Network

We specify the model, initialize loss, and specify an optimizer.

In [30]:
model = CXN(in_ch_0, in_ch_1, in_ch_2, num_classes=2, n_layers=2)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
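
The model returns one vector of class logits per sample, and the loss compares it with a single integer label. As a sketch of what the cross-entropy computes in this single-sample case (assuming the default settings, with no class weights or label smoothing), the loss is the negative log-softmax of the logit at the target class. The function below is a plain-Python illustration, not the PyTorch implementation.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of `logits` evaluated at index `target`."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Uninformative logits over two classes give a loss of ln(2).
print(cross_entropy([0.0, 0.0], 1))

# A confident, correct prediction gives a smaller loss than a wrong one.
print(cross_entropy([5.0, 1.0], 0), cross_entropy([5.0, 1.0], 1))
```

A loss near ln(2) ≈ 0.693 therefore corresponds to a two-class model that has not learned anything yet.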

The following cell performs the training, looping over the dataset for `num_epochs` epochs (3 by default) and evaluating on the test set every 2 epochs.

In [31]:
test_interval = 2
for epoch_i in range(1, args.num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, y in zip(x_0_train, x_1_train, labels_train):

        opt.zero_grad()

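        # Move inputs to the same device as the model before the forward pass.
        y_hat = model(
            x_0.to(device),
            x_1.to(device),
            adjacency_0.to(device),
            incidence_2_t.to(device),
        )
        y = torch.tensor(y, dtype=torch.long, device=device)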
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    # wandb.log({"loss": np.mean(epoch_loss), "Train_acc": train_acc, "epoch": epoch_i})
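    if epoch_i % test_interval == 0:
        model.eval()  # switch to evaluation mode; model.train() is called again next epoch
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, y in zip(x_0_test, x_1_test, labels_test):
                y = torch.tensor(y, dtype=torch.long, device=device)
                y_hat = model(
                    x_0.to(device),
                    x_1.to(device),
                    adjacency_0.to(device),
                    incidence_2_t.to(device),
                )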

                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)
            # wandb.log({"Test_acc": test_acc})

Epoch: 1 loss: 0.2939 Train_acc: 0.9800
Epoch: 2 loss: 0.0628 Train_acc: 1.0000
Test_acc: 1.0000
Epoch: 3 loss: 0.0297 Train_acc: 1.0000
