In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from toponetx.classes.simplicial_complex import SimplicialComplex
from topomodelx.nn.hypergraph.unigin import UniGIN

# Train a UniGIN TNN

# Pre-processing

## Import data ##

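The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. We then lift each graph into our domain of choice, a hypergraph, by first building a simplicial complex from the graph and then converting that complex into a hypergraph.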

We will also retrieve:
- the input signal on the nodes of each hypergraph, which is what we feed to the model as input;
- the binary label associated with each hypergraph.

In [2]:
dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)[:100]

# Lift every graph to a hypergraph (graph -> simplicial complex -> hypergraph)
# and collect four aligned lists: the hypergraph, its node features, its label,
# and its incidence matrix converted to a sparse torch tensor.
# NOTE(review): this assumes the hypergraph's node order matches the row order
# of graph.x — confirm against toponetx's to_hypergraph/incidence_matrix.
hg_list, x_1_list, y_list, incidence_1_list = [], [], [], []
for graph in dataset:
    hg = SimplicialComplex(to_networkx(graph)).to_hypergraph()
    # densify the incidence matrix, then wrap it as a sparse torch tensor
    incidence_1 = torch.from_numpy(hg.incidence_matrix().todense()).to_sparse()

    hg_list.append(hg)
    x_1_list.append(graph.x)
    y_list.append(int(graph.y))
    incidence_1_list.append(incidence_1)

# Create the Neural Network

Using the `UniGIN` class, we create a neural network made of `n_layers` stacked UniGIN layers.

In [3]:
# Seed torch's global RNG before building the model so that any random weight
# initialisation, and therefore the training run below, is reproducible on a
# fresh kernel (Restart & Run All).
SEED = 42
torch.manual_seed(SEED)

node_dim = x_1_list[0].shape[1]  # number of input features per node
intermediate_channels = 32  # hidden width of the UniGIN layers
out_dim = 2  # one score per class; MUTAG labels are binary
model = UniGIN(
    in_channels_node=node_dim,
    intermediate_channels=intermediate_channels,
    out_channels=out_dim,
    n_layers=2,
)

# Train the Neural Network

We specify the optimizer and the loss function used to train the model.

In [4]:
# Cross-entropy expects unnormalised class scores (logits) plus integer targets.
crit = torch.nn.CrossEntropyLoss()
# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

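Split the dataset into training, validation, and test sets. Without shuffling, the last 20% of the graphs are held out for testing, and then the last 20% of the remainder are held out for validation.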

In [5]:
# Split the three parallel lists (node features, incidence matrices, labels)
# together in a single call per stage. train_test_split applies the same index
# split to every array it receives, so the lists are guaranteed to stay aligned,
# even if shuffling is turned on later (pass random_state= in that case).
# With shuffle=False the test set is the last 20% of the graphs, and the
# validation set is the last 20% of what remains.
(
    x_1_train,
    x_1_test,
    incidence_1_train,
    incidence_1_test,
    y_train,
    y_test,
) = train_test_split(
    x_1_list, incidence_1_list, y_list, test_size=0.2, shuffle=False
)

(
    x_1_train,
    x_1_val,
    incidence_1_train,
    incidence_1_val,
    y_train,
    y_val,
) = train_test_split(
    x_1_train, incidence_1_train, y_train, test_size=0.2, shuffle=False
)

The cell below trains the model for 50 epochs, printing the training loss (summed over all training graphs) and the validation accuracy after every epoch. After training, the model is evaluated on a separate test set.

In [6]:
num_epochs = 50


def evaluate(net, xs, incidences, labels):
    """Return the classification accuracy of ``net`` on a list of hypergraphs.

    Parameters
    ----------
    net : torch.nn.Module
        Graph-level classifier, called as ``net(x_1, incidence_1)``.
    xs : list of torch.Tensor
        Node features, one tensor per hypergraph.
    incidences : list of torch.Tensor
        Incidence matrices, aligned with ``xs``.
    labels : list of int
        Ground-truth class labels, aligned with ``xs``.

    Returns
    -------
    float
        Fraction of hypergraphs whose argmax prediction equals the label;
        0.0 when ``labels`` is empty (instead of dividing by zero).
    """
    net.eval()
    correct = 0
    with torch.no_grad():
        for x_1, incidence_1, y in zip(xs, incidences, labels):
            pred = torch.argmax(net(x_1, incidence_1))
            correct += int(pred == y)
    return correct / len(labels) if labels else 0.0


for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch update: accumulate the per-graph losses over the whole
    # training set, then take a single optimizer step per epoch.
    loss = 0
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train):
        output = model(x_1, incidence_1)
        # add a batch dimension of 1 so CrossEntropyLoss sees (N, C) vs (N,)
        loss += crit(output.unsqueeze(0), torch.tensor([y]))
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch} loss: {loss.item()}")
    val_acc = evaluate(model, x_1_val, incidence_1_val, y_val)
    print(f"Epoch {epoch} Validation accuracy: {val_acc}")

test_acc = evaluate(model, x_1_test, incidence_1_test, y_test)
print(f"Test accuracy: {test_acc}")

Epoch 0 loss: 43.388511657714844
Epoch 0 Validation accuracy: 0.5625
Epoch 1 loss: 41.726402282714844
Epoch 1 Validation accuracy: 0.5625
Epoch 2 loss: 40.228939056396484
Epoch 2 Validation accuracy: 0.5625
Epoch 3 loss: 38.930335998535156
Epoch 3 Validation accuracy: 0.5625
Epoch 4 loss: 37.89554214477539
Epoch 4 Validation accuracy: 0.5625
Epoch 5 loss: 37.096839904785156
Epoch 5 Validation accuracy: 0.5625
Epoch 6 loss: 36.53312301635742
Epoch 6 Validation accuracy: 0.5625
Epoch 7 loss: 36.17842483520508
Epoch 7 Validation accuracy: 0.5625
Epoch 8 loss: 35.982330322265625
Epoch 8 Validation accuracy: 0.5625
Epoch 9 loss: 35.8914680480957
Epoch 9 Validation accuracy: 0.5625
Epoch 10 loss: 35.861602783203125
Epoch 10 Validation accuracy: 0.5625
Epoch 11 loss: 35.86212921142578
Epoch 11 Validation accuracy: 0.5625
Epoch 12 loss: 35.87543869018555
Epoch 12 Validation accuracy: 0.5625
Epoch 13 loss: 35.89324951171875
Epoch 13 Validation accuracy: 0.5625
Epoch 14 loss: 35.91067886352539
E