# Train a Higher-Order Attention Network for Complex Classification.

In this notebook we will train a HOAN for mesh classification (as defined in [HZPMG22]). We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the combinatorial complex. 

In [223]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split

from toponetx import CombinatorialComplex
import toponetx.datasets as datasets
from topomodelx.nn.combinatorial.hoan_mc_layer import HOANMCLayer
from topomodelx.base.aggregation import Aggregation

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [224]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec 16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a combinatorial complex.

We will also retrieve:
- input signal on the nodes, edges, and faces for each of these combinatorial complexes, as that will be what we feed the model in input
- the label associated to the combinatorial complexes

In [225]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


## Define neighborhood structures and lift into combinatorial complex domain. ##


Now, we lift each simplicial complex into a combinatorial complex. 

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on each combinatorial complex. In the case of this architecture, we need the boundary matrices (or incidence matrices) $B_1$, $B_2$ with shape $n_\text{nodes} \times n_\text{1cells}$ and $n_\text{1cells} \times n_\text{2cells}$ respectively. We also need the upwards adacency matrices for nodes and one-cells $A_{\uparrow, 0}$, $A_{\uparrow, 1}$, and the downwards adacency matrix for two cells, $A_{\downarrow, 2}$.


In [226]:
def calc_adjacency_down(incidence_mat):
    """Used to calculate downwards adjacency matrix from corresponding incidence matrix."""
    down_lap_2 = torch.sparse.mm(
        incidence_mat.float().transpose(1, 0), incidence_mat.float()
    )
    down_lap_2 = down_lap_2
    diag_tens = torch.zeros((incidence_mat.shape[1],))
    one_indices, two_indices = incidence_mat.indices()
    for i in two_indices:
        diag_tens[i].add_(1)
    D_down = torch.diag(diag_tens).to_sparse()
    adj_matrix = D_down - down_lap_2
    return adj_matrix

In [227]:
cc_list = []
incidence_1_list = []
incidence_2_list = []
up_adjacency_0_list = []
up_adjacency_1_list = []
down_adjacency_2_list = []
for simplex in simplexes:
    cc = simplex.to_combinatorial_complex()
    cc_list.append(cc)

    incidence_1 = cc.incidence_matrix(rank=0, to_rank=1)
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_1_list.append(incidence_1)

    incidence_2 = cc.incidence_matrix(rank=1, to_rank=2)
    incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()
    incidence_2_list.append(incidence_2)

    up_adjacency_0 = cc.adjacency_matrix(rank=0, via_rank=1)
    up_adjacency_0 = torch.from_numpy(up_adjacency_0.todense()).to_sparse()
    up_adjacency_0_list.append(up_adjacency_0)

    up_adjacency_1 = cc.adjacency_matrix(rank=1, via_rank=2)
    up_adjacency_1 = torch.from_numpy(up_adjacency_1.todense()).to_sparse()
    up_adjacency_1_list.append(up_adjacency_1)

    down_adjacency_2 = calc_adjacency_down(incidence_2)
    down_adjacency_2_list.append(down_adjacency_2)

In [228]:
num_classes = max(ys) + 1

# Create the Neural Network

Using the HOANMCLayer class, we create a neural network with one topological layer.

In [229]:
channels = [x_0s[0].shape[-1], x_1s[0].shape[-1], x_2s[0].shape[-1]]
channels

[6, 10, 7]

In [230]:
class HOANMCNN(torch.nn.Module):
    """Neural network implementation of Template for hypergraph classification.

    Parameters
    ---------
    channels : list[int] length 3
        Dimension of features at nodes, one-cells, and two-cells respectively
    n_classes : int
        Number of classes in output.
    n_layers : 1
        Amount of HOAN message passing layers.

    """

    def __init__(self, channels, n_classes, n_layers=1):
        super().__init__()
        self.n_classes = n_classes
        layers = []
        for _ in range(n_layers):
            layers.append(
                HOANMCLayer(
                    channels=channels,
                )
            )
        self.layers = torch.nn.ModuleList(layers)
        self.tanh = torch.nn.Tanh()
        self.linear_0 = torch.nn.Linear(channels[0], n_classes)
        self.linear_1 = torch.nn.Linear(channels[1], n_classes)
        self.linear_2 = torch.nn.Linear(channels[2], n_classes)

        self.inter_aggr = Aggregation(
            aggr_func="mean",
            update_func="sigmoid",
        )

        self.intra_aggr = Aggregation(
            aggr_func="sum",
            update_func=None,
        )

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        up_adjacency_0,
        incidence_1,
        up_adjacency_1,
        incidence_2,
        down_adjacency_2,
    ):
        """
        Forward computation through layers. Then local aggregation, then tanh activation, then linear layer, then global aggregation
        (with sigmoid update function).

        Parameters
        ----------
        x_0 : tensor shape=[n_nodes, channels[0]]
            Features on nodes.
        x_1 : tensor shape=[n_1cells, channels[1]]
            Features on one-cells.
        x_2 : tensor shape=[n_2cells, channels[2]]
            Features on two-cells.
        up_adjacency_0 : tensor shape=[n_nodes, n_nodes]
            Adjacency matrix for nodes across one-cells.
        incidence_1 : tensor shape=[n_nodes, n_1cells]
            Incidence matrix mapping one-cells to nodes.
        up_adjacency_1 : tensor shape=[n_1cells, n_1cells]
            Adjacency matrix for one-cells across two-cells.
        incidence_2 : tensor shape=[n_1cells, n_2cells]
            Incidence matrix mapping two-cells to one-cells.
        down_adjacency_2 : tensor shape=[n_2cells, n_2cells]
            Adjacency matrix for two-cells through one-cells.
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(
                x_0,
                x_1,
                x_2,
                up_adjacency_0,
                incidence_1,
                up_adjacency_1,
                incidence_2,
                down_adjacency_2,
            )
        x_0 = self.intra_aggr(torch.split(x_0, 1, dim=0))
        x_1 = self.intra_aggr(torch.split(x_1, 1, dim=0))
        x_2 = self.intra_aggr(torch.split(x_2, 1, dim=0))

        x_0 = self.tanh(x_0)
        x_1 = self.tanh(x_1)
        x_2 = self.tanh(x_2)

        x_0 = self.linear_0(x_0)
        x_1 = self.linear_1(x_1)
        x_2 = self.linear_2(x_2)

        return self.inter_aggr([x_0, x_1, x_2])

# Training the Model

In [231]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, random_state=43)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, random_state=43)
x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, random_state=43)

up_adj_0_train, up_adj_0_test = train_test_split(
    up_adjacency_0_list, test_size=test_size, random_state=43
)
up_adj_1_train, up_adj_1_test = train_test_split(
    up_adjacency_1_list, test_size=test_size, random_state=43
)
down_adj_2_train, down_adj_2_test = train_test_split(
    down_adjacency_2_list, test_size=test_size, random_state=43
)
incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, random_state=43
)
incidence_2_train, incidence_2_test = train_test_split(
    incidence_2_list, test_size=test_size, random_state=43
)

y_train, y_test = train_test_split(ys, test_size=test_size, random_state=43)
y_train.shape

(80,)

In [232]:
model = HOANMCNN(
    channels=channels,
    n_classes=num_classes,
)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

In [233]:
test_interval = 1
num_epochs = 4

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, up_adj0, inc_1, up_adj1, inc_2, down_adj2, y in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        up_adj_0_train,
        incidence_1_train,
        up_adj_1_train,
        incidence_2_train,
        down_adj_2_train,
        y_train,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).long().to(device),
        )

        up_adj0 = up_adj0.float().to(device)
        inc_1 = inc_1.float().to(device)
        up_adj1 = up_adj1.float().to(device)
        inc_2 = inc_2.float().to(device)
        down_adj2 = down_adj2.float().to(device)

        opt.zero_grad()
        y_hat = model(x_0, x_1, x_2, up_adj0, inc_1, up_adj1, inc_2, down_adj2)
        loss = loss_fn(y_hat.flatten(), y)
        loss.backward()

        opt.step()
        epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )

    if epoch_i % test_interval == 0:
        corr = 0
        with torch.no_grad():
            for x_0, x_1, x_2, up_adj0, inc_1, up_adj1, inc_2, down_adj2, y in zip(
                x_0_train,
                x_1_train,
                x_2_train,
                up_adj_0_train,
                incidence_1_train,
                up_adj_1_train,
                incidence_2_train,
                down_adj_2_train,
                y_train,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).long().to(device),
                )
                x_list = [x_0, x_1, x_2]

                up_adj0 = up_adj0.float().to(device)
                inc_1 = inc_1.float().to(device)
                up_adj1 = up_adj1.float().to(device)
                inc_2 = inc_2.float().to(device)
                down_adj2 = down_adj2.float().to(device)

                y_hat = model(x_0, x_1, x_2, up_adj0, inc_1, up_adj1, inc_2, down_adj2)
                test_loss = loss_fn(y_hat.flatten(), y)
                if torch.argmax(y_hat) == y.item():
                    corr += 1
            acc = corr / (y_test.shape[0])
            print(f"Test Loss: {test_loss:.4f}", flush=True)
            print(f"Test Accuracy: {acc:.4f}")

Epoch: 1 loss: 3.4497
Test Loss: 3.4534
Test Accuracy: 0.2000
Epoch: 2 loss: 3.4228
Test Loss: 3.4924
Test Accuracy: 0.2000
Epoch: 3 loss: 3.4005
Test Loss: 3.4550
Test Accuracy: 0.2500
Epoch: 4 loss: 3.3898
Test Loss: 3.4582
Test Accuracy: 0.2500
