# GNN expressive power


## Setup


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install lightning
!pip install torch torchvision torchaudio
!pip install torch-geometric
!pip install networkx
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html


In [None]:
import os.path as osp

import lightning as L
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
import torch_scatter

import torch_geometric.transforms as T
from torch_geometric.data.lightning import LightningDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import MLP, global_add_pool

from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.nn.conv import (
    EdgeConv,
    GATConv,
    GATv2Conv,
    GCNConv,
    GINConv,
    MessagePassing,
    PNAConv,
    SAGEConv,
)
from typing import Callable, Optional, Union, Tuple, List, Final
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
)
from torch import Tensor
from torch_geometric.utils import spmm

from torch_geometric.data import Data
import networkx as nx

import numpy as np


from torch_geometric.utils import to_dense_adj, to_undirected, to_networkx

from abc import ABC, abstractmethod

from torch_geometric.transforms import BaseTransform

from torch_geometric.data import Dataset

from lightning import LightningModule
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

from sklearn.model_selection import KFold


In [None]:
# Prefer CUDA, then Apple MPS, and fall back to CPU when neither is available
# (torch.device("mps") on a machine without MPS fails at first use).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

## Data processing


In [None]:
def calculate_shortest_paths(graph: nx.Graph):
    """Group ordered node pairs of ``graph`` by shortest-path (hop) distance.

    Args:
        graph: networkx graph; matrix rows/columns follow its node order, as
            used by ``nx.floyd_warshall_numpy``.

    Returns:
        ``(res, rest)`` where ``res[i]`` is a ``(2, E_i)`` int64 tensor of all
        ordered pairs at distance exactly ``i + 1`` (``i`` in
        ``0..max_dist-1``), and ``rest`` is a ``(2, E)`` int64 tensor of pairs
        with no connecting path.
    """
    dist = nx.floyd_warshall_numpy(graph)
    # Unreachable pairs come back as inf; tag them with -1 instead.
    dist = np.where(dist == np.inf, -1, dist)

    # ``max`` of an empty matrix raises, so treat that case as "no buckets".
    m = int(dist.max()) if dist.size else 0

    # Distance 0 (the diagonal) is intentionally skipped: buckets start at 1.
    res = [
        torch.tensor(dist == d).nonzero().t().contiguous()
        for d in range(1, m + 1)
    ]
    rest = torch.tensor(dist == -1).nonzero().t().contiguous()

    return res, rest

In [None]:
def calculate_resistance_distance_matrix(graph: nx.Graph):
    """Group ordered node pairs of ``graph`` by rounded resistance distance.

    Resistance distances are computed per connected component from
    ``Q = pinv(L + J / n)`` (``L`` the component Laplacian, ``J`` the all-ones
    matrix, ``n`` the component size) as ``R_ij = Q_ii + Q_jj - 2 Q_ij``.
    Pairs in different components are marked with -1.

    Args:
        graph: undirected networkx graph. Node labels are used directly as
            matrix indices, so they are assumed to be ``0..N-1`` (as produced
            by ``to_networkx``) — TODO confirm for other callers.

    Returns:
        ``(res, rest)`` where ``res[i]`` is a ``(2, E_i)`` int64 tensor of all
        ordered pairs whose resistance distance rounds to ``i + 1``, and
        ``rest`` is a ``(2, E)`` int64 tensor of pairs in different components.

    Raises:
        Exception: if the -1 markers disagree with ``nx.is_connected``.
    """
    cc = nx.is_connected(graph)

    N = graph.number_of_nodes()

    g_components_list = [
        graph.subgraph(c).copy() for c in nx.connected_components(graph)
    ]
    # Entries not overwritten below (cross-component pairs) stay -1.
    g_resistance_matrix = np.zeros((N, N)) - 1.0
    g_index = 0
    for item in g_components_list:
        cur_adj = nx.to_numpy_array(item)
        cur_num_nodes = cur_adj.shape[0]
        cur_res_dis = np.linalg.pinv(
            np.diag(cur_adj.sum(axis=-1))
            - cur_adj
            + np.ones((cur_num_nodes, cur_num_nodes), dtype=np.float32) / cur_num_nodes
        ).astype(np.float32)
        A = np.diag(cur_res_dis)[:, None]
        B = np.diag(cur_res_dis)[None, :]
        cur_res_dis = A + B - 2 * cur_res_dis
        # Components are first laid out as consecutive diagonal blocks ...
        g_resistance_matrix[
            g_index : g_index + cur_num_nodes, g_index : g_index + cur_num_nodes
        ] = cur_res_dis
        g_index += cur_num_nodes
    # ... then rows and columns are permuted back to the original node labels.
    g_cur_index = []
    for item in g_components_list:
        g_cur_index.extend(list(item.nodes))
    ori_idx = np.arange(N)
    g_resistance_matrix[g_cur_index, :] = g_resistance_matrix[ori_idx, :]
    g_resistance_matrix[:, g_cur_index] = g_resistance_matrix[:, ori_idx]

    # Sanity check: -1 entries must exist exactly when the graph is disconnected.
    if (g_resistance_matrix.min() == -1 and cc) or (
        g_resistance_matrix.min() != -1 and not cc
    ):
        raise Exception(
            "Resistance distance matrix is inconsistent with graph connectivity."
        )

    # NOTE(review): np.rint maps distances <= 0.5 between distinct nodes (e.g.
    # adjacent nodes of a 5-clique, R = 0.4) to 0, so those pairs end up in no
    # bucket — confirm this is intended.
    rd_int = np.rint(g_resistance_matrix)

    rest = torch.tensor(g_resistance_matrix == -1).nonzero().t().contiguous()

    m = int(rd_int.max())

    # Bucket i holds pairs whose rounded distance is i + 1, the same layout
    # calculate_shortest_paths uses.
    res = [
        torch.tensor(rd_int == i + 1).nonzero().t().contiguous()
        for i in range(m)
    ]

    return res, rest

In [None]:
# Number of shortest-path buckets every graph is padded to in ``calc``.
MAX_DISTANCE: int = 65
# Number of resistance-distance buckets every graph is padded to in ``calc``.
MAX_RESISTANCE_DISTANCE: int = 30


def calc(data: Data, calc_sp: bool, calc_rd: bool):
    """Attach bucketed shortest-path / resistance-distance edge indices to ``data``.

    Mutates ``data`` in place and returns ``None``. For each requested metric:

    * ``{sp,rd}_edge_index_{i}`` holds the pairs at (rounded) distance
      ``i + 1``. Slots up to ``MAX_DISTANCE`` / ``MAX_RESISTANCE_DISTANCE`` are
      padded with empty ``(2, 0)`` int64 tensors, so every graph exposes the
      same attribute set (presumably needed when PyG collates a batch).
    * ``{sp,rd}_max`` stores ``num_buckets + 1``.
    * ``{sp,rd}_rest_index`` holds pairs in different connected components.

    Args:
        data: graph to annotate.
        calc_sp: compute shortest-path buckets.
        calc_rd: compute resistance-distance buckets.

    Raises:
        ValueError: for directed graphs, graphs with edge attributes or
            self-loops, or graphs with more buckets than the matching MAX_*.
    """
    if data.is_directed():
        raise ValueError("Only undirected graphs are supported.")

    if data.edge_attr is not None:
        raise ValueError("Edge attributes are not supported.")

    if data.has_self_loops():
        raise ValueError("Graph has self-loops.")

    nx_graph = to_networkx(data, to_undirected=True)

    if calc_sp:
        shortest_paths, rest = calculate_shortest_paths(nx_graph)

        num_d = len(shortest_paths)
        # Slots 0..MAX_DISTANCE-1 exist, so exactly MAX_DISTANCE buckets still fit.
        if num_d > MAX_DISTANCE:
            raise ValueError(
                f"Graph max distance too high ({num_d} > {MAX_DISTANCE})."
            )

        for i in range(MAX_DISTANCE):
            # Pad with int64 so the dtype matches the real buckets from nonzero().
            edge_index = (
                shortest_paths[i]
                if i < num_d
                else torch.empty(2, 0, dtype=torch.long)
            )
            setattr(data, f"sp_edge_index_{i}", edge_index)

        # One more than the bucket count; get_sp may therefore also read one
        # empty padding slot.
        data.sp_max = num_d + 1
        data.sp_rest_index = rest

    if calc_rd:
        rd_edge_index, rest = calculate_resistance_distance_matrix(nx_graph)

        num_d = len(rd_edge_index)
        # Same bound as the shortest-path case: exactly MAX_* buckets still fit.
        if num_d > MAX_RESISTANCE_DISTANCE:
            raise ValueError(
                f"Graph has max resistance distance too high ({num_d} > {MAX_RESISTANCE_DISTANCE})."
            )

        for i in range(MAX_RESISTANCE_DISTANCE):
            edge_index = (
                rd_edge_index[i]
                if i < num_d
                else torch.empty(2, 0, dtype=torch.long)
            )
            setattr(data, f"rd_edge_index_{i}", edge_index)

        data.rd_max = num_d + 1
        data.rd_rest_index = rest


def get_sp(data, k):
    """Return the shortest-path bucket edge indices to aggregate over.

    Args:
        data: graph or batch annotated by ``calc``. ``sp_max`` may be the plain
            int ``calc`` stores or the collated per-graph tensor of a batch.
        k: optional cap on the number of buckets; ``None`` means no cap.

    Returns:
        ``[data.sp_edge_index_0, ..., data.sp_edge_index_{n-1}]`` with
        ``n = min(max(sp_max), k)``.
    """
    # as_tensor accepts both an int and a tensor, so single graphs work too.
    max_index = int(torch.as_tensor(data.sp_max).max())
    if k is not None:
        max_index = min(max_index, k)

    return [getattr(data, f"sp_edge_index_{i}") for i in range(max_index)]


def get_rd(data, k):
    """Return the resistance-distance bucket edge indices to aggregate over.

    Args:
        data: graph or batch annotated by ``calc``. ``rd_max`` may be the plain
            int ``calc`` stores or the collated per-graph tensor of a batch.
        k: optional cap on the number of buckets; ``None`` means no cap.

    Returns:
        ``[data.rd_edge_index_0, ..., data.rd_edge_index_{n-1}]`` with
        ``n = min(max(rd_max), k)``.
    """
    # as_tensor accepts both an int and a tensor, so single graphs work too.
    max_index = int(torch.as_tensor(data.rd_max).max())
    if k is not None:
        max_index = min(max_index, k)
    return [getattr(data, f"rd_edge_index_{i}") for i in range(max_index)]

In [None]:
class TransformExtended(BaseTransform):
    """Pre-transform that adds shortest-path and resistance-distance buckets.

    Delegates to ``calc`` with both metrics enabled. ``calc`` only sets
    attributes on the graph, so the same object is handed back.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, data: Data) -> Data:
        calc(data, calc_rd=True, calc_sp=True)
        return data

## Models


### GINConvBase

Abstract base class for GIN convolutional layers.


In [None]:
class GINConvBase(MessagePassing, ABC):
    """Shared machinery for the GIN variants in this notebook.

    Subclasses implement ``forward(x, data, size)``. Messages are the
    unmodified source features ``x_j``, aggregated with ``"add"`` unless an
    ``aggr`` keyword is supplied.

    Args:
        nn: module applied to the combined node representation.
        eps: value ``eps`` is reset to by ``reset_parameters``.
        train_eps: learn ``eps`` as a parameter when True, otherwise keep it
            as a fixed buffer.
        max_distance: stored on the instance; not read by this class.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(
        self,
        nn: Callable,
        eps: float = 0.0,
        train_eps: bool = True,
        max_distance: int = MAX_DISTANCE,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps

        # Uninitialised storage; reset_parameters() writes the starting value.
        eps_storage = torch.empty(1)
        if not train_eps:
            self.register_buffer("eps", eps_storage)
        else:
            self.eps = torch.nn.Parameter(eps_storage)

        self.max_distance = max_distance

    def reset_parameters(self):
        super().reset_parameters()
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    @abstractmethod
    def forward(
        self, x: Union[Tensor, OptPairTensor], data: Data, size: Size = None
    ) -> Tensor:
        ...

    def message(self, x_j: Tensor) -> Tensor:
        # Identity message: neighbours contribute their raw features.
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: OptPairTensor) -> Tensor:
        # Strip stored edge values so the sparse product is an unweighted aggregation.
        sparse = (
            adj_t.set_value(None, layout=None)
            if isinstance(adj_t, SparseTensor)
            else adj_t
        )
        return spmm(sparse, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(nn={self.nn})"

### GINConvPlain

Plain GIN convolutional layer built on GINConvBase; it aggregates over the graph's 1-hop `edge_index`.


In [None]:
class GINConvPlain(GINConvBase):
    """Standard GIN layer over ``data.edge_index``.

    Computes ``nn((1 + eps) * x_i + AGG_{j in N(i)} x_j)``; the aggregation is
    the one configured on ``GINConvBase`` (``"add"`` by default).
    """

    def forward(
        self, x: Union[Tensor, OptPairTensor], data: Data, size: Size = None
    ) -> Tensor:
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor)
        aggregated = self.propagate(data.edge_index, x=x, size=size)

        root = x[1]
        if root is None:
            raise ValueError("x_r is None")

        return self.nn((1 + self.eps) * root + aggregated)

### GINConvSP

GIN convolutional layer that aggregates neighbors grouped by shortest-path distance.


In [None]:
class GINConvSP(GINConvBase):
    """GIN layer that aggregates over shortest-path distance buckets.

    For every bucket ``i`` returned by ``get_sp`` (pairs at shortest-path
    distance ``i + 1``), the aggregated neighbour features are weighted by
    ``softmax(sp_weights)[i]``. With ``k=None``, up to ``MAX_DISTANCE`` buckets
    are used, and pairs in different connected components
    (``data.sp_rest_index``) are added with a learnable per-channel scale
    ``1 + sp_global``.

    Args:
        nn: module applied to the combined node representation.
        in_channels: input feature size (required; sizes ``sp_global``).
        k: maximum number of distance buckets, or ``None`` for all.
        **kwargs: forwarded to ``GINConvBase``.

    Raises:
        ValueError: if ``in_channels`` is missing.
    """

    def __init__(
        self,
        nn: Callable,
        in_channels: Optional[int] = None,
        k: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(nn, **kwargs)

        if in_channels is None:
            raise ValueError("in_channels param missing")

        self.k = k
        if self.k is None:
            self.sp_global = torch.nn.Parameter(
                torch.empty(in_channels), requires_grad=True
            )
            self.k = MAX_DISTANCE
        else:
            self.sp_global = None

        # One scalar weight per bucket; softmax-normalised in forward().
        self.sp_weights = torch.nn.Parameter(torch.empty(self.k), requires_grad=True)

    def reset_parameters(self):
        super().reset_parameters()

        # Equal logits -> softmax gives every bucket weight 1 / k initially.
        self.sp_weights.data.fill_(1.0)

        if self.sp_global is not None:
            self.sp_global.data.fill_(0.0)

    def forward(self, x, data: Data, size: Size = None) -> Tensor:
        sp_edge_index = get_sp(data, self.k)

        if isinstance(x, Tensor):
            x = (x, x)

        x_r = x[1]
        if x_r is None:
            raise ValueError("x_r is None")

        normalized_weights = F.softmax(self.sp_weights, dim=0)

        neighbor_data = torch.zeros_like(x_r)
        for i, edge_index in enumerate(sp_edge_index):
            # propagate() returns one aggregated row per target node; the whole
            # [num_nodes, channels] result is used, not a single row.
            neighbor_data += (
                self.propagate(edge_index, x=x, size=size) * normalized_weights[i]
            )

        if self.sp_global is not None:
            neighbor_data += (
                self.propagate(data.sp_rest_index, x=x, size=size)
                * (1 + self.sp_global)
            )

        out = (1 + self.eps) * x_r + neighbor_data

        return self.nn(out)

### GINConvRD

GIN convolutional layer that aggregates neighbors grouped by (rounded) resistance distance.


In [None]:
class GINConvRD(GINConvBase):
    """GIN layer that aggregates over resistance-distance buckets.

    For every bucket ``i`` returned by ``get_rd`` (pairs whose resistance
    distance rounds to ``i + 1``), the aggregated neighbour features are
    weighted by ``softmax(rd_weights)[i]``. With ``k=None``, up to
    ``MAX_RESISTANCE_DISTANCE`` buckets are used, and pairs in different
    connected components (``data.rd_rest_index``) are added with a learnable
    per-channel scale ``1 + rd_global``.

    Args:
        nn: module applied to the combined node representation.
        in_channels: input feature size (required; sizes ``rd_global``).
        k: maximum number of distance buckets, or ``None`` for all.
        **kwargs: forwarded to ``GINConvBase``.

    Raises:
        ValueError: if ``in_channels`` is missing.
    """

    def __init__(
        self,
        nn: Callable,
        in_channels: Optional[int] = None,
        k: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(nn, **kwargs)

        if in_channels is None:
            raise ValueError("in_channels param missing")

        self.k = k
        if self.k is None:
            self.rd_global = torch.nn.Parameter(
                torch.empty(in_channels), requires_grad=True
            )
            self.k = MAX_RESISTANCE_DISTANCE
        else:
            self.rd_global = None

        # One scalar weight per bucket; softmax-normalised in forward().
        self.rd_weights = torch.nn.Parameter(torch.empty(self.k), requires_grad=True)

    def reset_parameters(self):
        super().reset_parameters()

        # Re-sample in place from N(0, 1): the parameter keeps its device and
        # dtype, whereas rebinding ``.data`` to a new CPU tensor would move it
        # off the GPU when reset after ``.to(device)``.
        torch.nn.init.normal_(self.rd_weights)

        if self.rd_global is not None:
            self.rd_global.data.fill_(0.0)

    def forward(self, x, data: Data, size: Size = None) -> Tensor:
        rd_edge_index = get_rd(data, self.k)

        if isinstance(x, Tensor):
            x = (x, x)

        x_r = x[1]
        if x_r is None:
            raise ValueError("x_r is None")

        normalized_weights = F.softmax(self.rd_weights, dim=0)

        neighbor_data = torch.zeros_like(x_r)
        for i, edge_index in enumerate(rd_edge_index):
            # propagate() returns one aggregated row per target node; the whole
            # [num_nodes, channels] result is used, not a single row.
            neighbor_data += (
                self.propagate(edge_index, x=x, size=size) * normalized_weights[i]
            )

        if self.rd_global is not None:
            neighbor_data += (
                self.propagate(data.rd_rest_index, x=x, size=size)
                * (1 + self.rd_global)
            )

        out = (1 + self.eps) * x_r + neighbor_data

        return self.nn(out)

### GINExtended

GIN implementation that builds each of its layers from a given GINConvBase subclass (passed as `conv_class`).


In [None]:
class GINExtended(BasicGNN):
    """``BasicGNN`` whose layers are instances of a configurable ``GINConvBase`` subclass.

    ``init_conv`` accepts two extra keyword arguments, passed through the
    constructor (as ``Model`` does):

    * ``conv_class``: the ``GINConvBase`` subclass instantiated for every layer.
    * ``model_args``: extra keyword arguments for ``conv_class`` (e.g. ``{"k": 2}``).

    ``forward`` takes the whole ``Data``/``Batch`` object because the
    distance-based convolutions read their precomputed edge indices from it.
    """

    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(
        self,
        in_channels: int,
        out_channels: int,
        conv_class: GINConvBase = GINConvPlain,
        model_args: Optional[dict] = None,
        **kwargs
    ) -> MessagePassing:
        """Build one layer: a two-layer MLP wrapped in ``conv_class``."""
        self.conv_class = conv_class

        mlp = MLP(
            [in_channels, out_channels, out_channels],
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        # ``None`` default instead of a shared mutable ``{}``.
        conv_kwargs = {} if model_args is None else model_args
        arch = conv_class(mlp, in_channels=in_channels, **conv_kwargs, **kwargs)
        arch.reset_parameters()
        return arch

    def forward(  # noqa
        self,
        data: Data,
        batch_size: Optional[int] = None,
        num_sampled_nodes_per_hop: Optional[List[int]] = None,
        num_sampled_edges_per_hop: Optional[List[int]] = None,
    ):
        """Run every layer on ``data`` and return node embeddings.

        Args:
            data: graph (batch) providing ``x``, ``batch`` and whatever edge
                indices the chosen convolution reads.
            batch_size: forwarded to normalisation layers that accept it.
            num_sampled_nodes_per_hop: not supported; must be ``None``.
            num_sampled_edges_per_hop: accepted for signature compatibility
                with ``BasicGNN.forward``; ignored.

        Raises:
            NotImplementedError: if ``num_sampled_nodes_per_hop`` is given.
        """
        if num_sampled_nodes_per_hop is not None:
            # Hop-wise trimming would shrink ``x`` while each convolution still
            # reads the full, untrimmed edge indices stored on ``data``.
            raise NotImplementedError(
                "GINExtended does not support num_sampled_nodes_per_hop."
            )

        x, batch = data.x, data.batch

        xs: List[Tensor] = []
        assert len(self.convs) == len(self.norms)
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            x = conv(x, data)

            # As in BasicGNN: the last layer skips act/norm/dropout unless JK is used.
            if i < self.num_layers - 1 or self.jk_mode is not None:
                if self.act is not None and self.act_first:
                    x = self.act(x)
                if self.supports_norm_batch:
                    x = norm(x, batch, batch_size)
                else:
                    x = norm(x)
                if self.act is not None and not self.act_first:
                    x = self.act(x)
                x = self.dropout(x)
                if hasattr(self, "jk"):
                    xs.append(x)

        x = self.jk(xs) if hasattr(self, "jk") else x
        x = self.lin(x) if hasattr(self, "lin") else x

        return x

### Model


In [None]:
class Model(LightningModule):
    """Graph classifier: ``GINExtended`` encoder, sum pooling, MLP head.

    Args:
        in_channels: node feature size.
        out_channels: number of classes.
        hidden_channels: width of the GNN and of the classifier.
        num_layers: number of GNN layers.
        dropout: dropout used in the GNN and in the classifier MLP.
        conv_class: ``GINConvBase`` subclass used for every GNN layer.
        learning_rate: Adam learning rate.
        model_args: extra keyword arguments for ``conv_class`` (e.g. ``{"k": 2}``);
            ``None`` means none.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_layers: int = 3,
        dropout: float = 0.5,
        conv_class: GINConvBase = GINConvPlain,
        learning_rate: float = 0.005,
        model_args: Optional[dict] = None,
    ):
        super().__init__()

        print("in_channels:", in_channels)
        print("hidden_channels:", hidden_channels)
        print("out_channels:", out_channels)

        self.gnn = GINExtended(
            in_channels,
            hidden_channels,
            num_layers,
            conv_class=conv_class,
            dropout=dropout,
            jk="cat",
            # Fresh dict per instance instead of a shared mutable default.
            model_args={} if model_args is None else model_args,
        )

        self.classifier = MLP(
            [hidden_channels] * 2 + [out_channels],
            norm=None,
            dropout=dropout,
        )

        self.train_acc = Accuracy(task="multiclass", num_classes=out_channels)
        self.val_acc = Accuracy(task="multiclass", num_classes=out_channels)
        self.test_acc = Accuracy(task="multiclass", num_classes=out_channels)

        self.learning_rate = learning_rate

    def forward(self, data):
        x = self.gnn(data)
        x = global_add_pool(x, data.batch)
        x = self.classifier(x)
        return x

    def training_step(self, data: Data, batch_idx):
        y_hat = self(data)

        loss = F.cross_entropy(y_hat, data.y)

        # Multiclass accuracy uses the argmax, so log-probabilities are fine here.
        y_smax = F.log_softmax(y_hat, -1)
        self.train_acc(y_smax, data.y)

        self.log(
            "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=data.num_graphs
        )
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_step=False, on_epoch=True
        )
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data)
        loss = F.cross_entropy(y_hat, data.y)
        # Use the real number of graphs so a smaller last batch is weighted
        # correctly in the epoch-level mean.
        self.log(
            "val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=data.num_graphs
        )
        self.val_acc(y_hat.softmax(dim=-1), data.y)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True)

    def test_step(self, data, batch_idx):
        y_hat = self(data)
        y_smax = y_hat.softmax(dim=-1)
        self.test_acc(y_smax, data.y)

        self.log("test_acc", self.test_acc, prog_bar=True, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

## Training


In [None]:
def train(
    dataset: Dataset,
    conv_class: GINConvBase,
    num_layers: int = 3,
    k_distance: Optional[int] = None,
    max_epochs=150,
    batch_size=64,
    early_stopping: bool = False,
):
    """Train ``conv_class`` on a fixed split and test two checkpoints.

    The first 10% of ``dataset`` is the test set, the next 10% the validation
    set and the rest the training set. No shuffling happens here.

    Args:
        dataset: graph dataset (shuffle beforehand if needed).
        conv_class: ``GINConvBase`` subclass for every GNN layer.
        num_layers: number of GNN layers.
        k_distance: ``k`` passed to ``conv_class`` (ignored by GINConvPlain).
        max_epochs: training epoch limit.
        batch_size: graphs per batch.
        early_stopping: also stop on ``val_loss`` with patience 15, as
            ``train_kfold`` does. Off by default.

    Returns:
        ``(last_results, best_results)`` from ``Trainer.test`` for the last
        checkpoint and for the best-``val_acc`` checkpoint.
    """
    L.seed_everything(42)

    n = len(dataset)
    test_dataset = dataset[: n // 10]
    val_dataset = dataset[n // 10 : 2 * n // 10]
    train_dataset = dataset[2 * n // 10 :]

    datamodule = LightningDataset(
        train_dataset, val_dataset, test_dataset, batch_size=batch_size, num_workers=4
    )

    model = Model(
        dataset.num_node_features,
        dataset.num_classes,
        conv_class=conv_class,
        num_layers=num_layers,
        model_args={"k": k_distance},
    )

    # save_last=True so that ckpt_path="last" below resolves to a real checkpoint.
    checkpoint = ModelCheckpoint(
        monitor="val_acc", save_top_k=1, mode="max", save_last=True
    )
    callbacks = [checkpoint]
    if early_stopping:
        callbacks.append(
            EarlyStopping(
                monitor="val_loss", min_delta=0.00, patience=15, verbose=False, mode="min"
            )
        )

    # Prefer CUDA, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        accelerator = "gpu"
    elif torch.backends.mps.is_available():
        accelerator = "mps"
    else:
        accelerator = "cpu"

    trainer = Trainer(
        accelerator=accelerator,
        devices=1,
        max_epochs=max_epochs,
        log_every_n_steps=5,
        callbacks=callbacks,
    )

    trainer.fit(
        model,
        train_dataloaders=datamodule.train_dataloader(),
        val_dataloaders=datamodule.val_dataloader(),
    )
    last_results = trainer.test(ckpt_path="last", dataloaders=datamodule.test_dataloader())
    best_results = trainer.test(ckpt_path="best", dataloaders=datamodule.test_dataloader())
    return last_results, best_results

In [None]:

def train_kfold(
    dataset: Dataset,
    conv_class: GINConvBase,
    n_splits: int = 3,
    max_epochs: int = 150,
    num_layers: int = 3,
    k_distance: Optional[int] = None,
    batch_size: int = 64,
):
    """Cross-validate ``conv_class`` on ``dataset`` with shuffled K-fold splits.

    For each fold, a fresh ``Model`` (dropout 0) is trained with early
    stopping on ``val_loss``. The checkpoint with the best ``val_acc`` is then
    evaluated on the same held-out fold.

    NOTE(review): the held-out fold is used both to select the checkpoint and
    to report accuracy, so the reported numbers are optimistically biased.

    Args:
        dataset: graph dataset.
        conv_class: ``GINConvBase`` subclass for every GNN layer.
        n_splits: number of folds.
        max_epochs: training epoch limit per fold.
        num_layers: number of GNN layers.
        k_distance: ``k`` passed to ``conv_class`` (ignored by GINConvPlain).
        batch_size: graphs per batch.

    Returns:
        list with one ``test_acc`` value per fold; mean and std are printed.
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Prefer CUDA, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        accelerator = "gpu"
    elif torch.backends.mps.is_available():
        accelerator = "mps"
    else:
        accelerator = "cpu"

    accuracies = []

    for train_idx, val_idx in kf.split(dataset):
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        datamodule = LightningDataset(
            train_subset, val_subset, batch_size=batch_size, num_workers=4
        )

        early_stop_callback = EarlyStopping(
            monitor="val_loss", min_delta=0.00, patience=15, verbose=False, mode="min"
        )

        model = Model(
            dataset.num_node_features,
            dataset.num_classes,
            conv_class=conv_class,
            num_layers=num_layers,
            model_args={"k": k_distance},
            dropout=0.0,
        )

        checkpoint = ModelCheckpoint(monitor="val_acc", save_top_k=1, mode="max")

        trainer = Trainer(
            accelerator=accelerator,
            devices=1,
            max_epochs=max_epochs,
            log_every_n_steps=5,
            callbacks=[checkpoint, early_stop_callback],
        )

        trainer.fit(
            model,
            train_dataloaders=datamodule.train_dataloader(),
            val_dataloaders=datamodule.val_dataloader(),
        )

        # Evaluate the best checkpoint on this fold's held-out split.
        test_result = trainer.test(ckpt_path="best", dataloaders=datamodule.val_dataloader())
        print(test_result)
        accuracies.append(test_result[0]["test_acc"])

    print(f"Average Accuracy: {np.mean(accuracies):.4f} +/- {np.std(accuracies):.4f}")

    return accuracies


## Experiments


In [None]:
L.seed_everything(42)

# Deterministic cuDNN kernels and no autotuning, for reproducible runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Directory for datasets processed with the distance pre-transform.
root = "datasets"

### PROTEINS Dataset


In [None]:

# dataset_preprocess = TUDataset(root, "MUTAG", pre_transform=transform_preprocess)
# PROTEINS without the distance pre-transform, kept in a separate directory.
root_plain = "datasets_plain"
dataset_proteins_plain = TUDataset(root_plain, "PROTEINS").shuffle()

In [None]:
# train(dataset_proteins_plain, GINConvPlain, l=3, max_epochs=50)

In [None]:
transform_preprocess = T.Compose([TransformExtended()])

# PROTEINS with shortest-path / resistance-distance buckets precomputed.
dataset_proteins = TUDataset(
    root, "PROTEINS", pre_transform=transform_preprocess
).shuffle()

In [None]:
# Baseline: plain GIN over the 1-hop edge_index, 3-fold CV on PROTEINS.
train_kfold(dataset_proteins, GINConvPlain, n_splits=3, num_layers=3, max_epochs=100)

In [None]:
# Shortest-path GIN limited to the distance-1 and distance-2 buckets (no sp_global term).
train_kfold(dataset_proteins, GINConvSP, n_splits=3, num_layers=3, k_distance=2, max_epochs=100)

In [None]:
# Shortest-path GIN limited to distance buckets 1..5.
train_kfold(dataset_proteins, GINConvSP, n_splits=3, num_layers=3, k_distance=5, max_epochs=100)

In [None]:
# Shortest-path GIN with up to MAX_DISTANCE buckets plus the learned cross-component term.
train_kfold(dataset_proteins, GINConvSP, n_splits=3, num_layers=3, k_distance=None, max_epochs=100)

In [None]:
# Resistance-distance GIN limited to rounded distances 1 and 2.
train_kfold(dataset_proteins, GINConvRD, n_splits=3, num_layers=3, k_distance=2, max_epochs=100)

In [None]:
# Resistance-distance GIN limited to rounded distances 1..5.
train_kfold(dataset_proteins, GINConvRD, n_splits=3, num_layers=3, k_distance=5, max_epochs=100)

In [None]:
# Resistance-distance GIN with up to MAX_RESISTANCE_DISTANCE buckets plus the cross-component term.
train_kfold(dataset_proteins, GINConvRD, n_splits=3, num_layers=3, k_distance=None, max_epochs=100)

In [None]:
# train(dataset_proteins, GINConvSP, num_layers=3, k_distance=1, max_epochs=50)

In [None]:
# train(dataset_proteins, GINConvSP, num_layers=3, k_distance=1, max_epochs=50)

In [None]:
# train(dataset_proteins, GINConvRD, num_layers=3, k_distance=None, max_epochs=50)

In [None]:
# train(dataset_proteins, GINConvSP, l=3, k=None, max_epochs=50)

In [None]:
# train(dataset_enzymes, GINConvSP, l=3, k=1, max_epochs=50)

In [None]:
# train(dataset_enzymes, GINConvSP, l=3, k=5, max_epochs=50)

In [None]:
# train(dataset_enzymes, GINConvSP, l=3, k=None, max_epochs=50)

## Synthetic Dataset

In [None]:
# Code adapted from Zhang et al. for generating synthetic graphs

import torch
import numpy as np
# from ogb.graphproppred import PygGraphPropPredDataset
# from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
from functools import lru_cache
import pyximport
import torch.distributed as dist

# NOTE(review): pyximport installs an import hook that compiles Cython (.pyx)
# modules when they are imported. The only such import here
# (`from . import algos`) is commented out, so the hook appears unused in this
# notebook — confirm before removing.
pyximport.install(setup_args={"include_dirs": np.get_include()})
# from . import algos

from functools import partial
import networkx as nx
from torch_geometric.data import Data

@torch.jit.script
def convert_to_single_emb(x, offset: int = 512):
    """Shift each feature column into its own index range.

    Column ``c`` gets ``1 + c * offset`` added (a 1-D input counts as a single
    column), so values from different columns never collide.
    """
    feature_num = x.size(1) if len(x.size()) > 1 else 1
    # Build the offsets on x's device; a CPU tensor would fail for CUDA inputs.
    feature_offset = 1 + torch.arange(
        0, feature_num * offset, offset, dtype=torch.long, device=x.device
    )
    x = x + feature_offset
    return x


# Graph generators used by generate_single_graphs, keyed by list position (0-8).
# Entries 4 and 5 are called with (l, k); all the others with n.
function_list = [
    partial(nx.fast_gnp_random_graph, p=0.2, seed=2022),
    partial(nx.connected_watts_strogatz_graph, k=3, p=0.8),
    partial(nx.random_tree, seed=2022),
    partial(nx.random_powerlaw_tree, seed=2022),
    partial(nx.planted_partition_graph, p_in=0.75, p_out=0.01, seed=2022),
    nx.connected_caveman_graph,
    partial(nx.circulant_graph, offsets=[1, 2]),
    partial(nx.circulant_graph, offsets=[1, 3]),
    nx.cycle_graph,
]

function_dict = dict(enumerate(function_list))

relabel_list = list("abcdefg")


def counter_example_1_1(k, m):
    """A ``2km``-cycle plus a hub joined to every ``m``-th cycle node.

    Cycle nodes are ``0..2km-1`` and the hub is node ``2km``. The hub connects
    to the cycle nodes whose 1-based index is divisible by ``m``.
    """
    cycle_len = 2 * k * m
    hub = cycle_len + 1
    cycle_nodes = np.arange(1, cycle_len + 1)
    cycle_edges = np.stack([cycle_nodes, cycle_nodes % cycle_len + 1])
    spokes = cycle_nodes[cycle_nodes % m == 0]
    spoke_edges = np.stack([np.full(spokes.shape, hub, dtype=int), spokes])

    # Switch from 1-based to 0-based labels; edge order fixes node insertion order.
    all_edges = np.hstack([cycle_edges, spoke_edges]) - 1
    graph = nx.Graph()
    graph.add_edges_from(all_edges.T.tolist())
    return graph

def counter_example_1_2(k, m):
    """Two disjoint ``km``-cycles plus a hub joined to every ``m``-th node.

    The cycles are on nodes ``0..km-1`` and ``km..2km-1``; the hub is node
    ``2km``. The hub connects to nodes whose 1-based index (over ``1..2km``) is
    divisible by ``m``.
    """
    half = k * m
    hub = 2 * half + 1
    first = np.arange(1, half + 1)
    first_cycle = np.stack([first, first % half + 1])
    # The second cycle is the first one shifted by ``half`` labels.
    second_cycle = first_cycle + half

    all_nodes = np.arange(1, 2 * half + 1)
    spokes = all_nodes[all_nodes % m == 0]
    spoke_edges = np.stack([np.full(spokes.shape, hub, dtype=int), spokes])

    all_edges = np.hstack([first_cycle, spoke_edges, second_cycle]) - 1
    graph = nx.Graph()
    graph.add_edges_from(all_edges.T.tolist())
    return graph

def counter_example_2_1(m):
    """A ``2m``-cycle on nodes ``0..2m-1`` plus the chord ``(m - 1, 2m - 1)``."""
    size = 2 * m
    nodes = np.arange(1, size + 1)
    cycle = np.stack([nodes, nodes % size + 1])
    chord = np.array([[m], [size]], dtype=int)

    all_edges = np.hstack([cycle, chord]) - 1
    graph = nx.Graph()
    graph.add_edges_from(all_edges.T.tolist())
    return graph

def counter_example_2_2(m):
    """Two disjoint ``m``-cycles joined by a single edge (a bridge).

    The cycles are on nodes ``0..m-1`` and ``m..2m-1``; the bridge is
    ``(m - 1, 2m - 1)``.
    """
    first = np.arange(1, m + 1)
    first_cycle = np.stack([first, first % m + 1])
    bridge = np.array([[m], [2 * m]], dtype=int)
    second_cycle = first_cycle + m

    all_edges = np.hstack([first_cycle, bridge, second_cycle]) - 1
    graph = nx.Graph()
    graph.add_edges_from(all_edges.T.tolist())
    return graph

def counter_example_3_component(n, d, out_deg, base_index):
    """Greedily build the edge list of one component for ``counter_exmple_3``.

    Args:
        n: number of nodes in the component.
        d: target degree for every node, counting the reserved slot below.
        out_deg: the first ``out_deg`` local nodes start with one degree
            reserved for an edge to another component; ``counter_exmple_3``
            later attaches those edges to these nodes.
        base_index: offset added to the local indices ``0..n-1``.

    Returns:
        list of ``(u, v)`` tuples in global node indices.

    The procedure is deterministic for fixed arguments (the shuffle below is
    commented out).

    NOTE(review): if node ``i`` does not reach degree ``d`` in its own pass, a
    later pass for one of its neighbours ``j`` may append ``(j, i)`` even though
    ``(i, j)`` already exists. ``nx.Graph`` collapses the duplicate, but both
    counters were incremented, so the result can fall short of degree ``d`` —
    verify if exact regularity matters.
    """
    node_degree_list = np.zeros(n, dtype=int)
    edge_index = []
    # Reserve one degree on each of the first out_deg nodes for inter-component edges.
    node_degree_list[:out_deg] = 1
    # Outer visiting order: ascending initial degree, computed once.
    choose_list_i = np.argsort(node_degree_list)
    for i in choose_list_i:
        # np.random.shuffle(node_index_list)
        if node_degree_list[i] == d:
            continue
        # Re-sorted each pass so currently least-connected nodes are tried first.
        choose_list_j = np.argsort(node_degree_list)
        for j in choose_list_j:
            if j == i:
                continue
            if node_degree_list[j] < d:
                edge_index.append((i + base_index, j + base_index))
                node_degree_list[j] += 1
                node_degree_list[i] += 1
            if node_degree_list[i] == d:
                break

    return edge_index

def counter_exmple_3():
    """Sample a graph of 2-4 near-3-regular components linked along a random tree.

    Step (1): each component is built with ``counter_example_3_component``.
    Step (2): one edge is added per edge of a random tree over the components.

    Returns:
        nx.Graph with consecutive node labels per component.

    NOTE(review): the name has a typo ("exmple"); it is kept because
    generate_cut_examples refers to it.
    """
    # (1) sample the number of connected components
    total_nodes = 40
    n_components = np.random.choice([2,3,4])
    # n_components = 4
    # np.random.seed(10)
    # NOTE(review): nx.random_tree is deprecated in recent networkx releases
    # (random_labeled_tree replaces it) — verify against the installed version.
    components_tree = nx.random_tree(n_components)
    components_tree_edge_list = list(components_tree.edges())
    # Indexed by position below; assumes the list is ordered by node 0..n-1.
    components_tree_degree_list = list(components_tree.degree())

    # Component sizes are capped at total_nodes // n_components + 1.
    avg_n_max = total_nodes // n_components
    avg_n_min_odd = 7
    avg_n_min_even = 6

    edge_index_list = []
    base_index = [0]
    # Always 3: every node targets degree 3.
    all_d = np.random.choice([3])
    for i in range(n_components):
        cur_d = all_d
        # Number of inter-component edges this component must provide.
        cur_ed = components_tree_degree_list[i][1]

        # n gets the same parity as cur_ed so that n * 3 - cur_ed (the degree
        # sum left for internal edges) is even.
        if cur_ed % 2 == 1:
            cur_n = np.random.choice(np.arange(start=max(avg_n_min_odd, (cur_d+cur_ed+1)//2*2+1), stop=avg_n_max+2, step=2))
        else:
            cur_n = np.random.choice(np.arange(start=max(avg_n_min_even, ((cur_d+cur_ed+1)//2+1)*2), stop=avg_n_max+2, step=2))

        # NOTE(review): counter_example_3_component is deterministic for fixed
        # arguments, so a disconnected result would repeat and this loop would
        # never terminate.
        connect_flag = True
        while connect_flag:
            g_edge_index = counter_example_3_component(cur_n, cur_d, cur_ed, base_index[i])
            G = nx.Graph()
            G.add_edges_from(g_edge_index)
            if nx.is_connected(G):
                connect_flag = False
                edge_index_list.extend(g_edge_index)
                base_index.append(base_index[-1] + cur_n)

    # (2) connect components
    # Each node may take at most one inter-component edge.
    last_degree_list = np.ones(base_index[-1])
    for pair_i, pair_j in components_tree_edge_list:
        i_out_deg = components_tree_degree_list[pair_i][1]
        j_out_deg = components_tree_degree_list[pair_j][1]

        # Pick the first unused node among the first out_deg nodes of each
        # component (the ones with a reserved degree slot).
        to_i = 0
        to_j = 0
        for i in range(i_out_deg):
            if last_degree_list[i+base_index[pair_i]] > 0:
                to_i = i+base_index[pair_i]
                last_degree_list[i+base_index[pair_i]] -= 1
                break
        for j in range(j_out_deg):
            if last_degree_list[j + base_index[pair_j]] > 0:
                to_j = j+base_index[pair_j]
                last_degree_list[j + base_index[pair_j]] -= 1
                break
        edge_index_list.append((to_i, to_j))

    G = nx.Graph()
    G.add_edges_from(edge_index_list)
    return G

def generate_single_graphs(n, idx):
    """Build one graph with generator ``function_dict[idx]`` (``idx`` in 0..8).

    Generators 4 and 5 (planted partition, connected caveman) are called with
    ``l=4, k=3``; every other generator is called with ``n=n``.
    """
    assert 0 <= idx <= 8
    make_graph = function_dict[idx]
    if idx in (4, 5):
        return make_graph(l=4, k=3)
    return make_graph(n=n)


def generate_cut_examples():
    """Return the fixed list of 13 counter-example graphs.

    The order is:

    * one random ``counter_exmple_3`` graph;
    * ``counter_example_2_1`` and then ``counter_example_2_2``, each for m = 3, 4;
    * ``counter_example_1_1`` and then ``counter_example_1_2``, each for
      (k, m) = (1, 3), (2, 3), (1, 4), (2, 4).
    """
    examples = [counter_exmple_3()]
    for single_param in (counter_example_2_1, counter_example_2_2):
        examples.extend(single_param(m) for m in (3, 4))
    for two_param in (counter_example_1_1, counter_example_1_2):
        examples.extend(two_param(k, m) for m in (3, 4) for k in (1, 2))
    return examples


def convert_graph_to_item(graph):
    """Convert a networkx graph into a PyG ``Data`` object with constant features.

    Nodes are relabelled to consecutive integers ``0..N-1``. An undirected
    graph is converted to a directed one, so every undirected edge appears in
    both directions in ``edge_index``. Every node gets an all-ones feature row
    of width 9.

    Args:
        graph: a networkx graph (directed or undirected).

    Returns:
        ``Data`` with ``edge_index`` (shape ``[2, E]``, ``torch.long``),
        ``num_nodes`` (``N``) and ``x`` (shape ``[N, 9]``, all ones).
    """
    G_int = nx.convert_node_labels_to_integers(graph)
    G = G_int if nx.is_directed(G_int) else G_int.to_directed()

    # ``G.edges()`` yields (u, v) pairs even for multigraphs, where plain
    # iteration over ``G.edges`` would yield (u, v, key). ``reshape(-1, 2)``
    # keeps the layout valid for an edgeless graph, where ``torch.tensor([])``
    # would otherwise be a 1-D float tensor.
    edges = torch.tensor(list(G.edges()), dtype=torch.long).reshape(-1, 2)

    new_item = Data()
    new_item.edge_index = edges.t().contiguous()
    # Use the public ``num_nodes`` attribute. PyG 2.x ignores the legacy
    # ``__num_nodes__`` field, so the count would be inferred as
    # ``edge_index.max() + 1`` and isolated nodes would silently disappear.
    new_item.num_nodes = G.number_of_nodes()
    new_item.x = torch.ones(new_item.num_nodes, 9)
    return new_item


def _item_and_nx_graph(new_graph):
    """Convert ``new_graph`` to an embedded PyG item plus an undirected nx view.

    Returns ``(item, graph)``:

    * ``item`` comes from ``convert_graph_to_item``, with ``item.x`` replaced
      by ``convert_to_single_emb(item.x)``.
    * ``graph`` is an ``nx.Graph`` on nodes ``0..N-1`` (``N = item.x.size(0)``)
      with one undirected edge per ``edge_index`` column.
    """
    item = convert_graph_to_item(new_graph)
    num_nodes = item.x.size(0)

    # Build the graph directly from edge_index instead of through an N x N
    # dense adjacency matrix. The node and edge sets are the same, but this
    # costs O(N + E) rather than O(N^2).
    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    graph.add_edges_from(item.edge_index.t().tolist())

    item.x = convert_to_single_emb(item.x)
    return item, graph


def preprocess_item_articulation_point(new_graph):
    """Label ``new_graph`` by whether it contains an articulation (cut) vertex.

    Returns the PyG item with ``x`` passed through ``convert_to_single_emb``
    and ``y`` set to a 0-d ``torch.long`` tensor. ``y`` is 1 if networkx finds
    at least one articulation point, else 0.
    """
    item, graph = _item_and_nx_graph(new_graph)
    # articulation_points is a generator, so one element is enough. The
    # ``is not None`` test matters because node 0 is a valid (falsy) answer.
    has_cut_vertex = next(nx.articulation_points(graph), None) is not None
    item.y = torch.tensor(int(has_cut_vertex), dtype=torch.long)
    return item


def preprocess_item_bridge(new_graph):
    """Label ``new_graph`` by whether it contains a bridge (cut edge).

    Returns the PyG item with ``x`` passed through ``convert_to_single_emb``
    and ``y`` set to a 0-d ``torch.long`` tensor. ``y`` is 1 if the graph has
    at least one bridge, else 0.
    """
    item, graph = _item_and_nx_graph(new_graph)
    item.y = torch.tensor(int(nx.has_bridges(graph)), dtype=torch.long)
    return item


def generate_bridges():
    """Return every cut counter-example converted to a bridge-labelled PyG item.

    The output keeps the order of ``generate_cut_examples()``.
    """
    return list(map(preprocess_item_bridge, generate_cut_examples()))

def generate_aps():
    """Return every cut counter-example converted to an articulation-point-labelled item.

    The output keeps the order of ``generate_cut_examples()``.
    """
    return list(map(preprocess_item_articulation_point, generate_cut_examples()))



In [None]:
# Build both labelled lists, then print each list's labels and size.
bridges = generate_bridges()
aps = generate_aps()

for title, items in (("Bridges: ", bridges), ("Aps: ", aps)):
    print(title, [d.y.item() for d in items], len(items))


In [None]:
# Split the bridge examples by label, then repeat positives 8x and negatives 5x.
b1, b0 = [], []
for item in bridges:
    if item.y == 1:
        b1.append(item)
    elif item.y == 0:
        b0.append(item)

bridges_balanced = b1 * 8 + b0 * 5
by = [item.y for item in bridges_balanced]

print("Total: ", len(bridges_balanced))
print("Has bridge: ", sum(by))

In [None]:
# Split the articulation-point examples by label, then repeat positives 6x and
# negatives 7x.
a1 = [d for d in aps if d.y == 1]
a0 = [d for d in aps if d.y == 0]

aps_balanced = a1 * 6 + a0 * 7

ay = [d.y for d in aps_balanced]

print("Total: ", len(aps_balanced))
# y == 1 here means "graph has an articulation point" (not a bridge).
print("Has articulation point: ", sum(ay))

In [None]:
from torch_geometric.data import InMemoryDataset
class SyntheticDataset(InMemoryDataset):
    """In-memory PyG dataset built from an explicit list of ``Data`` objects.

    PyG only runs ``process`` when ``<root>/processed/data.pt`` is missing.
    If that file already exists, it is loaded and ``data`` is ignored, so a
    new data list needs a new ``root``.

    Args:
        root: directory that holds the processed cache.
        data: list of ``Data`` objects to store.
        transform: optional per-access transform, forwarded to PyG.
        pre_transform: optional callable applied once to the *whole list*
            before saving. Unlike the usual PyG convention, it is not called
            per item.
    """

    def __init__(self, root, data, transform=None, pre_transform=None):
        # Must be set before super().__init__, which may call process().
        self.data_list = data
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        # Bind data_list unconditionally. Binding it only inside the
        # pre_transform branch raised UnboundLocalError when pre_transform
        # was None.
        data_list = self.data_list
        if self.pre_transform is not None:
            data_list = self.pre_transform(data_list)
        self.save(data_list, self.processed_paths[0])

In [None]:
# Counter used to build the x_bridge dataset root name ("x_bridge_<n>"); it is
# incremented before each dataset is built so the root is new and
# SyntheticDataset processes the list instead of loading a cached file.
# NOTE(review): the start value 786 presumably skips roots left by earlier
# runs — confirm, since an existing root would be reused silently.
bridged_idx = 786

In [None]:
# Four hand-picked bridge examples (positions in `bridges`), repeated 5x.
x_bridges = [bridges[i] for i in (0, 11, 7, 1)]
print(x_bridges)

# New root name per run so SyntheticDataset does not reuse a cached file.
bridged_idx += 1

x_bridges20 = x_bridges * 5
x_bridge_dataset = SyntheticDataset(
    f"x_bridge_{bridged_idx}",
    x_bridges20,
    pre_transform=transform_preprocess,
).shuffle()

In [None]:
import tempfile

# mkdtemp always creates a new, empty root, so SyntheticDataset processes
# bridges_balanced instead of loading a cached data.pt from an earlier run
# (a random 0..999 suffix could collide with an existing directory).
bridge_dataset = SyntheticDataset(
    tempfile.mkdtemp(prefix="bridge_"),
    bridges_balanced,
    pre_transform=transform_preprocess,
)
bridge_dataset = bridge_dataset.shuffle()

In [None]:
import tempfile

# mkdtemp always creates a new, empty root, so SyntheticDataset processes
# aps_balanced instead of loading a cached data.pt from an earlier run
# (a random 0..999 suffix could collide with an existing directory).
aps_dataset = SyntheticDataset(
    tempfile.mkdtemp(prefix="aps_"),
    aps_balanced,
    pre_transform=transform_preprocess,
)
aps_dataset = aps_dataset.shuffle()

In [None]:
def train_all(
    dataset: Dataset,
    conv_class: type,
    num_layers: int = 3,
    k_distance: Optional[int] = None,
    max_epochs=150,
    hidden_channels=64,
    learning_rate=0.005
):
    """Fit ``Model`` on ``dataset`` and evaluate it on that same training data.

    This measures how well an architecture can fit the training graphs, not
    generalisation. There is no validation or test split, and both evaluation
    passes below reuse the training loader.

    Args:
        dataset: PyG dataset; its ``num_node_features`` and ``num_classes``
            size the model.
        conv_class: convolution class forwarded to ``Model`` (e.g.
            ``GINConvPlain`` or ``GINConvSP``).
        num_layers: forwarded to ``Model``.
        k_distance: forwarded to ``Model`` as ``model_args={"k": k_distance}``.
        max_epochs: number of training epochs.
        hidden_channels: forwarded to ``Model``.
        learning_rate: forwarded to ``Model``.
    """
    L.seed_everything(42)

    datamodule = LightningDataset(dataset, batch_size=1, num_workers=2)
    # Build the loader once and reuse it for fitting and both test passes.
    train_loader = datamodule.train_dataloader()

    model = Model(
        dataset.num_node_features,
        dataset.num_classes,
        conv_class=conv_class,
        num_layers=num_layers,
        model_args={"k": k_distance},
        hidden_channels=hidden_channels,
        learning_rate=learning_rate,
        dropout=0.
    )

    # Keep only the checkpoint with the highest logged train_acc.
    checkpoint = ModelCheckpoint(
        monitor="train_acc",
        save_top_k=1,
        mode="max",
        dirpath="models",
        filename='{epoch}-{train_acc:.2f}',
    )

    # Fall back to CPU when neither CUDA nor Apple MPS is available. Assuming
    # "mps" whenever CUDA is missing fails on CPU-only machines.
    if torch.cuda.is_available():
        accelerator = "gpu"
    elif torch.backends.mps.is_available():
        accelerator = "mps"
    else:
        accelerator = "cpu"

    trainer = Trainer(
        accelerator=accelerator,
        devices=1,
        max_epochs=max_epochs,
        log_every_n_steps=5,
        callbacks=[checkpoint],
    )

    trainer.fit(model, train_dataloaders=train_loader)

    # Final-epoch weights. This ModelCheckpoint never writes a "last"
    # checkpoint (save_last is unset), so ckpt_path="last" had nothing to
    # load and fell back to the in-memory model. Pass the model explicitly.
    trainer.test(model, dataloaders=train_loader)
    # Weights from the epoch with the highest train_acc.
    trainer.test(ckpt_path="best", dataloaders=train_loader)

In [None]:
# Print the largest `sp_max` value over the articulation-point dataset.
# NOTE(review): `sp_max` is not set in this chunk; presumably it is added by
# `transform_preprocess` (the pre_transform used for aps_dataset) — confirm.
x = [f.sp_max for f in aps_dataset]
print(max(x))

In [None]:
# Display the label of the first graph in the (shuffled) x_bridge dataset.
x_bridge_dataset[0].y

In [None]:
# 10-layer GINConvSP on the balanced bridge data, lr 1e-3, 500 epochs.
train_all(
    dataset=bridge_dataset,
    conv_class=GINConvSP,
    num_layers=10,
    hidden_channels=128,
    learning_rate=0.001,
    k_distance=None,
    max_epochs=500,
)

In [None]:
# 10-layer GINConvPlain baseline on the balanced bridge data, 500 epochs.
train_all(
    dataset=bridge_dataset,
    conv_class=GINConvPlain,
    num_layers=10,
    hidden_channels=128,
    k_distance=None,
    max_epochs=500,
)

In [None]:
import networkx as nx

import matplotlib.pyplot as plt



def vis(data):
    """Print the labels of ``data`` and draw each graph in its own figure.

    Args:
        data: an iterable of PyG ``Data`` objects (a list or a dataset), or a
            single ``Data`` object. A bare ``Data`` is wrapped in a list first,
            because iterating a ``Data`` yields ``(key, value)`` pairs rather
            than graphs.
    """
    if isinstance(data, Data):
        data = [data]

    print([d.y for d in data])
    for i, d in enumerate(data):
        print("Idx:", i)
        plt.figure(i)
        g = to_networkx(d, to_undirected=True)
        nx.draw(g)

    plt.show()

In [None]:
# Print the labels of every x_bridge graph and draw each one.
vis(x_bridge_dataset)

In [None]:
# Single-layer GINConvSP on the balanced bridge data, 300 epochs.
train_all(
    dataset=bridge_dataset,
    conv_class=GINConvSP,
    num_layers=1,
    k_distance=None,
    max_epochs=300,
)

In [None]:
# 5-layer GINConvPlain on the balanced bridge data, 300 epochs.
train_all(
    dataset=bridge_dataset,
    conv_class=GINConvPlain,
    num_layers=5,
    k_distance=None,
    max_epochs=300,
)

In [None]:
# vis iterates its argument, and iterating a single Data yields (key, value)
# pairs, so wrap the graph in a list.
vis([bridge_dataset[0]])

In [None]:
# vis iterates its argument, and iterating a single Data yields (key, value)
# pairs, so wrap the graph in a list.
vis([bridge_dataset[1]])

In [None]:
# Single-layer GINConvPlain on the balanced bridge data, 1000 epochs.
train_all(
    dataset=bridge_dataset,
    conv_class=GINConvPlain,
    num_layers=1,
    k_distance=None,
    max_epochs=1000,
)

In [None]:
# Single-layer GINConvSP, hidden width 64, on the balanced bridge data, 100 epochs.
train_all(
    dataset=bridge_dataset,
    conv_class=GINConvSP,
    num_layers=1,
    hidden_channels=64,
    k_distance=None,
    max_epochs=100,
)