In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from toponetx.classes.simplicial_complex import SimplicialComplex
from topomodelx.nn.hypergraph.unigin import UniGIN
from topomodelx.utils.sparse import from_sparse

# Train a UniGIN TNN

# Pre-processing

## Import data ##

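The first step is to import MUTAG, a benchmark dataset for graph classification, keeping only its first 100 graphs. We then lift each graph into our domain of choice, a hypergraph: each graph is first turned into a simplicial complex, which is then converted into a hypergraph.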

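We also retrieve:
- the input signal on the nodes of each hypergraph, which is what we feed to the model as input;
- the binary label associated with each hypergraph.

Finally, we compute the incidence matrix of each hypergraph, which the model receives together with the node signal.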

In [2]:
dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)
# Restrict the tutorial to the first 100 graphs of MUTAG.
dataset = dataset[:100]

# Materialise the graphs once so every list below is built from the same objects.
graphs = list(dataset)

# Lift each graph: networkx graph -> simplicial complex -> hypergraph.
hg_list = [SimplicialComplex(to_networkx(graph)).to_hypergraph() for graph in graphs]
# Node signal of each graph (the tensor stored in `graph.x`).
x_1_list = [graph.x for graph in graphs]
# Graph-level label converted to a plain Python int.
y_list = [int(graph.y) for graph in graphs]

# Incidence matrix of each hypergraph, converted with topomodelx's `from_sparse`.
# NOTE(review): assumes the node order used by `incidence_matrix()` matches the
# row order of `graph.x` — confirm.
incidence_1_list = [from_sparse(hg.incidence_matrix()) for hg in hg_list]

# Create the Neural Network

Using the `UniGIN` class, we create a neural network with `n_layers` stacked layers.

In [3]:
# Model hyperparameters.
in_channels = x_1_list[0].shape[1]  # width of the per-node input signal
hidden_channels = 32
out_channels = 2  # one score per class of the binary label
n_layers = 3
task_level = "graph"  # graph-level task: one output per hypergraph

# Seed torch's global RNG *before* building the model so that the weight
# initialisation is reproducible. The seed in the training cell runs only after
# the weights already exist, so on its own it cannot make them deterministic.
torch.manual_seed(0)

model = UniGIN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    input_drop=0.2,
    layer_drop=0.2,
    n_layers=n_layers,
    task_level=task_level,
)

# Train the Neural Network

We specify the loss function and an optimizer for the model's parameters.

In [4]:
# Cross-entropy over the model's class scores, minimised with Adam.
crit = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

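Split the dataset into training, validation, and test sets: 20% of the graphs are held out for testing, then 20% of the remainder for validation. The data is not shuffled, so the split follows the dataset order.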

In [5]:
# Hold out the last 20% of graphs for testing, then the last 20% of what remains
# for validation. With shuffle=False each split is a contiguous slice in dataset
# order; passing the three parallel lists to a single call applies the same
# indices to all of them, so features, incidences and labels stay aligned.
(
    x_1_train, x_1_test,
    incidence_1_train, incidence_1_test,
    y_train, y_test,
) = train_test_split(x_1_list, incidence_1_list, y_list, test_size=0.2, shuffle=False)

(
    x_1_train, x_1_val,
    incidence_1_train, incidence_1_val,
    y_train, y_val,
) = train_test_split(x_1_train, incidence_1_train, y_train, test_size=0.2, shuffle=False)

The cell below trains the model for 70 epochs, printing the training loss and the validation accuracy every `test_interval` (10) epochs. After training, the model is evaluated on a separate test set.

In [6]:
def evaluate_accuracy(model, x_1_split, incidence_1_split, y_split):
    """Return the fraction of graphs in a split whose predicted class equals the label.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier called as ``model(x_1, incidence_1)``; the prediction is the
        ``argmax`` of its output.
    x_1_split, incidence_1_split, y_split : list
        Parallel lists of node signals, incidence matrices and int labels.

    Returns
    -------
    float
        Accuracy in ``[0, 1]``, or ``nan`` for an empty split (instead of
        raising ``ZeroDivisionError``).
    """
    if len(y_split) == 0:
        return float("nan")
    model.eval()
    with torch.no_grad():
        correct = sum(
            int(torch.argmax(model(x_1, incidence_1)).item() == y)
            for x_1, incidence_1, y in zip(x_1_split, incidence_1_split, y_split)
        )
    return correct / len(y_split)


# Seeds the RNG used during training (e.g. dropout).
torch.manual_seed(0)
test_interval = 10
num_epochs = 70
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch training: sum the per-graph losses, then take one optimizer step.
    loss = 0
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train):
        output = model(x_1, incidence_1)
        # unsqueeze adds a batch dimension of 1 to match the single target label.
        loss += crit(output.unsqueeze(0), torch.tensor([y]))
    loss.backward()
    optimizer.step()
    if epoch % test_interval == 0:
        print(f"Epoch {epoch} loss: {loss.item()}")
        val_accuracy = evaluate_accuracy(model, x_1_val, incidence_1_val, y_val)
        print(f"Epoch {epoch} Validation accuracy: {val_accuracy}")

test_accuracy = evaluate_accuracy(model, x_1_test, incidence_1_test, y_test)
print(f"Test accuracy: {test_accuracy}")

Epoch 0 loss: 264.6469421386719
Epoch 0 Validation accuracy: 0.5625
Epoch 10 loss: 41.20396041870117
Epoch 10 Validation accuracy: 0.5625
Epoch 20 loss: 35.169708251953125
Epoch 20 Validation accuracy: 0.5625
Epoch 30 loss: 30.082611083984375
Epoch 30 Validation accuracy: 0.5625
Epoch 40 loss: 26.874971389770508
Epoch 40 Validation accuracy: 0.5625
Epoch 50 loss: 21.022449493408203
Epoch 50 Validation accuracy: 0.5625
Epoch 60 loss: 26.520870208740234
Epoch 60 Validation accuracy: 0.625
Test accuracy: 0.7
