# Train a Hypergraph Message Passing Neural Network (HMPNN)

In this notebook, we will create and train a Hypergraph Message Passing Neural Network in the hypergraph domain. This method is introduced in the paper [Message Passing Neural Networks for
Hypergraphs](https://arxiv.org/abs/2203.16995) by Heydari et Livi 2022. We will use a benchmark dataset, Cora, a collection of 2708 academic papers and 5429 citation relations, to do the task of node classification. There are 7 category labels, namely `Case_Based`, `Genetic_Algorithms`, `Neural_Networks`, `Probabilistic_Methods`, `Reinforcement_Learning`, `Rule_Learning` and `Theory`.

Each document is initially represented as a binary vector of length 1433, standing for a unique subset of the words within the papers, in which a value of 1 means the presence of its corresponding word in the paper.

In [None]:
import torch
import torch_geometric.datasets as geom_datasets
from sklearn.metrics import accuracy_score

from topomodelx.nn.hypergraph.hmpnn import HMPNN

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

Here we download the dataset. It contains initial representation of nodes, the adjacency information, category labels and train-val-test masks.

In [None]:

cora = geom_datasets.Planetoid(root="/TopoModelX/data/cora", name="Cora")
data = cora.data

x_0s = data.x
y = data.y
edge_index = data.edge_index

train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask

## Define neighborhood structures and lift into hypergraph domain. ##

Now we retrieve the neighborhood structure (i.e. their representative matrice) that we will use to send messges from node to hyperedges. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\text{nodes} \times n_\text{edges}$.

In citation Cora dataset we lift graph structure to the hypergraph domain by creating hyperedges from 1-hop graph neighbourhood of each node. 


In [None]:
# Ensure the graph is undirected (optional but often useful for one-hop neighborhoods).
edge_index = to_undirected(edge_index)

# Create a list of one-hop neighborhoods for each node.
one_hop_neighborhoods = []
for node in range(data.num_nodes):
    # Get the one-hop neighbors of the current node.
    neighbors = data.edge_index[1, data.edge_index[0] == node]

    # Append the neighbors to the list of one-hop neighborhoods.
    one_hop_neighborhoods.append(neighbors.numpy())

# Detect and eliminate duplicate hyperedges.
unique_hyperedges = set()
hyperedges = []
for neighborhood in one_hop_neighborhoods:
    # Sort the neighborhood to ensure consistent comparison.
    neighborhood = tuple(sorted(neighborhood))
    if neighborhood not in unique_hyperedges:
        hyperedges.append(list(neighborhood))
        unique_hyperedges.add(neighborhood)    

Additionally we print the statictis associated with obtained incidence matrix

In [None]:

# Calculate hyperedge statistics.
hyperedge_sizes = [len(he) for he in hyperedges]
min_size = min(hyperedge_sizes)
max_size = max(hyperedge_sizes)
mean_size = np.mean(hyperedge_sizes)
median_size = np.median(hyperedge_sizes)
std_size = np.std(hyperedge_sizes)
num_single_node_hyperedges = sum(np.array(hyperedge_sizes) == 1)

# Print the hyperedge statistics.
print(f'Hyperedge statistics: ')
print('Number of hyperedges without duplicated hyperedges', len(hyperedges))
print(f'min = {min_size}, ')
print(f'max = {max_size}, ')
print(f'mean = {mean_size}, ')
print(f'median = {median_size}, ')
print(f'std = {std_size}, ')
print(f'Number of hyperedges with size equal to one = {num_single_node_hyperedges}')


Construct incidence matrix

In [None]:
max_edges = len(hyperedges)
incidence_1 = np.zeros((x_0s.shape[0], max_edges))
for col, neighibourhood in enumerate(hyperedges):
    for row in neighibourhood:
        incidence_1[row, col] = 1

assert all(incidence_1.sum(0)>0) == True, "Some hyperedges are empty"
assert all(incidence_1.sum(1)>0) == True, "Some nodes are not in any hyperedges"
incidence_1 = torch.Tensor(incidence_1).to_sparse_coo()

# Train the Neural Network

We then specify the hyperparameters and construct the model, the loss and optimizer.

In [5]:
torch.manual_seed(41)

in_features = 1433
hidden_features = 8
num_classes = 7
n_layers = 1

model = HMPNN(in_features, (256, hidden_features), num_classes, n_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()



train_mask = dataset["train_mask"]
val_mask = dataset["val_mask"]
test_mask = dataset["test_mask"]


x = dataset["x"]
y = dataset["y"]
incidence_1 = dataset["incidence_1"]

Now it's time to train the model, looping over the network for a low amount of epochs. We keep training minimal for the purpose of rapid testing.

In [6]:
torch.manual_seed(0)
test_interval = 5
num_epochs=100


initial_x_1 = torch.zeros_like(x)
for epoch in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    y_hat = model(x, initial_x_1, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()

    train_loss = loss.item()
    y_pred = y_hat.argmax(dim=-1)
    train_acc = accuracy_score(y[train_mask], y_pred[train_mask])
    
    if epoch % test_interval == 0:
        model.eval()
        
        y_hat = model(x, initial_x_1, incidence_1)
        val_loss = loss_fn(y_hat[val_mask], y[val_mask]).item()
        y_pred = y_hat.argmax(dim=-1)
        val_acc = accuracy_score(y[val_mask], y_pred[val_mask])


        test_loss = loss_fn(y_hat[test_mask], y[test_mask]).item()
        y_pred = y_hat.argmax(dim=-1)
        test_acc = accuracy_score(y[test_mask], y_pred[test_mask])
        print(
            f"Epoch: {epoch + 1} train loss: {train_loss:.4f} train acc: {train_acc:.2f} "
            f"val loss: {val_loss:.4f} val acc: {val_acc:.2f}"
            f"test loss: {test_acc:.4f} val acc: {test_acc:.2f}"
        )

        

        
        

Epoch: 6 train loss: 1.7054 train acc: 0.63 val loss: 1.8855 val acc: 0.34test loss: 0.3410 val acc: 0.34
Epoch: 11 train loss: 1.6260 train acc: 0.69 val loss: 1.8205 val acc: 0.32test loss: 0.3380 val acc: 0.34
Epoch: 16 train loss: 1.5252 train acc: 0.79 val loss: 1.7742 val acc: 0.35test loss: 0.3920 val acc: 0.39
Epoch: 21 train loss: 1.4274 train acc: 0.85 val loss: 1.7526 val acc: 0.35test loss: 0.3850 val acc: 0.39
Epoch: 26 train loss: 1.3159 train acc: 0.83 val loss: 1.7416 val acc: 0.36test loss: 0.3850 val acc: 0.39
Epoch: 31 train loss: 1.2421 train acc: 0.84 val loss: 1.6929 val acc: 0.39test loss: 0.4150 val acc: 0.41
Epoch: 36 train loss: 1.1510 train acc: 0.91 val loss: 1.6686 val acc: 0.41test loss: 0.4340 val acc: 0.43
Epoch: 41 train loss: 1.0704 train acc: 0.91 val loss: 1.5701 val acc: 0.49test loss: 0.4900 val acc: 0.49
Epoch: 46 train loss: 0.9775 train acc: 0.90 val loss: 1.5366 val acc: 0.50test loss: 0.5130 val acc: 0.51
Epoch: 51 train loss: 0.9132 train acc