# Tutorial: Set-up, create and train a convolutional CXN

In this notebook, we create and train a simplified, non-attentional version of a CXN network, originally proposed in the paper by Hajij et. al : Cell Complex Neural Networks (https://arxiv.org/pdf/2010.00743.pdf). We will load a cellular complex dataset from the web and train the model to perform classificaiton on this dataset.

In [1]:
import data
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader

torch.multiprocessing.set_sharing_strategy("file_system")
# import wandb

from topomodelx.nn.cellular.convcxn_layer import ConvCXNLayer

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
device = torch.device(device)

cpu


We specify the hyperparameters for training. This sort of formulation can be useful for performing hyperparameter sweeps later on. Wandb is a useful tool for tracking sweeps -- feel free to uncomment the wandb code in order to launch your own experiment tracking with it.

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--dataset", default="shrec_16", type=str)
parser.add_argument("--hidden_dim", default=64, type=int)
parser.add_argument("--input_to_lin_layers", "-i", default=32, type=int)
parser.add_argument("--dropout", default=0.1, type=float)
parser.add_argument("--num_epochs", default=3, type=int)
parser.add_argument("--num_features", default=50, type=int)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--with_rotation", default=1, type=int, choices=[0, 1])

args, unknown = parser.parse_known_args()
training_cfg = {
    "dataset": args.dataset,
    "lr": args.lr,
    "hidden_dim": args.hidden_dim,
    "input_to_lin_layers": args.input_to_lin_layers,
    "dropout": args.dropout,
    "num_epochs": args.num_epochs,
}
ds_name = args.dataset
num_features = args.num_features
rot = args.with_rotation
cat = True
num_classes = 5
# wandb.login()
# wandb.init(config=args, name=f"ccnn_att{ds_name}_added_noise_lr_{args.lr}_hidden_{args.hidden_dim}",project='shrec_16', entity='')

# Pre-processing

## Import data and neighborhood structures

Now we load the train/test datasets from the web. The data object will contain both the cell features as well as the associate neighborhood matrices.

In [4]:
train_dataset = data.Shrec16AugDataset(
    root="./dataset",
    name="shrec_16",
    split="train",
    num_rot=1,
    cat=True,
    num_features=50,
)
test_dataset = data.Shrec16AugDataset(
    root="./dataset",
    name="shrec_16",
    split="test",
    num_rot=1,
    cat=True,
    num_features=50,
)

train_loader = DataLoader(train_dataset, batch_size=None, num_workers=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=None, num_workers=4, shuffle=False)
data = train_dataset[0]

From the loaded dataset, we extract the features on nodes/edges/faces, as well as their dimensions. We will need this for defining our model.

In [5]:
x_0 = data.x
x_1 = data.x_e
x_2 = data.x_f

in_ch_0 = x_0.size(1)
in_ch_1 = x_1.size(1)
in_ch_2 = x_2.size(1)
print(f"in_ch_v {in_ch_0} in_ch_e {in_ch_1} in_ch_f {in_ch_2}")

in_ch_v 6 in_ch_e 50 in_ch_f 50


# Create the Neural Network

Using the ConvCXNLayer class, we create a neural network with stacked layers. We define the amount of channels on the face and edge ranks to be different, making this a heterogenous network.

In [6]:
class ConvCXN(torch.nn.Module):
    """Convolutional CXN.

    Parameters
    ----------
    in_ch_0 : int
        Dimension of input features on nodes.
    in_ch_1 : int
        Dimension of input features on edges.
    in_ch_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    n_layers : int
        Number of CXN layers.
    """

    def __init__(self, in_ch_0, in_ch_1, in_ch_2, num_classes, n_layers=2):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                ConvCXNLayer(
                    in_channels_0=in_ch_0,
                    in_channels_1=in_ch_1,
                    in_channels_2=in_ch_2,
                )
            )
        self.layers = layers
        self.lin_0 = torch.nn.Linear(in_ch_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_ch_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_ch_2, num_classes)

    def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
        """Forward computation through ConvCXN layers then linear layers."""
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        return torch.mean(x_2, dim=0) + torch.mean(x_1, dim=0) + torch.mean(x_0, dim=0)

# Train the Neural Network

We specify the model, initialize loss, and specify an optimizer.

In [9]:
model = ConvCXN(in_ch_0, in_ch_1, in_ch_2, num_classes=num_classes, n_layers=2)

# wandb.watch(model, log_freq=len(train_dataset))
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=args.lr)

The following cell performs the training, looping over the network for 5 epochs and testing after every 2 epochs.

In [10]:
test_interval = 2
for epoch_i in range(1, args.num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for batch_i, data in enumerate(train_loader, start=1):
        opt.zero_grad()
        x_0 = data.x.to(device)
        x_1 = data.x_e.to(device)

        y = data.y.to(device)

        A0 = data.A0.to_sparse().to(device)
        B2T = data.B2T.to_sparse().to(device)

        y_hat = model(x_0.float(), x_1.float(), A0.float(), B2T.float())

        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        print(f"Done {batch_i}/{len(train_dataset)}", end="\r")
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    # wandb.log({"loss": np.mean(epoch_loss), "Train_acc": train_acc, "epoch": epoch_i})
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for batch_i, data in enumerate(test_loader, start=1):
                x_0 = data.x.to(device)
                x_1 = data.x_e.to(device)
                y = data.y.to(device)

                A0 = data.A0.to_sparse().to(device)
                B2T = data.B2T.to_sparse().to(device)

                y_hat = model(x_0.float(), x_1.float(), A0.float(), B2T.float())

                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
                print(f"Done {batch_i}/{len(test_dataset)}", end="\r")
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)
            # wandb.log({"Test_acc": test_acc})

Epoch: 1 loss: 2.0716 Train_acc: 0.2000
Epoch: 2 loss: 1.7892 Train_acc: 0.2000
Test_acc: 0.2000
Epoch: 3 loss: 1.6736 Train_acc: 0.2250
