In [None]:
import hashlib
import networkx as nx
from tqdm.auto import tqdm
from typing import List, Dict, Optional


class WeisfeilerLehmanHashing(object):
    """
    Weisfeiler-Lehman feature extractor class.

    All work happens at construction time: the initial node labels are set
    and ``wl_iterations`` refinement rounds are run immediately.

    Args:
        graph (NetworkX graph): NetworkX graph for which we do WL hashing.
        wl_iterations (int): Number of WL iterations.
        use_node_attribute (Optional[str]): Optional attribute name whose values
            are used as initial node labels. When None, node degrees are used.
        erase_base_features (bool): Delete the base features after hashing.

    Raises:
        ValueError: If ``use_node_attribute`` is given but at least one node
            does not carry that attribute.
    """

    def __init__(
        self,
        graph: nx.classes.graph.Graph,
        wl_iterations: int,
        use_node_attribute: Optional[str],
        erase_base_features: bool,
    ):
        """
        Initialization method which also executes feature extraction.
        """
        self.wl_iterations = wl_iterations
        self.graph = graph
        self.use_node_attribute = use_node_attribute
        self.erase_base_features = erase_base_features
        self._set_features()
        self._do_recursions()

    def _set_features(self):
        """
        Creating the features.

        Sets ``self.features`` (node -> current label) and initialises
        ``self.extracted_features`` (node -> [str(initial label)]).

        Raises:
            ValueError: If some nodes lack ``self.use_node_attribute``.
        """
        if self.use_node_attribute is not None:
            # We retrieve the features of the nodes with the attribute name
            # `use_node_attribute` and assign them into a dictionary with structure:
            # {node_a_name: feature_of_node_a}
            # Nodes without this feature will not appear in the dictionary.
            features = nx.get_node_attributes(self.graph, self.use_node_attribute)

            # We check whether all nodes have the requested feature
            if len(features) != self.graph.number_of_nodes():
                missing_nodes = []
                # We find up to five missing nodes so to make
                # a more informative error message.
                for node in tqdm(
                    self.graph.nodes,
                    total=self.graph.number_of_nodes(),
                    leave=False,
                    dynamic_ncols=True,
                    desc="Searching for missing nodes"
                ):
                    if node not in features:
                        missing_nodes.append(node)
                        # Stop as soon as five examples have been collected.
                        if len(missing_nodes) >= 5:
                            break
                raise ValueError(
                    (
                        "We expected for ALL graph nodes to have a node "
                        "attribute name `{}` to be used as part of "
                        "the requested embedding algorithm, but only {} "
                        "out of {} nodes has the correct attribute. "
                        "Consider checking for typos and missing values, "
                        "and use some imputation technique as necessary. "
                        "Some of the nodes without the requested attribute "
                        "are: {}"
                    ).format(
                        self.use_node_attribute,
                        len(features),
                        self.graph.number_of_nodes(),
                        missing_nodes
                    )
                )
            # If so, we assign the feature set.
            self.features = features
        else:
            # Default: use node degrees as initial labels
            self.features = {
                node: self.graph.degree(node) for node in self.graph.nodes()
            }

        # extracted_features[node] = [f^0(node), f^1(node), ..., f^T(node)]
        # Each node gets its own fresh list, so later in-place appends never alias.
        self.extracted_features = {k: [str(v)]
                                   for k, v in self.features.items()}

    def _erase_base_features(self):
        """
        Erasing the base features (the initial labels).

        Removes the first entry f^0 from every node's label history in place.
        """
        for history in self.extracted_features.values():
            del history[0]

    def _do_a_recursion(self) -> Dict[int, str]:
        """
        The method does a single WL recursion.

        For every node, its current label and the sorted labels of its
        neighbors are joined with "_" and MD5-hashed into the new label. The
        new label is also appended to that node's history in
        ``self.extracted_features``.

        NOTE(review): because "_" is used both as separator and may occur inside
        user-supplied attribute labels, two different (label, neighbor-labels)
        combinations can produce the same string (e.g. "a" + ["b_c"] vs
        "a_b" + ["c"]). Changing the separator would change every hash, so it is
        left as is for reproducibility.

        Return types:
            * **new_features** *(dict of strings)* - The hash table with extracted WL features.
        """
        new_features = {}
        for node in self.graph.nodes():
            neighbor_labels = sorted(
                str(self.features[neb]) for neb in self.graph.neighbors(node)
            )
            features = "_".join([str(self.features[node])] + neighbor_labels)
            hashing = hashlib.md5(features.encode()).hexdigest()
            new_features[node] = hashing
            # Histories are private per-node lists, so appending in place is
            # equivalent to rebuilding the whole dict but avoids re-copying
            # every history on every round.
            self.extracted_features[node].append(hashing)
        return new_features

    def _do_recursions(self):
        """
        The method does a series of WL recursions.
        """
        for _ in range(self.wl_iterations):
            self.features = self._do_a_recursion()
        if self.erase_base_features:
            self._erase_base_features()

    def get_node_features(self) -> Dict[int, List[str]]:
        """
        Return the node level features.

        Returns:
            dict: {node: [feature_it0, feature_it1, ..., feature_itT]}
            (without feature_it0 when ``erase_base_features`` is True).
        """
        return self.extracted_features

    def get_graph_features(self) -> List[str]:
        """
        Return the graph level features as a bag (multiset)
        of all node features across all iterations.

        Returns:
            list of str: concatenation over nodes and iterations.
        """
        return [
            feature
            for node, features in self.extracted_features.items()
            for feature in features
        ]


# ============================================================
# Example usage
# ============================================================
if __name__ == "__main__":
    # Example 1: WL on Zachary's karate club graph, seeded with node degrees.
    G = nx.karate_club_graph()
    wl = WeisfeilerLehmanHashing(
        graph=G,
        wl_iterations=2,
        use_node_attribute=None,
        erase_base_features=False,
    )
    node_feats, graph_feats = wl.get_node_features(), wl.get_graph_features()

    print("Example 1: Karate graph (degree-based WL)")
    # With erase_base_features=False: [degree, round-1 hash, round-2 hash].
    print("Node 0 features:", node_feats[0])
    print("Total number of graph-level features:", len(graph_feats))
    print("First 10 graph features:", graph_feats[:10])

    # Example 2: the path 0-1-2-3, seeded from a categorical "label" attribute
    # (0 -> "A", 1 -> "B", 2 -> "A", 3 -> "C").
    H = nx.path_graph(4)
    nx.set_node_attributes(H, dict(enumerate("ABAC")), name="label")

    wl_attr = WeisfeilerLehmanHashing(
        graph=H,
        wl_iterations=3,
        use_node_attribute="label",
        erase_base_features=True,  # keep only the three hashed rounds per node
    )
    node_feats_attr = wl_attr.get_node_features()
    graph_feats_attr = wl_attr.get_graph_features()

    print("\nExample 2: Small graph with 'label' node attribute")
    for node, feats in node_feats_attr.items():
        print(f"Node {node} features:", feats)
    print("Graph-level feature multiset (first few):", graph_feats_attr[:10])


Example 1: Karate graph (degree-based WL)
Node 0 features: ['16', 'ef9153b15158f0e1f50641e7bf8c1b28', '50402170d75b12360445faec47820429']
Total number of graph-level features: 102
First 10 graph features: ['16', 'ef9153b15158f0e1f50641e7bf8c1b28', '50402170d75b12360445faec47820429', '9', 'da1b21c9ce096bf157fe29ba7c59563b', '3d0d0163a1a838e3b3f26bf1090220ca', '10', 'f3c087dc4f4756e76a65388e0f7bf976', '4bdae6d3c2f6dc1f170cd3804b8e2d42', '6']

Example 2: Small graph with 'label' node attribute
Node 0 features: ['350f8fcc3cd566cd76e9906ea8e5936b', 'b0f06103498e6e63e39c3de8de18fbbe', 'c31c3ea6b78a54ef15a9855ce4d0527a']
Node 1 features: ['03cd6f755301fc7ddc543563e50e2183', 'cfbf572289373f2e2c464e052c709b19', '1140848dbe25c31020720fcdd74e00ae']
Node 2 features: ['3510e0bfa31ada47a9b76fc9ba9dff17', 'e079c71875a4f69ce9b55abb5bc794a4', '34be3d3dd583406233aa5b19bb74542b']
Node 3 features: ['a4d1dc80e304ab3d90468c2f37cba861', '073e851d7bbb87ccca45ab16c7b612e2', 'a5e07b02f5915e2528c69d13be7539f0']


In [None]:
import networkx as nx

# Classic 1-WL failure case: a 6-cycle versus two disjoint triangles.
# Every node in both graphs has degree 2, so degree-seeded WL cannot
# separate them no matter how many rounds are run.
G1 = nx.cycle_graph(6)
G2 = nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))

wl1, wl2 = (
    WeisfeilerLehmanHashing(
        g, wl_iterations=5, use_node_attribute=None, erase_base_features=False
    )
    for g in (G1, G2)
)

# Compare the two label bags as multisets (order-insensitive).
same_bag = sorted(wl1.get_graph_features()) == sorted(wl2.get_graph_features())
print("Graph features match?", same_bag)


Graph features match? True


In [None]:
# Install TopoNetX (its install output shows it also pulls in trimesh); the
# following cells import it as `tnx` for the complex-based neighborhoods.
!pip install toponetx

Collecting toponetx
  Downloading TopoNetX-0.2.0-py3-none-any.whl.metadata (13 kB)
Collecting trimesh (from toponetx)
  Downloading trimesh-4.10.0-py3-none-any.whl.metadata (13 kB)
Downloading TopoNetX-0.2.0-py3-none-any.whl (114 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.4/114.4 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading trimesh-4.10.0-py3-none-any.whl (736 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m736.6/736.6 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh, toponetx
Successfully installed toponetx-0.2.0 trimesh-4.10.0


In [None]:
"""Functions for computing neighborhoods of a complex."""

from typing import Literal

import toponetx as tnx
from scipy.sparse import csr_matrix


def neighborhood_from_complex(
    domain: tnx.Complex,
    neighborhood_type: Literal["adj", "coadj"] = "adj",
    neighborhood_dim=None,
) -> tuple[list, csr_matrix]:
    """Compute the neighborhood of a complex.

    This function returns the indices and matrix for the neighborhood specified by
    `neighborhood_type` and `neighborhood_dim` for the input complex `domain`.

    Parameters
    ----------
    domain : toponetx.classes.Complex
        The complex to compute the neighborhood for.
    neighborhood_type : {"adj", "coadj"}, default="adj"
        The type of neighborhood to compute. "adj" for adjacency matrix, "coadj" for coadjacency matrix.
    neighborhood_dim : dict, optional
        The integer parameters needed to specify the neighborhood of the cells to generate the embedding.
        When None, ``{"rank": 0, "via_rank": -1}`` is used.
        In TopoNetX (co)adjacency neighborhood matrices are specified via one or two parameters:

        - For Cell/Simplicial/Path complexes the (co)adjacency matrix is specified by a single
          parameter, namely ``neighborhood_dim["rank"]``; ``"via_rank"`` is ignored.
        - For Combinatorial complexes and ColoredHyperGraphs the (co)adjacency matrix is specified
          by two parameters, namely ``neighborhood_dim["rank"]`` and ``neighborhood_dim["via_rank"]``;
          both keys must be present.

    Notes
    -----
        For example, neighborhood_dim={"rank": 1, "via_rank": -1} specifies the dimension for
        which the cell embeddings are going to be computed.
        "rank": 1 means that the embeddings will be computed for the first dimension.
        The integer "via_rank": -1 is ignored when the input is a cell/simplicial/path complex
        and must be specified when the input complex is a combinatorial complex or
        colored hypergraph.

    Returns
    -------
    ind : list
        The index object TopoNetX returns for ``index=True``; it describes which cell
        each row/column of ``A`` corresponds to. (Presumably a list or a cell -> position
        mapping depending on the complex class -- confirm against the TopoNetX version used.)
    A : scipy.sparse.csr_matrix
        The matrix representing the neighborhood.

    Raises
    ------
    TypeError
        If `neighborhood_type` is invalid.
    TypeError
        If `domain` is not a SimplicialComplex, CellComplex, PathComplex, ColoredHyperGraph or CombinatorialComplex.
    KeyError
        If a key required for the given complex type is missing from `neighborhood_dim`.
    """
    if neighborhood_dim is None:
        neighborhood_dim = {"rank": 0, "via_rank": -1}

    if neighborhood_type not in ("adj", "coadj"):
        raise TypeError(
            f"Input neighborhood_type must be `adj` or `coadj`, got {neighborhood_type}."
        )

    # Collect the rank arguments required by this complex family.
    if isinstance(domain, tnx.SimplicialComplex | tnx.CellComplex | tnx.PathComplex):
        rank_args = (neighborhood_dim["rank"],)
    elif isinstance(domain, tnx.CombinatorialComplex | tnx.ColoredHyperGraph):
        rank_args = (neighborhood_dim["rank"], neighborhood_dim["via_rank"])
    else:
        raise TypeError(
            "Input Complex can only be a SimplicialComplex, CellComplex, PathComplex ColoredHyperGraph or CombinatorialComplex."
        )

    matrix_fn = (
        domain.adjacency_matrix
        if neighborhood_type == "adj"
        else domain.coadjacency_matrix
    )
    ind, A = matrix_fn(*rank_args, index=True)
    return ind, A

In [51]:
import hashlib
from typing import Any, Dict, List, Optional, Union

import networkx as nx
from scipy.sparse import csr_matrix

#from topoembedx.neighborhood import neighborhood_from_complex


class HigherOrderWeisfeilerLehmanHashing:
    """
    A fully general multi-neighborhood Weisfeiler–Lehman (WL) refinement engine.

    This class supports:

    (1) Standard graph WL       (domain = networkx.Graph)
    (2) Higher-order WL         (TopoNetX complex + one neighborhood)
    (3) Full SWL-style WL       (domain + multiple neighborhoods)

    Update rule:

        c^{t+1}_σ = HASH(
            c^t_σ,
            M^{(1)}_t(σ),
            M^{(2)}_t(σ),
            ...,
            M^{(m)}_t(σ)
        )

    where each M^{(i)}_t(σ) is the multiset of neighbor labels under
    the i-th neighborhood adjacency A^{(i)}, i.e. the labels of the columns
    holding a nonzero entry in row σ of A^{(i)}.

    Every cell starts from the same label "0", so only the neighborhood
    structure (not cell attributes) drives the refinement.

    Parameters
    ----------
    wl_iterations : int, default=3
        Number of refinement rounds T.
    erase_base_features : bool, default=False
        If True, the initial label "0" is removed from every label history
        once fitting finishes.
    """

    # ------------------------------------------------------------------
    def __init__(self, wl_iterations: int = 3, erase_base_features: bool = False):
        self.wl_iterations = wl_iterations
        self.erase_base_features = erase_base_features

        self.domain = None
        self.nodes: List[Any] = []           # list of cell IDs (TopoNetX) or nodes (graph)
        self.adj_mats: List[csr_matrix] = [] # list of CSR neighborhood matrices

        self.extracted_features: Dict[Any, List[str]] = {}  # node → [labels]

    # ------------------------------------------------------------------
    def fit(
        self,
        domain: Union[nx.Graph, "tnx.Complex"],
        neighborhood_types: Union[str, List[str]] = "adj",
        neighborhood_dims: Optional[Union[Dict, List[Optional[Dict]]]] = None,
    ) -> "HigherOrderWeisfeilerLehmanHashing":
        """
        Fit Higher-Order WL on a domain with one or more neighborhoods.

        Parameters
        ----------
        domain : nx.Graph OR TopoNetX Complex
            For a graph, plain 1-WL is run on its nodes and the neighborhood
            arguments are ignored.
        neighborhood_types : str or list[str]
            One neighborhood type per matrix, forwarded to
            ``neighborhood_from_complex``.
        neighborhood_dims : dict or list[dict]
            Either one dict (or None) shared by all types, or one entry per type.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If no neighborhood type is given, if the types and dims lists differ
            in length, if the neighborhoods do not share one index list, or if a
            matrix is not square of size len(index).
        """

        self.domain = domain

        # -------------------------------------------------------
        # CASE 1: Graph domain
        # -------------------------------------------------------
        if isinstance(domain, nx.Graph):
            self.nodes = list(domain.nodes())
            # weight=None stores every edge as 1, so an edge whose "weight"
            # attribute is 0 is not discarded as an explicit zero below.
            A = nx.to_scipy_sparse_array(domain, nodelist=self.nodes, weight=None)
            self.adj_mats = [self._to_clean_csr(A)]

        # -------------------------------------------------------
        # CASE 2: TopoNetX complex domain
        # -------------------------------------------------------
        else:
            self.nodes, self.adj_mats = self._complex_neighborhoods(
                domain, neighborhood_types, neighborhood_dims
            )

        n = len(self.nodes)
        for A in self.adj_mats:
            if A.shape != (n, n):
                raise ValueError(
                    f"Neighborhood matrix has shape {A.shape}, expected ({n}, {n})."
                )

        # -------------------------------------------------------
        # Initialize labels in index space {0, ..., n-1}
        # -------------------------------------------------------
        self._node_to_pos = {node: i for i, node in enumerate(self.nodes)}

        labels = {pos: "0" for pos in range(n)}
        self.extracted_features = {node: ["0"] for node in self.nodes}

        # -------------------------------------------------------
        # Run WL iterations
        # -------------------------------------------------------
        for _ in range(self.wl_iterations):
            labels = self._wl_step(labels)

        # Remove base features if needed
        if self.erase_base_features:
            for history in self.extracted_features.values():
                if history:
                    del history[0]

        return self

    # ------------------------------------------------------------------
    def _complex_neighborhoods(
        self,
        domain: Any,
        neighborhood_types: Union[str, List[str]],
        neighborhood_dims: Optional[Union[Dict, List[Optional[Dict]]]],
    ) -> "tuple[List[Any], List[csr_matrix]]":
        """Build the shared cell index and one cleaned CSR matrix per neighborhood."""
        # Normalize types list
        if isinstance(neighborhood_types, str):
            types_list = [neighborhood_types]
        else:
            types_list = list(neighborhood_types)
        if not types_list:
            raise ValueError("At least one neighborhood type is required.")

        # Normalize dims list
        if neighborhood_dims is None or isinstance(neighborhood_dims, dict):
            dims_list = [neighborhood_dims] * len(types_list)
        else:
            dims_list = list(neighborhood_dims)

        if len(types_list) != len(dims_list):
            raise ValueError(
                "neighborhood_types and neighborhood_dims must have same length."
            )

        base_ind: Optional[List[Any]] = None
        mats: List[csr_matrix] = []
        for ntype, ndim in zip(types_list, dims_list):
            ind, A = neighborhood_from_complex(
                domain,
                neighborhood_type=ntype,
                neighborhood_dim=ndim,
            )
            # list() accepts both a plain list and a mapping index (whose keys
            # are then taken in insertion order -- assumed to be row order).
            ind = list(ind)
            if base_ind is None:
                base_ind = ind
            elif ind != base_ind:
                # All index lists must match so rows refer to the same cells.
                raise ValueError(
                    "All neighborhoods must use the same index list. "
                    "Reorder before passing to WL."
                )
            mats.append(self._to_clean_csr(A))

        return base_ind, mats

    # ------------------------------------------------------------------
    @staticmethod
    def _to_clean_csr(A: Any) -> csr_matrix:
        """Return a private CSR copy of ``A`` without explicitly stored zeros."""
        # Sparse matrices may store explicit zero entries (e.g. after
        # setdiag(0)); those positions must not count as neighbors. Copying
        # first keeps eliminate_zeros() from mutating the caller's matrix.
        mat = csr_matrix(A, copy=True)
        mat.eliminate_zeros()
        return mat

    # ------------------------------------------------------------------
    def _wl_step(self, labels: Dict[int, str]) -> Dict[int, str]:
        """One iteration of WL refinement.

        ``labels`` maps positions 0..n-1 to current labels. Returns the new
        position -> label dict and appends each new label to the matching
        cell's history in ``self.extracted_features``.
        """
        new_labels = {}

        for pos, ext_node in enumerate(self.nodes):
            parts = [labels[pos]]  # own label

            # For each adjacency matrix, the neighbors of `pos` are the column
            # indices stored in its CSR row slice indptr[pos]:indptr[pos+1].
            for A in self.adj_mats:
                start, stop = A.indptr[pos], A.indptr[pos + 1]
                neigh_labels = sorted(labels[j] for j in A.indices[start:stop].tolist())
                parts.append("_".join(neigh_labels))

            concat = "|".join(parts)
            hashed = hashlib.md5(concat.encode()).hexdigest()
            new_labels[pos] = hashed
            self.extracted_features[ext_node].append(hashed)

        return new_labels

    # ------------------------------------------------------------------
    def get_cell_features(self) -> Dict[Any, List[str]]:
        """Return dict: cell/node → list of WL labels over iterations."""
        return self.extracted_features

    # ------------------------------------------------------------------
    def get_domain_features(self) -> List[str]:
        """Return a flat list of all WL labels across all cells."""
        return [
            f for feats in self.extracted_features.values() for f in feats
        ]


# Higher-Order Weisfeiler–Lehman on Complexes (Single- and Multi-Neighborhood)

This tutorial explains the idea behind **higher-order Weisfeiler–Lehman (WL)** on **TopoNetX complexes**, and how it is implemented in the class

```python
HigherOrderWeisfeilerLehmanHashing
```

which supports:

1. **Standard graph WL** (domain = `networkx.Graph`),
2. **Higher-order WL** on complexes using a **single** neighborhood matrix,
3. **SWL-style multi-neighborhood WL**, where several neighborhood matrices are combined in one update rule.

The goals are:

* to see how WL is **lifted from graphs to complexes** (simplicial, cell, combinatorial),
* to understand the role of **ranks** and **neighborhood matrices** (A^{(r,k)}),
* to see how **multiple neighborhoods** are combined in the update rule,
* and to interpret the behavior of the test examples (successes and limitations).

---

## 1. Classical 1-WL on Graphs

Given a graph

$$
G = (V, E), \qquad V = \{1, \dots, n\},
$$

the 1-dimensional Weisfeiler–Lehman test (1-WL, or color refinement) maintains a label
$$
\ell_t(v)
$$
for every node $v \in V$ at iteration $t$.

The **neighbor multiset** at iteration $t$ is

$$
M_t(v) = \{\!\{ \ell_t(u) : u \in N(v) \}\!\},
$$

where $N(v)$ is the set of neighbors of $v$.

The **WL update** at iteration (t+1) is

$$
\ell_{t+1}(v) = h\big(\ell_t(v),\, M_t(v)\big),
$$

where $h$ is an injective encoding (in practice a hash such as MD5, injective up to collisions) of

* the current label $\ell_t(v)$, and
* the multiset of neighbor labels $M_t(v)$

into a new discrete label.

After $T$ iterations, a **graph-level WL signature** is often taken as the multiset

$$
\mathrm{Sig}(G) = \{\!\{ \ell_T(v) : v \in V \}\!\}.
$$

If $\mathrm{Sig}(G_1) \neq \mathrm{Sig}(G_2)$, then WL distinguishes $G_1$ and $G_2$.
If the signatures are equal, WL fails to distinguish them (for example, the classical case of $C_6$ versus $C_3 \sqcup C_3$).

In the class `HigherOrderWeisfeilerLehmanHashing`, when the domain is a **graph**:

* we ignore `neighborhood_types` and `neighborhood_dims`,
* we use the standard adjacency matrix of the graph,
* and we run exactly this **1-WL** procedure on the nodes.

---

## 2. Complexes and Ranks in TopoNetX

TopoNetX provides several kinds of complexes:

* `SimplicialComplex`: vertices, edges, triangles, tetrahedra, and higher-dimensional simplices,
* `CellComplex`: more general cells glued along boundaries,
* `CombinatorialComplex`: very general; cells are finite sets with an incidence structure,
* `ColoredHyperGraph`: hypergraph-like structure with colors.

Each complex is graded by a **rank** (or dimension):

* rank $0$: vertices (0-cells),
* rank $1$: edges (1-cells),
* rank $2$: faces (2-cells),
* rank $3$: volumes (3-cells), etc.

We write

$$
C_r = \{c_1, \dots, c_{n_r}\}
$$

for the set of all cells of rank $r$.

Higher-order WL on complexes will run on these cells (rather than on vertices only), by viewing them as nodes in an appropriate **cell graph**.

---

## 3. Neighborhood Matrices $A^{(r,k)}$ from TopoNetX

To define neighborhoods of cells in a complex, we use **neighborhood matrices** provided (or induced) by TopoNetX and by the helper

```python
neighborhood_from_complex
```

For a complex $\mathcal{K}$, a rank $r$, and a “via rank” $k$, we can build a binary matrix

$$
A^{(r,k)} \in \{0,1\}^{n_r \times n_r},
$$

where:

* rows and columns index rank-$r$ cells $c_i \in C_r$,
* $A^{(r,k)}_{ij} = 1$ if cells $c_i$ and $c_j$ are considered neighbors “via rank $k$”.

Two common patterns:

### 3.1 Adjacency (type `"adj"`)

This is typically used when $r < k$. Examples:

* **vertex–vertex adjacency via edges**:

  $A^{(0,1)}_{uv} = 1$ if vertices $u$ and $v$ share an edge,

* **edge–edge adjacency via faces**:

  $A^{(1,2)}_{ij} = 1$ if edges $c_i$ and $c_j$ lie in a common 2-cell (face).

### 3.2 Coadjacency (type `"coadj"`)

This is typically used when $r > k$. Examples:

* **face–face coadjacency via edges**:

  $A^{(2,1)}_{ij} = 1$ if faces $c_i$ and $c_j$ share an edge,

* **edge–edge coadjacency via vertices**:

  $A^{(1,0)}_{ij} = 1$ if edges $c_i$ and $c_j$ share a vertex.

In code, we obtain these matrices (and their index ordering) via:

```python
ind, A = neighborhood_from_complex(
    domain,
    neighborhood_type="adj",  # or "coadj", "boundary", "coboundary"
    neighborhood_dim={"rank": r, "via_rank": k},
)
```

* `ind` is a list of cell IDs (TopoNetX cells, simplices, etc.),
* `A` is a sparse `csr_matrix` encoding the neighborhood relation.

The WL class itself only needs:

* the ordered list of cells `ind`,
* one or more adjacency-like matrices $A^{(r,k)}$ on those cells.

---

## 4. From Single-Neighborhood HO-WL to Multi-Neighborhood SWL

### 4.1 Single-Neighborhood Higher-Order WL

If we choose **one** neighborhood matrix $A^{(r,k)}$ for a fixed rank $r$, we can build a **cell graph**

$$
G_r = (C_r, E_r), \qquad (c_i, c_j) \in E_r \iff A^{(r,k)}_{ij} = 1.
$$

We then run 1-WL on this graph exactly as on a usual graph, but with the nodes being rank-$r$ cells. At iteration $t$, each cell $c \in C_r$ has a label $\ell_t(c)$. Its neighbor multiset is

$$
M_t(c) = \{\!\{ \ell_t(c') : c' \in N(c) \}\!\},
$$

where $N(c)$ is the set of neighbors of $c$ according to $A^{(r,k)}$.

The WL update is

$$
\ell_{t+1}(c) = h\big(\ell_t(c),\, M_t(c)\big).
$$

After $T$ iterations, the **higher-order WL signature** of $\mathcal{K}$ at rank $r$ is

$$
\mathrm{Sig}^{(r,k)}(\mathcal{K}) = \{\!\{ \ell_T(c) : c \in C_r \}\!\}.
$$

Some special cases:

* $r = 0, k = 1$: standard 1-WL on the 1-skeleton graph (vertices via edges),
* $r = 2, k = 1$: WL on faces that share edges,
* $r = 3, k = 0$: WL on 3-cells that share vertices, and so on.

In code, this corresponds to:

```python
hol = HigherOrderWeisfeilerLehmanHashing(
    wl_iterations=3,
    erase_base_features=False,
).fit(
    domain=complex_object,
    neighborhood_types="coadj",              # or "adj"
    neighborhood_dims={"rank": r, "via_rank": k},
)

cell_features = hol.get_cell_features()   # cell → [label_0, ..., label_T]
signature    = hol.get_domain_features()  # flat bag of labels
```

This recovers the usual “higher-order WL on complexes” with a **single** neighborhood matrix.

---

### 4.2 Multi-Neighborhood SWL-Style WL

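In SWL-style refinements (for example, the Simplicial WL test of Bodnar et al., 2021, which updates each simplex from its boundary, coboundary, upper- and lower-adjacent neighbors), we want to combine **several** neighborhood relations in a single update.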

Suppose we have $m$ neighborhood matrices on the same rank-$r$ cells:

$$
A^{(1)}, A^{(2)}, \dots, A^{(m)},
$$

all of size $n_r \times n_r$, and all using the same ordering of cells $C_r$.

At iteration $t$, each cell $c$ has a label $\ell_t(c)$. For each matrix $A^{(i)}$, we define

$$
M_t^{(i)}(c) = \{\!\{ \ell_t(c') : A^{(i)}_{cc'} = 1 \}\!\}.
$$

The **multi-neighborhood WL update** is then

$$
\ell_{t+1}(c)
= h\big(\ell_t(c),\, M_t^{(1)}(c),\, M_t^{(2)}(c),\, \dots,\, M_t^{(m)}(c)\big).
$$

This is exactly what `HigherOrderWeisfeilerLehmanHashing` does when you pass multiple neighborhoods:

```python
hol = HigherOrderWeisfeilerLehmanHashing(
    wl_iterations=3,
    erase_base_features=False,
).fit(
    domain=complex_object,
    neighborhood_types=["adj", "coadj"],
    neighborhood_dims=[
        {"rank": 1, "via_rank": 2},  # e.g. edges via faces
        {"rank": 1, "via_rank": 0},  # e.g. edges via vertices
    ],
)
```

Internally:

* for each pair `(neighborhood_type, neighborhood_dim)` it calls `neighborhood_from_complex`,
* it collects the index lists and matrices,
* it checks that all index lists are identical (so that all matrices refer to the same ordered set of cells),
* it stores all these matrices in `self.adj_mats`,
* at each WL iteration, it aggregates neighbor labels from **each** matrix before hashing.

The final labels for each cell thus encode information from **all** neighborhoods simultaneously.

---

## 5. How the Class Behaves Conceptually

Given a domain `domain`, the call

```python
hol = HigherOrderWeisfeilerLehmanHashing(
    wl_iterations=T,
    erase_base_features=False,
).fit(
    domain,
    neighborhood_types=...,   # optional for graphs
    neighborhood_dims=...,    # optional for graphs
)
```

does the following:

1. **Domain type**:

   * If `domain` is a `networkx.Graph`:

     * the nodes of the graph become `self.nodes`,
     * the standard graph adjacency is stored as a single matrix in `self.adj_mats`,
     * classical 1-WL is run on the graph.

   * If `domain` is a TopoNetX complex (simplicial, cell, combinatorial, etc.):

     * for each neighborhood in `neighborhood_types` and `neighborhood_dims`,
       it calls `neighborhood_from_complex(domain, neighborhood_type, neighborhood_dim)`,
     * it collects the index lists and matrices,
     * it checks that all index lists are identical; if not, it raises an error,
     * it stores that common index list as `self.nodes`,
     * it stores all corresponding matrices in `self.adj_mats`.

2. **Initialization**:

   * every position $0, 1, \dots, n-1$ gets initial label `"0"`,
   * for each external cell ID, `self.extracted_features[cell]` is initialized as `["0"]`.

3. **WL iterations**:

   * for each iteration:

     * for each position `pos`:

       * start from its own current label,
       * for each matrix in `self.adj_mats`:

         * look up the neighbors of `pos`,
         * collect their labels, sort them, and append them to the feature parts,
       * concatenate all parts into a single string and hash it with `md5`,
       * store the new label in the next-iteration label dictionary,
       * also append the new label to `self.extracted_features[cell_id]`.

4. **Output**:

   * `.get_cell_features()` returns a dictionary

     ```python
     {cell_id: [label_0, label_1, ..., label_T]}
     ```

   * `.get_domain_features()` returns a flat list of all labels across all cells and all iterations, which you can view as a multiset encoding the WL signature.

   * If `erase_base_features=True`, the initial label `"0"` is removed from each label list before returning.

This gives you a single, unified WL engine that handles:

* standard 1-WL on graphs,
* higher-order WL on complexes with one neighborhood matrix,
* and multi-neighborhood SWL-style updates involving several adjacency structures at once.



In [58]:
# ============================================================
# Higher-Order WL on TopoNetX Complexes + Test Suite
# ============================================================

import hashlib
from typing import Any, Dict, List, Optional, Union, Literal

import networkx as nx
import numpy as np
import toponetx as tnx
from scipy.sparse import csr_matrix, hstack, vstack


# ============================================================
# Neighborhood construction for complexes
# ============================================================

def neighborhood_from_complex(
    domain: tnx.Complex,
    neighborhood_type: Literal["adj", "coadj", "boundary", "coboundary"] = "adj",
    neighborhood_dim: Optional[Dict] = None,
) -> tuple[list, csr_matrix]:
    """
    Build a square neighborhood matrix, and its index list, for a TopoNetX complex.

    Neighborhood types
    ------------------
    ``"adj"`` / ``"coadj"``
        The nodes are rank-``r`` cells of ``domain``. The matrix is whatever
        ``domain.adjacency_matrix`` / ``domain.coadjacency_matrix`` returns.
    ``"boundary"`` / ``"coboundary"``
        The nodes are the two index lists returned by ``domain.incidence_matrix``
        (the first list, i.e. the rows of ``B``, then the second list, i.e. the
        columns of ``B``). The matrix is the adjacency of the undirected
        bipartite (Hasse) graph ``[[0, |B|], [|B|^T, 0]]``. Both names produce
        this same graph, because an undirected graph cannot tell "down" from "up".

    Parameters
    ----------
    domain : toponetx complex
        One of SimplicialComplex, CellComplex, PathComplex,
        CombinatorialComplex or ColoredHyperGraph.
    neighborhood_type : {"adj", "coadj", "boundary", "coboundary"}, default "adj"
        Which neighborhood to construct.
    neighborhood_dim : dict, optional
        Rank selection. Defaults to ``{"rank": 0, "via_rank": -1}``.

        - Simplicial / Cell / Path: only ``"rank"`` is read. It is passed
          positionally to ``(co)adjacency_matrix`` or ``incidence_matrix``.
        - Combinatorial / ColoredHyperGraph: ``"via_rank"`` is the second
          positional argument of ``(co)adjacency_matrix`` (``None`` when
          missing) and the ``via_rank`` keyword of ``incidence_matrix``
          (``rank - 1`` when missing).

    Returns
    -------
    ind : list
        Cell identifiers in matrix-row order.
    A : scipy.sparse.csr_matrix
        Square neighborhood matrix aligned with ``ind``.

    Raises
    ------
    TypeError
        If ``neighborhood_type`` is not one of the four supported names, if
        ``domain`` is not a supported complex type, or (boundary types only)
        if ``domain`` has no ``incidence_matrix`` method.
    KeyError
        If ``neighborhood_dim`` has no ``"rank"`` entry where one is read.
    """
    dims = {"rank": 0, "via_rank": -1} if neighborhood_dim is None else neighborhood_dim

    if neighborhood_type not in ("adj", "coadj", "boundary", "coboundary"):
        raise TypeError(
            "Input neighborhood_type must be one of "
            "'adj', 'coadj', 'boundary', or 'coboundary', "
            f"got {neighborhood_type}."
        )

    rank_only_types = (tnx.SimplicialComplex, tnx.CellComplex, tnx.PathComplex)
    rank_via_types = (tnx.CombinatorialComplex, tnx.ColoredHyperGraph)
    unsupported = (
        "Input Complex can only be a SimplicialComplex, CellComplex, "
        "PathComplex, ColoredHyperGraph or CombinatorialComplex."
    )

    # --- adjacency / coadjacency on a single rank --------------------------
    if neighborhood_type in ("adj", "coadj"):
        if isinstance(domain, rank_only_types):
            args = (dims["rank"],)
        elif isinstance(domain, rank_via_types):
            args = (dims["rank"], dims.get("via_rank", None))
        else:
            raise TypeError(unsupported)
        build = (
            domain.adjacency_matrix
            if neighborhood_type == "adj"
            else domain.coadjacency_matrix
        )
        ind, A = build(*args, index=True)
        return list(ind), csr_matrix(A).asformat("csr")

    # --- boundary / coboundary: bipartite Hasse graph from incidence -------
    if not hasattr(domain, "incidence_matrix"):
        raise TypeError(
            "The given complex does not provide an 'incidence_matrix' method, "
            "so 'boundary'/'coboundary' neighborhoods are not supported."
        )

    rank = dims["rank"]
    if isinstance(domain, rank_only_types):
        ind_low, ind_high, B = domain.incidence_matrix(rank, index=True)  # type: ignore[arg-type]
    elif isinstance(domain, rank_via_types):
        ind_low, ind_high, B = domain.incidence_matrix(  # type: ignore[arg-type]
            rank=rank, via_rank=dims.get("via_rank", rank - 1), index=True
        )
    else:
        raise TypeError(unsupported)

    # Orientation signs do not matter for an undirected neighborhood.
    incidence = abs(B).asformat("csr")
    n_low, n_high = incidence.shape
    top_row = hstack([csr_matrix((n_low, n_low)), incidence])
    bottom_row = hstack([incidence.transpose(), csr_matrix((n_high, n_high))])
    hasse = vstack([top_row, bottom_row]).asformat("csr")
    return [*ind_low, *ind_high], hasse


# ============================================================
# General Higher-Order WL (multi-neighborhood)
# ============================================================

class HigherOrderWeisfeilerLehmanHashing:
    """
    General multi-neighborhood Weisfeiler–Lehman (WL) refinement.

    Supports:
      (1) Standard graph WL          (domain = networkx.Graph)
      (2) Higher-order WL on complex (one neighborhood)
      (3) SWL-style WL               (multiple neighborhoods A^(r,k))

    Update rule (conceptually):

        c^{t+1}_σ = HASH(
            c^t_σ,
            M^{(1)}_t(σ),
            ...,
            M^{(m)}_t(σ)
        )

    where M^{(i)}_t(σ) is the multiset of neighbor labels under the i-th
    neighborhood matrix A^{(i)}. τ is a neighbor of σ under A^{(i)} exactly
    when A^{(i)}[σ, τ] != 0. Explicitly stored zeros, which some sparse
    constructions leave behind, are dropped before refinement.

    Attributes
    ----------
    wl_iterations : int
        Number of refinement rounds run by ``fit``.
    erase_base_features : bool
        If True, the initial label ``"0"`` is removed from every label list
        at the end of ``fit``.
    domain : networkx.Graph or TopoNetX complex or None
        The domain passed to the most recent ``fit`` call.
    nodes : list
        External IDs (vertices / cells) in matrix-index order.
    adj_mats : list of csr_matrix
        One square ``len(nodes) x len(nodes)`` matrix per neighborhood.
    extracted_features : dict
        External ID -> list of labels, one per iteration, preceded by the
        initial ``"0"`` unless ``erase_base_features`` is set.
    """

    def __init__(self, wl_iterations: int = 3, erase_base_features: bool = False):
        self.wl_iterations = wl_iterations
        self.erase_base_features = erase_base_features

        self.domain = None
        self.nodes: List[Any] = []             # external IDs: vertices / cells
        self.adj_mats: List[csr_matrix] = []   # one square CSR matrix per neighborhood

        self.extracted_features: Dict[Any, List[str]] = {}  # external node -> labels
        self._pos_to_node: Dict[int, Any] = {}              # matrix index -> external node

    # -------------------------------------------------------
    def fit(
        self,
        domain: Union[nx.Graph, "tnx.Complex"],
        neighborhood_types: Union[str, List[str]] = "adj",
        neighborhood_dims: Optional[Union[Dict, List[Optional[Dict]]]] = None,
    ) -> "HigherOrderWeisfeilerLehmanHashing":
        """
        Fit WL on a domain with one or more neighborhoods.

        Parameters
        ----------
        domain : networkx.Graph or TopoNetX complex
        neighborhood_types : str or list[str]
            If domain is a graph:
                - this argument is ignored; plain adjacency is used.
            If domain is a complex:
                - either a single neighborhood type (e.g. "adj")
                - or a list of types (e.g. ["adj","coadj"]).
        neighborhood_dims : dict or list[dict], optional
            Rank specifications (e.g. {"rank": 2, "via_rank": 1}). A single
            dict (or None) is reused for every neighborhood type.

        Returns
        -------
        HigherOrderWeisfeilerLehmanHashing
            ``self``, so calls can be chained
            (``.fit(...).get_domain_features()``).

        Raises
        ------
        ValueError
            If the numbers of types and dims differ, if no neighborhood type
            is given, if the neighborhoods use different index lists, or if a
            neighborhood matrix is not square with side ``len(nodes)``.
        """
        self.domain = domain

        if isinstance(domain, nx.Graph):
            # Standard 1-WL: one neighborhood, the graph's own adjacency.
            self.nodes, self.adj_mats = self._graph_neighborhoods(domain)
        else:
            self.nodes, self.adj_mats = self._complex_neighborhoods(
                domain, neighborhood_types, neighborhood_dims
            )

        # Every row/column index must map to a position in self.nodes.
        # Otherwise _wl_step would look up labels that do not exist.
        n = len(self.nodes)
        for i, A in enumerate(self.adj_mats):
            if A.shape != (n, n):
                raise ValueError(
                    f"Neighborhood matrix {i} has shape {A.shape}, expected "
                    f"({n}, {n}) to match the index list."
                )

        # Labels live in index space {0, ..., n-1}; features are keyed by external ID.
        self._pos_to_node = dict(enumerate(self.nodes))
        labels = {i: "0" for i in range(n)}
        self.extracted_features = {node: ["0"] for node in self.nodes}

        for _ in range(self.wl_iterations):
            labels = self._wl_step(labels)

        if self.erase_base_features:
            for feats in self.extracted_features.values():
                if feats:
                    del feats[0]

        return self

    # -------------------------------------------------------
    def _graph_neighborhoods(self, graph):
        """Return ``(nodes, [A])`` for plain 1-WL on a NetworkX graph."""
        nodes = list(graph.nodes())
        if not nodes:
            # networkx refuses to build a matrix for an empty node list.
            return nodes, [csr_matrix((0, 0))]
        # weight=None stores 1 for every edge. An edge whose "weight"
        # attribute is 0 therefore still counts as a neighbor after the
        # explicit-zero elimination in _as_csr.
        A = nx.to_scipy_sparse_array(graph, nodelist=nodes, weight=None)
        return nodes, [self._as_csr(A)]

    def _complex_neighborhoods(self, domain, neighborhood_types, neighborhood_dims):
        """Return ``(nodes, mats)`` for a complex from one or more neighborhood specs."""
        if isinstance(neighborhood_types, str):
            types_list = [neighborhood_types]
        else:
            types_list = list(neighborhood_types)

        # A single dict (or None) applies to every neighborhood type.
        if neighborhood_dims is None or isinstance(neighborhood_dims, dict):
            dims_list = [neighborhood_dims] * len(types_list)
        else:
            dims_list = list(neighborhood_dims)

        if len(types_list) != len(dims_list):
            raise ValueError(
                "neighborhood_types and neighborhood_dims must have same length."
            )
        if not types_list:
            raise ValueError("At least one neighborhood type is required.")

        ind_lists = []
        mats = []
        for ntype, ndim in zip(types_list, dims_list):
            ind, A_raw = neighborhood_from_complex(
                domain,
                neighborhood_type=ntype,
                neighborhood_dim=ndim,
            )
            ind_lists.append(list(ind))
            mats.append(self._as_csr(A_raw))

        # Row i of every matrix must describe the same cell.
        base = ind_lists[0]
        if any(cur != base for cur in ind_lists[1:]):
            raise ValueError(
                "All neighborhoods must use the same index list. "
                "Reorder externally before calling WL."
            )
        return list(base), mats

    @staticmethod
    def _as_csr(matrix) -> csr_matrix:
        """Return an independent, canonical CSR copy of ``matrix`` without stored zeros."""
        # copy=True so the in-place canonicalisation below never touches the
        # caller's (or TopoNetX's) matrix.
        A = csr_matrix(matrix, copy=True)
        A.sum_duplicates()   # one stored entry per (row, col)
        A.eliminate_zeros()  # stored zeros are not edges
        return A

    # -------------------------------------------------------
    def _wl_step(self, labels: Dict[int, str]) -> Dict[int, str]:
        """One WL refinement step in index space.

        Appends each new label to ``self.extracted_features`` and returns the
        new ``position -> label`` mapping.
        """
        new_labels: Dict[int, str] = {}
        # Read the CSR structure directly instead of building a sparse row per call.
        structures = [(A.indptr, A.indices) for A in self.adj_mats]

        for pos in range(len(self.nodes)):
            parts = [labels[pos]]  # own current label

            # Aggregate the multiset of neighbor labels for each neighborhood.
            for indptr, indices in structures:
                neigh_idx = indices[indptr[pos]:indptr[pos + 1]].tolist()
                parts.append("_".join(sorted(labels[j] for j in neigh_idx)))

            hashed = hashlib.md5("|".join(parts).encode()).hexdigest()
            new_labels[pos] = hashed
            self.extracted_features[self._pos_to_node[pos]].append(hashed)

        return new_labels

    # -------------------------------------------------------
    def get_cell_features(self) -> Dict[Any, List[str]]:
        """Return dict: external cell/node → list of WL labels over iterations."""
        return self.extracted_features

    def get_domain_features(self) -> List[str]:
        """Return flattened multiset of labels across all cells and iterations."""
        return [f for feats in self.extracted_features.values() for f in feats]


# ============================================================
# Pretty printing helper for tests
# ============================================================

def print_check(name, cond, when_true, when_false):
    """Print a three-line pass/fail report for one named check, then a blank line.

    A truthy ``cond`` selects the ``✓`` mark and ``when_true``. A falsy one
    selects ``✗`` and ``when_false``. ``cond`` itself is echoed on the
    "Result" line.
    """
    if cond:
        mark, interpretation = "✓", when_true
    else:
        mark, interpretation = "✗", when_false
    report = (
        f"{mark} {name}",
        f"    Result        : {cond}",
        f"    Interpretation: {interpretation}",
    )
    # The trailing "\n" plus print's own newline reproduce the blank separator line.
    print("\n".join(report) + "\n")


# ============================================================
# TEST 1 — Graph WL fails, HO-WL succeeds (rank 2, CC)
# ============================================================

def test_graph_wl_fails_but_higherorder_wl_succeeds():
    """Compare graph WL and face-level HO-WL on two CCs with the same 1-skeleton.

    Both complexes have vertices 0..3 and the same edge list. CC1 adds one
    rank-2 cell {0,1,2,3}. CC2 adds two rank-2 cells {0,1,2} and {0,2,3}.

    Returns
    -------
    (graph_equal, ho_diff) : tuple[bool, bool]
        graph_equal: True if 1-WL on the rank-0 adjacency (via rank 1) gives
        identical sorted label lists for both complexes.
        ho_diff: True if HO-WL on rank-2 coadjacency (via rank 1) gives
        different sorted label lists.
    """
    vertices = [0, 1, 2, 3]
    # NOTE(review): (3, 0) and (0, 3) name the same edge; presumably the
    # second add_cell is a no-op in TopoNetX. Confirm.
    edge_list = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2), (0, 3)]

    def build(faces):
        cc = tnx.CombinatorialComplex()
        for v in vertices:
            cc.add_cell([v], rank=0)
        for e in edge_list:
            cc.add_cell(list(e), rank=1)
        for face in faces:
            cc.add_cell(face, rank=2)
        return cc

    cc1 = build([[0, 1, 2, 3]])
    cc2 = build([[0, 1, 2], [0, 2, 3]])

    def signature(wl):
        return sorted(wl.get_domain_features())

    # Graph WL on the 0-skeleton adjacency (rank=0 via rank=1).
    skeleton_graphs = []
    for cc in (cc1, cc2):
        _, A0 = neighborhood_from_complex(
            cc, neighborhood_type="adj", neighborhood_dim={"rank": 0, "via_rank": 1}
        )
        skeleton_graphs.append(nx.from_scipy_sparse_array(A0))
    g_sig1, g_sig2 = (
        signature(HigherOrderWeisfeilerLehmanHashing().fit(G)) for G in skeleton_graphs
    )
    graph_equal = g_sig1 == g_sig2

    # HO-WL on 2-cells (faces via edges: rank=2, via_rank=1, coadj).
    face_dims = {"rank": 2, "via_rank": 1}
    h_sig1, h_sig2 = (
        signature(
            HigherOrderWeisfeilerLehmanHashing().fit(
                cc, neighborhood_types="coadj", neighborhood_dims=face_dims
            )
        )
        for cc in (cc1, cc2)
    )
    ho_diff = h_sig1 != h_sig2

    return graph_equal, ho_diff


# ============================================================
# TEST 2 — Classic 1-WL failure (C6 vs C3 ⊔ C3)
# ============================================================

def test_higherorder_wl_failure_on_cell_complex():
    """Run rank-0 adjacency WL on C6 and on C3 ⊔ C3 built as 1-dimensional CellComplexes.

    Returns True iff both complexes yield identical sorted label lists, which
    is the classical 1-WL outcome for these two 2-regular graphs.
    """
    six_cycle = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    two_triangles = [[0, 1], [1, 2], [2, 0], [3, 4], [4, 5], [5, 3]]
    complexes = [tnx.CellComplex(cells, ranks=1) for cells in (six_cycle, two_triangles)]

    vertex_dims = {"rank": 0, "via_rank": -1}
    sig_c6, sig_c3c3 = (
        sorted(
            HigherOrderWeisfeilerLehmanHashing()
            .fit(cx, neighborhood_types="adj", neighborhood_dims=vertex_dims)
            .get_domain_features()
        )
        for cx in complexes
    )
    return sig_c6 == sig_c3c3


# ============================================================
# TEST 3 — Invariance under relabeling (SC, rank 1)
# ============================================================

def test_higherorder_wl_invariance_sc():
    """Check that rank-1 adjacency HO-WL ignores a vertex relabeling of an SC.

    The second complex is the first with every vertex shifted by 10. Returns
    True iff both give identical sorted label lists.
    """
    simplex_lists = ([[0, 1, 2], [1, 2, 3]], [[10, 11, 12], [11, 12, 13]])
    complexes = [tnx.SimplicialComplex(simplices) for simplices in simplex_lists]

    edge_dims = {"rank": 1, "via_rank": -1}
    original, relabeled = (
        sorted(
            HigherOrderWeisfeilerLehmanHashing()
            .fit(sc, neighborhood_types="adj", neighborhood_dims=edge_dims)
            .get_domain_features()
        )
        for sc in complexes
    )
    return original == relabeled


# ============================================================
# TEST 4 — Edge-level limitation (CX, rank 1 via rank 2)
# ============================================================

def test_edge_level_higherorder_wl_behavior():
    """Compare rank-1 "coadj" HO-WL on a quad-filled vs a two-triangle-filled CellComplex.

    Both complexes share the edges 01, 12, 23, 30, 02. ``cx_quad`` adds the
    2-cell [0,1,2,3]; ``cx_tris`` adds [0,1,2] and [0,2,3]. Returns True iff
    both give identical sorted label lists.

    NOTE(review): for a CellComplex, ``neighborhood_from_complex`` forwards only
    ``rank`` to ``coadjacency_matrix``, so ``via_rank=2`` has no effect here.
    Presumably the matrix relates edges through shared lower-rank cells rather
    than through faces, which would explain why the faces go unseen. Confirm
    against TopoNetX.
    """
    shared_edges = [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2]]

    cx_quad = tnx.CellComplex(shared_edges, ranks=1)
    cx_quad.add_cell([0, 1, 2, 3], rank=2)

    cx_tris = tnx.CellComplex(shared_edges, ranks=1)
    for face in ([0, 1, 2], [0, 2, 3]):
        cx_tris.add_cell(face, rank=2)

    edge_dims = {"rank": 1, "via_rank": 2}
    sig_quad, sig_tris = (
        sorted(
            HigherOrderWeisfeilerLehmanHashing()
            .fit(cx, neighborhood_types="coadj", neighborhood_dims=edge_dims)
            .get_domain_features()
        )
        for cx in (cx_quad, cx_tris)
    )
    return sig_quad == sig_tris


# ============================================================
# TEST 5 — CC rank-2 and rank-3 sensitivity
# ============================================================

def test_cc_rank2_and_rank3_behavior():
    """Probe whether rank-2 and rank-3 HO-WL separate two CCs that differ in 3-cells.

    Both complexes share vertices, edges and the rank-2 cells [0,1,2], [1,2,3].
    ``ccA`` adds [0,1,2] at rank 3; ``ccB`` adds [0,1,2] and [1,2,3] at rank 3.

    NOTE(review): [0, 1, 2] is added both at rank 2 and at rank 3. How TopoNetX
    resolves a set added at two ranks is not visible here. Confirm.

    Returns
    -------
    (r2, r3) : tuple[bool, bool]
        True where rank-2 coadjacency (via 1), resp. rank-3 coadjacency
        (via 0), gives different sorted label lists for the two complexes.
    """
    verts = [0, 1, 2, 3]
    edges = [[0, 1], [1, 2], [2, 0], [1, 3], [2, 3]]
    faces = [[0, 1, 2], [1, 2, 3]]

    def build(top_cells):
        cc = tnx.CombinatorialComplex()
        for v in verts:
            cc.add_cell([v], rank=0)
        for e in edges:
            cc.add_cell(e, rank=1)
        for f in faces:
            cc.add_cell(f, rank=2)
        for cell in top_cells:
            cc.add_cell(cell, rank=3)
        return cc

    ccA = build([[0, 1, 2]])
    ccB = build([[0, 1, 2], [1, 2, 3]])

    def signature(cc, dims):
        wl = HigherOrderWeisfeilerLehmanHashing().fit(
            cc, neighborhood_types="coadj", neighborhood_dims=dims
        )
        return sorted(wl.get_domain_features())

    faces_via_edges = {"rank": 2, "via_rank": 1}
    solids_via_vertices = {"rank": 3, "via_rank": 0}

    r2 = signature(ccA, faces_via_edges) != signature(ccB, faces_via_edges)
    r3 = signature(ccA, solids_via_vertices) != signature(ccB, solids_via_vertices)
    return r2, r3


# ============================================================
# TEST 6 — Invariance under relabeling (CC, rank 2)
# ============================================================

def test_higherorder_wl_invariance_cc():
    """Check that rank-2 coadjacency HO-WL ignores a vertex relabeling of a CC.

    Builds a 4-cycle with two triangular rank-2 cells, once on vertices 0..3
    and once on 10..13. Returns True iff both give identical sorted label lists.
    """
    def build(offset):
        cc = tnx.CombinatorialComplex()
        for v in range(4):
            cc.add_cell([v + offset], rank=0)
        for a, b in ((0, 1), (1, 2), (2, 3), (3, 0)):
            cc.add_cell([a + offset, b + offset], rank=1)
        for tri in ((0, 1, 2), (0, 2, 3)):
            cc.add_cell([x + offset for x in tri], rank=2)
        return cc

    cc1 = build(0)
    cc2 = build(10)

    face_dims = {"rank": 2, "via_rank": 1}
    sig1, sig2 = (
        sorted(
            HigherOrderWeisfeilerLehmanHashing()
            .fit(cc, neighborhood_types="coadj", neighborhood_dims=face_dims)
            .get_domain_features()
        )
        for cc in (cc1, cc2)
    )
    return sig1 == sig2


# ============================================================
# (Optional) Neighborhood dimension sanity test (CellComplex)
# ============================================================

def _sanity_check_neighborhood_cell_complex():
    """Print neighborhood-matrix shapes next to index-list lengths for two CellComplexes.

    First the rank-0 ``"adj"`` neighborhoods are printed, then the rank-1
    ``"coadj"`` ones, so each matrix side can be eyeballed against the length
    of its index list. Returns None.
    """
    complexes = (
        tnx.CellComplex([[0, 1, 2, 3], [1, 2, 3, 4], [1, 3, 4, 5, 6, 7, 8]]),
        tnx.CellComplex([[0, 1, 2], [1, 2, 3]]),
    )

    def neighborhoods(kind, dims):
        return [
            neighborhood_from_complex(cx, neighborhood_type=kind, neighborhood_dim=dims)
            for cx in complexes
        ]

    adjacency = neighborhoods("adj", {"rank": 0, "via_rank": -1})
    print("Sanity check: CellComplex adjacency shapes:")
    for k, (ind, A) in enumerate(adjacency, start=1):
        print(f"  cc{k} adj shape:", A.todense().shape, f"len(ind{k}_adj) =", len(ind))

    coadjacency = neighborhoods("coadj", {"rank": 1, "via_rank": 0})
    for k, (ind, A) in enumerate(coadjacency, start=1):
        print(f"  cc{k} coadj shape:", A.todense().shape, f"len(ind{k}_co) =", len(ind))
    print()


# ============================================================
# RUN ALL TESTS
# ============================================================

# Runs every check defined above, in order, when the cell executes. Each
# check prints a ✓/✗ report through print_check: the boolean result plus the
# interpretation string matching that result.
print("\n================ HIGHER-ORDER WL TEST SUITE ================\n")

# TEST 1 returns two booleans. The "expected" reading of each is that both are True.
print("TEST 1: Graph WL failure vs Higher-Order WL success\n")
g_equal, ho_diff = test_graph_wl_fails_but_higherorder_wl_succeeds()
print_check(
    "Graph WL on 0-skeleton adjacency (CC1 vs CC2)",
    g_equal,
    "Correct: graph WL ONLY sees 0–1 skeleton → cannot distinguish CC1/CC2.",
    "Unexpected: graph WL wrongly detects a difference at 0–1 skeleton level.",
)
print_check(
    "Higher-Order WL on rank-2 coadjacency (faces via edges)",
    ho_diff,
    "Correct: HO-WL on faces distinguishes single quad vs two triangles.",
    "Unexpected: HO-WL on faces does not detect the rank-2 difference.",
)

# TEST 2: True means WL produced identical signatures for C6 and C3⊔C3.
print("TEST 2: Classic 1-WL failure (C6 vs C3⊔C3) in CellComplex form\n")
fail_ok = test_higherorder_wl_failure_on_cell_complex()
print_check(
    "Rank-0 adjacency WL signature (C6 vs C3⊔C3)",
    fail_ok,
    "Correct: reproduces the known 1-WL failure example.",
    "Unexpected: WL separates C6 from C3⊔C3, contradicting the classical result.",
)

# TEST 3: True means the relabeled simplicial complex got the same signature.
print("TEST 3: Invariance under vertex relabeling (SimplicialComplex, rank 1)\n")
inv_sc = test_higherorder_wl_invariance_sc()
print_check(
    "Rank-1 adjacency WL signature (isomorphic SCs)",
    inv_sc,
    "Correct: WL is invariant under combinatorial isomorphism.",
    "Unexpected: WL breaks under a simple vertex relabeling.",
)

# TEST 4: True means the two cell complexes were NOT distinguished.
# NOTE: for a CellComplex, neighborhood_from_complex passes only "rank" to
# coadjacency_matrix, so the via_rank=2 requested by the test is not used.
print("TEST 4: Edge-level HO-WL limitation (CellComplex, rank 1 via rank 2)\n")
edge_equal = test_edge_level_higherorder_wl_behavior()
print_check(
    "Rank-1 coadjacency via rank-2 (edges via faces)",
    edge_equal,
    "As observed: this particular neighborhood fails to distinguish 1 quad vs 2 triangles.",
    "Stronger-than-expected: in this environment WL does distinguish them.",
)

# TEST 5: each boolean is True when that rank's signatures differ between ccA and ccB.
print("TEST 5: CombinatorialComplex — rank-2 and rank-3 sensitivity\n")
r2, r3 = test_cc_rank2_and_rank3_behavior()
print_check(
    "Rank-2 coadjacency (faces via edges)",
    r2,
    "Correct: rank-2 HO-WL already detects the higher-rank differences.",
    "Unexpected: rank-2 HO-WL was not sensitive to the extra 3-cell.",
)
print_check(
    "Rank-3 coadjacency (3-cells via vertices)",
    r3,
    "Correct: rank-3 HO-WL detects the difference in 3-cells.",
    "Unexpected: rank-3 HO-WL was not sensitive to the extra 3-cell.",
)

# TEST 6: True means the relabeled combinatorial complex got the same signature.
print("TEST 6: Invariance under vertex relabeling (CombinatorialComplex, rank 2)\n")
inv_cc = test_higherorder_wl_invariance_cc()
print_check(
    "Rank-2 coadjacency invariance (isomorphic CCs)",
    inv_cc,
    "Correct: HO-WL respects combinatorial isomorphism in the CC setting.",
    "Unexpected: HO-WL does not respect isomorphism here.",
)

# Shape/length printout only; nothing is asserted.
print("OPTIONAL: Neighborhood sanity check on CellComplex\n")
_sanity_check_neighborhood_cell_complex()

print("=============================================================\n")




TEST 1: Graph WL failure vs Higher-Order WL success

✓ Graph WL on 0-skeleton adjacency (CC1 vs CC2)
    Result        : True
    Interpretation: Correct: graph WL ONLY sees 0–1 skeleton → cannot distinguish CC1/CC2.

✓ Higher-Order WL on rank-2 coadjacency (faces via edges)
    Result        : True
    Interpretation: Correct: HO-WL on faces distinguishes single quad vs two triangles.

TEST 2: Classic 1-WL failure (C6 vs C3⊔C3) in CellComplex form

✓ Rank-0 adjacency WL signature (C6 vs C3⊔C3)
    Result        : True
    Interpretation: Correct: reproduces the known 1-WL failure example.

TEST 3: Invariance under vertex relabeling (SimplicialComplex, rank 1)

✓ Rank-1 adjacency WL signature (isomorphic SCs)
    Result        : True
    Interpretation: Correct: WL is invariant under combinatorial isomorphism.

TEST 4: Edge-level HO-WL limitation (CellComplex, rank 1 via rank 2)

✓ Rank-1 coadjacency via rank-2 (edges via faces)
    Result        : True
    Interpretation: As observe


# Higher-Order Weisfeiler–Lehman Hashing on Complexes

This notebook explores how the classical Weisfeiler–Lehman (WL) graph test can be generalized to **cell complexes**, **simplicial complexes**, and **combinatorial complexes** using **neighborhood matrices**, and how this *higher-order* WL can distinguish structures that are invisible to graph WL.

We will:

1. Recall the **1-dimensional WL test on graphs**.
2. Define **neighborhood matrices** on complexes (TopoNetX-style).
3. Define **higher-order WL on complexes** using the class `HigherOrderWeisfeilerLehmanHashing`.
4. Walk through several examples where graph WL fails but higher-order WL succeeds (or behaves differently).
5. Show a final example where **graph, SC, and CX WL all fail**, but **CC WL** succeeds.

---

## 1. Weisfeiler–Lehman (WL) on Graphs

Let a graph be

$$
G = (V, E).
$$

The **1-dimensional WL test** (also called *color refinement*) iteratively updates node labels based on neighbor labels.

### 1.1 Initialization

Each node $v \in V$ starts with an initial label $\ell_0(v)$.

This can be, for example, a constant label for all nodes, or something simple like the node degree.

Formally, we have

$$
\ell_0 : V \to \mathcal{A},
$$

where $\mathcal{A}$ is a set of discrete labels (strings, integers, etc.).

### 1.2 WL Update Rule

At iteration $t = 0, 1, \dots$, each node has a label $\ell_t(v)$.

Define the **multiset of neighbor labels** at iteration $t$ as

$$
M_t(v) = \{ \ell_t(u) \;:\; u \in N(v) \},
$$

where $N(v)$ is the set of neighbors of $v$.

The **WL update** at iteration $t+1$ is

$$
\ell_{t+1}(v) = h\big( \ell_t(v), M_t(v) \big),
$$

where $h$ is an *injective* hashing / encoding function. In practice, we serialize $\ell_t(v)$ and the (sorted) neighbor labels into a string and hash it with MD5. MD5 is injective only up to a negligible collision probability.

After $T$ iterations, the **WL signature** of a graph is the multiset of all node labels:

$$
\mathrm{sig}(G) = \big\{ \ell_T(v) \;:\; v \in V \big\}.
$$

Two graphs are **1-WL indistinguishable** if their WL signatures are the same.

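In our code, when we call `HigherOrderWeisfeilerLehmanHashing().fit(G)` with a `networkx.Graph`:

* the class automatically builds the usual adjacency matrix of $G$,
* runs several WL iterations,
* and `get_domain_features()` returns a flat list of the labels from **all** iterations. The list includes the initial label `"0"` unless `erase_base_features=True`. Compared as a sorted multiset, it encodes $\mathrm{sig}(G)$ (the final-iteration labels) together with the labels of the earlier iterations.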

---

## 2. Complexes and Neighborhood Matrices (TopoNetX)

Graphs only give us one level of structure (vertices + edges). In **TopoNetX**, we can work with richer objects:

* `SimplicialComplex`: simplices (vertices, edges, triangles, tetrahedra, …).
* `CellComplex`: cells of arbitrary shapes (edges, polygons, 3D cells, …) glued along faces.
* `CombinatorialComplex`: a very general incidence-based complex where cells can have arbitrary ranks and incidence patterns.
* Other classes: `PathComplex`, `ColoredHyperGraph`, etc.

Each complex is **graded by rank**:

* Rank $0$: vertices (0-cells),
* Rank $1$: edges (1-cells),
* Rank $2$: faces (2-cells),
* Rank $3$: volumes / 3-cells, etc.

We denote the set of rank-$r$ cells by

$$
C_r = \{ c_1, \dots, c_{n_r} \}.
$$

### 2.1 Rank and Via-Rank

We would like to define “neighbors” between rank-$r$ cells in a way that respects the combinatorial structure of the complex.

TopoNetX does this via **neighborhood matrices** defined by:

* a **rank** $r$: which cells we are labeling (vertices, edges, faces, …),
* a **via-rank** $k$: which cells we use to connect them,
* a **neighborhood type**: `"adj"` or `"coadj"`.

Intuitively:

* **Adjacency** (`"adj"`): neighbors via higher-rank cells

  * “two rank-$r$ cells are adjacent if they both lie in some rank-$k$ *coface* (usually $k > r$)”.
* **Coadjacency** (`"coadj"`): neighbors via lower-rank cells

  * “two rank-$r$ cells are coadjacent if they share some rank-$k$ *face* (usually $k < r$)”.

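TopoNetX exposes these via methods such as:

* `adjacency_matrix(rank, via_rank, index=True)`
* `coadjacency_matrix(rank, via_rank, index=True)`

This two-rank form is the one `neighborhood_from_complex` uses for `CombinatorialComplex` and `ColoredHyperGraph`. For `SimplicialComplex`, `CellComplex`, and `PathComplex` it calls these methods with the rank only, so `via_rank` is not used there.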

Our helper function

```python
ind, A = neighborhood_from_complex(domain,
                                   neighborhood_type=...,
                                   neighborhood_dim={"rank": r, "via_rank": k})
```

wraps these calls and returns:

* `ind`: a list of identifiers of the rank-$r$ cells,
* `A`: a sparse matrix encoding the adjacency or coadjacency between these cells.

Some common patterns:

* **0-skeleton (graph) adjacency** for simplicial / cell complexes:

  ```python
  neighborhood_dim = {"rank": 0, "via_rank": -1}
  ```

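  This is the usual vertex–vertex adjacency via edges. For these complex types only `"rank"` is read; `"via_rank"` is ignored, so the `-1` is just a placeholder.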

* **2-cell coadjacency via edges** (faces that share an edge):

  ```python
  neighborhood_dim = {"rank": 2, "via_rank": 1}
  ```

* **3-cell coadjacency via vertices** in a `CombinatorialComplex`:

  ```python
  neighborhood_dim = {"rank": 3, "via_rank": 0}
  ```

The key idea: for each pair $(r,k)$ and neighborhood type, we get a **graph of $r$-cells**, and we can apply WL to that graph.

---

## 3. Higher-Order WL on Complexes

We now generalize WL from **nodes in a graph** to **cells of arbitrary rank** in a complex.

### 3.1 Induced Cell Graph

Fix:

* a complex $\mathcal{K}$,
* a rank $r$ (we will run WL on rank-$r$ cells),
* a via-rank $k$,
* a neighborhood type (`"adj"` or `"coadj"`).

From TopoNetX we obtain:

* an index set

  $$
  C_r = \{ c_1, \dots, c_{n_r} \},
  $$

* a neighborhood matrix

  $$
  A^{(r,k)} \in \{0,1\}^{n_r \times n_r}.
  $$

We interpret this as a **graph**:

* **Nodes**: the rank-$r$ cells $c_i \in C_r$,
* **Edges**: $(c_i, c_j)$ is an edge iff

  $$
  A^{(r,k)}_{ij} \neq 0.
  $$

We call this the **cell graph at rank $r$** induced by $(r,k)$.

### 3.2 Higher-Order WL Update Rule

We now run WL on this cell graph.

At iteration $t$, each cell $c \in C_r$ has a label $\ell_t(c)$. Its neighbor label multiset is

$$
M_t(c) = \{\!\{ \ell_t(c') \;:\; c' \in N(c) \}\!\},
$$

where $N(c)$ is the set of neighbors of $c$ in the cell graph.

The WL update rule is

$$
\ell_{t+1}(c) = h\big( \ell_t(c), M_t(c) \big).
$$

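Just as in the graph case, $h$ is (ideally) an injective encoding of the current label and the multiset of neighbor labels. In practice it is implemented with a hash function, so injectivity holds only up to hash collisions.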

### 3.3 Higher-Order WL Signature

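After $T$ iterations, the **higher-order WL signature** of $\mathcal{K}$ at rank $r$ (via rank $k$) is

$$
\mathrm{sig}^{(r,k)}(\mathcal{K}) = \big\{\!\big\{ \ell_T(c) \;:\; c \in C_r \big\}\!\big\}.
$$

(The helper functions in the code below use a closely related signature: the sorted bag of WL labels collected across all iterations, rather than only the final labels $\ell_T$.)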

Special cases:

* $r = 0$, $k = 1$: standard 1-WL on the 0-skeleton graph.
* $r = 2$, $k = 1$: WL on faces that share an edge.
* $r = 3$, $k = 0$: WL on 3-cells that share vertices, etc.

In code, we implement this using the class:

```python
HigherOrderWeisfeilerLehmanHashing(
    wl_iterations=...,
    erase_base_features=...
).fit(
    domain=complex_or_graph,
    neighborhood_types="adj" or "coadj",
    neighborhood_dims={"rank": r, "via_rank": k},
)
```

**Graph case.** If `domain` is a `networkx.Graph`, the class ignores `neighborhood_types` and builds the standard adjacency matrix automatically.

**Multi-neighborhood case.** The class also supports multiple neighborhoods at once:

```python
HigherOrderWeisfeilerLehmanHashing().fit(
    domain=complex,
    neighborhood_types=["adj", "coadj"],
    neighborhood_dims=[{"rank": 2, "via_rank": 1},
                      {"rank": 2, "via_rank": 0}],
)
```

In that case, the update rule becomes

$$
\ell_{t+1}(c) = h\Big( \ell_t(c),\, M_t^{(1)}(c),\, M_t^{(2)}(c),\, \dots \Big),
$$

where each $M_t^{(i)}(c)$ is the multiset of neighbor labels under the $i$-th neighborhood matrix.

---

## 4. How to Read the Printed Output of the Test Suite

In the code cell that runs the **Higher-Order WL test suite**, you will see output like:

```text
TEST 1: Graph WL failure vs Higher-Order WL success

✓ Graph WL on 0-skeleton adjacency (CC1 vs CC2)
    Result        : True
    Interpretation: Correct: graph WL ONLY sees 0–1 skeleton → cannot distinguish CC1/CC2.
```

Each check has:

* A **label**: what we are comparing (e.g., which rank, which neighborhood).
* A boolean `Result: True/False`:

  * For some checks, `True` means “signatures are equal” (WL **fails** to distinguish).
  * For others, `True` means “signatures are different” (WL **succeeds** to distinguish).
* An **Interpretation** that explains why that outcome is expected or interesting.

Always read the label + interpretation together:
they tell you whether “True” means **success** or **intentional failure** in that context.

---

## 5. Examples and What They Demonstrate

The test suite code builds a sequence of examples that illustrate:

1. **Graph WL failure vs higher-order WL success** (CombinatorialComplex, rank 2).
2. **Reproducing a classic 1-WL failure** (cycle $C_6$ vs two triangles $C_3 \sqcup C_3$) using CellComplexes.
3. **Invariance under vertex relabeling** on a SimplicialComplex.
4. **A limitation of a particular edge-level neighborhood** on a CellComplex.
5. **A CombinatorialComplex example where both rank-2 and rank-3 HO-WL see higher-order structure.**
6. **Invariance under vertex relabeling** on a CombinatorialComplex.

On top of that, we also have a separate **“graph / SC / CX all fail, but CC succeeds”** example, which you can run in a separate cell.

Below is a conceptual map of these examples, independent of the exact function names in the code.

---

### 5.1 Test 1 — Graph WL Fails, Higher-Order WL Succeeds (CombinatorialComplex, Rank 2)

**Complex type:** `CombinatorialComplex`.

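* **Vertices:** $\{0,1,2,3\}$.
* **Edges:** $(0,1), (1,2), (2,3), (3,0), (0,2), (0,3)$,
  identical in both complexes. Although described as "$K_4$-like", this list names the undirected edge $\{0,3\}$ twice (as $(3,0)$ and $(0,3)$). The 1-skeleton is therefore a 4-cycle with the single diagonal $(0,2)$; the full $K_4$ would also need $(1,3)$.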

We build two complexes:

1. **CC1**:

   * Rank-0: vertices.
   * Rank-1: edges as above.
   * Rank-2: one 2-cell ([0,1,2,3]) (a single quadrilateral face).

2. **CC2**:

   * Rank-0: vertices.
   * Rank-1: same edges.
   * Rank-2: two 2-cells ([0,1,2]) and ([0,2,3]) (two triangles).

**Behavior:**

* **Graph WL on 0-skeleton adjacency** (rank 0 via rank 1):

  * Underlying vertex–edge graph is identical.
  * WL signatures match.
  * This is a **failure** of graph WL to see the difference in faces.

* **Higher-order WL on rank-2 coadjacency via rank 1** (faces sharing edges):

  * CC1 has 1 face; CC2 has 2 faces that share an edge.
  * The induced cell graphs at rank 2 differ.
  * Higher-order WL signatures differ → **success** at rank 2.

---

### 5.2 Test 2 — Classic 1-WL Failure: $C_6$ vs $C_3 \sqcup C_3$ as CellComplex

**Complex type:** `CellComplex` (1-dimensional).

We construct:

* `cx1`: a 6-cycle $C_6$,
* `cx2`: disjoint union of two triangles $C_3 \sqcup C_3$,

using only 0- and 1-cells.

**Behavior:**

* We run higher-order WL on **rank-0 adjacency** (`"adj"`, `{"rank": 0, "via_rank": -1}`).
* This is essentially the standard 1-WL on the underlying graphs.
* It is known that 1-WL cannot distinguish $C_6$ from $C_3 \sqcup C_3$.

The test confirms that the WL signatures are equal, reproducing this **classical failure**.

---

### 5.3 Test 3 — Invariance Under Vertex Relabeling (SimplicialComplex, Rank 1)

**Complex type:** `SimplicialComplex`.

We build two complexes:

* `sc1`: 2-simplices ([0,1,2]) and ([1,2,3]).
* `sc2`: the same structure with vertices relabeled
  ([10,11,12]) and ([11,12,13]).

These complexes are **isomorphic** as simplicial complexes.

**Behavior:**

* We run higher-order WL on **rank-1 adjacency** (edges via shared vertices).
* The induced edge-level graphs are isomorphic.
* WL signatures are equal → **invariance under relabeling**, as desired.

---

### 5.4 Test 4 — Edge-Level HO-WL Limitation (CellComplex, Rank 1 via Rank 2)

**Complex type:** `CellComplex`.

Common 1-skeleton (vertices and edges):

* Square with a diagonal:

  * Vertices: ({0,1,2,3}),
  * Edges: ((0,1), (1,2), (2,3), (3,0), (0,2)).

We build two complexes:

1. `cx_A`:

   * One quadrilateral face ([0,1,2,3]).

2. `cx_B`:

   * Two triangular faces ([0,1,2]) and ([0,2,3]).

**Behavior:**

* We run higher-order WL on **rank-1 coadjacency via rank 2**:

  * Nodes: edges.
  * Edges: pairs of edges that co-occur in a 2-cell.
* Empirically (with TopoNetX + our neighborhood definition), the WL signatures on edges are **equal**.

This is a **negative example**: in this configuration, **edge-level coadjacency** is not strong enough to distinguish “one quad vs two triangles”.

It illustrates that higher-order WL is not uniformly more powerful:
its expressiveness depends on **which rank** and **which neighborhood** you choose.

---

### 5.5 Test 5 — CombinatorialComplex: Rank-2 and Rank-3 Sensitivity

**Complex type:** `CombinatorialComplex`.

We build two complexes `cc_A`, `cc_B` that share a 2D structure but differ in rank 3.

**Shared structure:**

* Vertices: ({0,1,2,3}).
* Edges: two triangles sharing an edge:

  * $(0,1),(1,2),(2,0)$ and $(1,2),(2,3),(3,1)$ (the edge $(1,2)$ is shared).
* Faces: ([0,1,2]) and ([1,2,3]) in both.

**Rank-3 structure:**

* `cc_A`: one 3-cell attached over ([0,1,2]).
* `cc_B`: two 3-cells, over ([0,1,2]) and ([1,2,3]).

**Behavior:**

1. **Rank-2 HO-WL** (faces via edges: `"coadj"`, `{"rank": 2, "via_rank": 1}`):

   * Faces are neighbors if they share an edge.
   * Because of how 3-cells affect co-incidence patterns, the rank-2 neighborhood structure changes.
   * WL signatures differ → **rank-2 HO-WL is already sensitive** to the rank-3 difference.

2. **Rank-3 HO-WL** (3-cells via vertices: `"coadj"`, `{"rank": 3, "via_rank": 0}`):

   * Nodes: 3-cells.
   * Obviously, 1 vs 2 cells gives a different graph.
   * WL signatures differ → **rank-3 HO-WL also detects** the difference.

This test shows that in some settings:

* Lower ranks (faces) can already encode higher-rank information through neighborhood matrices.
* Higher ranks (3-cells) make the difference even more explicit.

---

### 5.6 Test 6 — Invariance Under Vertex Relabeling (CombinatorialComplex, Rank 2)

**Complex type:** `CombinatorialComplex`.

We build two isomorphic complexes:

* `cc1`:

  * Vertices: ({0,1,2,3}),
  * Edges: a 4-cycle,
  * Faces: ([0,1,2]) and ([0,2,3]).

* `cc2`:

  * Same structure but vertices relabeled ({10,11,12,13}),
  * Faces: ([10,11,12]) and ([10,12,13]).

**Behavior:**

* We run higher-order WL on **rank-2 coadjacency via rank 1** (faces via edges).
* The induced face-level graphs are isomorphic.
* WL signatures match → HO-WL is **invariant under combinatorial isomorphism** at this rank.

---

## 6. “Graph / SC / CX Fail, CC Succeeds” Example

In a separate cell (not in the 6-test suite), we also build an example with the following structure:

* Base graph: two triangles sharing an edge.
* `SimplicialComplex` and `CellComplex` versions:

  * Same vertices, edges, and faces.
* `CombinatorialComplex` versions:

  * Same 0-, 1-, and 2-cells,
  * Different rank-3 cells:

    * One 3-cell vs two 3-cells.

Then we observe:

1. Graph WL on the base graph: **equal** (cannot see rank-2 or rank-3 structure).
2. Higher-order WL on SC rank-2: **equal** (faces are identical).
3. Higher-order WL on CX rank-2: **equal** (same 2D cells).
4. Higher-order WL on CC rank-3 (via vertices): **different** (3-cells differ).

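This shows that:

> Even if **all lower-rank WL tests** (graph, SC, CX) fail,
> a **CombinatorialComplex** with higher rank and an appropriate neighborhood can still distinguish the structures.

Note that the graph, SC, and CX versions of A and B are identical by construction, since none of them can represent rank-3 cells. Their equal signatures therefore reflect what those representations can encode, not a weakness of the WL procedure itself.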

---


Next we provide the code for the example of Section 6 (“Graph / SC / CX fail, CC succeeds”). The six-test suite of Section 5 runs in a separate cell.

In [57]:
# ============================================================
# 5. COMBINATORIAL COMPLEX TEST (DETAILED):
# WL FAILS ON GRAPH / SC / CX BUT SUCCEEDS ON CC (HIGHER RANK)
# ============================================================
#
# We build one underlying combinatorial structure in four different ways:
#   (1) As a plain graph (just vertices + edges).
#   (2) As a SimplicialComplex (with 2-simplices).
#   (3) As a CellComplex (with 2-cells).
#   (4) As a CombinatorialComplex (with 0-,1-,2-, and 3-cells).
#
# Then we create TWO versions A and B:
#   - In (1)-(3), A and B are *identical* (no difference in structure).
#   - In (4), A and B differ ONLY in rank-3 cells:
#       cc_A has ONE rank-3 cell,
#       cc_B has TWO rank-3 cells.
#
# For each representation we run one or more WL tests and print:
#   - Whether the WL signatures are equal.
#   - A short interpretation of what that equality/inequality means.
#
# The point of this cell is to show:
#   • Plain graph WL cannot see the CC difference (as expected).
#   • WL on SC and CX (including rank-2) also cannot see the CC difference.
#   • Only higher-order WL on the CombinatorialComplex at rank 3 detects the
#     extra 3-cell in cc_B and separates cc_A from cc_B.
# ============================================================

import networkx as nx
import toponetx as tnx

# (Assumes you already defined:
#  - neighborhood_from_complex
#  - HigherOrderWeisfeilerLehmanHashing
#  in earlier cells)


# ---------------------------------------------------------------------
# Small helper for didactic printing
# ---------------------------------------------------------------------
def print_check(label: str, cond: bool, when_true: str, when_false: str) -> None:
    """
    Report one labeled boolean check and explain what its outcome means.

    The report consists of a ✓/✗ marker followed by the label, the raw
    boolean value, the interpretation matching that value, and a blank
    separator line.

    Parameters
    ----------
    label : str
        Short description of the test being performed.
    cond : bool
        Outcome of the check (e.g. ``sig_A == sig_B``).
    when_true : str
        Interpretation shown when ``cond`` is truthy.
    when_false : str
        Interpretation shown when ``cond`` is falsy.
    """
    symbol, meaning = ("✓", when_true) if cond else ("✗", when_false)
    report = "\n".join(
        (
            f"{symbol} {label}",
            f"    Result        : {cond}",
            f"    Interpretation: {meaning}",
            "",  # empty last entry -> print() adds the blank separator line
        )
    )
    print(report)


# ---------------------------------------------------------------------
# 1. Utility functions: WL on graphs and WL on complexes
# ---------------------------------------------------------------------
def graph_wl_signature_from_edges(edges, wl_iterations: int = 3):
    """
    Compute a classical 1-WL signature for a graph given as an edge list.

    The edge list is turned into a ``networkx.Graph`` and hashed with
    HigherOrderWeisfeilerLehmanHashing in its plain-graph mode.

    Parameters
    ----------
    edges : list of [u, v]
        Undirected edges of the graph.
    wl_iterations : int
        Number of WL refinement iterations.

    Returns
    -------
    list of str
        Sorted multiset of graph-level features (WL labels).
    """
    graph = nx.Graph()
    graph.add_edges_from(edges)

    hasher = HigherOrderWeisfeilerLehmanHashing(
        wl_iterations=wl_iterations,
        erase_base_features=False,
    )
    # The flat bag of features from every iteration serves as the signature.
    features = hasher.fit(domain=graph).get_domain_features()
    return sorted(features)


def graph_wl_signature_on_rank0_complex(domain, wl_iterations: int = 3):
    """
    Run WL on the 0-skeleton of a TopoNetX complex.

    Conceptually:
        1. We extract the vertex adjacency matrix of the complex.
        2. We treat that as a graph. ``nx.from_scipy_sparse_array`` creates
           one node per matrix row, so vertices without any neighbor are
           kept as isolated nodes.
        3. We run the standard WL procedure on that graph using
           HigherOrderWeisfeilerLehmanHashing in graph mode.

    For CombinatorialComplex / ColoredHyperGraph we must specify a
    via_rank to talk about adjacency between vertices via edges.
    For simplicial and cell complexes, via_rank is ignored.

    Parameters
    ----------
    domain : TopoNetX complex
        One of SimplicialComplex, CellComplex, CombinatorialComplex, etc.
    wl_iterations : int
        Number of WL refinement iterations.

    Returns
    -------
    list of str
        Sorted multiset of WL features on the 0-skeleton graph.
    """
    if isinstance(domain, (tnx.CombinatorialComplex, tnx.ColoredHyperGraph)):
        # vertices (rank 0) adjacent via edges (rank 1)
        neigh_dim = {"rank": 0, "via_rank": 1}
    else:
        # via_rank is ignored in SimplicialComplex / CellComplex
        neigh_dim = {"rank": 0, "via_rank": -1}

    # Only the matrix is needed: the signature below is a sorted multiset,
    # so it does not depend on which cell identifier each row belongs to.
    _, adjacency = neighborhood_from_complex(
        domain,
        neighborhood_type="adj",
        neighborhood_dim=neigh_dim,
    )

    graph = nx.from_scipy_sparse_array(adjacency)

    wl = HigherOrderWeisfeilerLehmanHashing(
        wl_iterations=wl_iterations,
        erase_base_features=False,
    ).fit(domain=graph)

    return sorted(wl.get_domain_features())


def higher_order_wl_signature_complex(
    domain,
    neighborhood_type: str,
    neighborhood_dim: dict,
    wl_iterations: int = 3,
):
    """
    Compute a higher-order WL signature of a TopoNetX complex for one neighborhood.

    The rank-r cells selected by ``neighborhood_dim`` become the nodes of a
    cell graph whose edges come from the (co)adjacency matrix A^{(r,k)}.
    WL refinement is then run on that cell graph, which is the
    "higher-order WL" described in the tutorial.

    Parameters
    ----------
    domain : TopoNetX complex
    neighborhood_type : {"adj", "coadj"}
        Type of neighborhood: adjacency or coadjacency.
    neighborhood_dim : dict
        Dictionary with keys "rank" and "via_rank" describing which
        cells we run WL on and via which rank we connect them.
    wl_iterations : int
        Number of WL refinement iterations.

    Returns
    -------
    list of str
        Sorted multiset of higher-order WL features over the chosen rank.
    """
    fit_kwargs = {
        "domain": domain,
        "neighborhood_types": neighborhood_type,
        "neighborhood_dims": neighborhood_dim,
    }
    hasher = HigherOrderWeisfeilerLehmanHashing(
        wl_iterations=wl_iterations,
        erase_base_features=False,
    )
    fitted = hasher.fit(**fit_kwargs)
    return sorted(fitted.get_domain_features())


# ---------------------------------------------------------------------
# 2. Base structure: two triangles sharing an edge (underlying graph)
# ---------------------------------------------------------------------
# All four representations below start from this one edge list on the
# vertices 0, 1, 2, 3:
#   - triangle (0,1,2) contributes (0,1), (1,2), (2,0);
#   - triangle (1,2,3) reuses the shared edge (1,2) and adds (1,3), (2,3).
#
# This is the "two triangles sharing an edge" example described in the text.
edges_base = [
    [u, v]
    for u, v in (
        (0, 1), (1, 2), (2, 0),  # triangle 0-1-2
        (1, 3), (2, 3),          # remaining sides of triangle 1-2-3
    )
]


# ---------------------------------------------------------------------
# 3. SimplicialComplex: sc_A and sc_B are IDENTICAL
# ---------------------------------------------------------------------
# sc_A and sc_B are two independent SimplicialComplex objects built from
# the same pair of 2-simplices, [0,1,2] and [1,2,3].
#
# Since they are the same complex, we expect:
#   - WL on the 0-skeleton to report equal signatures;
#   - rank-2 higher-order WL (on the faces) to report equal signatures.
sc_A, sc_B = (
    tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]]) for _copy in range(2)
)


# ---------------------------------------------------------------------
# 4. CellComplex: cx_A and cx_B are IDENTICAL
# ---------------------------------------------------------------------
# Each cell complex is built by the same recipe:
#   - 1-skeleton: vertices {0,1,2,3} and the edges listed in edges_base;
#   - 2-cells: one filling triangle (0,1,2), one filling triangle (1,2,3).
#
# Because cx_A and cx_B are built identically, we expect both 0-skeleton
# WL and rank-2 higher-order WL to report equal signatures.
def _two_filled_triangles_cx():
    """Return a fresh CellComplex on edges_base with both triangles filled."""
    complex_ = tnx.CellComplex(edges_base, ranks=1)
    for face in ([0, 1, 2], [1, 2, 3]):
        complex_.add_cell(face, rank=2)
    return complex_


cx_A = _two_filled_triangles_cx()
cx_B = _two_filled_triangles_cx()


# ---------------------------------------------------------------------
# 5. CombinatorialComplex: cc_A vs cc_B differ ONLY in rank-3
# ---------------------------------------------------------------------
# Both combinatorial complexes are assembled by the same routine:
#   - rank 0: the vertices 0, 1, 2, 3
#   - rank 1: every edge in edges_base
#   - rank 2: the cells [0,1,2] and [1,2,3]
# and then receive different rank-3 cells:
#   - cc_A: a single rank-3 cell on the vertex set [0,1,2]
#   - cc_B: two rank-3 cells, on [0,1,2] and on [1,2,3]
#
# The intended difference is visible only to WL run on rank-3 cells.
#
# NOTE(review): each rank-3 cell reuses the exact vertex set of a rank-2
# cell. In the usual definition of a combinatorial complex the rank is a
# function of the cell, so TopoNetX may reject, ignore, or re-rank the
# second add_cell for the same set. Confirm which. If it re-ranks the
# cell, cc_A and cc_B do not actually share the same rank-2 cells as
# stated above.
def _two_triangle_cc(rank3_cells):
    """Build the shared rank-0/1/2 structure, then add the given rank-3 cells."""
    complex_ = tnx.CombinatorialComplex()
    for vertex in (0, 1, 2, 3):
        complex_.add_cell([vertex], rank=0)
    for head, tail in edges_base:
        complex_.add_cell([head, tail], rank=1)
    for face in ([0, 1, 2], [1, 2, 3]):
        complex_.add_cell(face, rank=2)
    for solid in rank3_cells:
        complex_.add_cell(solid, rank=3)
    return complex_


cc_A = _two_triangle_cc([[0, 1, 2]])
cc_B = _two_triangle_cc([[0, 1, 2], [1, 2, 3]])


# ---------------------------------------------------------------------
# 6. Run all WL tests and print detailed, didactic output
# ---------------------------------------------------------------------
# Every check below is phrased so that "✓ / True" marks the EXPECTED
# outcome. The didactic summary at the end is printed only when all
# checks came out as expected, so it can never contradict the results.
print("\n================ CC TEST: GRAPH / SC / CX FAIL, CC SUCCEEDS ================\n")

print("Base structure:")
print("  • Two triangles (0,1,2) and (1,2,3) sharing the edge (1,2)")
print("  • Vertices: {0, 1, 2, 3}")
print("  • Edges   :", edges_base)
print()
print("We build FOUR representations of this structure (A and B each):")
print("  (1) Plain graph: only vertices + edges (no faces).")
print("  (2) SimplicialComplex  sc_A vs sc_B: two 2-simplices [0,1,2], [1,2,3].")
print("  (3) CellComplex        cx_A vs cx_B: two 2-cells filling the triangles.")
print("  (4) CombinatorialComplex cc_A vs cc_B: same 0–2 structure,")
print("      but cc_B has an extra rank-3 cell compared to cc_A.\n")

print("We now run a sequence of WL tests and compare the signatures for A vs B:")
print("  • If signatures are EQUAL: WL cannot distinguish A and B in that view.")
print("  • If signatures DIFFER   : WL successfully detects a structural difference.\n")

# 1) Plain graph WL — should be equal (same edge list)
sig_graph_A = graph_wl_signature_from_edges(edges_base)
sig_graph_B = graph_wl_signature_from_edges(edges_base)
print_check(
    "Graph WL on plain graph (edges only) A vs B",
    sig_graph_A == sig_graph_B,
    when_true="As expected: the underlying graph is identical, so 1-WL cannot distinguish them.",
    when_false="Unexpected: 1-WL claims the same edge set leads to different signatures."
)

# 2) SimplicialComplex: graph WL on 0-skeleton
sig_sc_A_graph = graph_wl_signature_on_rank0_complex(sc_A)
sig_sc_B_graph = graph_wl_signature_on_rank0_complex(sc_B)
print_check(
    "Graph WL on SimplicialComplex 0-skeleton A vs B",
    sig_sc_A_graph == sig_sc_B_graph,
    when_true="As expected: sc_A and sc_B have the same 0-skeleton graph.",
    when_false="Unexpected: WL on the vertex graph thinks sc_A and sc_B differ."
)

#    Higher-order WL on rank-2 coadjacency for SC — should be equal
sig_sc_A_2 = higher_order_wl_signature_complex(
    sc_A, "coadj", {"rank": 2, "via_rank": -1}
)
sig_sc_B_2 = higher_order_wl_signature_complex(
    sc_B, "coadj", {"rank": 2, "via_rank": -1}
)
print_check(
    "Higher-order WL on SC (rank=2 faces, coadjacency) A vs B",
    sig_sc_A_2 == sig_sc_B_2,
    when_true="As expected: sc_A and sc_B have exactly the same 2-simplices.",
    when_false="Unexpected: WL on 2-simplices claims a difference where there should be none."
)

# 3) CellComplex: graph WL on 0-skeleton
sig_cx_A_graph = graph_wl_signature_on_rank0_complex(cx_A)
sig_cx_B_graph = graph_wl_signature_on_rank0_complex(cx_B)
print_check(
    "Graph WL on CellComplex 0-skeleton A vs B",
    sig_cx_A_graph == sig_cx_B_graph,
    when_true="As expected: cx_A and cx_B share the same vertex-edge graph.",
    when_false="Unexpected: WL on the vertex graph thinks cx_A and cx_B differ."
)

#    Higher-order WL on rank-2 coadjacency for CX — should be equal
sig_cx_A_2 = higher_order_wl_signature_complex(
    cx_A, "coadj", {"rank": 2, "via_rank": -1}
)
sig_cx_B_2 = higher_order_wl_signature_complex(
    cx_B, "coadj", {"rank": 2, "via_rank": -1}
)
print_check(
    "Higher-order WL on CX (rank=2 faces, coadjacency) A vs B",
    sig_cx_A_2 == sig_cx_B_2,
    when_true="As expected: cx_A and cx_B have identical 2-cell structure.",
    when_false="Unexpected: WL on 2-cells claims a difference where there should be none."
)

# 4) CombinatorialComplex: graph WL on 0-skeleton
sig_cc_A_graph = graph_wl_signature_on_rank0_complex(cc_A)
sig_cc_B_graph = graph_wl_signature_on_rank0_complex(cc_B)
print_check(
    "Graph WL on CombinatorialComplex 0-skeleton A vs B",
    sig_cc_A_graph == sig_cc_B_graph,
    when_true="As expected: cc_A and cc_B share the same vertex-edge graph.",
    when_false="Unexpected: WL on the vertex graph thinks cc_A and cc_B differ."
)

#    Higher-order WL on rank-3 coadjacency via rank-0 — should DIFFER.
#    We test for *inequality* so that "✓ / True" marks the desired
#    outcome here too, consistent with every check above.
sig_cc_A_3 = higher_order_wl_signature_complex(
    cc_A, "coadj", {"rank": 3, "via_rank": 0}
)
sig_cc_B_3 = higher_order_wl_signature_complex(
    cc_B, "coadj", {"rank": 3, "via_rank": 0}
)
print_check(
    "Higher-order WL on CC (rank=3 cells, coadj via vertices) A vs B",
    sig_cc_A_3 != sig_cc_B_3,
    when_true="As desired: WL on rank-3 cells distinguishes cc_A (1 three-cell) from cc_B (2 three-cells).",
    when_false="This would mean WL on rank-3 cells failed to see the extra 3-cell in cc_B (not desired).",
)

# The didactic conclusion only holds if every check matched its expectation.
checks_as_expected = (
    sig_graph_A == sig_graph_B
    and sig_sc_A_graph == sig_sc_B_graph
    and sig_sc_A_2 == sig_sc_B_2
    and sig_cx_A_graph == sig_cx_B_graph
    and sig_cx_A_2 == sig_cx_B_2
    and sig_cc_A_graph == sig_cc_B_graph
    and sig_cc_A_3 != sig_cc_B_3
)

print("Summary:")
if checks_as_expected:
    print("  • All WL tests that only see the 0-skeleton or the 2D structure (SC/CX) say A and B look the same.")
    print("  • Only the higher-order WL on the CombinatorialComplex at rank 3 (with coadjacency via rank 0)")
    print("    is sensitive enough to detect the extra 3-cell in cc_B.")
    print("  • This illustrates how combinatorial complexes + flexible (r, k) neighborhoods give WL")
    print("    strictly more expressive power than graphs, simplicial complexes, or cell complexes alone.\n")
else:
    print("  • At least one check above did not match its expected outcome (marked ✗),")
    print("    so the conclusion of this example does not hold for this run.\n")
print("===========================================================================\n")




Base structure:
  • Two triangles (0,1,2) and (1,2,3) sharing the edge (1,2)
  • Vertices: {0, 1, 2, 3}
  • Edges   : [[0, 1], [1, 2], [2, 0], [1, 3], [2, 3]]

We build FOUR representations of this structure (A and B each):
  (1) Plain graph: only vertices + edges (no faces).
  (2) SimplicialComplex  sc_A vs sc_B: two 2-simplices [0,1,2], [1,2,3].
  (3) CellComplex        cx_A vs cx_B: two 2-cells filling the triangles.
  (4) CombinatorialComplex cc_A vs cc_B: same 0–2 structure,
      but cc_B has an extra rank-3 cell compared to cc_A.

We now run a sequence of WL tests and compare the signatures for A vs B:
  • If signatures are EQUAL: WL cannot distinguish A and B in that view.
  • If signatures DIFFER   : WL successfully detects a structural difference.

✓ Graph WL on plain graph (edges only) A vs B
    Result        : True
    Interpretation: As expected: the underlying graph is identical, so 1-WL cannot distinguish them.

✓ Graph WL on SimplicialComplex 0-skeleton A vs B
    Re