In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from toponetx.classes.simplicial_complex import SimplicialComplex

from topomodelx.nn.hypergraph.unigcn import UniGCN
import warnings

warnings.filterwarnings("ignore")
# %load_ext autoreload
# %autoreload 2

# Train a UniGCN TNN

# Pre-processing

## Import data ##

The first step is to import the MUTAG dataset, a benchmark for graph classification. We then lift each graph into our domain of choice, a hypergraph: each graph is first turned into a simplicial complex, which is then converted to a hypergraph.

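We will also retrieve:
- the input signal of each hypergraph, i.e. the node features `graph.x` of the original graph, which is what we feed the model as input
- the incidence matrix of each hypergraph, stored as a sparse CSR tensor
- the binary label associated with each hypergraph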

In [2]:
# Load MUTAG and keep only its first 100 graphs so the tutorial runs quickly.
dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)
dataset = dataset[:100]

hg_list, x_1_list, y_list, incidence_1_list = [], [], [], []
for graph in dataset:
    # Lift: PyG graph -> networkx graph -> simplicial complex -> hypergraph.
    hg = SimplicialComplex(to_networkx(graph)).to_hypergraph()
    # Densify the incidence matrix, then store it as a compressed-sparse-row tensor.
    incidence_1 = torch.from_numpy(hg.incidence_matrix().todense()).to_sparse_csr()

    hg_list.append(hg)
    incidence_1_list.append(incidence_1)
    x_1_list.append(graph.x)
    y_list.append(int(graph.y))

# Create the Neural Network

Using the `UniGCN` class, we create a neural network with stacked UniGCN layers.

In [3]:
# Seed *before* building the model so the weight initialization is reproducible
# and re-running this cell always yields the same starting model.
torch.manual_seed(0)

# x_1_list holds each graph's `graph.x` (PyG node-feature matrix), so its width
# is the model's input size.
in_channels = x_1_list[0].shape[1]
hidden_channels = 32
out_channels = 1  # a single output per graph for binary classification
n_layers = 2
# MUTAG is a graph-classification benchmark: the task level is a property of the
# task, not of the number of output channels.
task_level = "graph"

model = UniGCN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    task_level=task_level,
)

# Train the Neural Network

We specify the optimizer and the loss function (binary cross-entropy).

In [4]:
# BCELoss expects probabilities, which is why the training loop applies a
# sigmoid to the model output before computing the loss.
crit = torch.nn.BCELoss(reduction="mean")
# Adam over all model weights.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

Split the dataset, keeping its original order (no shuffling), into train, validation, and test sets.

In [5]:
# Split all parallel lists in ONE call so features, incidence matrices and labels
# are always cut at the same indices -- this stays true even if shuffling is
# enabled later. With shuffle=False the original order is kept: the last 20% of
# the graphs form the test set, and the last 20% of the remainder form the
# validation set.
(
    x_1_train, x_1_test,
    incidence_1_train, incidence_1_test,
    y_train, y_test,
) = train_test_split(x_1_list, incidence_1_list, y_list, test_size=0.2, shuffle=False)

(
    x_1_train, x_1_val,
    incidence_1_train, incidence_1_val,
    y_train, y_val,
) = train_test_split(x_1_train, incidence_1_train, y_train, test_size=0.2, shuffle=False)

# Sanity check: every split keeps features, incidences and labels aligned.
assert len(x_1_train) == len(incidence_1_train) == len(y_train)
assert len(x_1_val) == len(incidence_1_val) == len(y_val)
assert len(x_1_test) == len(incidence_1_test) == len(y_test)

The following cell performs the training for a small number of epochs, reporting the validation accuracy every `test_interval` epochs and the test accuracy at the end. We keep training minimal for the purpose of rapid testing.

In [6]:
def accuracy(net, x_1_split, incidence_1_split, y_split):
    """Return the fraction of graphs in a split whose prediction matches the label.

    Switches ``net`` to eval mode and runs without gradient tracking. Each output
    is passed through a sigmoid and thresholded at 0.5, mirroring the training
    objective. Returns NaN for an empty split instead of dividing by zero.
    """
    if len(y_split) == 0:
        return float("nan")
    net.eval()
    correct = 0
    with torch.no_grad():
        for x_1, incidence_1, y in zip(x_1_split, incidence_1_split, y_split):
            prob = torch.sigmoid(net(x_1.float(), incidence_1.float()))
            pred = prob > 0.5
            if pred == y:
                correct += 1
    return correct / len(y_split)


# The weights were already initialized when the model was built, so this seed
# only affects randomness during training itself (if the model uses any, e.g.
# dropout).
torch.manual_seed(0)
test_interval = 10
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch training: sum the per-graph losses, then take one optimizer step.
    loss = 0
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train):
        prob = torch.sigmoid(model(x_1.float(), incidence_1.float()))
        loss += crit(prob, torch.tensor([y]).float())
    loss.backward()
    optimizer.step()

    if epoch % test_interval == 0:
        print(f"Epoch {epoch} train loss: {loss.item()}")
        val_acc = accuracy(model, x_1_val, incidence_1_val, y_val)
        print(f"Epoch {epoch} Val acc: {val_acc}")

test_acc = accuracy(model, x_1_test, incidence_1_test, y_test)
print(f"Test accuracy: {test_acc}")

Epoch 0 train loss: 88.3000259399414
Epoch 0 Val acc: 0.4375
Epoch 10 train loss: 49.60942077636719
Epoch 10 Val acc: 0.5625
Epoch 20 train loss: 30.046979904174805
Epoch 20 Val acc: 0.625
Test accuracy: 0.8
