In [1]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from toponetx import SimplicialComplex
from topomodelx.nn.hypergraph.unigcn_layer import UniGCNLayer
import warnings

warnings.filterwarnings("ignore")

# Pre-processing

## Import data ##

The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. We then lift each graph into our domain of choice, a hypergraph.

We will also retrieve:
- the input signal on the edges of each of these hypergraphs, as that is what we feed to the model as input
- the binary label associated with each hypergraph

In [2]:
# Load MUTAG (stored under `root`) and keep only the first 100 graphs.
dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)
dataset = dataset[:100]

# For every graph: lift it to a hypergraph (graph -> networkx -> simplicial
# complex -> hypergraph) and record its feature matrix and integer label.
# NOTE(review): `graph.x` is stored under the name `x_1`; in PyG `x` normally
# holds node features -- confirm this is the intended model input.
hg_list, x_1_list, y_list = [], [], []
for graph in dataset:
    nx_graph = to_networkx(graph)
    hg_list.append(SimplicialComplex(nx_graph).to_hypergraph())
    x_1_list.append(graph.x)
    y_list.append(int(graph.y))

# One incidence matrix per hypergraph, densified and then stored as a sparse
# CSR torch tensor.
incidence_1_list = [
    torch.from_numpy(hg.incidence_matrix().todense()).to_sparse_csr()
    for hg in hg_list
]

# Create the Neural Network

Using the UniGCNLayer class, we create a neural network with stacked layers.

In [3]:
# Feature width (second dimension) of the first model input in `x_1_list`.
channels_edge = x_1_list[0].size(1)
# Feature width (second dimension) of `x` on the first graph of the dataset.
channels_node = dataset[0].x.size(1)

In [4]:
class UniGCNNN(torch.nn.Module):
    """Stacked UniGCN layers followed by a pooled binary-classification head.

    Parameters
    ----------
    channels_edge : int
        Feature width used as both ``in_channels`` and ``out_channels`` of
        every ``UniGCNLayer`` and as the input width of the final linear layer.
    channels_node : int
        Accepted for interface compatibility; it is not used by this class.
    n_layers : int, default=2
        Number of stacked ``UniGCNLayer`` message-passing layers.
    """

    def __init__(self, channels_edge, channels_node, n_layers=2):
        super().__init__()
        # Every layer keeps the feature width unchanged, so they can be chained.
        self.layers = torch.nn.ModuleList(
            [
                UniGCNLayer(in_channels=channels_edge, out_channels=channels_edge)
                for _ in range(n_layers)
            ]
        )
        # Classification head: pooled feature vector -> single value.
        self.linear = torch.nn.Linear(channels_edge, 1)

    def forward(self, x_1, incidence_1):
        """Run the stacked layers, max-pool over rows, then apply linear + sigmoid.

        Parameters
        ----------
        x_1 : tensor
            shape = [n_rows, channels_edge]
            Input features, handed to each layer together with ``incidence_1``.
            NOTE(review): the original documentation calls these edge features
            of shape [n_edges, channels_edge] -- confirm against
            ``UniGCNLayer`` and the caller.
        incidence_1 : tensor
            Incidence matrix passed unchanged to every layer.
            Presumably shape = [n_nodes, n_edges] -- TODO confirm.

        Returns
        -------
        tensor
            shape = [1]
            Sigmoid output in (0, 1): the score assigned to the whole complex.
        """
        features = x_1
        for message_passing in self.layers:
            features = message_passing(features, incidence_1)
        # Global max pooling over the first dimension (one value per channel).
        pooled_x, _ = torch.max(features, dim=0)
        logit = self.linear(pooled_x)
        return torch.sigmoid(logit)

# Train the Neural Network

We specify the model, the loss, and an optimizer.

In [5]:
# Seed PyTorch's global RNG before the model is built so that the randomly
# initialised weights are the same on every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

model = UniGCNNN(channels_edge, channels_node, n_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# BCELoss expects probabilities; the model's forward pass ends with a sigmoid.
crit = torch.nn.BCELoss()

Split the dataset into train, val, and test sets.

In [6]:
# Split the three parallel lists with ONE call per stage: train_test_split
# applies the same partition to every array it is given, so features,
# incidence matrices and labels stay aligned by construction (and would remain
# aligned even if shuffling were enabled later).

# Stage 1: 80% (train + val) / 20% test, order preserved.
(
    x_1_train,
    x_1_test,
    incidence_1_train,
    incidence_1_test,
    y_train,
    y_test,
) = train_test_split(
    x_1_list, incidence_1_list, y_list, test_size=0.2, shuffle=False
)

# Stage 2: split the remaining samples 80% train / 20% validation.
(
    x_1_train,
    x_1_val,
    incidence_1_train,
    incidence_1_val,
    y_train,
    y_val,
) = train_test_split(
    x_1_train, incidence_1_train, y_train, test_size=0.2, shuffle=False
)

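The following cell performs the training, looping over the training set for `num_epochs` epochs and reporting the training loss and validation accuracy as it goes; the test accuracy is computed once training has finished. Reduce `num_epochs` to keep training minimal for the purpose of rapid testing.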

In [7]:
num_epochs = 500
# Report every `log_every` epochs (must be >= 1). The default of 1 logs every
# epoch; raise it (e.g. to 50) to avoid flooding the notebook output.
log_every = 1


def evaluate_accuracy(model, x_1_split, incidence_1_split, y_split):
    """Return the fraction of samples in a split that ``model`` classifies correctly.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x_1.float(), incidence_1.float())``; must return a
        single-element tensor that is thresholded at 0.5.
    x_1_split : list of tensor
        Input features, one entry per sample.
    incidence_1_split : list of tensor
        Incidence matrices, aligned with ``x_1_split``.
    y_split : list of int
        Binary labels (0 or 1), aligned with ``x_1_split``.

    Returns
    -------
    float
        Number of correct predictions divided by ``len(y_split)``.

    Raises
    ------
    ZeroDivisionError
        If ``y_split`` is empty.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for x_1, incidence_1, y in zip(x_1_split, incidence_1_split, y_split):
            output = model(x_1.float(), incidence_1.float())
            # Compare plain Python ints instead of relying on the truthiness
            # of a one-element tensor.
            correct += int(int(output.item() > 0.5) == y)
    return correct / len(y_split)


for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch update: per-graph losses are summed and back-propagated once
    # per epoch.
    loss = 0
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train):
        output = model(x_1.float(), incidence_1.float())
        loss += crit(output, torch.tensor([y]).float())
    loss.backward()
    optimizer.step()
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print(f"Epoch {epoch} loss: {loss.item()}")
        val_accuracy = evaluate_accuracy(model, x_1_val, incidence_1_val, y_val)
        print(f"Epoch {epoch} Validation accuracy: {val_accuracy}")

# Final, one-off evaluation on the held-out test split.
test_accuracy = evaluate_accuracy(model, x_1_test, incidence_1_test, y_test)
print(f"Test accuracy: {test_accuracy}")

Epoch 0 loss: 123.86717987060547
Epoch 0 Validation accuracy: 0.375
Epoch 1 loss: 51.91304397583008
Epoch 1 Validation accuracy: 0.5625
Epoch 2 loss: 39.290748596191406
Epoch 2 Validation accuracy: 0.5625
Epoch 3 loss: 51.299095153808594
Epoch 3 Validation accuracy: 0.5625
Epoch 4 loss: 60.093505859375
Epoch 4 Validation accuracy: 0.5625
Epoch 5 loss: 63.15834045410156
Epoch 5 Validation accuracy: 0.5625
Epoch 6 loss: 61.655792236328125
Epoch 6 Validation accuracy: 0.5625
Epoch 7 loss: 56.99005889892578
Epoch 7 Validation accuracy: 0.5625
Epoch 8 loss: 50.47241973876953
Epoch 8 Validation accuracy: 0.5625
Epoch 9 loss: 43.39476776123047
Epoch 9 Validation accuracy: 0.5625
Epoch 10 loss: 37.2061882019043
Epoch 10 Validation accuracy: 0.5625
Epoch 11 loss: 34.064640045166016
Epoch 11 Validation accuracy: 0.5625
Epoch 12 loss: 35.42961502075195
Epoch 12 Validation accuracy: 0.5
Epoch 13 loss: 39.61798095703125
Epoch 13 Validation accuracy: 0.375
Epoch 14 loss: 42.49945831298828
Epoch 14 V

Epoch 117 loss: 23.782100677490234
Epoch 117 Validation accuracy: 0.5625
Epoch 118 loss: 23.752033233642578
Epoch 118 Validation accuracy: 0.5625
Epoch 119 loss: 23.716232299804688
Epoch 119 Validation accuracy: 0.5625
Epoch 120 loss: 23.686473846435547
Epoch 120 Validation accuracy: 0.5625
Epoch 121 loss: 23.65744972229004
Epoch 121 Validation accuracy: 0.5625
Epoch 122 loss: 23.63184928894043
Epoch 122 Validation accuracy: 0.5625
Epoch 123 loss: 23.596235275268555
Epoch 123 Validation accuracy: 0.5625
Epoch 124 loss: 23.568506240844727
Epoch 124 Validation accuracy: 0.5625
Epoch 125 loss: 23.550748825073242
Epoch 125 Validation accuracy: 0.5625
Epoch 126 loss: 23.529848098754883
Epoch 126 Validation accuracy: 0.5625
Epoch 127 loss: 23.503042221069336
Epoch 127 Validation accuracy: 0.5625
Epoch 128 loss: 23.472322463989258
Epoch 128 Validation accuracy: 0.5625
Epoch 129 loss: 23.467374801635742
Epoch 129 Validation accuracy: 0.5625
Epoch 130 loss: 23.446603775024414
Epoch 130 Validati

Epoch 232 loss: 21.420305252075195
Epoch 232 Validation accuracy: 0.5625
Epoch 233 loss: 21.406696319580078
Epoch 233 Validation accuracy: 0.5625
Epoch 234 loss: 21.394020080566406
Epoch 234 Validation accuracy: 0.5625
Epoch 235 loss: 21.37169075012207
Epoch 235 Validation accuracy: 0.5625
Epoch 236 loss: 21.35594940185547
Epoch 236 Validation accuracy: 0.5625
Epoch 237 loss: 21.34318733215332
Epoch 237 Validation accuracy: 0.5625
Epoch 238 loss: 21.33356475830078
Epoch 238 Validation accuracy: 0.5625
Epoch 239 loss: 21.31667137145996
Epoch 239 Validation accuracy: 0.5625
Epoch 240 loss: 21.302318572998047
Epoch 240 Validation accuracy: 0.5625
Epoch 241 loss: 21.28500747680664
Epoch 241 Validation accuracy: 0.5625
Epoch 242 loss: 21.26559066772461
Epoch 242 Validation accuracy: 0.5625
Epoch 243 loss: 21.25322914123535
Epoch 243 Validation accuracy: 0.5625
Epoch 244 loss: 21.238706588745117
Epoch 244 Validation accuracy: 0.5625
Epoch 245 loss: 21.220836639404297
Epoch 245 Validation acc

Epoch 348 loss: 19.65319061279297
Epoch 348 Validation accuracy: 0.5625
Epoch 349 loss: 19.64687728881836
Epoch 349 Validation accuracy: 0.5625
Epoch 350 loss: 19.641027450561523
Epoch 350 Validation accuracy: 0.5625
Epoch 351 loss: 19.62809181213379
Epoch 351 Validation accuracy: 0.5625
Epoch 352 loss: 19.599239349365234
Epoch 352 Validation accuracy: 0.5625
Epoch 353 loss: 19.57975196838379
Epoch 353 Validation accuracy: 0.5625
Epoch 354 loss: 19.56479835510254
Epoch 354 Validation accuracy: 0.5625
Epoch 355 loss: 19.56086540222168
Epoch 355 Validation accuracy: 0.5625
Epoch 356 loss: 19.538368225097656
Epoch 356 Validation accuracy: 0.5625
Epoch 357 loss: 19.528959274291992
Epoch 357 Validation accuracy: 0.5625
Epoch 358 loss: 19.50667953491211
Epoch 358 Validation accuracy: 0.5625
Epoch 359 loss: 19.503355026245117
Epoch 359 Validation accuracy: 0.5625
Epoch 360 loss: 19.473203659057617
Epoch 360 Validation accuracy: 0.5625
Epoch 361 loss: 19.480627059936523
Epoch 361 Validation ac

Epoch 464 loss: 18.290189743041992
Epoch 464 Validation accuracy: 0.5625
Epoch 465 loss: 18.294099807739258
Epoch 465 Validation accuracy: 0.5625
Epoch 466 loss: 18.279863357543945
Epoch 466 Validation accuracy: 0.5625
Epoch 467 loss: 18.259613037109375
Epoch 467 Validation accuracy: 0.5625
Epoch 468 loss: 18.255352020263672
Epoch 468 Validation accuracy: 0.5625
Epoch 469 loss: 18.243770599365234
Epoch 469 Validation accuracy: 0.5625
Epoch 470 loss: 18.24730682373047
Epoch 470 Validation accuracy: 0.5625
Epoch 471 loss: 18.235071182250977
Epoch 471 Validation accuracy: 0.5625
Epoch 472 loss: 18.221952438354492
Epoch 472 Validation accuracy: 0.5625
Epoch 473 loss: 18.20149803161621
Epoch 473 Validation accuracy: 0.5625
Epoch 474 loss: 18.19734001159668
Epoch 474 Validation accuracy: 0.5625
Epoch 475 loss: 18.199737548828125
Epoch 475 Validation accuracy: 0.5625
Epoch 476 loss: 18.179275512695312
Epoch 476 Validation accuracy: 0.5625
Epoch 477 loss: 18.176069259643555
Epoch 477 Validatio