In [None]:
import hashlib
import networkx as nx
from tqdm.auto import tqdm
from typing import List, Dict, Optional


class WeisfeilerLehmanHashing(object):
    """
    Weisfeiler-Lehman feature extractor class.

    Feature extraction runs inside ``__init__``. First a base (iteration-0)
    label is assigned to every node. Then ``wl_iterations`` rounds of WL
    refinement are applied. In each round a node's new label is the MD5 hex
    digest of its current label joined with ``"_"`` to the sorted labels of
    its neighbors.

    Args:
        graph (NetworkX graph): NetworkX graph for which we do WL hashing.
        wl_iterations (int): Number of WL iterations.
        use_node_attribute (Optional[str]): Name of the node attribute used as
            the base label. If None, node degrees are used instead.
        erase_base_features (bool): Delete the base (iteration-0) features
            after hashing.

    Raises:
        ValueError: If ``use_node_attribute`` is set but some nodes do not
            carry that attribute.
    """

    def __init__(
        self,
        graph: nx.classes.graph.Graph,
        wl_iterations: int,
        use_node_attribute: Optional[str],
        erase_base_features: bool,
    ):
        """
        Store the configuration and immediately run feature extraction.
        """
        self.wl_iterations = wl_iterations
        self.graph = graph
        self.use_node_attribute = use_node_attribute
        self.erase_base_features = erase_base_features
        self._set_features()
        self._do_recursions()

    def _set_features(self):
        """
        Assign the base (iteration-0) label of every node.

        Uses the ``use_node_attribute`` node attribute when it is set, and the
        node degree otherwise. Seeds ``extracted_features`` with the string
        form of those labels.

        Raises:
            ValueError: If some nodes lack the requested attribute.
        """
        if self.use_node_attribute is not None:
            # {node: attribute_value}; nodes without the attribute are absent,
            # so a size mismatch means at least one node is missing it.
            features = nx.get_node_attributes(self.graph, self.use_node_attribute)

            if len(features) != self.graph.number_of_nodes():
                # Collect at most five offending nodes: enough to make the
                # error message informative without scanning huge graphs.
                max_examples = 5
                missing_nodes = []
                for node in tqdm(
                    self.graph.nodes,
                    total=self.graph.number_of_nodes(),
                    leave=False,
                    dynamic_ncols=True,
                    desc="Searching for missing nodes"
                ):
                    if node not in features:
                        missing_nodes.append(node)
                        if len(missing_nodes) >= max_examples:
                            break
                raise ValueError(
                    (
                        "We expected for ALL graph nodes to have a node "
                        "attribute name `{}` to be used as part of "
                        "the requested embedding algorithm, but only {} "
                        "out of {} nodes has the correct attribute. "
                        "Consider checking for typos and missing values, "
                        "and use some imputation technique as necessary. "
                        "Some of the nodes without the requested attribute "
                        "are: {}"
                    ).format(
                        self.use_node_attribute,
                        len(features),
                        self.graph.number_of_nodes(),
                        missing_nodes
                    )
                )
            self.features = features
        else:
            # Default: use node degrees as initial labels.
            self.features = {
                node: self.graph.degree(node) for node in self.graph.nodes()
            }

        # extracted_features[node] = [f^0(node), f^1(node), ..., f^T(node)]
        self.extracted_features = {k: [str(v)] for k, v in self.features.items()}

    def _erase_base_features(self):
        """
        Drop the base (iteration-0) label from every node's feature list.
        """
        # Each node owns its own list (built fresh per node), so deleting in
        # place cannot affect other nodes.
        for node_features in self.extracted_features.values():
            del node_features[0]

    def _do_a_recursion(self) -> Dict[int, str]:
        """
        Perform a single WL refinement step.

        For every node, the current label and the sorted labels of its
        neighbors are joined with ``"_"`` and hashed with MD5. The new label
        is appended to the node's history in ``extracted_features``.

        Return types:
            * **new_features** *(dict of strings)* - The hash table with extracted WL features.
        """
        new_features = {}
        for node in self.graph.nodes():
            neighbor_labels = sorted(
                str(self.features[neb]) for neb in self.graph.neighbors(node)
            )
            # Sorting makes the neighbor multiset order-independent.
            signature = "_".join([str(self.features[node])] + neighbor_labels)
            new_features[node] = hashlib.md5(signature.encode()).hexdigest()

        # Append this iteration's feature to the history.
        self.extracted_features = {
            k: self.extracted_features[k] + [v] for k, v in new_features.items()
        }
        return new_features

    def _do_recursions(self):
        """
        Run ``wl_iterations`` refinement steps, then optionally erase the base labels.
        """
        for _ in range(self.wl_iterations):
            self.features = self._do_a_recursion()
        if self.erase_base_features:
            self._erase_base_features()

    def get_node_features(self) -> Dict[int, List[str]]:
        """
        Return the node level features.

        Returns:
            dict: {node: [feature_it0, feature_it1, ..., feature_itT]}
            (without ``feature_it0`` when ``erase_base_features`` is True).
        """
        return self.extracted_features

    def get_graph_features(self) -> List[str]:
        """
        Return the graph level features as a bag (multiset)
        of all node features across all iterations.

        Returns:
            list of str: concatenation over nodes and iterations.
        """
        return [
            feature
            for features in self.extracted_features.values()
            for feature in features
        ]


# ============================================================
# Example usage
# ============================================================
if __name__ == "__main__":
    # Example 1: the karate club benchmark graph, refined from node degrees
    # (use_node_attribute=None) for two WL rounds, keeping the base labels.
    G = nx.karate_club_graph()
    wl = WeisfeilerLehmanHashing(
        graph=G,
        wl_iterations=2,
        use_node_attribute=None,
        erase_base_features=False,
    )
    node_feats = wl.get_node_features()
    graph_feats = wl.get_graph_features()

    print("Example 1: Karate graph (degree-based WL)")
    print("Node 0 features:", node_feats[0])
    print("Total number of graph-level features:", len(graph_feats))
    print("First 10 graph features:", graph_feats[:10])

    # Example 2: the path 0-1-2-3 labelled A, B, A, C through the "label"
    # node attribute; three WL rounds, base labels dropped afterwards.
    H = nx.Graph()
    H.add_nodes_from(range(4))
    H.add_edges_from(zip(range(3), range(1, 4)))
    nx.set_node_attributes(H, dict(enumerate("ABAC")), name="label")

    wl_attr = WeisfeilerLehmanHashing(
        graph=H,
        wl_iterations=3,
        use_node_attribute="label",
        erase_base_features=True,
    )
    node_feats_attr = wl_attr.get_node_features()
    graph_feats_attr = wl_attr.get_graph_features()

    print("\nExample 2: Small graph with 'label' node attribute")
    for labelled_node, history in node_feats_attr.items():
        print(f"Node {labelled_node} features:", history)

    print("Graph-level feature multiset (first few):", graph_feats_attr[:10])


Example 1: Karate graph (degree-based WL)
Node 0 features: ['16', 'ef9153b15158f0e1f50641e7bf8c1b28', '50402170d75b12360445faec47820429']
Total number of graph-level features: 102
First 10 graph features: ['16', 'ef9153b15158f0e1f50641e7bf8c1b28', '50402170d75b12360445faec47820429', '9', 'da1b21c9ce096bf157fe29ba7c59563b', '3d0d0163a1a838e3b3f26bf1090220ca', '10', 'f3c087dc4f4756e76a65388e0f7bf976', '4bdae6d3c2f6dc1f170cd3804b8e2d42', '6']

Example 2: Small graph with 'label' node attribute
Node 0 features: ['350f8fcc3cd566cd76e9906ea8e5936b', 'b0f06103498e6e63e39c3de8de18fbbe', 'c31c3ea6b78a54ef15a9855ce4d0527a']
Node 1 features: ['03cd6f755301fc7ddc543563e50e2183', 'cfbf572289373f2e2c464e052c709b19', '1140848dbe25c31020720fcdd74e00ae']
Node 2 features: ['3510e0bfa31ada47a9b76fc9ba9dff17', 'e079c71875a4f69ce9b55abb5bc794a4', '34be3d3dd583406233aa5b19bb74542b']
Node 3 features: ['a4d1dc80e304ab3d90468c2f37cba861', '073e851d7bbb87ccca45ab16c7b612e2', 'a5e07b02f5915e2528c69d13be7539f0']


In [None]:
import networkx as nx

# Two 2-regular graphs on six vertices that 1-WL cannot tell apart:
# a hexagon (C6) and two disjoint triangles (C3 ⊔ C3).
G1 = nx.cycle_graph(6)
G2 = nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))

wl1, wl2 = (
    WeisfeilerLehmanHashing(
        g, wl_iterations=5, use_node_attribute=None, erase_base_features=False
    )
    for g in (G1, G2)
)

# Compare the graph-level bags as multisets (order-independent).
print(
    "Graph features match?",
    sorted(wl1.get_graph_features()) == sorted(wl2.get_graph_features()),
)


Graph features match? True


In [None]:
!pip install toponetx

Collecting toponetx
  Downloading TopoNetX-0.2.0-py3-none-any.whl.metadata (13 kB)
Collecting trimesh (from toponetx)
  Downloading trimesh-4.10.0-py3-none-any.whl.metadata (13 kB)
Downloading TopoNetX-0.2.0-py3-none-any.whl (114 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.4/114.4 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading trimesh-4.10.0-py3-none-any.whl (736 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m736.6/736.6 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh, toponetx
Successfully installed toponetx-0.2.0 trimesh-4.10.0


In [None]:
"""Functions for computing neighborhoods of a complex."""

from typing import Literal

import toponetx as tnx
from scipy.sparse import csr_matrix


def neighborhood_from_complex(
    domain: tnx.Complex,
    neighborhood_type: Literal["adj", "coadj"] = "adj",
    neighborhood_dim=None,
) -> tuple[list, csr_matrix]:
    """Compute the neighborhood of a complex.

    This function returns the indices and matrix for the neighborhood specified by
    `neighborhood_type` and `neighborhood_dim` for the input complex `domain`.

    Parameters
    ----------
    domain : toponetx.classes.Complex
        The complex to compute the neighborhood for.
    neighborhood_type : {"adj", "coadj"}, default="adj"
        The type of neighborhood to compute. "adj" for adjacency matrix, "coadj" for coadjacency matrix.
    neighborhood_dim : dict, optional
        The integer parameters needed to specify the neighborhood of the cells to generate the embedding.
        Defaults to ``{"rank": 0, "via_rank": -1}`` when None.
        In TopoNetX (co)adjacency neighborhood matrices are specified via one or two parameters.
        - For Cell/Simplicial/Path complexes the (co)adjacency matrix is specified by a single parameter,
        namely neighborhood_dim["rank"]; neighborhood_dim["via_rank"] is ignored.
        - For Combinatorial/ColoredHyperGraph the (co)adjacency matrix is specified by two parameters,
        namely neighborhood_dim["rank"] and neighborhood_dim["via_rank"].

    Notes
    -----
        Here neighborhood_dim={"rank": 1, "via_rank": -1} specifies the dimension for
        which the cell embeddings are going to be computed.
        "rank": 1 means that the embeddings will be computed for the first dimension.
        The integer "via_rank": -1 is ignored when the input is a cell, simplicial or
        path complex and must be specified when the input complex is a combinatorial
        complex or colored hypergraph.

    Returns
    -------
    ind : list
        The index of the cells of the requested rank, as returned by the TopoNetX
        (co)adjacency method called with ``index=True``.
    A : scipy.sparse.csr_matrix
        The matrix representing the neighborhood.

    Raises
    ------
    TypeError
        If `domain` is not a SimplicialComplex, CellComplex, PathComplex, ColoredHyperGraph or CombinatorialComplex.
    TypeError
        If `neighborhood_type` is invalid.
    KeyError
        If `neighborhood_dim` lacks a key required by the domain type.
    """
    if neighborhood_dim is None:
        neighborhood_dim = {"rank": 0, "via_rank": -1}

    if neighborhood_type not in ("adj", "coadj"):
        raise TypeError(
            f"Input neighborhood_type must be `adj` or `coadj`, got {neighborhood_type}."
        )

    if isinstance(domain, tnx.SimplicialComplex | tnx.CellComplex | tnx.PathComplex):
        # Single-parameter neighborhoods: only "rank" is used.
        if neighborhood_type == "adj":
            ind, A = domain.adjacency_matrix(neighborhood_dim["rank"], index=True)
        else:
            ind, A = domain.coadjacency_matrix(neighborhood_dim["rank"], index=True)
    elif isinstance(domain, tnx.CombinatorialComplex | tnx.ColoredHyperGraph):
        # Two-parameter neighborhoods: cells of "rank" related via "via_rank".
        if neighborhood_type == "adj":
            ind, A = domain.adjacency_matrix(
                neighborhood_dim["rank"], neighborhood_dim["via_rank"], index=True
            )
        else:
            ind, A = domain.coadjacency_matrix(
                neighborhood_dim["rank"], neighborhood_dim["via_rank"], index=True
            )
    else:
        raise TypeError(
            "Input Complex can only be a SimplicialComplex, CellComplex, PathComplex ColoredHyperGraph or CombinatorialComplex."
        )

    return ind, A

In [None]:
"""Higher-order Weisfeiler–Lehman hashing on complexes."""

from __future__ import annotations

import hashlib
from typing import Any, Dict, Hashable, List, Literal, Optional, Tuple

import networkx as nx
import toponetx as tnx
from scipy.sparse import csr_matrix



class HigherOrderWeisfeilerLehmanHashing:
    r"""
    Higher-order Weisfeiler–Lehman (WL) feature extractor on TopoNetX complexes.

    This class generalizes the classical Weisfeiler–Lehman hashing from graphs
    to arbitrary topological domains supported by TopoNetX
    (CellComplex, SimplicialComplex, CombinatorialComplex, PathComplex, ColoredHyperGraph).

    Instead of a graph adjacency, it uses an arbitrary (co)adjacency neighborhood
    matrix computed from the complex (e.g. adjacency of 0-cells, coadjacency of 1-cells,
    adjacency via another rank, etc.).

    Mathematically, we consider:
        - A complex :math:`\mathcal{K}`.
        - A neighborhood matrix :math:`A` over a chosen family of cells (e.g. rank-0, rank-1, ...).
        - WL is then run on the graph :math:`G = (V, E)` with
          :math:`V = \{0, \dots, n-1\}` indexing those cells, and
          edges induced by the non-zero pattern of :math:`A`, plus a
          self-loop on every vertex.

    Parameters
    ----------
    wl_iterations : int
        Number of WL iterations (depth of refinement).
    erase_base_features : bool, default=False
        If True, drop the base features (iteration 0) from the final feature lists.

    Notes
    -----
    - The base features can be:
        * explicitly provided via ``cell_features`` in :meth:`fit`, or
        * default to the degree in the neighborhood graph induced by the
          chosen (co)adjacency matrix. Because every vertex carries a
          self-loop and NetworkX counts a self-loop twice in ``degree``,
          this default equals (number of distinct neighbors + 2).
    - After fitting, you can obtain:
        * cell-level features via :meth:`get_cell_features`
        * complex-level (bag-of-features) representation via :meth:`get_domain_features`.
    """

    # Neighborhood matrix and index list
    A: csr_matrix
    ind: List[Hashable]

    def __init__(self, wl_iterations: int = 2, erase_base_features: bool = False) -> None:
        self.wl_iterations = wl_iterations
        self.erase_base_features = erase_base_features

        # Will be populated by fit()
        self.domain: Optional[tnx.Complex] = None
        self.neighborhood_type: Optional[str] = None
        self.neighborhood_dim: Optional[Dict[str, int]] = None

        self.graph_: Optional[nx.Graph] = None
        self._index_to_cell: Dict[int, Hashable] = {}
        self._cell_to_index: Dict[Hashable, int] = {}

        # WL internal state
        self.features: Dict[int, Any] = {}
        self.extracted_features: Dict[int, List[str]] = {}

    # -------------------------------------------------------------------------
    # Core public API
    # -------------------------------------------------------------------------
    def fit(
        self,
        domain: tnx.Complex,
        neighborhood_type: Literal["adj", "coadj"] = "adj",
        neighborhood_dim: Optional[Dict[str, int]] = None,
        cell_features: Optional[Dict[Hashable, Any]] = None,
    ) -> "HigherOrderWeisfeilerLehmanHashing":
        r"""
        Fit the higher-order WL hashing on a complex.

        Parameters
        ----------
        domain : toponetx.classes.Complex
            A complex object. The complex can be one of:
            - CellComplex
            - CombinatorialComplex
            - PathComplex
            - SimplicialComplex
            - ColoredHyperGraph
        neighborhood_type : {"adj", "coadj"}, default="adj"
            The type of neighborhood to compute:
            - "adj"   → adjacency matrix
            - "coadj" → coadjacency matrix
        neighborhood_dim : dict, optional
            Integer parameters specifying the neighborhood of the cells to generate features.
            Follows the same convention as :func:`neighborhood_from_complex`:

            - For Cell/Simplicial/Path complexes, the (co)adjacency is specified by:
                  neighborhood_dim["rank"]
            - For Combinatorial/ColoredHyperGraph, it is specified by:
                  neighborhood_dim["rank"], neighborhood_dim["via_rank"]
        cell_features : dict, optional
            Optional base features for the cells. If provided, this must be a dictionary
            mapping each cell identifier in the chosen domain rank to a scalar or
            hashable attribute.

            The keys must match the entries of the index list ``ind`` returned by
            :func:`neighborhood_from_complex`. If None, the base features default to
            the degrees in the induced neighborhood graph (self-loop included).

        Returns
        -------
        HigherOrderWeisfeilerLehmanHashing
            The fitted instance (for chaining).

        Raises
        ------
        ValueError
            If ``cell_features`` is given but misses some cells.
        """
        self.domain = domain
        self.neighborhood_type = neighborhood_type
        self.neighborhood_dim = neighborhood_dim

        # 1. Build neighborhood matrix and index list
        self.ind, self.A = neighborhood_from_complex(
            domain, neighborhood_type=neighborhood_type, neighborhood_dim=neighborhood_dim
        )

        # 2. Build an internal graph on indices 0..n-1 from the non-zero
        #    pattern of A. ``nx.from_scipy_sparse_array`` creates an edge for
        #    every *stored* entry, including explicitly stored zeros, so those
        #    are dropped on a copy first (self.A itself is left untouched).
        pattern = csr_matrix(self.A, copy=True)
        pattern.eliminate_zeros()
        g = nx.from_scipy_sparse_array(pattern)

        # Always give every cell a self-loop (mirroring Cell2Vec behavior), so
        # a cell's own label also appears in its neighbor multiset.
        g.add_edges_from((idx, idx) for idx in range(g.number_of_nodes()))

        self.graph_ = g
        # NOTE(review): if ``ind`` is a mapping rather than a list, enumeration
        # follows its key order; confirm this matches the matrix row order.
        self._index_to_cell = dict(enumerate(self.ind))
        self._cell_to_index = {cell: idx for idx, cell in self._index_to_cell.items()}

        # 3. Set base features and run WL recursions
        self._set_features(cell_features=cell_features)
        self._do_recursions()

        return self

    def get_index_features(self) -> Dict[int, List[str]]:
        """
        Get WL feature sequences indexed by the internal integer index.

        Returns
        -------
        dict
            Mapping from internal index (0..n-1) to list of WL features
            across iterations, i.e.:
                {idx: [f^0(idx), f^1(idx), ..., f^T(idx)]}
        """
        return self.extracted_features

    def get_cell_features(self) -> Dict[Hashable, List[str]]:
        """
        Get WL feature sequences for each cell in the chosen rank.

        Returns
        -------
        dict
            Mapping from cell identifier to list of WL features across iterations.
            The cell identifiers coincide with the entries of ``self.ind`` returned
            by :func:`neighborhood_from_complex`.
        """
        return {
            self._index_to_cell[idx]: feats
            for idx, feats in self.extracted_features.items()
        }

    def get_domain_features(self) -> List[str]:
        """
        Get a bag (multiset) of all WL features across all cells and iterations.

        Returns
        -------
        list of str
            Concatenation of all WL features, i.e. a global representation for the complex.
        """
        return [
            feature
            for features in self.extracted_features.values()
            for feature in features
        ]

    # -------------------------------------------------------------------------
    # Internal helpers: WL logic
    # -------------------------------------------------------------------------
    def _require_graph(self) -> nx.Graph:
        """Return the internal graph, raising if fit() has not built it yet."""
        # An explicit raise (not ``assert``) so the guard survives ``python -O``.
        if self.graph_ is None:
            raise RuntimeError("Graph has not been constructed. Call fit() first.")
        return self.graph_

    def _set_features(self, cell_features: Optional[Dict[Hashable, Any]] = None) -> None:
        """
        Initialize base features for the WL recursion.

        Raises
        ------
        RuntimeError
            If called before the internal graph has been built.
        ValueError
            If ``cell_features`` misses some cells of the chosen rank.
        """
        graph = self._require_graph()

        if cell_features is not None:
            # Map cell_features (keyed by cells) to internal index-based features
            features: Dict[int, Any] = {}
            missing_cells = []
            for idx, cell in self._index_to_cell.items():
                if cell not in cell_features:
                    missing_cells.append(cell)
                else:
                    features[idx] = cell_features[cell]

            if missing_cells:
                # Show a small subset of missing cells for debugging
                preview = missing_cells[:5]
                raise ValueError(
                    "Provided cell_features is missing values for some cells. "
                    f"Example missing cells: {preview}"
                )

            self.features = features
        else:
            # Default base features: degree in the neighborhood graph
            # (the self-loop added in fit() contributes 2).
            self.features = {
                idx: graph.degree(idx)  # type: ignore[arg-type]
                for idx in graph.nodes
            }

        # Initialize extracted_features with the base labels as strings
        self.extracted_features = {idx: [str(v)] for idx, v in self.features.items()}

    def _erase_base_features(self) -> None:
        """Erase the base features (iteration 0) from the feature lists."""
        for feats in self.extracted_features.values():
            if feats:
                del feats[0]

    def _do_a_recursion(self) -> Dict[int, str]:
        """
        Perform a single WL refinement step.

        For each index i:
            new_feat(i) = hash( feat(i), multiset{ feat(j) : j in N(i) } )

        Raises
        ------
        RuntimeError
            If called before the internal graph has been built.
        """
        graph = self._require_graph()

        new_features: Dict[int, str] = {}

        for idx in graph.nodes:
            neigh_vals = [self.features[nb] for nb in graph.neighbors(idx)]

            # Concatenate current feature with sorted neighbor features;
            # sorting makes the neighbor multiset order-independent.
            parts = [str(self.features[idx])] + sorted(str(v) for v in neigh_vals)
            concat = "_".join(parts)

            new_features[idx] = hashlib.md5(concat.encode()).hexdigest()

        # Append this iteration's feature to the history
        self.extracted_features = {
            idx: self.extracted_features[idx] + [feat]
            for idx, feat in new_features.items()
        }

        return new_features

    def _do_recursions(self) -> None:
        """Run all WL iterations, then optionally erase the base labels."""
        for _ in range(self.wl_iterations):
            self.features = self._do_a_recursion()

        if self.erase_base_features:
            self._erase_base_features()


In [None]:
"""Tests for HigherOrderWeisfeilerLehmanHashing."""

import numpy as np
import networkx as nx
import toponetx as tnx




class TestHigherOrderWL:
    """
    Behavioural tests for HigherOrderWeisfeilerLehmanHashing.

    (1) Graph WL on the 0-skeleton cannot separate two combinatorial complexes
        that share their 0–1 skeleton, while higher-order WL on the 2-cells can.

    (2) Higher-order WL on rank-0 adjacency of a cell complex inherits the
        classic 1-WL blind spot (C6 vs 2*C3).

    (3) Robustness: higher-order WL features are invariant under relabeling.
    """

    # -------------------------------------------------------------------------
    # Shared helpers
    # -------------------------------------------------------------------------
    @staticmethod
    def _square_complex(filled):
        """Return a CombinatorialComplex on a 4-cycle, with a 2-cell if ``filled``."""
        cc = tnx.CombinatorialComplex()
        for vertex in range(4):
            cc.add_cell([vertex], rank=0)
        for u, v in [(0, 1), (1, 2), (2, 3), (3, 0)]:
            cc.add_cell([u, v], rank=1)
        if filled:
            # The boundary data of the square face is handled by TopoNetX.
            cc.add_cell([0, 1, 2, 3], rank=2)
        return cc

    @staticmethod
    def _fit_hol(domain, iterations, rank, via_rank):
        """Fit higher-order WL with an "adj" neighborhood, keeping base features."""
        return HigherOrderWeisfeilerLehmanHashing(
            wl_iterations=iterations, erase_base_features=False
        ).fit(
            domain,
            neighborhood_type="adj",
            neighborhood_dim={"rank": rank, "via_rank": via_rank},
        )

    # -------------------------------------------------------------------------
    # (1) Graph WL on 0-skeleton fails, higher-order WL on cells succeeds
    # -------------------------------------------------------------------------
    def test_graph_wl_fails_but_higherorder_wl_succeeds(self):
        """
        Two CombinatorialComplexes share the same vertex-edge structure (a
        4-cycle) but only the first has a 2-cell filling the square.

        - Graph WL on the vertex adjacency (rank 0 via rank 1) sees the same
          graph twice and must not tell them apart.
        - Higher-order WL on the 2-cells (rank 2 via rank 1) must tell them
          apart, because only one complex has any 2-cell.
        """
        cc1 = self._square_complex(filled=True)
        cc2 = self._square_complex(filled=False)

        # 1.a Graph WL on the shared 1-skeleton: expected to FAIL.
        skeleton_graphs = [
            nx.from_scipy_sparse_array(
                neighborhood_from_complex(
                    cc,
                    neighborhood_type="adj",
                    neighborhood_dim={"rank": 0, "via_rank": 1},
                )[1]
            )
            for cc in (cc1, cc2)
        ]
        wl1, wl2 = (
            WeisfeilerLehmanHashing(
                graph=graph,
                wl_iterations=2,
                use_node_attribute=None,
                erase_base_features=False,
            )
            for graph in skeleton_graphs
        )

        assert sorted(wl1.get_graph_features()) == sorted(
            wl2.get_graph_features()
        ), "Graph WL on 0-skeleton should not distinguish cc1 and cc2."

        # 1.b Higher-order WL on 2-cells: expected to SUCCEED. cc2 has no
        #     rank-2 cells, so its neighborhood (and feature bag) may be empty.
        # NOTE(review): TopoNetX may require via_rank > rank for "adj" on a
        # CombinatorialComplex; if so, this neighborhood should be "coadj" —
        # confirm against the installed TopoNetX version.
        feats1 = self._fit_hol(cc1, 2, 2, 1).get_domain_features()
        feats2 = self._fit_hol(cc2, 2, 2, 1).get_domain_features()

        # Either the length or the content of the bags has to differ.
        assert feats1 != feats2, "Higher-order WL on 2-cells should distinguish cc1 and cc2."

    # -------------------------------------------------------------------------
    # (2) Failure example on a cell complex for a given neighborhood
    # -------------------------------------------------------------------------
    def test_higherorder_wl_failure_on_cell_complex_rank0(self):
        """
        The classic 1-WL failure, lifted to 1-dimensional CellComplexes.

        C6 and C3 ⊔ C3 are both 2-regular on six vertices. Rank-0 adjacency
        carries exactly the graph information, so higher-order WL on it must
        fail to separate them, just like graph 1-WL.
        """
        hexagon_edges = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
        two_triangle_edges = [[0, 1], [1, 2], [2, 0], [3, 4], [4, 5], [5, 3]]

        cx1 = tnx.CellComplex(hexagon_edges, ranks=1)
        cx2 = tnx.CellComplex(two_triangle_edges, ranks=1)

        feats1, feats2 = (
            sorted(self._fit_hol(cx, 3, 0, -1).get_domain_features())
            for cx in (cx1, cx2)
        )

        assert feats1 == feats2, (
            "Higher-order WL on rank-0 adjacency should fail to "
            "distinguish C6 from C3 ⊔ C3, mirroring the 1-WL failure."
        )

    # -------------------------------------------------------------------------
    # (3) Robustness: invariance under relabeling (isomorphic complexes)
    # -------------------------------------------------------------------------
    def test_higherorder_wl_invariance_under_relabeling(self):
        """
        Two SimplicialComplexes that differ only by the vertex relabeling
        0→10, 1→11, 2→12, 3→13 must yield the same bag of rank-1 features.
        """
        sc1 = tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]])
        sc2 = tnx.SimplicialComplex([[10, 11, 12], [11, 12, 13]])

        feats1, feats2 = (
            sorted(self._fit_hol(sc, 2, 1, -1).get_domain_features())
            for sc in (sc1, sc2)
        )

        assert feats1 == feats2, (
            "Higher-order WL should be invariant under vertex relabeling "
            "for isomorphic complexes."
        )



# Higher-Order Weisfeiler–Lehman on Complexes

This tutorial explains the idea behind **higher-order Weisfeiler–Lehman (WL)** on **TopoNetX complexes**, and walks through the six tests implemented in the code below.

The goal is to understand:

- How we **lift WL from graphs to complexes** (simplicial, cell, combinatorial),
- What it means to choose a **rank** and a **neighborhood matrix**,
- Why some tests are **success examples** (HO-WL sees more than graph WL),
- And why some tests are **deliberate failure examples** (limitations).


---

## 1. Recap: Classical WL on Graphs

Let a graph be

$$
G = (V, E), \quad V = \{1, \dots, n\}.
$$

Classical 1-WL (color refinement) iteratively refines node labels:

- At iteration \(t\), each node \(v \in V\) has a label \(\ell_t(v)\) (e.g. initial label is degree).
- We update:

  $$
  \ell_{t+1}(v)
    = h\Big(
        \ell_t(v),
        \{ \! \{ \ell_t(u) \mid u \in N(v) \} \! \}
      \Big),
  $$

  where \(N(v)\) is the set of neighbors of \(v\), \(\{\!\{ \cdot \}\!\}\) is a **multiset**, and \(h\) is a hash/injection that maps a pair (current label, multiset of neighbor labels) to a **new label**.

- After \(T\) iterations, the **graph signature** is typically the multiset of all node labels:

  $$
  \text{Sig}(G) = \{\!\{ \ell_T(v) \mid v \in V \}\!\}.
  $$

If two graphs have **different signatures**, WL distinguishes them.  
If they have **identical signatures**, WL **fails** to distinguish them (this is a known limitation, e.g. \(C_6\) vs \(C_3 \sqcup C_3\)).


---

## 2. Complexes and Ranks in TopoNetX

In TopoNetX we work with several kinds of complexes:

- **SimplicialComplex**: built from simplices: vertices, edges, triangles, tetrahedra, etc.
- **CellComplex**: built from cells glued along boundaries (more general than simplices).
- **CombinatorialComplex**: very general; cells are just finite sets with an incidence structure.

Each complex is **graded by rank**:

- Rank \(0\): vertices (0-cells),
- Rank \(1\): edges (1-cells),
- Rank \(2\): faces (2-cells),
- Rank \(3\): volumes / 3-cells, etc.

We denote the set of \(r\)-cells by

$$
C_r = \{ c_1, \dots, c_{n_r} \}.
$$


---

## 3. Neighborhood Matrices \(A^{(r, k)}\)

TopoNetX provides **neighborhood matrices** that encode how cells of one rank are related via cells of another rank. Conceptually:

- **Adjacency** (type `"adj"`):  
  rank \(r\) via rank \(k\) with \(r < k\).
- **Coadjacency** (type `"coadj"`):  
  rank \(r\) via rank \(k\) with \(r > k\).

We write this as a binary matrix

$$
A^{(r,k)} \in \{0,1\}^{n_r \times n_r},
$$

where:

- Rows/columns index \(r\)-cells \(c_i \in C_r\),
- \(A^{(r,k)}_{ij} = 1\) if \(c_i\) and \(c_j\) are “neighbors via rank \(k\)”  
  (e.g. two faces share an edge, two edges share a vertex, etc.).

Examples:

- **Rank-0 adjacency via rank-1 (graph adjacency)**:  
  vertices are neighbors if they share an edge.
- **Rank-2 coadjacency via rank-1**:  
  faces are neighbors if they share an edge.
- **Rank-1 coadjacency via rank-2**:  
  edges are neighbors if they co-occur in a 2-cell (face).


---

## 4. Higher-Order WL on Complexes

Given a complex and a choice of rank \(r\) and neighborhood type, we:

1. Choose rank \(r\) (e.g. faces, edges, vertices).
2. Compute the neighborhood matrix

   $$
   A^{(r,k)} \in \{0,1\}^{n_r \times n_r}.
   $$

3. View this as a graph

   $$
   G_r = (C_r, E_r),
   $$

   where \(C_r\) is the set of \(r\)-cells, and there is an edge between \(c_i\) and \(c_j\) iff \(A^{(r,k)}_{ij} = 1\).

4. Run **graph WL** on this cell-graph \(G_r\):

   - Nodes: \(c \in C_r\),
   - Neighbors: cells \(c' \in C_r\) with \(A^{(r,k)}_{cc'} = 1\),
   - Update rule:

     $$
     \ell_{t+1}(c)
       = h\Big(
           \ell_t(c),
           \{ \! \{ \ell_t(c') \mid c' \in N(c) \} \! \}
         \Big).
     $$

5. Aggregate all labels at the final iteration into a **bag-of-features**:

   $$
   \text{Sig}_r(\mathcal{K}) = \{\!\{ \ell_T(c) \mid c \in C_r \}\!\}.
   $$

This is precisely what the class `HigherOrderWeisfeilerLehmanHashing` does:

- `rank` and `via_rank` are chosen through `neighborhood_dim={"rank": r, "via_rank": k}`.
- `neighborhood_type="adj"` or `"coadj"` determines which matrix is used.
- `get_domain_features()` returns a multiset of hashes encoding \(\text{Sig}_r(\mathcal{K})\).

We compare **two complexes** by comparing their signatures at the same rank and neighborhood:

- If signatures are **different**, HO-WL distinguishes them at that level.
- If signatures are **equal**, HO-WL (with that choice of rank/neighborhood) **fails** to distinguish them.


---

## 5. How to Read the Printed Output

The test code prints lines like:

```text
✓ Graph WL on 0-skeleton adjacency (CC1 vs CC2)
    Result        : True
    Interpretation: As desired: graph WL only sees the shared 0–1 skeleton ...
```

Each line has:

* A **label** (what we are comparing),
* A boolean `Result: True/False`,
* An **Interpretation** that explains whether that result is expected and what it means.

Important:

* In some checks, `True` means **signatures are equal** (a *failure* to distinguish, by design).
* In other checks, `True` means **signatures are different** (a *success* in distinguishing).

The label and the interpretation text tell you which one we are testing in each case.

---

## 6. Summary of the Six Tests

Below is a conceptual description of each test, matching the code and the printouts.

### Test 1 — Graph WL fails, Higher-Order WL succeeds (CombinatorialComplex, rank 2)

**Complex type:** `CombinatorialComplex`.

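* Vertices: \(\{0,1,2,3\}\).
* Edges: a square with the diagonal \((0,2)\), namely \((0,1),(1,2),(2,3),(3,0),(0,2)\) (**same** in both complexes). This is \(K_4\) minus the edge \(\{1,3\}\). Cells are vertex sets, so \((0,3)\) and \((3,0)\) denote the same edge.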

We build two complexes:

1. **CC1**:

   * One single 2-cell \([0,1,2,3]\) (a quadrilateral face).
2. **CC2**:

   * Two 2-cells \([0,1,2]\) and \([0,2,3]\) (two triangles).

We do:

1. **Graph WL on 0-skeleton adjacency (rank 0 via rank 1)**:

   * Only sees the vertex–edge graph.
   * The graph is identical for CC1 and CC2, so the WL signatures match.
   * `Result: True` here means **WL cannot distinguish** CC1 vs CC2 at the graph level — this is a **failure** of graph WL (as we want for the example).

2. **Higher-Order WL on rank-2 coadjacency via rank-1 (faces via edges)**:

   * Nodes are 2-cells (faces).
   * Coadjacency connects faces that share an edge.
   * CC1 has 1 face, CC2 has 2 faces → the 2-cell graphs differ.
   * `Result: True` here means **WL signatures differ**, so HO-WL **distinguishes** CC1 vs CC2 — a **success** for higher-order WL.

### Test 2 — Classic 1-WL Failure: \(C_6\) vs \(C_3 \sqcup C_3\) as CellComplex

**Complex type:** `CellComplex` (1-dimensional).

* `cx1`: encodes the 6-cycle graph \(C_6\).
* `cx2`: encodes two disjoint triangles \(C_3 \sqcup C_3\).

We run:

* **Higher-Order WL on rank-0 adjacency** (`neighborhood_type="adj"`, `{"rank": 0, "via_rank": -1}`):

  * This is exactly WL on the graph where nodes are vertices and edges connect adjacent vertices.
  * It is known that 1-WL **fails** to distinguish \(C_6\) from \(C_3 \sqcup C_3\).
  * `Result: True` here means **signatures are equal**, so WL fails — a **deliberate reproduction** of the classical failure example.

### Test 3 — Invariance Under Vertex Relabeling (SimplicialComplex, rank 1)

**Complex type:** `SimplicialComplex`.

* `sc1`: has 2-simplices \([0,1,2]\) and \([1,2,3]\).
* `sc2`: is the same complex but with vertex labels shifted:
  \([10,11,12]\) and \([11,12,13]\).

These two complexes are **isomorphic** as simplicial complexes.

We run:

* **Higher-Order WL on rank-1 adjacency**
  (`neighborhood_type="adj"`, `{"rank": 1, "via_rank": -1}`):

  * Nodes are edges (1-simplices).
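  * Two edges are neighbors under the `"adj"` relation. In TopoNetX this is presumably *upper* adjacency for a SimplicialComplex: edges that are faces of a common triangle. Worth confirming. The invariance argument does not depend on which relation is used.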
  * Isomorphic complexes should yield **identical WL signatures** at this rank.

`Result: True` means the signatures match, so WL is **invariant under relabeling**, as desired.

### Test 4 — Edge-Level HO-WL Limitation (CellComplex, rank 1 via rank 2)

**Complex type:** `CellComplex`.

Common 1-skeleton (vertices and edges):

* Vertices: \(\{0,1,2,3\}\),
* Edges: square with a diagonal: \((0,1),(1,2),(2,3),(3,0),(0,2)\).

We build two complexes:

1. `cx_A`: one 2-cell \([0,1,2,3]\) (a single quadrilateral face).
2. `cx_B`: two 2-cells \([0,1,2]\) and \([0,2,3]\) (two triangles).

We run:

* **Higher-Order WL on rank-1 coadjacency via rank-2**
  (`neighborhood_type="coadj"`, `{"rank": 1, "via_rank": 2}`):

  * Nodes are edges.
  * Two edges are neighbors if they co-occur in a face.
  * Intuitively, you might expect this to see a difference between “one quad” and “two triangles”.

Empirically (with the current neighborhood definition):

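* `Result: True` means **the WL signatures on edges are equal** for `cx_A` and `cx_B`,
* So **this particular edge-level neighborhood fails** to distinguish the two decompositions.
* A likely explanation, worth confirming against `neighborhood_from_complex`:
  * For a `CellComplex`, TopoNetX's rank-1 coadjacency relates edges that share a *vertex* and does not use `via_rank`.
  * It therefore only sees the 1-skeleton, which is identical in `cx_A` and `cx_B`.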

This is a **negative example**: it shows that **choice of rank and neighborhood matters**; higher-order WL is not automatically more powerful in every configuration.

### Test 5 — CC Example Where Both Rank-2 and Rank-3 HO-WL See Higher-Order Differences

**Complex type:** `CombinatorialComplex`.

We build two complexes `cc_A`, `cc_B`:

* Shared structure:

  * Vertices: \(\{0,1,2,3\}\).
  * Edges form two triangles sharing the edge \((1,2)\):
    \((0,1),(1,2),(2,0)\) and \((1,2),(2,3),(1,3)\).
  * Faces: \([0,1,2]\) and \([1,2,3]\) in both complexes.

**Difference in rank-3 structure**:

* `cc_A`: one rank-3 cell attached “above” \([0,1,2]\).
* `cc_B`: two rank-3 cells, one above \([0,1,2]\) and one above \([1,2,3]\).

We run:

1. **Rank-2 HO-WL (faces via edges)**
   (`"coadj"`, `{"rank": 2, "via_rank": 1}`):

   * Nodes are faces (2-cells).
   * The coadjacency graph of faces already changes because of how rank-3 cells influence the co-incidence patterns.
   * `Result: True` means WL signatures differ, so rank-2 HO-WL **detects** the difference.

2. **Rank-3 HO-WL (3-cells via vertices)**
   (`"coadj"`, `{"rank": 3, "via_rank": 0}`):

   * Nodes are 3-cells.
   * Adjacency depends on shared vertices.
   * Obviously, one vs two 3-cells gives a different structure.
   * `Result: True` means WL signatures differ, so rank-3 HO-WL also **detects** the difference.

This test illustrates that:

* In this setting, both **rank-2** and **rank-3** HO-WL are sensitive to the higher-order differences.
* It is not only the “top rank” that can see them; sometimes lower ranks already encode this information via neighborhood matrices.

### Test 6 — Invariance Under Vertex Relabeling (CombinatorialComplex, rank 2)

**Complex type:** `CombinatorialComplex`.

We build two isomorphic complexes:

* `cc1`: vertices \(\{0,1,2,3\}\), edges forming a 4-cycle, faces \([0,1,2]\) and \([0,2,3]\).
* `cc2`: same structure but vertices relabeled \(\{10,11,12,13\}\), faces \([10,11,12]\) and \([10,12,13]\).

We run:

* **Higher-Order WL on rank-2 coadjacency via rank-1**
  (`"coadj"`, `{"rank": 2, "via_rank": 1}`):

  * Nodes are faces.
  * Faces are neighbors if they share an edge.
  * Because the complexes are isomorphic, we expect **identical WL signatures**.

`Result: True` means HO-WL is **invariant under vertex relabeling** at this rank in the combinatorial complex setting, as desired.

---

The code is provided below:


In [None]:
# ============================================================
# HIGHER-ORDER WL TEST SUITE
# ============================================================
#
# This cell implements SIX tests and *explains them in the console output*:
#
#  (1) Graph WL fails, Higher-Order WL on 2-cells succeeds (CombinatorialComplex).
#  (2) Reproducing a classic WL failure (C6 vs C3 ⊔ C3) in a CellComplex.
#  (3) Invariance under vertex relabeling on a SimplicialComplex.
#  (4) Edge-level Higher-Order WL on CellComplex: an empirical FAILURE case
#      (same 1-skeleton, different faces, but with this neighborhood WL
#       does not distinguish them).
#  (5) CombinatorialComplex where BOTH rank-2 and rank-3 HO-WL detect
#      different higher-order structure.
#  (6) Invariance under vertex relabeling on a CombinatorialComplex.
#
# In addition to docstrings, the main block prints:
#   - A short description of the complexes built in each test.
#   - Which ranks / via-ranks / neighborhood types are used.
#   - What “Result: True/False” *means* in context.
# ============================================================

import networkx as nx
import toponetx as tnx




# ------------------------------------------------------------------------------
# Helper for detailed, didactic printing
# ------------------------------------------------------------------------------
def print_check(name: str, condition: bool, when_true: str, when_false: str) -> None:
    """
    Print a labelled boolean result together with its interpretation.

    The output is a header line with a check mark (✓ when ``condition`` is
    truthy, ✗ otherwise) and the label. It is followed by the raw result,
    the matching explanation, and a trailing blank line.

    Parameters
    ----------
    name : str
        Short label describing what was compared.
    condition : bool
        Outcome of the comparison / property being reported.
    when_true : str
        Explanation printed when ``condition`` is truthy.
    when_false : str
        Explanation printed when ``condition`` is falsy.
    """
    if condition:
        symbol, explanation = "✓", when_true
    else:
        symbol, explanation = "✗", when_false

    # A single print of three newline-terminated lines; print's own "\n"
    # supplies the trailing blank separator line.
    report = (
        f"{symbol} {name}\n"
        f"    Result        : {condition}\n"
        f"    Interpretation: {explanation}\n"
    )
    print(report)


# ------------------------------------------------------------------------------
# (1) Graph WL fails but Higher-Order WL succeeds (CC, rank-2)
# ------------------------------------------------------------------------------
def test_graph_wl_fails_but_higherorder_wl_succeeds():
    """
    Two CombinatorialComplex objects with identical 0–1 skeleton
    but different 2-cell structure.

    Setup
    -----
    Vertices: {0, 1, 2, 3}
    Edges   : [(0,1), (1,2), (2,3), (3,0), (0,2)]
              A square with the diagonal (0,2), i.e. K4 minus the edge {1,3}.
              Cells are vertex sets, so (0,3) would be the same edge as
              (3,0); it is not listed a second time.

    CC1:
        - Same vertices and edges.
        - One rank-2 cell [0,1,2,3] (a single quadrilateral face).

    CC2:
        - Same vertices and edges.
        - Two rank-2 cells [0,1,2] and [0,2,3] (two triangles).

    Tests
    -----
    1) Graph WL on 0-skeleton adjacency (rank=0 via rank=1):
         - Uses vertex adjacency only.
         - Both CC1 and CC2 induce the SAME graph.
         - EXPECT: WL signatures equal → graph WL fails to see the difference.

    2) Higher-Order WL on rank-2 coadjacency (faces via edges):
         - Nodes = 2-cells (faces).
         - Edges between faces if they share a 1-cell (edge).
         - CC1 has 1 face, CC2 has 2 faces → different coadjacency pattern.
         - EXPECT: WL signatures differ → HO-WL succeeds.

    Both comparisons sort the feature lists first. They therefore compare
    *multisets* of WL labels and do not depend on cell enumeration order.

    Returns
    -------
    tuple of (bool, bool)
        ``(graph_wl_equal, higherorder_distinguish)``.
    """

    # Shared 0–1 skeleton: square 0-1-2-3 plus the diagonal (0,2).
    verts = [0, 1, 2, 3]
    edges = [
        (0, 1),
        (1, 2),
        (2, 3),
        (3, 0),
        (0, 2),
    ]

    # CC1: one quadrilateral 2-cell
    cc1 = tnx.CombinatorialComplex()
    for v in verts:
        cc1.add_cell([v], rank=0)
    for e in edges:
        cc1.add_cell(list(e), rank=1)
    cc1.add_cell([0, 1, 2, 3], rank=2)

    # CC2: two triangular 2-cells
    cc2 = tnx.CombinatorialComplex()
    for v in verts:
        cc2.add_cell([v], rank=0)
    for e in edges:
        cc2.add_cell(list(e), rank=1)
    cc2.add_cell([0, 1, 2], rank=2)
    cc2.add_cell([0, 2, 3], rank=2)

    # --- 1.a Graph WL on 0-skeleton adjacency (rank=0 via rank=1) ---
    # Only the matrix is needed; the returned cell index is discarded.
    _, A1_0 = neighborhood_from_complex(
        cc1, neighborhood_type="adj", neighborhood_dim={"rank": 0, "via_rank": 1}
    )
    _, A2_0 = neighborhood_from_complex(
        cc2, neighborhood_type="adj", neighborhood_dim={"rank": 0, "via_rank": 1}
    )

    G1 = nx.from_scipy_sparse_array(A1_0)
    G2 = nx.from_scipy_sparse_array(A2_0)

    wl1 = WeisfeilerLehmanHashing(
        G1, wl_iterations=3, use_node_attribute=None, erase_base_features=False
    )
    wl2 = WeisfeilerLehmanHashing(
        G2, wl_iterations=3, use_node_attribute=None, erase_base_features=False
    )

    graph_wl_equal = sorted(wl1.get_graph_features()) == sorted(
        wl2.get_graph_features()
    )

    # --- 1.b Higher-order WL on 2-cells (coadjacency via edges, rank=2 via rank=1) ---
    hol1 = HigherOrderWeisfeilerLehmanHashing(
        wl_iterations=3, erase_base_features=False
    ).fit(
        cc1,
        neighborhood_type="coadj",
        neighborhood_dim={"rank": 2, "via_rank": 1},
    )

    hol2 = HigherOrderWeisfeilerLehmanHashing(
        wl_iterations=3, erase_base_features=False
    ).fit(
        cc2,
        neighborhood_type="coadj",
        neighborhood_dim={"rank": 2, "via_rank": 1},
    )

    # Sort before comparing, as everywhere else in this suite. An unsorted
    # comparison would report "distinguished" for two identical multisets
    # that merely come out in a different order.
    feats1 = sorted(hol1.get_domain_features())
    feats2 = sorted(hol2.get_domain_features())

    higherorder_distinguish = feats1 != feats2

    return graph_wl_equal, higherorder_distinguish


# ------------------------------------------------------------------------------
# (2) Higher-Order WL failure on CellComplex (rank-0 adjacency)
# ------------------------------------------------------------------------------
def test_higherorder_wl_failure_on_cell_complex():
    """
    Reproduce the textbook 1-WL failure, C6 vs C3 ⊔ C3, with CellComplex.

    Setup
    -----
    - First complex : the 6-cycle C6 on vertices {0,...,5}.
    - Second complex: two disjoint triangles C3 ⊔ C3 on vertices {0,...,5}.

    The underlying graphs are **not isomorphic**. Both are 2-regular on six
    vertices, however, and 1-WL is known to give them identical colorings.

    Each graph is stored as a 1-dimensional CellComplex. Higher-Order WL is
    then run on rank-0 adjacency (``"adj"``, rank 0), which amounts to 1-WL
    on the graph itself.

    Returns
    -------
    bool
        True when the sorted feature bags coincide, i.e. WL fails to tell the
        two complexes apart (the expected outcome).
    """
    cycle_six = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    twin_triangles = [[0, 1], [1, 2], [2, 0], [3, 4], [4, 5], [5, 3]]

    bags = []
    for edge_list in (cycle_six, twin_triangles):
        complex_ = tnx.CellComplex(edge_list, ranks=1)
        hashing = HigherOrderWeisfeilerLehmanHashing(
            wl_iterations=3, erase_base_features=False
        ).fit(
            complex_,
            neighborhood_type="adj",
            neighborhood_dim={"rank": 0, "via_rank": -1},
        )
        bags.append(sorted(hashing.get_domain_features()))

    cycle_bag, triangles_bag = bags
    return cycle_bag == triangles_bag


# ------------------------------------------------------------------------------
# (3) Robustness: invariance under vertex relabeling (SC, rank-1)
# ------------------------------------------------------------------------------
def test_higherorder_wl_invariance_sc():
    """
    Check that HO-WL on a SimplicialComplex ignores vertex names.

    Setup
    -----
    - Base complex    : 2-simplices [0,1,2] and [1,2,3].
    - Relabeled copy  : the same simplices with 0→10, 1→11, 2→12, 3→13.

    The two complexes are isomorphic by construction.

    Test
    ----
    Higher-Order WL is run on both with
        neighborhood_type = "adj"
        neighborhood_dim  = {"rank": 1, "via_rank": -1}
    so the WL nodes are the edges (1-simplices).

    NOTE(review): for a SimplicialComplex, TopoNetX's rank-1 ``"adj"`` is
    presumably upper adjacency (edges on a common triangle) rather than
    "edges sharing a vertex". Confirm against ``neighborhood_from_complex``.
    The invariance expectation is unaffected.

    Returns
    -------
    bool
        True when the sorted feature bags are equal (expected: invariance).
    """
    versions = (
        [[0, 1, 2], [1, 2, 3]],
        [[10, 11, 12], [11, 12, 13]],
    )

    bags = []
    for simplices in versions:
        complex_ = tnx.SimplicialComplex(simplices)
        hashing = HigherOrderWeisfeilerLehmanHashing(
            wl_iterations=2, erase_base_features=False
        ).fit(
            complex_,
            neighborhood_type="adj",
            neighborhood_dim={"rank": 1, "via_rank": -1},
        )
        bags.append(sorted(hashing.get_domain_features()))

    base_bag, relabeled_bag = bags
    return base_bag == relabeled_bag


# ------------------------------------------------------------------------------
# (4) Edge-level HO-WL on CellComplex: empirical failure case
# ------------------------------------------------------------------------------
def test_edge_level_higherorder_wl_behavior():
    """
    Negative example: edge-level HO-WL on a CellComplex that does NOT
    separate two different 2D decompositions of the same 1-skeleton.

    Setup
    -----
    Shared 1-skeleton: vertices {0,1,2,3} and the edges of a square with a
    diagonal, (0,1), (1,2), (2,3), (3,0), (0,2).

    cx_A : a single 2-cell [0,1,2,3] (one quadrilateral face).
    cx_B : two 2-cells [0,1,2] and [0,2,3] (two triangular faces).

    Test
    ----
    Higher-Order WL is run on the edges (rank 1) with
        neighborhood_type = "coadj"
        neighborhood_dim  = {"rank": 1, "via_rank": 2}

    Empirically the two feature bags come out EQUAL. This relation does not
    see the difference between one quad and two triangles, which shows that
    the power of HO-WL depends on the chosen rank and neighborhood.

    NOTE(review): for a CellComplex, TopoNetX's rank-1 coadjacency likely
    relates edges that share a *vertex* and ignores ``via_rank``. If so, the
    relation depends only on the (identical) 1-skeleton, which would explain
    the equality. Confirm against ``neighborhood_from_complex``.

    Returns
    -------
    bool
        True when the sorted feature bags are equal, i.e. "WL fails here".
    """
    square_with_diagonal = [
        [0, 1],
        [1, 2],
        [2, 3],
        [3, 0],
        [0, 2],
    ]

    def build_complex(two_cells):
        # Start from the common 1-skeleton, then attach the given 2-cells.
        cx = tnx.CellComplex(square_with_diagonal, ranks=1)
        for cell in two_cells:
            cx.add_cell(cell, rank=2)
        return cx

    def edge_signature(cx):
        # Sorted bag of rank-1 HO-WL features (coadj, via_rank=2).
        hashing = HigherOrderWeisfeilerLehmanHashing(
            wl_iterations=3, erase_base_features=False
        ).fit(
            cx,
            neighborhood_type="coadj",
            neighborhood_dim={"rank": 1, "via_rank": 2},
        )
        return sorted(hashing.get_domain_features())

    cx_A = build_complex([[0, 1, 2, 3]])          # one quadrilateral face
    cx_B = build_complex([[0, 1, 2], [0, 2, 3]])  # two triangular faces

    signature_A = edge_signature(cx_A)
    signature_B = edge_signature(cx_B)

    # True → signatures coincide → WL fails to distinguish.
    return signature_A == signature_B


# ------------------------------------------------------------------------------
# (5) CC example: BOTH rank-2 and rank-3 HO-WL distinguish
# ------------------------------------------------------------------------------
def test_cc_rank2_and_rank3_behavior():
    """
    CombinatorialComplex example in which both rank-2 and rank-3 HO-WL pick
    up a difference in higher-order structure.

    Setup
    -----
    Vertices: {0,1,2,3}
    Edges   : two triangles sharing the edge (1,2):
              (0,1), (1,2), (2,0), (1,3), (2,3)
    Faces   : [0,1,2] and [1,2,3], added at rank 2 in both complexes.

    cc_A: one rank-3 cell on the vertex set [0,1,2].
    cc_B: two rank-3 cells, on [0,1,2] and on [1,2,3].

    NOTE(review): each rank-3 cell has exactly the same vertex set as an
    existing rank-2 face. How TopoNetX treats a cell added a second time with
    a different rank (kept at both ranks, re-ranked, or rejected) determines
    what the rank-2 comparison actually sees. Confirm against
    ``CombinatorialComplex.add_cell``.

    Behavior (empirically)
    ----------------------
    - Rank-2 coadjacency via rank-1 (faces via edges): signatures DIFFER.
    - Rank-3 coadjacency via rank-0 (3-cells via vertices): signatures DIFFER.

    Returns
    -------
    tuple of (bool, bool)
        ``(rank2_diff, rank3_diff)``; True means the sorted feature bags of
        cc_A and cc_B differ at that rank.
    """
    verts = [0, 1, 2, 3]
    edges = [[0, 1], [1, 2], [2, 0], [1, 3], [2, 3]]
    faces = [[0, 1, 2], [1, 2, 3]]

    def build_complex(top_cells):
        # Shared 0–1–2 skeleton plus the requested rank-3 cells.
        cc = tnx.CombinatorialComplex()
        for v in verts:
            cc.add_cell([v], rank=0)
        for e in edges:
            cc.add_cell(e, rank=1)
        for f in faces:
            cc.add_cell(f, rank=2)
        for cell in top_cells:
            cc.add_cell(cell, rank=3)
        return cc

    def signature(cc, rank, via_rank):
        # Sorted bag of HO-WL features for coadjacency at (rank, via_rank).
        hashing = HigherOrderWeisfeilerLehmanHashing(
            wl_iterations=3, erase_base_features=False
        ).fit(
            cc,
            neighborhood_type="coadj",
            neighborhood_dim={"rank": rank, "via_rank": via_rank},
        )
        return sorted(hashing.get_domain_features())

    cc_A = build_complex([[0, 1, 2]])
    cc_B = build_complex([[0, 1, 2], [1, 2, 3]])

    # Faces via edges (empirically differ).
    rank2_diff = signature(cc_A, 2, 1) != signature(cc_B, 2, 1)
    # 3-cells via vertices (empirically differ as well).
    rank3_diff = signature(cc_A, 3, 0) != signature(cc_B, 3, 0)

    return rank2_diff, rank3_diff


# ------------------------------------------------------------------------------
# (6) Invariance under vertex relabeling (CC, rank-2)
# ------------------------------------------------------------------------------
def test_higherorder_wl_invariance_cc():
    """
    Check Higher-Order WL invariance under a global vertex relabeling (CC).

    Two isomorphic CombinatorialComplex objects are built from one base
    structure. The second is the first with every vertex label shifted by +10:

    - cc1: vertices {0,1,2,3}; edges [0,1], [1,2], [2,3], [3,0];
      faces [0,1,2], [0,2,3].
    - cc2: vertices {10,11,12,13}; edges [10,11], [11,12], [12,13], [13,10];
      faces [10,11,12], [10,12,13].

    Higher-Order WL runs for 2 iterations with base features kept. It works
    on rank-2 cells, with coadjacency through rank-1 cells, so two faces are
    neighbours when they share an edge.

    For CombinatorialComplex, adjacency_matrix requires rank < via_rank and
    coadjacency_matrix requires rank > via_rank. With rank=2 and via_rank=1,
    the neighborhood type therefore has to be "coadj".

    Returns
    -------
    bool
        True when the sorted WL feature lists of cc1 and cc2 coincide. This
        is the expected outcome: HO-WL should respect combinatorial
        isomorphism.
    """
    base_vertices = [0, 1, 2, 3]
    base_edges = [[0, 1], [1, 2], [2, 3], [3, 0]]
    base_faces = [[0, 1, 2], [0, 2, 3]]

    def _relabelled_complex(offset):
        # Rebuild the base structure rank by rank (vertices, edges, faces),
        # adding `offset` to every vertex label while keeping cell order.
        cc = tnx.CombinatorialComplex()
        cells_by_rank = (
            (0, [[v] for v in base_vertices]),
            (1, base_edges),
            (2, base_faces),
        )
        for rank, cells in cells_by_rank:
            for cell in cells:
                cc.add_cell([node + offset for node in cell], rank=rank)
        return cc

    def _fit_faces_via_edges(cc):
        # Rank-2 cells are neighbours when they share a rank-1 cell.
        return HigherOrderWeisfeilerLehmanHashing(
            wl_iterations=2, erase_base_features=False
        ).fit(
            cc,
            neighborhood_type="coadj",
            neighborhood_dim={"rank": 2, "via_rank": 1},
        )

    # Build both complexes first, then fit both, then read both signatures.
    complexes = [_relabelled_complex(offset) for offset in (0, 10)]
    fitted = [_fit_faces_via_edges(cc) for cc in complexes]
    signature_1, signature_2 = (
        sorted(hasher.get_domain_features()) for hasher in fitted
    )
    return signature_1 == signature_2


# ------------------------------------------------------------------------------
# Run all tests with detailed explanations
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Driver for the six higher-order WL checks defined above. Each section
    # prints a description of the complexes being compared, runs one test
    # function, and reports its boolean outcome(s) via print_check. The
    # result names (g_equal, ho_diff, ...) stay at module level so later
    # notebook cells can still inspect them.

    def _print_test_header(title, complex_lines, compare_lines):
        """Print a test title plus its 'Complex types' / 'What we compare' lists."""
        # Joining with newlines and printing once produces exactly the same
        # text as one print() call per line.
        print(
            "\n".join(
                [
                    title,
                    "  Complex types:",
                    *complex_lines,
                    "  What we compare:",
                    *compare_lines,
                ]
            )
        )

    print(
        "\n".join(
            [
                "\n================ HIGHER-ORDER WL TESTS (SUITE) ================\n",
                "Each test builds one or more TopoNetX complexes (SC, CX, CC),",
                "selects a neighborhood matrix A^(r,k), and compares WL signatures.",
                "Legend: 'Result: True' means the equality / property being tested",
                "        holds; the Interpretation line explains why that is good",
                "        or bad for our understanding of higher-order WL.\n",
            ]
        )
    )

    # --- TEST 1: graph WL vs HO-WL on combinatorial complexes -------------
    _print_test_header(
        "TEST 1: Graph WL failure vs Higher-Order WL success (CC, rank-2)\n",
        [
            "    - CC1, CC2: CombinatorialComplex on vertices {0,1,2,3}.",
            "    - Both share the same 0–1 skeleton, but differ in 2-cells:",
            "        CC1: one 2-cell [0,1,2,3].",
            "        CC2: two 2-cells [0,1,2] and [0,2,3].",
        ],
        [
            "    - Graph WL on 0-skeleton adjacency (r=0 via k=1).",
            "    - Higher-Order WL on 2-cells via edges (coadj, r=2 > k=1).\n",
        ],
    )
    g_equal, ho_diff = test_graph_wl_fails_but_higherorder_wl_succeeds()
    print_check(
        "Graph WL on 0-skeleton adjacency (CC1 vs CC2)",
        g_equal,
        when_true="As desired: graph WL only sees the shared 0–1 skeleton and cannot distinguish CC1 vs CC2.",
        when_false="Unexpected: graph WL claims CC1 and CC2 differ, but their 0–1 skeleton is identical.",
    )
    print_check(
        "Higher-Order WL on rank-2 coadjacency (faces via edges) (CC1 vs CC2)",
        ho_diff,
        when_true="As desired: WL on 2-cells distinguishes the single-face vs two-face structure.",
        when_false="Unexpected: WL on 2-cells fails to see the difference in rank-2 structure.",
    )

    # --- TEST 2: classical 1-WL failure replayed on cell complexes --------
    _print_test_header(
        "TEST 2: Reproducing known 1-WL failure (C6 vs C3 ⊔ C3) as CellComplex\n",
        [
            "    - cx1: CellComplex encoding a 6-cycle C6.",
            "    - cx2: CellComplex encoding disjoint union C3 ⊔ C3.",
        ],
        [
            "    - Higher-Order WL on rank-0 adjacency (nodes via edges).",
            "    - This replicates standard 1-WL on the underlying graphs.\n",
        ],
    )
    fail_ok = test_higherorder_wl_failure_on_cell_complex()
    print_check(
        "Higher-Order WL on rank-0 adjacency (C6 vs C3⊔C3)",
        fail_ok,
        when_true="As expected: this matches the known 1-WL failure; WL signatures are identical.",
        when_false="Unexpected: WL separates C6 from C3⊔C3, contradicting the standard failure example.",
    )

    # --- TEST 3: relabeling invariance on simplicial complexes ------------
    _print_test_header(
        "TEST 3: Invariance under vertex relabeling (SimplicialComplex, rank-1)\n",
        [
            "    - sc1: SC with 2-simplices [0,1,2], [1,2,3].",
            "    - sc2: same SC but vertices relabeled {0,1,2,3} → {10,11,12,13}.",
        ],
        [
            "    - Higher-Order WL on rank-1 adjacency (edges via shared vertices).",
            "    - We test invariance under combinatorial isomorphism.\n",
        ],
    )
    inv_sc_ok = test_higherorder_wl_invariance_sc()
    print_check(
        "Higher-Order WL on rank-1 adjacency (isomorphic SCs, relabeled vertices)",
        inv_sc_ok,
        when_true="As expected: WL is invariant under graph/complex isomorphism (vertex relabeling).",
        when_false="Unexpected: WL breaks invariance under simple relabeling, indicating a bug.",
    )

    # --- TEST 4: edge-level HO-WL limitation on cell complexes ------------
    _print_test_header(
        "TEST 4: Edge-level HO-WL on CellComplex (empirical failure case)\n",
        [
            "    - cx_A: CellComplex with 1 quad face [0,1,2,3].",
            "    - cx_B: CellComplex with 2 triangular faces [0,1,2], [0,2,3].",
            "    - Both share the same 1-skeleton: square with a diagonal.",
        ],
        [
            "    - Higher-Order WL on rank-1 coadjacency via rank-2 (edges via faces).",
            "    - Empirically, WL signatures are equal → this neighborhood",
            "      does NOT distinguish the two 2D decompositions.\n",
        ],
    )
    edge_equal = test_edge_level_higherorder_wl_behavior()
    print_check(
        "Higher-Order WL on rank-1 coadjacency via rank-2 (edges via faces)",
        edge_equal,
        when_true="Empirically: HO-WL on edges does NOT distinguish one quad vs two triangles with this neighborhood (a limitation).",
        when_false="Empirically: in your environment HO-WL on edges DOES distinguish the decompositions (stronger than expected).",
    )

    # --- TEST 5: rank-2 and rank-3 HO-WL on combinatorial complexes -------
    _print_test_header(
        "TEST 5: CC example where BOTH rank-2 and rank-3 HO-WL detect higher-order differences\n",
        [
            "    - cc_A, cc_B: CombinatorialComplexes built from two triangles",
            "      sharing an edge on vertices {0,1,2,3}.",
            "    - Faces [0,1,2] and [1,2,3] exist in both.",
            "    - Rank-3 structure differs: cc_A has 1 three-cell, cc_B has 2.",
        ],
        [
            "    - Rank-2 HO-WL: faces via edges (coadj, r=2 via k=1).",
            "    - Rank-3 HO-WL: 3-cells via vertices (coadj, r=3 via k=0).\n",
        ],
    )
    rank2_diff, rank3_diff = test_cc_rank2_and_rank3_behavior()
    print_check(
        "Higher-Order WL on rank-2 coadjacency (faces via edges) (CC_A vs CC_B)",
        rank2_diff,
        when_true="As observed: rank-2 HO-WL is already sensitive to the extra rank-3 cell structure.",
        when_false="Unexpected: rank-2 HO-WL fails to detect the higher-rank differences.",
    )
    print_check(
        "Higher-Order WL on rank-3 coadjacency via vertices (3-cells via vertices) (CC_A vs CC_B)",
        rank3_diff,
        when_true="As expected: HO-WL on rank-3 cells also detects the extra 3-cell in CC_B.",
        when_false="Unexpected: rank-3 HO-WL fails to detect the higher-rank differences.",
    )

    # --- TEST 6: relabeling invariance on combinatorial complexes ---------
    _print_test_header(
        "TEST 6: Invariance under vertex relabeling (CombinatorialComplex, rank-2)\n",
        [
            "    - cc1: CC on vertices {0,1,2,3}, edges forming a 4-cycle,",
            "      faces [0,1,2] and [0,2,3].",
            "    - cc2: same CC but vertices relabeled {0,1,2,3} → {10,11,12,13}.",
        ],
        [
            "    - Higher-Order WL on rank-2 coadjacency via rank-1 (faces via edges).",
            "    - We test invariance under combinatorial isomorphism in the CC setting.\n",
        ],
    )
    inv_cc_ok = test_higherorder_wl_invariance_cc()
    print_check(
        "Higher-Order WL on rank-2 coadjacency via edges (isomorphic CCs, relabeled vertices)",
        inv_cc_ok,
        when_true="As expected: HO-WL respects combinatorial isomorphism in the CC setting.",
        when_false="Unexpected: HO-WL does not respect isomorphism here, indicating an implementation issue.",
    )

    print("===============================================================\n")




Each test builds one or more TopoNetX complexes (SC, CX, CC),
selects a neighborhood matrix A^(r,k), and compares WL signatures.
Legend: 'Result: True' means the equality / property being tested
        holds; the Interpretation line explains why that is good
        or bad for our understanding of higher-order WL.

TEST 1: Graph WL failure vs Higher-Order WL success (CC, rank-2)

  Complex types:
    - CC1, CC2: CombinatorialComplex on vertices {0,1,2,3}.
    - Both share the same 0–1 skeleton, but differ in 2-cells:
        CC1: one 2-cell [0,1,2,3].
        CC2: two 2-cells [0,1,2] and [0,2,3].
  What we compare:
    - Graph WL on 0-skeleton adjacency (r=0 via k=1).
    - Higher-Order WL on 2-cells via edges (coadj, r=2 > k=1).

✓ Graph WL on 0-skeleton adjacency (CC1 vs CC2)
    Result        : True
    Interpretation: As desired: graph WL only sees the shared 0–1 skeleton and cannot distinguish CC1 vs CC2.

✓ Higher-Order WL on rank-2 coadjacency (faces via edges) (CC1 vs CC2)
 


# Higher-Order Weisfeiler–Lehman Hashing on Complexes

This notebook explores how the classical Weisfeiler–Lehman (WL) graph test can be generalized to **cell complexes**, **simplicial complexes**, and **combinatorial complexes** using **neighborhood matrices**, and how this “higher-order” WL can distinguish structures that are invisible to graph WL.

We will:

1. Recall the **1-dimensional WL test on graphs**.
2. Define **neighborhood matrices** on complexes (TopoNetX-style).
3. Define **Higher-Order WL on complexes**.
4. Walk through six examples where graph WL fails but higher-order WL succeeds.
5. Show a final example where **graph, SC, and CX WL all fail**, but **CC WL** succeeds.

---

## 1. Weisfeiler–Lehman (WL) on Graphs

Let a graph be

$$
G = (V, E).
$$

The **1-dimensional WL test** (also called color refinement) iteratively updates node labels based on neighbor labels.

### 1.1 Initialization

Each node $v \in V$ starts with an initial label $\ell_0(v)$.
This can be, for example, the same label for every node, or the degree of the node.

We can think of
$\ell_0 : V \to \mathcal{A}$, where $\mathcal{A}$ is a set of discrete labels (strings, integers, …).

### 1.2 WL Update Rule

At iteration $t = 0, 1, \dots$, we have labels $\ell_t(v)$ for each node.

Define the **multiset of neighbor labels** at iteration $t$ as

$$
M_t(v) = \{ \ell_t(u) \,:\, u \in N(v) \},
$$

where $N(v)$ is the set of neighbors of $v$.

Then the **WL update** at iteration \(t+1\) is

$$
\ell_{t+1}(v)
=
h\big( \ell_t(v), M_t(v) \big),
$$

where $h$ is an *injective* hashing / encoding function. In practice, we serialize $\ell_t(v)$ and the sorted list of neighbor labels into a string and hash it. A hash is injective only up to collisions, which are negligible in practice.

After $T$ iterations, the **WL signature** of a graph is the multiset of all node labels:

$$
\mathrm{sig}(G)
=
\{ \ell_T(v) : v \in V \}.
$$

Two graphs are **1-WL indistinguishable** if their WL signatures are the same.

In our code, when we call

```python
graph_wl_signature_on_rank0(domain)
````

we construct the **0-skeleton graph** of `domain` (vertices and edges), and then run this standard WL procedure.

---

## 2. Complexes and Neighborhood Matrices (TopoNetX)

Graphs only give us one level of structure (vertices + edges). In **TopoNetX**, we can work with richer objects:

* **SimplicialComplex:** sets of simplices (vertices, edges, triangles, tetrahedra, …).
* **CellComplex:** cells of arbitrary shapes (edges, polygons, 3D cells, …), glued along faces.
* **CombinatorialComplex:** a very general incidence-based complex with cells of arbitrary ranks and arbitrary incidence patterns.
* Other classes: **PathComplex**, **ColoredHyperGraph**, etc.

For each complex, TopoNetX can build various **neighborhood matrices** that relate cells of rank $r$ via cells of rank $k$:

* adjacency,
* coadjacency,
* etc.

### 2.1 Rank and Via-Rank

Let $\mathcal{K}$ be a complex and let $C_r$ denote the set of all **rank-$r$** cells:

$$
C_r = \{ c_1, \dots, c_{n_r} \}.
$$

Examples:

* In a simplicial complex:

  * rank 0: vertices,
  * rank 1: edges,
  * rank 2: triangles,
  * rank 3: tetrahedra, etc.
* In a cell complex:

  * rank 0: vertices,
  * rank 1: edges,
  * rank 2: polygonal faces, etc.
* In a combinatorial complex:

  * rank $r$ is an abstract level; cells can be anything, as long as incidences are defined.

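We can define a **neighborhood** between rank-$r$ cells using some **via-rank** $k$:

* adjacency: “two rank-$r$ cells are adjacent if they share a rank-$k$ coface”,
* coadjacency: “two rank-$r$ cells are coadjacent if they share a rank-$k$ face”.

For a `CombinatorialComplex`, this means `adjacency_matrix` requires `rank < via_rank`, and `coadjacency_matrix` requires `rank > via_rank`.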

TopoNetX provides functions such as:

* `adjacency_matrix(rank, via_rank, index=True)`
* `coadjacency_matrix(rank, via_rank, index=True)`

Our helper `neighborhood_from_complex` wraps these methods and returns:

* `ind`: a list of ids of the chosen rank-$r$ cells,
* `A`: a sparse matrix encoding the adjacency or coadjacency between these cells.

Some common patterns:

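* **0-skeleton (graph) adjacency**:

  * for simplicial / cell complexes (where `via_rank` is ignored):

    ```python
    {"rank": 0, "via_rank": -1}
    ```

  * for combinatorial complexes / colored hypergraphs (vertices adjacent via edges):

    ```python
    {"rank": 0, "via_rank": 1}
    ```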

* **2-cell coadjacency via edges** (faces sharing an edge):

  ```python
  {"rank": 2, "via_rank": 1}
  ```

* **rank-3 coadjacency via vertices** in a `CombinatorialComplex`:

  ```python
  {"rank": 3, "via_rank": 0}
  ```

This is the key tool: given `(rank, via_rank)` and `neighborhood_type` (`"adj"` or `"coadj"`), we can build a **graph of cells**, and then apply WL to that graph.

---

## 3. Higher-Order WL on Complexes

We now generalize WL from **nodes in a graph** to **cells of arbitrary rank** in a complex.

### 3.1 Induced Cell Graph

Fix:

* a complex $\mathcal{K}$,
* a rank $r$ (we will do WL on rank-$r$ cells),
* a via-rank $k$ (we use cells of rank $k$ to define adjacency),
* a neighborhood type (adjacency or coadjacency).

From TopoNetX we obtain a matrix

$$
A^{(r,k)} \in \{0,1\}^{n_r \times n_r},
$$

and an index set

$$
C_r = \{ c_1, \dots, c_{n_r} \}.
$$

We treat $(C_r, A^{(r,k)})$ as a **graph**:

* **nodes**: the cells $c_i$,
* **edges**: there is an edge between $c_i$ and $c_j$ iff

  $$
  A^{(r,k)}_{ij} \neq 0.
  $$

We call this the **cell graph at rank $r$** induced by $(r, k)$.

### 3.2 Higher-Order WL Update Rule

We now run WL on this cell graph.

* Labels are now on cells: $\ell_t : C_r \to \mathcal{A}$.

* For each cell $c \in C_r$, define its neighbor multiset:

  $$
  M_t(c) = \{ \ell_t(c') : c' \in N(c) \},
  $$

  where $N(c)$ is the set of neighbors in the cell graph.

* WL update at iteration $t+1$ is

  $$
  \ell_{t+1}(c) = h\big( \ell_t(c), M_t(c) \big).
  $$

Here again, $h$ is any injective encoding of the pair “current label + multiset of neighbor labels”.

### 3.3 Higher-Order WL Signature

After $T$ iterations, the **higher-order WL signature** of $\mathcal{K}$ at rank $r$ is the multiset of all labels:

$$
\mathrm{sig}^{(r,k)}(\mathcal{K})=\{ \ell_T(c) : c \in C_r \}.
$$

Special cases:

* $r = 0$, $k = 1$ (adjacency): standard 1-WL on the 0-skeleton graph.
* $r = 2$, $k = 1$ (coadjacency): WL on faces that share edges.
* $r = 3$, $k = 0$ (coadjacency): WL on 3-cells that share vertices, etc.

In code, we do this with:

```python
HigherOrderWeisfeilerLehmanHashing(
    wl_iterations=..., erase_base_features=...
).fit(
    domain=complex,
    neighborhood_type="adj" or "coadj",
    neighborhood_dim={"rank": r, "via_rank": k},
)
```

---

## 4. Six Examples Where Graph WL Fails but Higher-Order WL Succeeds

In **Cell 1** of this notebook, we construct six pairs of complexes. For each pair, we compute:

1. **Graph WL on the 0-skeleton** (classical 1-WL on the vertex-edge graph).
2. **Higher-order WL using coadjacency**, at rank 2 or another appropriate rank.

In every example, we deliberately ensure:

* The 0-skeleton graph is the same for the two complexes → **graph WL signatures match**.
* The higher-rank structure (faces, hyperedges, etc.) is different → **higher-order WL can tell them apart**.

Below we describe each pair **explicitly**.

---

### 4.1 Example 1 — Filled vs Hollow Square (CellComplex)

**Vertices:** $\{0, 1, 2, 3\}$.

**1-skeleton edges:**

* $(0,1)$,
* $(1,2)$,
* $(2,3)$,
* $(3,0)$.

#### Complex A: `square_filled`

* Rank-0 cells: the 4 vertices.
* Rank-1 cells: the 4 edges above.
* Rank-2 cells: **one** face filling the square, `[0,1,2,3]`.

#### Complex B: `square_hollow`

* Rank-0 cells: the same 4 vertices.
* Rank-1 cells: the same 4 edges.
* Rank-2 cells: **none**.

#### Behavior

* 0-skeleton graph: same 4-cycle in both → graph WL says “equal”.
* Rank-2 coadjacency:

  * In `square_filled`, there is one 2-cell.
  * In `square_hollow`, there are zero 2-cells.

  So higher-order WL on rank 2 (coadjacency) distinguishes them.

---

### 4.2 Example 2 — Same 1-Skeleton, Different 2-Cell Attachments

**Vertices:** $\{0, 1, 2, 3\}$.

**1-skeleton edges:**

* Square edges: $(0,1), (1,2), (2,3), (3,0)$,
* Diagonal: $(0,2)$.

Same in both complexes.

#### Complex A: `cx_A`

* Rank-0: vertices 0,1,2,3.
* Rank-1: edges above.
* Rank-2: **one** quadrilateral 2-cell `[0,1,2,3]`.

#### Complex B: `cx_B`

* Rank-0: same.
* Rank-1: same.
* Rank-2: **two** triangular 2-cells:

  * `[0,1,2]`,
  * `[0,2,3]`.

#### Behavior

* 0-skeleton graph: identical → graph WL equal.
* Rank-2 coadjacency graph:

  * Complex A: one 2-cell, no neighbors.
  * Complex B: two 2-cells that share an edge.

  The induced cell graphs differ, and higher-order WL on rank 2 (coadj) detects it.

---

### 4.3 Example 3 — Same Graph, Different Number of Filled Cycles

We build a graph with **two squares** that share the edge $(1,2)$.

**Vertices:** $\{0,1,2,3,4,5\}$.

**Edges:**

* Square 1: 0–1–2–3–0,
* Square 2: 1–4–5–2–1.

Same 1-skeleton in both complexes.

#### Complex A: `cx3_A`

* Rank-0: vertices.
* Rank-1: edges listed above.
* Rank-2: **one** face:

  * square `[0,1,2,3]` is filled,
  * square `[1,4,5,2]` is *not* filled.

#### Complex B: `cx3_B`

* Rank-0: same vertices.
* Rank-1: same edges.
* Rank-2: **two** faces:

  * square `[0,1,2,3]`,
  * square `[1,4,5,2]`.

#### Behavior

* 0-skeleton graph: the same → graph WL equal.
* Rank-2 coadjacency:

  * Complex A: one 2-cell.
  * Complex B: two 2-cells.

  WL on rank-2 coadjacency sees the different number of faces.

---

### 4.4 Example 4 — Same Graph ($C_6$), Extra Rank-2 Cell vs None (CombinatorialComplex)

**Vertices:** $\{0,1,2,3,4,5\}$.

**Edges (cycle $C_6$):**

* $(0,1), (1,2), (2,3), (3,4), (4,5), (5,0)$.

#### CombinatorialComplex A: `cc4_A`

* Rank-0: vertices.
* Rank-1: the 6 edges of the cycle.
* Rank-2: **one** hyperedge (rank-2 cell) on vertices `[0,2,4]`.

#### CombinatorialComplex B: `cc4_B`

* Rank-0: same vertices.
* Rank-1: same edges.
* Rank-2: **no** hyperedge.

#### Behavior

* 0-skeleton graph: same cycle $C_6$ → graph WL equal.
* Rank-2 coadjacency via rank-0:

  * `cc4_A` has one rank-2 cell,
  * `cc4_B` has none.

  Higher-order WL on rank-2 (coadj, via rank 0) distinguishes them.

This shows how combinatorial complexes can add hyperedges on top of a fixed graph.

---

### 4.5 Example 5 — Clique Complex vs Bare Graph

Underlying **graph:**

* Vertices: $\{0,1,2,3\}$.
* Edges:

  * triangle 0–1–2: $(0,1), (1,2), (2,0)$,
  * triangle 1–2–3: $(1,3), (2,3)$, plus the shared edge $(1,2)$.

#### SimplicialComplex: `sc5_A`

* Rank-0: vertices 0,1,2,3.
* Rank-1: edges as above.
* Rank-2: the 2-simplices:

  * `[0,1,2]`,
  * `[1,2,3]`.

This is essentially the **clique complex** of the graph.

#### CellComplex: `cc5_B`

* Rank-0: same vertices.
* Rank-1: same edges.
* Rank-2: **none**.

This is just the bare graph.

#### Behavior

* 0-skeleton graph: identical → graph WL equal.
* Rank-2 coadjacency:

  * `sc5_A` has 2 faces,
  * `cc5_B` has none.

  Higher-order WL on rank 2 (coadj) distinguishes them.

---

### 4.6 Example 6 — Subdivided vs Non-Subdivided Face

Underlying **graph:**

* Vertices: $\{0,1,2,3\}$.
* Edges:

  * square: $(0,1), (1,2), (2,3), (3,0)$,
  * diagonal: $(0,2)$.

Same in both complexes.

#### Complex A: `cx6_A`

* Rank-0: vertices.
* Rank-1: edges.
* Rank-2: one quadrilateral 2-cell `[0,1,2,3]`.

#### Complex B: `cx6_B`

* Rank-0: vertices.
* Rank-1: edges.
* Rank-2: two triangles:

  * `[0,1,2]`,
  * `[0,2,3]`.

#### Behavior

* 0-skeleton graph: same → graph WL equal.
* Rank-2 coadjacency:

  * Complex A: one 2-cell.
  * Complex B: two 2-cells that share an edge.

  Higher-order WL on rank 2 detects the different subdivisions.

---

## 5. Combinatorial Complex Example:

### Graph / SC / CX WL All Fail — CC WL Succeeds

In **Cell 2** of the notebook, we build a more subtle example to highlight why **CombinatorialComplex** + higher-order WL is strictly more expressive.

### 5.1 Base Graph: Two Triangles Sharing an Edge

Underlying **graph:**

* Vertices: $\{0,1,2,3\}$.
* Edges:

  * triangle $(0,1,2)$: $(0,1), (1,2), (2,0)$,
  * triangle $(1,2,3)$: $(1,3), (2,3)$, plus the shared edge $(1,2)$.

This is a “two triangles sharing the edge $(1,2)$” graph.

### 5.2 SimplicialComplex and CellComplex Versions

We build:

* `sc_A`, `sc_B`: **identical** simplicial complexes

  * Rank-0: vertices 0,1,2,3.
  * Rank-1: edges of the graph.
  * Rank-2: simplices `[0,1,2]`, `[1,2,3]`.
* `cx_A`, `cx_B`: **identical** cell complexes

  * Rank-0: vertices.
  * Rank-1: edges.
  * Rank-2: faces with the same incidence as in `sc_A`, `sc_B`.

So:

* Graph WL on `sc_A` vs `sc_B`: same 0-skeleton → equal signatures.
* Higher-order WL on rank-2 coadjacency in SC: same faces → equal.
* Same for CX: equal in both 0-skeleton WL and rank-2 WL.

### 5.3 CombinatorialComplex Versions

Now we build two **CombinatorialComplexes** `cc_A` and `cc_B` that share the same 0-, 1-, and 2-cells, but differ in **rank 3**.

#### CC A: `cc_A`

* Rank-0: vertices 0,1,2,3.
* Rank-1: the edges of the base graph.
* Rank-2: cells `[0,1,2]`, `[1,2,3]`.
* Rank-3: **one** cell attached to `[0,1,2]`.

#### CC B: `cc_B`

* Rank-0: same.
* Rank-1: same.
* Rank-2: same.
* Rank-3: **two** cells:

  * one attached to `[0,1,2]`,
  * one attached to `[1,2,3]`.

So the **graph structure and 2-dimensional structure are identical**, but the **3-dimensional structure differs**.

### 5.4 WL Behavior

We compute:

1. Graph WL on base graph (edges only): **equal**.
2. Graph WL on 0-skeleton of SC: **equal**.
3. Higher-order WL on SC rank-2 (coadj): **equal**.
4. Graph WL on 0-skeleton of CX: **equal**.
5. Higher-order WL on CX rank-2 (coadj): **equal**.
6. Graph WL on 0-skeleton of CC: **equal**.
7. **Higher-order WL on CC rank-3 (coadj via rank 0):**

   * `cc_A` has 1 rank-3 cell,
   * `cc_B` has 2 rank-3 cells.

   The induced cell graphs at rank 3 differ, and higher-order WL on $(r=3, k=0)$ yields **different signatures**.

Conclusion:

> Graph, SimplicialComplex, CellComplex WL — all fail.
> Only CombinatorialComplex WL on $(r=3, k=0)$ detects the difference.

---




Next, we provide code that reproduces the examples above using our package.

In [None]:
# ============================================================
# 5. COMBINATORIAL COMPLEX TEST (DETAILED):
# WL FAILS ON GRAPH / SC / CX BUT SUCCEEDS ON CC (HIGHER RANK)
# ============================================================
#
# This cell implements the example described in the tutorial:
#
# We build one underlying combinatorial structure in four different ways:
#   (1) As a plain graph (just vertices + edges).
#   (2) As a SimplicialComplex (with 2-simplices).
#   (3) As a CellComplex (with 2-cells).
#   (4) As a CombinatorialComplex (with 0-,1-,2-, and 3-cells).
#
# Then we create TWO versions A and B:
#   - In (1)-(3), A and B are *identical* (no difference in structure).
#   - In (4), A and B differ ONLY in rank-3 cells:
#       cc_A has ONE rank-3 cell,
#       cc_B has TWO rank-3 cells.
#
# For each representation we run one or more WL tests and print:
#   - Whether the WL signatures are equal.
#   - A short interpretation of what that equality/inequality means.
#
# The point of this cell is to show:
#   • Plain graph WL cannot see the CC difference (as expected).
#   • WL on SC and CX (including rank-2) also cannot see the CC difference.
#   • Only higher-order WL on the CombinatorialComplex at rank 3 detects the
#     extra 3-cell in cc_B and separates cc_A from cc_B.
# ============================================================

import networkx as nx
import toponetx as tnx



# ---------------------------------------------------------------------
# Small helper for didactic printing
# ---------------------------------------------------------------------
def print_check(label: str, cond: bool, when_true: str, when_false: str) -> None:
    """
    Report a boolean check as a short, human-readable block on stdout.

    The block consists of four lines: a check mark ("✓" when ``cond`` is
    truthy, "✗" otherwise) followed by ``label``, the raw value of ``cond``,
    the interpretation text matching that value, and a trailing blank line.

    Parameters
    ----------
    label : str
        Short description of the check being reported.
    cond : bool
        Outcome of the check (e.g. ``sig_A == sig_B``). It is printed as-is.
    when_true : str
        Interpretation shown when ``cond`` is truthy.
    when_false : str
        Interpretation shown when ``cond`` is falsy.
    """
    if cond:
        symbol, explanation = "✓", when_true
    else:
        symbol, explanation = "✗", when_false

    report = [
        f"{symbol} {label}",
        f"    Result        : {cond}",
        f"    Interpretation: {explanation}",
        "",  # separates consecutive reports
    ]
    for line in report:
        print(line)


# ---------------------------------------------------------------------
# 1. Utility functions: WL on graphs and WL on complexes
# ---------------------------------------------------------------------
def graph_wl_signature_from_edges(edges, wl_iterations: int = 3):
    """
    Compute a classical 1-WL signature for a plain graph given as an edge list.

    No complex is involved. The edges are loaded into an undirected
    NetworkX graph, and WeisfeilerLehmanHashing is run on it without node
    attributes, keeping the base features.

    Because the graph is built only from ``edges``, a vertex exists only if
    it appears in at least one edge. Isolated vertices cannot be expressed.

    Parameters
    ----------
    edges : list of [u, v]
        Undirected edges of the graph.
    wl_iterations : int
        Number of WL refinement iterations.

    Returns
    -------
    list of str
        The graph-level WL features, sorted so that two graphs can be
        compared with ``==``.
    """
    graph = nx.Graph()
    graph.add_edges_from(edges)

    features = WeisfeilerLehmanHashing(
        graph=graph,
        wl_iterations=wl_iterations,
        use_node_attribute=None,
        erase_base_features=False,
    ).get_graph_features()
    return sorted(features)


def graph_wl_signature_on_rank0_complex(domain, wl_iterations: int = 3):
    """
    Compute a graph WL signature on the 0-skeleton of a TopoNetX complex.

    Steps:
        1. Ask ``neighborhood_from_complex`` for the vertex (rank-0)
           adjacency matrix of ``domain``.
        2. Turn that sparse matrix into a NetworkX graph.
        3. Run standard Weisfeiler–Lehman hashing on the resulting graph.

    CombinatorialComplex and ColoredHyperGraph need an explicit via-rank:
    vertices are connected through shared edges (via_rank=1). For
    SimplicialComplex / CellComplex the via_rank is ignored, so -1 is passed
    as a placeholder.

    Parameters
    ----------
    domain : TopoNetX complex
        SimplicialComplex, CellComplex, CombinatorialComplex, etc.
    wl_iterations : int
        Number of WL refinement iterations.

    Returns
    -------
    list of str
        Sorted WL features of the 0-skeleton graph.
    """
    needs_via_rank = isinstance(
        domain, (tnx.CombinatorialComplex, tnx.ColoredHyperGraph)
    )
    via = 1 if needs_via_rank else -1

    # The cell index list is not needed: WL only uses graph structure.
    _cell_ids, adjacency = neighborhood_from_complex(
        domain,
        neighborhood_type="adj",
        neighborhood_dim={"rank": 0, "via_rank": via},
    )
    skeleton = nx.from_scipy_sparse_array(adjacency)

    hasher = WeisfeilerLehmanHashing(
        graph=skeleton,
        wl_iterations=wl_iterations,
        use_node_attribute=None,
        erase_base_features=False,
    )
    return sorted(hasher.get_graph_features())


def higher_order_wl_signature_complex(
    domain,
    neighborhood_type: str,
    neighborhood_dim: dict,
    wl_iterations: int = 3,
):
    """
    Compute a higher-order WL signature of a TopoNetX complex.

    The work is delegated to ``HigherOrderWeisfeilerLehmanHashing``.
    Conceptually:
        1. ``neighborhood_dim`` picks a rank r (and a via-rank k).
        2. The (co)adjacency matrix A^{(r,k)} over rank-r cells is built.
        3. Rank-r cells become the nodes of a graph defined by A^{(r,k)}.
        4. WL refinement runs on that cell-graph.

    This is exactly the "higher-order WL" described in the tutorial.

    Parameters
    ----------
    domain : TopoNetX complex
    neighborhood_type : {"adj", "coadj"}
        Whether cells are linked by adjacency or by coadjacency.
    neighborhood_dim : dict
        Mapping with keys "rank" and "via_rank": the rank of the cells WL
        runs on, and the rank through which they are connected.
    wl_iterations : int
        How many WL refinement rounds to perform.

    Returns
    -------
    list of str
        The higher-order WL features over the chosen rank, in sorted order.
    """
    hasher = HigherOrderWeisfeilerLehmanHashing(
        wl_iterations=wl_iterations,
        erase_base_features=False,
    )
    # Use the object returned by fit(), exactly as a chained call would.
    fitted = hasher.fit(
        domain=domain,
        neighborhood_type=neighborhood_type,
        neighborhood_dim=neighborhood_dim,
    )
    features = fitted.get_domain_features()
    return sorted(features)


# ---------------------------------------------------------------------
# 2. Base structure: two triangles sharing an edge (underlying graph)
# ---------------------------------------------------------------------
# We use the same base graph in all four representations:
#
#   vertices: 0, 1, 2, 3
#   edges:
#       triangle (0,1,2): (0,1), (1,2), (2,0)
#       triangle (1,2,3): (1,3), (2,3)   [(1,2) is shared and listed once]
#
# This is the "two triangles sharing an edge" example described in the text.
# Kept as a list of [u, v] lists: it is printed verbatim in the report below
# and passed unchanged to the plain-graph WL helper and the CellComplex
# constructors.
edges_base = [
    [0, 1],
    [1, 2],
    [2, 0],   # triangle 0-1-2
    [1, 3],
    [2, 3],   # triangle 1-2-3 (shares edge [1, 2] with triangle 0-1-2)
]


# ---------------------------------------------------------------------
# 3. SimplicialComplex: sc_A and sc_B are IDENTICAL
# ---------------------------------------------------------------------
# Both sc_A and sc_B contain exactly the two 2-simplices:
#   [0,1,2] and [1,2,3]
# (their vertices and edges are the faces of these two simplices).
#
# Because they are built from the same input they form a control pair:
#   - 0-skeleton WL should see them as the same.
#   - rank-2 higher-order WL (faces) should see them as the same.
# An "equal" verdict for this pair is guaranteed by construction; it does not
# say anything about the expressive power of WL on simplicial complexes.
sc_A = tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]])
sc_B = tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]])  # identical copy


# ---------------------------------------------------------------------
# 4. CellComplex: cx_A and cx_B are IDENTICAL
# ---------------------------------------------------------------------
# Both cell complexes have:
#   - vertices {0,1,2,3}
#   - edges given by edges_base
#   - two 2-cells filling the two triangles
#
# So again (another control pair, equal by construction):
#   - 0-skeleton WL should see them as equal.
#   - rank-2 higher-order WL should see them as equal.
# NOTE(review): `ranks=1` is meant to register edges_base as rank-1 cells;
# confirm how the CellComplex constructor interprets this keyword.
cx_A = tnx.CellComplex(edges_base, ranks=1)
cx_A.add_cell([0, 1, 2], rank=2)
cx_A.add_cell([1, 2, 3], rank=2)

cx_B = tnx.CellComplex(edges_base, ranks=1)
cx_B.add_cell([0, 1, 2], rank=2)
cx_B.add_cell([1, 2, 3], rank=2)


# ---------------------------------------------------------------------
# 5. CombinatorialComplex: cc_A vs cc_B differ ONLY in rank-3
# ---------------------------------------------------------------------
# cc_A and cc_B share:
#   - the same vertex set {0,1,2,3}
#   - the same edges edges_base
#   - the same rank-2 cells [0,1,2] and [1,2,3]
#
# They differ in rank-3:
#   - cc_A has exactly one rank-3 cell attached to [0,1,2]
#   - cc_B has two rank-3 cells:
#         one attached to [0,1,2]
#         one attached to [1,2,3]
#
# This is the extra information that only WL on rank-3 cells can see.
# NOTE(review): each rank-3 cell reuses the vertex set of an existing rank-2
# cell. Confirm that TopoNetX keeps both cells, rather than re-ranking or
# rejecting the duplicate set; the rank-3 comparison below relies on it.
cc_A = tnx.CombinatorialComplex()
for v in [0, 1, 2, 3]:
    cc_A.add_cell([v], rank=0)
for u, v in edges_base:
    cc_A.add_cell([u, v], rank=1)
cc_A.add_cell([0, 1, 2], rank=2)
cc_A.add_cell([1, 2, 3], rank=2)
# rank-3: one higher cell over triangle (0,1,2)
cc_A.add_cell([0, 1, 2], rank=3)

cc_B = tnx.CombinatorialComplex()
for v in [0, 1, 2, 3]:
    cc_B.add_cell([v], rank=0)
for u, v in edges_base:
    cc_B.add_cell([u, v], rank=1)
cc_B.add_cell([0, 1, 2], rank=2)
cc_B.add_cell([1, 2, 3], rank=2)
# rank-3: one over (0,1,2) and one over (1,2,3) -- the only difference from cc_A
cc_B.add_cell([0, 1, 2], rank=3)
cc_B.add_cell([1, 2, 3], rank=3)


# ---------------------------------------------------------------------
# 6. Run all WL tests and print detailed, didactic output
# ---------------------------------------------------------------------
# Each check below does three things:
#   - computes a sorted WL signature for the A and B variant of one
#     representation;
#   - passes their equality to print_check (defined elsewhere in the
#     notebook);
#   - passes the interpretation text to show for a True and a False result.
# NOTE(review): in the banner, "SC / CX FAIL" refers to pairs that are
# identical by construction (sections 3 and 4). For those pairs, "cannot
# distinguish" is the only possible outcome.
print("\n================ CC TEST: GRAPH / SC / CX FAIL, CC SUCCEEDS ================\n")

print("Base structure:")
print("  • Two triangles (0,1,2) and (1,2,3) sharing the edge (1,2)")
print("  • Vertices: {0, 1, 2, 3}")
print("  • Edges   :", edges_base)
print()
print("We build FOUR representations of this structure (A and B each):")
print("  (1) Plain graph: only vertices + edges (no faces).")
print("  (2) SimplicialComplex  sc_A vs sc_B: two 2-simplices [0,1,2], [1,2,3].")
print("  (3) CellComplex        cx_A vs cx_B: two 2-cells filling the triangles.")
print("  (4) CombinatorialComplex cc_A vs cc_B: same 0–2 structure,")
print("      but cc_B has an extra rank-3 cell compared to cc_A.\n")

print("We now run a sequence of WL tests and compare the signatures for A vs B:")
print("  • If signatures are EQUAL: WL cannot distinguish A and B in that view.")
print("  • If signatures DIFFER   : WL successfully detects a structural difference.\n")

# 1) Plain graph WL — should be equal (same edge list)
#    Both signatures come from the very same edges_base list, so this is a
#    determinism sanity check rather than a genuine A/B comparison.
sig_graph_A = graph_wl_signature_from_edges(edges_base)
sig_graph_B = graph_wl_signature_from_edges(edges_base)
print_check(
    "Graph WL on plain graph (edges only) A vs B",
    sig_graph_A == sig_graph_B,
    when_true="As expected: the underlying graph is identical, so 1-WL cannot distinguish them.",
    when_false="Unexpected: 1-WL claims the same edge set leads to different signatures."
)

# 2) SimplicialComplex: graph WL on 0-skeleton
sig_sc_A_graph = graph_wl_signature_on_rank0_complex(sc_A)
sig_sc_B_graph = graph_wl_signature_on_rank0_complex(sc_B)
print_check(
    "Graph WL on SimplicialComplex 0-skeleton A vs B",
    sig_sc_A_graph == sig_sc_B_graph,
    when_true="As expected: sc_A and sc_B have the same 0-skeleton graph.",
    when_false="Unexpected: WL on the vertex graph thinks sc_A and sc_B differ."
)

#    Higher-order WL on rank-2 coadjacency for SC — should be equal
#    (for SimplicialComplex, neighborhood_from_complex only uses "rank";
#    the via_rank of -1 is a placeholder)
sig_sc_A_2 = higher_order_wl_signature_complex(
    sc_A, "coadj", {"rank": 2, "via_rank": -1}
)
sig_sc_B_2 = higher_order_wl_signature_complex(
    sc_B, "coadj", {"rank": 2, "via_rank": -1}
)
print_check(
    "Higher-order WL on SC (rank=2 faces, coadjacency) A vs B",
    sig_sc_A_2 == sig_sc_B_2,
    when_true="As expected: sc_A and sc_B have exactly the same 2-simplices.",
    when_false="Unexpected: WL on 2-simplices claims a difference where there should be none."
)

# 3) CellComplex: graph WL on 0-skeleton
sig_cx_A_graph = graph_wl_signature_on_rank0_complex(cx_A)
sig_cx_B_graph = graph_wl_signature_on_rank0_complex(cx_B)
print_check(
    "Graph WL on CellComplex 0-skeleton A vs B",
    sig_cx_A_graph == sig_cx_B_graph,
    when_true="As expected: cx_A and cx_B share the same vertex-edge graph.",
    when_false="Unexpected: WL on the vertex graph thinks cx_A and cx_B differ."
)

#    Higher-order WL on rank-2 coadjacency for CX — should be equal
#    (as for SC, only "rank" is used for a CellComplex)
sig_cx_A_2 = higher_order_wl_signature_complex(
    cx_A, "coadj", {"rank": 2, "via_rank": -1}
)
sig_cx_B_2 = higher_order_wl_signature_complex(
    cx_B, "coadj", {"rank": 2, "via_rank": -1}
)
print_check(
    "Higher-order WL on CX (rank=2 faces, coadjacency) A vs B",
    sig_cx_A_2 == sig_cx_B_2,
    when_true="As expected: cx_A and cx_B have identical 2-cell structure.",
    when_false="Unexpected: WL on 2-cells claims a difference where there should be none."
)

# 4) CombinatorialComplex: graph WL on 0-skeleton
#    Vertices are connected via rank-1 cells, and cc_A / cc_B have the same
#    rank-1 cells (edges_base), so their vertex graphs coincide.
sig_cc_A_graph = graph_wl_signature_on_rank0_complex(cc_A)
sig_cc_B_graph = graph_wl_signature_on_rank0_complex(cc_B)
print_check(
    "Graph WL on CombinatorialComplex 0-skeleton A vs B",
    sig_cc_A_graph == sig_cc_B_graph,
    when_true="As expected: cc_A and cc_B share the same vertex-edge graph.",
    when_false="Unexpected: WL on the vertex graph thinks cc_A and cc_B differ."
)

#    Higher-order WL on rank-3 coadjacency via rank-0 — should DIFFER.
#    print_check only receives a boolean, so the condition is phrased as
#    "signatures differ". That way a True result is the desired outcome here,
#    just as it is for every equality check above. Presumably print_check
#    marks True with ✓, as in the recorded output.
sig_cc_A_3 = higher_order_wl_signature_complex(
    cc_A, "coadj", {"rank": 3, "via_rank": 0}
)
sig_cc_B_3 = higher_order_wl_signature_complex(
    cc_B, "coadj", {"rank": 3, "via_rank": 0}
)
print_check(
    "Higher-order WL on CC (rank=3 cells, coadj via vertices): A and B signatures differ",
    sig_cc_A_3 != sig_cc_B_3,
    when_true="As desired: WL on rank-3 cells distinguishes cc_A (1 three-cell) from cc_B (2 three-cells).",
    when_false="This would mean WL on rank-3 cells failed to see the extra 3-cell in cc_B (not desired).",
)

# Summary.
# - The CombinatorialComplex success claim is printed only when the rank-3
#   check actually separated cc_A from cc_B.
# - The closing bullet is worded to match what this experiment shows. The
#   SC/CX pairs are identical controls, so they cannot demonstrate a lack of
#   expressive power.
print("Summary:")
print("  • All WL tests that only see the 0-skeleton or the 2D structure (SC/CX) say A and B look the same.")
if sig_cc_A_3 != sig_cc_B_3:
    print("  • Only the higher-order WL on the CombinatorialComplex at rank 3 (with coadjacency via rank 0)")
    print("    is sensitive enough to detect the extra 3-cell in cc_B.")
else:
    print("  • Unexpected: higher-order WL on the CombinatorialComplex at rank 3 (with coadjacency via rank 0)")
    print("    did NOT detect the extra 3-cell in cc_B.")
print("  • This illustrates how combinatorial complexes + flexible (r, k) neighborhoods let WL operate")
print("    on higher-rank cells (here rank 3) that the graph, SC and CX views of this example do not contain.\n")
print("===========================================================================\n")




Base structure:
  • Two triangles (0,1,2) and (1,2,3) sharing the edge (1,2)
  • Vertices: {0, 1, 2, 3}
  • Edges   : [[0, 1], [1, 2], [2, 0], [1, 3], [2, 3]]

We build FOUR representations of this structure (A and B each):
  (1) Plain graph: only vertices + edges (no faces).
  (2) SimplicialComplex  sc_A vs sc_B: two 2-simplices [0,1,2], [1,2,3].
  (3) CellComplex        cx_A vs cx_B: two 2-cells filling the triangles.
  (4) CombinatorialComplex cc_A vs cc_B: same 0–2 structure,
      but cc_B has an extra rank-3 cell compared to cc_A.

We now run a sequence of WL tests and compare the signatures for A vs B:
  • If signatures are EQUAL: WL cannot distinguish A and B in that view.
  • If signatures DIFFER   : WL successfully detects a structural difference.

✓ Graph WL on plain graph (edges only) A vs B
    Result        : True
    Interpretation: As expected: the underlying graph is identical, so 1-WL cannot distinguish them.

✓ Graph WL on SimplicialComplex 0-skeleton A vs B
    Re

In [None]:
"""Functions for computing neighborhoods of a complex."""

from typing import Literal

import toponetx as tnx
from scipy.sparse import csr_matrix, vstack, hstack


def neighborhood_from_complex(
    domain: tnx.Complex,
    neighborhood_type: Literal["adj", "coadj", "boundary", "coboundary"] = "adj",
    neighborhood_dim=None,
) -> tuple[list, csr_matrix]:
    """Compute a neighborhood matrix for a TopoNetX complex.

    This function returns the indices and matrix for the neighborhood specified by
    ``neighborhood_type`` and ``neighborhood_dim`` for the input complex ``domain``.

    Supported neighborhood types
    ----------------------------
    1. ``"adj"`` (adjacency on a single rank)
       - Nodes are cells of a fixed rank r.
       - Two r-cells are adjacent if they share suitable (co)faces, as defined
         by TopoNetX's `adjacency_matrix` implementation.

    2. ``"coadj"`` (coadjacency on a single rank)
       - Similar to `"adj"`, but using the complex's `coadjacency_matrix`.

    3. ``"boundary"`` / ``"coboundary"`` (Hasse graph from incidence)
       - Here we use the complex's `incidence_matrix` to build an undirected
         **Hasse graph** between two ranks (by default r-1 and r).
       - Nodes are the union of the cells of both ranks.
       - Edges connect a cell of one rank to a cell of the other rank whenever
         their incidence entry is nonzero.

       For the purposes of an undirected neighborhood, `"boundary"` and
       `"coboundary"` return the **same** Hasse adjacency; they are conceptually
       different (downward vs upward), but the graph is the same.

    Parameters
    ----------
    domain : toponetx.classes.Complex
        The complex to compute the neighborhood for.
        Must be one of:
        - SimplicialComplex
        - CellComplex
        - PathComplex
        - CombinatorialComplex
        - ColoredHyperGraph
    neighborhood_type : {"adj", "coadj", "boundary", "coboundary"}, default="adj"
        The type of neighborhood to compute.
    neighborhood_dim : dict, optional
        Integer parameters specifying which rank(s) to use.

        For "adj"/"coadj":
        ------------------
        - For Simplicial/Cell/Path:
              neighborhood_dim["rank"]
          selects the rank r whose cells will be the nodes.

        - For Combinatorial/ColoredHyperGraph:
              neighborhood_dim["rank"], neighborhood_dim["via_rank"]
          specify both the rank r of the nodes and the intermediate rank
          via which adjacency/coadjacency is computed.

        For "boundary"/"coboundary":
        ----------------------------
        With r = neighborhood_dim["rank"]:

        - For Simplicial/Cell/Path:
              domain.incidence_matrix(r, index=True)
          is expected to return ``ind_low, ind_high, B``, where
              * ind_low  : labels of (r-1)-cells (rows of B),
              * ind_high : labels of r-cells (columns of B),
              * B        : incidence matrix (shape n_low × n_high).

        - For Combinatorial/ColoredHyperGraph:
              domain.incidence_matrix(k, r, index=True)
          with k = neighborhood_dim.get("via_rank", r - 1). It is expected to
          return row labels, column labels and the incidence matrix.
          TopoNetX declares this method as ``incidence_matrix(rank, to_rank, ...)``
          for these classes, so both ranks are passed positionally; there is no
          ``via_rank`` keyword. An explicit ``"via_rank"`` key always takes
          precedence over the r - 1 default.

        If ``neighborhood_dim`` is None, we default to:
            neighborhood_dim = {"rank": 0, "via_rank": -1}

    Returns
    -------
    ind : list or mapping
        Labels of the nodes of the neighborhood graph.
        - For "adj"/"coadj": the index object returned by TopoNetX's
          (co)adjacency method, passed through unchanged. It may be a dict
          mapping each cell to its row rather than a list.
        - For "boundary"/"coboundary": a list of the incidence matrix's row
          labels followed by its column labels (the Hasse graph nodes). This
          assumes the label containers iterate in row/column order.
    A : scipy.sparse.csr_matrix
        The matrix representing the neighborhood.
        - For "adj"/"coadj": square (n_r × n_r) adjacency/coadjacency matrix.
        - For "boundary"/"coboundary":
          square ((n_rows + n_cols) × (n_rows + n_cols)) adjacency matrix of
          the bipartite Hasse graph induced by the incidence matrix.

    Raises
    ------
    TypeError
        If `domain` is not a supported complex type.
    TypeError
        If `neighborhood_type` is invalid.
    TypeError
        If a "boundary"/"coboundary" neighborhood is requested for a domain
        without an ``incidence_matrix`` method.
    """
    # Default neighborhood dimensions
    if neighborhood_dim is None:
        neighborhood_dim = {"rank": 0, "via_rank": -1}

    if neighborhood_type not in ["adj", "coadj", "boundary", "coboundary"]:
        raise TypeError(
            "Input neighborhood_type must be one of "
            "'adj', 'coadj', 'boundary', or 'coboundary', "
            f"got {neighborhood_type}."
        )

    # ------------------------------------------------------------
    # Case 1: adjacency / coadjacency on a fixed rank
    # ------------------------------------------------------------
    if neighborhood_type in ["adj", "coadj"]:
        if isinstance(domain, tnx.SimplicialComplex | tnx.CellComplex | tnx.PathComplex):
            if neighborhood_type == "adj":
                ind, A = domain.adjacency_matrix(neighborhood_dim["rank"], index=True)
            else:
                ind, A = domain.coadjacency_matrix(neighborhood_dim["rank"], index=True)

        elif isinstance(domain, tnx.CombinatorialComplex | tnx.ColoredHyperGraph):
            if neighborhood_type == "adj":
                ind, A = domain.adjacency_matrix(
                    neighborhood_dim["rank"],
                    neighborhood_dim["via_rank"],
                    index=True,
                )
            else:
                ind, A = domain.coadjacency_matrix(
                    neighborhood_dim["rank"],
                    neighborhood_dim["via_rank"],
                    index=True,
                )
        else:
            raise TypeError(
                "Input Complex can only be a SimplicialComplex, CellComplex, "
                "PathComplex, ColoredHyperGraph or CombinatorialComplex."
            )

        return ind, A.asformat("csr")

    # ------------------------------------------------------------
    # Case 2: boundary / coboundary → Hasse graph from incidence_matrix
    # ------------------------------------------------------------
    if not hasattr(domain, "incidence_matrix"):
        raise TypeError(
            "The given complex does not provide an 'incidence_matrix' method, "
            "so 'boundary'/'coboundary' neighborhoods are not supported."
        )

    r = neighborhood_dim["rank"]

    # Two cases: (SC / Cell / Path) vs (CC / ColoredHyperGraph)
    if isinstance(domain, tnx.SimplicialComplex | tnx.CellComplex | tnx.PathComplex):
        # Rows: (r-1)-cells, columns: r-cells.
        ind_low, ind_high, B = domain.incidence_matrix(r, index=True)  # type: ignore[arg-type]
    elif isinstance(domain, tnx.CombinatorialComplex | tnx.ColoredHyperGraph):
        via = neighborhood_dim.get("via_rank", r - 1)
        # These classes name the second rank ``to_rank``, so a ``via_rank=``
        # keyword raises TypeError; pass both ranks positionally instead.
        # The lower rank goes first so that, like the branch above, rows
        # should be the lower-rank cells (TODO confirm orientation against
        # TopoNetX). Either way, ``ind`` below is built from the same row and
        # column labels as the blocks, so graph and labels stay consistent.
        ind_low, ind_high, B = domain.incidence_matrix(  # type: ignore[arg-type]
            via, r, index=True
        )
    else:
        raise TypeError(
            "Input Complex can only be a SimplicialComplex, CellComplex, "
            "PathComplex, ColoredHyperGraph or CombinatorialComplex."
        )

    # Make sure B is CSR and unsigned: orientation signs are irrelevant for
    # an undirected Hasse graph.
    B = abs(B).asformat("csr")
    n_low, n_high = B.shape

    # Build bipartite Hasse adjacency:
    #   [ 0   B ]
    #   [ B^T 0 ]
    zero_low = csr_matrix((n_low, n_low))
    zero_high = csr_matrix((n_high, n_high))

    upper = hstack([zero_low, B])
    lower = hstack([B.transpose(), zero_high])
    A_hasse = vstack([upper, lower]).asformat("csr")

    # Node labels follow the block layout: row labels first, then column labels.
    ind = list(ind_low) + list(ind_high)

    return ind, A_hasse


In [None]:
import toponetx as tnx
from scipy.sparse import csr_matrix

# If neighborhood_from_complex is in the same notebook, no import is needed.
# Otherwise:
# from your_module import neighborhood_from_complex

def check(msg, cond):
    """Print a pass (✓) or fail (✗) marker, then ``msg``, on one line."""
    mark = "✓" if cond else "✗"
    print(mark, msg)

print("\n================ NEIGHBORHOOD_FROM_COMPLEX TESTS ================\n")

# -------------------------------------------------------------------
# 1. SimplicialComplex: adjacency (rank 1) and coadjacency (rank 2)
# -------------------------------------------------------------------
print("Test 1: SimplicialComplex adjacency on rank 1 / coadjacency on rank 2")

# Two triangles sharing the edge (1, 2).
sc = tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]])

# Edge (rank-1) adjacency; via_rank is not used for simplicial complexes.
ind_adj, A_adj = neighborhood_from_complex(
    sc,
    neighborhood_type="adj",
    neighborhood_dim={"rank": 1, "via_rank": -1},
)
check("A_adj is CSR", isinstance(A_adj, csr_matrix))
check("A_adj is square", A_adj.shape[0] == A_adj.shape[1])
check("len(ind_adj) == A_adj.shape[0]", len(ind_adj) == A_adj.shape[0])
print("  rank-1 cells indices:", ind_adj)
print("  A_adj shape:", A_adj.shape, "\n")

# Triangle (rank-2) coadjacency: the two 2-simplices share an edge.
ind_coadj, A_coadj = neighborhood_from_complex(
    sc,
    neighborhood_type="coadj",
    neighborhood_dim={"rank": 2, "via_rank": -1},
)
check("A_coadj is CSR", isinstance(A_coadj, csr_matrix))
check("A_coadj is square", A_coadj.shape[0] == A_coadj.shape[1])
check("len(ind_coadj) == A_coadj.shape[0]", len(ind_coadj) == A_coadj.shape[0])
print("  rank-2 cells indices:", ind_coadj)
print("  A_coadj shape:", A_coadj.shape, "\n")

# -------------------------------------------------------------------
# 2. SimplicialComplex: boundary / coboundary via incidence_matrix
# -------------------------------------------------------------------
# Both neighborhood types should yield the same undirected Hasse graph
# between vertices and edges; the last check compares them entry-wise.
print("Test 2: SimplicialComplex boundary / coboundary (Hasse graph from incidence)")

# Single triangle so incidence_matrix(rank=1) should connect vertices (0-cells) and edges (1-cells)
sc2 = tnx.SimplicialComplex([[0, 1, 2]])

ind_b, A_b = neighborhood_from_complex(
    sc2,
    neighborhood_type="boundary",
    neighborhood_dim={"rank": 1},  # use incidence between rank-1 and rank-0
)
check("A_b is CSR", isinstance(A_b, csr_matrix))
check("A_b is square", A_b.shape[0] == A_b.shape[1])
check("A_b has nonzero entries", A_b.nnz > 0)
print("  boundary-Hasse indices (rank 0 + rank 1):", ind_b)
print("  A_b shape:", A_b.shape, "\n")

ind_cb, A_cb = neighborhood_from_complex(
    sc2,
    neighborhood_type="coboundary",
    neighborhood_dim={"rank": 1},
)
check("A_cb is CSR", isinstance(A_cb, csr_matrix))
check("A_cb has same shape as A_b", A_cb.shape == A_b.shape)
# (A_cb != A_b) is a sparse boolean matrix of differing entries; no stored
# entries means the two matrices are identical.
check("A_cb equals A_b (undirected Hasse)", (A_cb != A_b).nnz == 0)
print("  coboundary-Hasse indices:", ind_cb)
print("  A_cb shape:", A_cb.shape, "\n")

# -------------------------------------------------------------------
# 3. CombinatorialComplex: adj / coadj + boundary (if incidence defined)
# -------------------------------------------------------------------
print("Test 3: CombinatorialComplex adj / coadj / boundary")

cc = tnx.CombinatorialComplex()
# vertices
for v in [0, 1, 2, 3]:
    cc.add_cell([v], rank=0)
# edges (a 4-cycle)
for e in [(0, 1), (1, 2), (2, 3), (3, 0)]:
    cc.add_cell(list(e), rank=1)
# 2-cells
cc.add_cell([0, 1, 2], rank=2)
cc.add_cell([0, 2, 3], rank=2)

# rank-1 adjacency via rank-2
ind_cc_adj, A_cc_adj = neighborhood_from_complex(
    cc,
    neighborhood_type="adj",
    neighborhood_dim={"rank": 1, "via_rank": 2},
)
check("CC rank-1 adjacency is CSR", isinstance(A_cc_adj, csr_matrix))
check("CC rank-1 adjacency is square", A_cc_adj.shape[0] == A_cc_adj.shape[1])
print("  CC rank-1 nodes:", ind_cc_adj)
print("  A_cc_adj shape:", A_cc_adj.shape, "\n")

# rank-1 coadjacency via rank-0
ind_cc_coadj, A_cc_coadj = neighborhood_from_complex(
    cc,
    neighborhood_type="coadj",
    neighborhood_dim={"rank": 1, "via_rank": 0},
)
check("CC rank-1 coadjacency is CSR", isinstance(A_cc_coadj, csr_matrix))
check("CC rank-1 coadjacency is square", A_cc_coadj.shape[0] == A_cc_coadj.shape[1])
print("  CC rank-1 nodes (coadj):", ind_cc_coadj)
print("  A_cc_coadj shape:", A_cc_coadj.shape, "\n")

# Boundary between rank 1 and 2. A TypeError here can mean either that the
# complex has no incidence_matrix or that incidence_matrix was called with
# arguments it does not accept, so the error is reported verbatim instead of
# being attributed to a missing implementation.
try:
    ind_cc_b, A_cc_b = neighborhood_from_complex(
        cc,
        neighborhood_type="boundary",
        neighborhood_dim={"rank": 2},  # incidence between rank 2 and rank 1
    )
    check("CC boundary-Hasse is CSR", isinstance(A_cc_b, csr_matrix))
    check("CC boundary-Hasse is square", A_cc_b.shape[0] == A_cc_b.shape[1])
    print("  CC boundary-Hasse nodes (rank 1 + rank 2):", ind_cc_b)
    print("  A_cc_b shape:", A_cc_b.shape, "\n")
except TypeError as e:
    print("  (Skipped boundary test for CC: neighborhood_from_complex raised TypeError)")
    print("  Error:", e, "\n")

print("=============================================================\n")




Test 1: SimplicialComplex adjacency / coadjacency on rank 1
✓ A_adj is CSR
✓ A_adj is square
✓ len(ind_adj) == A_adj.shape[0]
  rank-1 cells indices: {(0, 1): 0, (0, 2): 1, (1, 2): 2, (1, 3): 3, (2, 3): 4}
  A_adj shape: (5, 5) 

✓ A_coadj is CSR
✓ A_coadj is square
  rank-2 cells indices: {(0, 1, 2): 0, (1, 2, 3): 1}
  A_coadj shape: (2, 2) 

Test 2: SimplicialComplex boundary / coboundary (Hasse graph from incidence)
✓ A_b is CSR
✓ A_b is square
✓ A_b has nonzero entries
  boundary-Hasse indices (rank 0 + rank 1): [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]
  A_b shape: (6, 6) 

✓ A_cb is CSR
✓ A_cb has same shape as A_b
✓ A_cb equals A_b (undirected Hasse)
  coboundary-Hasse indices: [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]
  A_cb shape: (6, 6) 

Test 3: CombinatorialComplex adj / coadj / boundary
✓ CC rank-1 adjacency is CSR
✓ CC rank-1 adjacency is square
  CC rank-1 nodes: OrderedDict({frozenset({0, 1}): 0, frozenset({1, 2}): 1, frozenset({2, 3}): 2, frozenset({0, 3}): 3})
  A

In [None]:
"""Testing the neighborhood module."""

import pytest
import toponetx as tnx

import topoembedx as tex


class TestNeighborhood:
    """Tests for ``tex.neighborhood.neighborhood_from_complex``."""

    def test_neighborhood_from_complex_raise_error(self):
        """A non-complex input must raise TypeError with this exact message."""
        expected_message = (
            "Input Complex can only be a SimplicialComplex, CellComplex, "
            "PathComplex ColoredHyperGraph or CombinatorialComplex."
        )
        with pytest.raises(TypeError) as excinfo:
            tex.neighborhood.neighborhood_from_complex(1)

        assert str(excinfo.value) == expected_message

    def test_neighborhood_from_complex_matrix_dimension_cell_complex(self):
        """Adjacency / coadjacency matrices of cell complexes have the expected size."""
        # Cells use vertices 0..8, i.e. 9 vertices.
        cc1 = tnx.classes.CellComplex(
            [[0, 1, 2, 3], [1, 2, 3, 4], [1, 3, 4, 5, 6, 7, 8]]
        )
        # Cells use vertices 0..3, i.e. 4 vertices.
        cc2 = tnx.classes.CellComplex([[0, 1, 2], [1, 2, 3]])

        # (complex, extra keyword arguments, expected matrix size), checked in
        # the same order as the individual calls this table replaces.
        cases = [
            (cc1, {}, 9),
            (cc2, {}, 4),
            (cc1, {"neighborhood_type": "coadj"}, 9),
            (cc2, {"neighborhood_type": "coadj"}, 4),
        ]
        for domain, kwargs, size in cases:
            ind, A = tex.neighborhood.neighborhood_from_complex(domain, **kwargs)
            assert A.todense().shape == (size, size)
            assert len(ind) == size