In [1]:
# multipersistence_mutag.py
# Multipersistence snapshot (2-parameter) pipeline for MUTAG:
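#  - parameter A: atom label stored as 'atomic_number' (node attribute; for one-hot
#    node features this is the argmax class index, not a true atomic number)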
#  - parameter B: node degree (computed on original graph)
#  - compute Betti (0,1,2) at each (eps_A, eps_B) grid point
#  - vectorize surfaces, train XGBoost, report metrics

import numpy as np
import networkx as nx
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, confusion_matrix
import xgboost as xgb

from torch_geometric.datasets import TUDataset
import pyflagser

# -------------------------
# User-tunable settings
# -------------------------
N_THRESH_A = 6   # number of thresholds along parameter A (atomic number)
N_THRESH_B = 6   # number of thresholds along parameter B (degree)
USE_QUANTILES = True  # use quantile-based thresholds across dataset (recommended)
MIN_DIM = 0
MAX_DIM = 2
RANDOM_STATE = 42

# -------------------------
# Helper functions
# -------------------------
def graph_from_pyg_data(data):
    """
    Build networkx graph and attach atomic_number attribute if available.
    Fallbacks: data.x (single scalar or one-hot -> argmax), data.z
    """
    G = nx.Graph()
    n = data.num_nodes
    for i in range(n):
        G.add_node(i)
    edge_index = data.edge_index.numpy()
    for u, v in edge_index.T:
        G.add_edge(int(u), int(v))

    # try to extract node scalar labels / atomic numbers
    atomic_numbers = None
    if hasattr(data, 'x') and data.x is not None:
        x_np = data.x.numpy()
        if x_np.ndim == 1 or (x_np.ndim == 2 and x_np.shape[1] == 1):
            atomic_numbers = x_np.reshape(-1)
        elif x_np.ndim == 2:
            # assume one-hot -> argmax label
            atomic_numbers = np.argmax(x_np, axis=1)
    if atomic_numbers is None and hasattr(data, 'z'):
        atomic_numbers = data.z.numpy().reshape(-1)
    if atomic_numbers is None:
        atomic_numbers = np.zeros(n, dtype=float)

    for i, val in enumerate(atomic_numbers):
        G.nodes[i]['atomic_number'] = float(val)

    # also attach original degree (on full graph) as a separate node attribute
    for i in G.nodes():
        G.nodes[i]['orig_degree'] = float(G.degree(i))

    return G

def compute_betti_numbers_for_active(active_nodes, full_graph):
    """
    Given list of active node indices and the full graph (networkx),
    compute Betti numbers up to MAX_DIM using pyflagser.flagser_unweighted.
    """
    if len(active_nodes) == 0:
        return (0, 0, 0)
    H = full_graph.subgraph(sorted(active_nodes)).copy()
    nodelist = sorted(H.nodes())
    if len(nodelist) == 0:
        return (0, 0, 0)
    Adj = nx.to_numpy_array(H, nodelist=nodelist)
    my_flag = pyflagser.flagser_unweighted(
        Adj, min_dimension=MIN_DIM, max_dimension=MAX_DIM,
        directed=False, coeff=2, approximation=None
    )
    x = my_flag.get("betti", [])
    b0 = int(x[0]) if len(x) > 0 else 0
    b1 = int(x[1]) if len(x) > 1 else 0
    b2 = int(x[2]) if len(x) > 2 else 0
    return (b0, b1, b2)

def build_thresholds(values, n_thresholds, use_quantiles=True):
    """
    Build sorted unique threshold array for given values across dataset.
    """
    vals = np.asarray(values)
    if vals.size == 0:
        return np.array([0.0])
    if use_quantiles:
        qs = np.linspace(0.0, 1.0, n_thresholds)
        thr = np.quantile(vals, qs)
    else:
        thr = np.linspace(vals.min(), vals.max(), n_thresholds)
    return np.unique(thr)

def vectorize_betti_surfaces(betti0, betti1, betti2):
    """
    Vectorize 2D Betti surfaces into 1D feature vector.
    Strategy:
      - flatten each surface (row-major)
      - append summary stats (mean, max, sum, variance) per dimension
    """
    f = []
    for surf in (betti0, betti1, betti2):
        flat = surf.flatten()
        f.extend(flat.tolist())
        # basic summaries
        f.append(np.mean(flat))
        f.append(np.max(flat))
        f.append(np.sum(flat))
        f.append(np.var(flat))
    return np.array(f, dtype=float)

# -------------------------
# Main pipeline
# -------------------------
def main():
    # Load MUTAG
    dataset = TUDataset(root='data/TUDataset', name='MUTAG')
    print(f"Loaded MUTAG with {len(dataset)} graphs.")

    graphs = []
    labels = []
    all_atomic_vals = []
    all_degree_vals = []

    # convert and collect global distributions
    for data in dataset:
        G = graph_from_pyg_data(data)
        graphs.append(G)
        labels.append(int(data.y.item()))
        all_atomic_vals.extend([G.nodes[n]['atomic_number'] for n in G.nodes()])
        all_degree_vals.extend([G.nodes[n]['orig_degree'] for n in G.nodes()])

    # Build thresholds for both parameters
    thr_A = build_thresholds(all_atomic_vals, N_THRESH_A, use_quantiles=USE_QUANTILES)
    thr_B = build_thresholds(all_degree_vals, N_THRESH_B, use_quantiles=USE_QUANTILES)
    print("Thresholds A (atomic number):", thr_A)
    print("Thresholds B (degree):", thr_B)

    features = []
    for G in tqdm(graphs, desc="Computing multipersistence surfaces"):
        # For each graph build 2D Betti surfaces sized (len(thr_A), len(thr_B))
        nA = len(thr_A)
        nB = len(thr_B)
        betti0_surf = np.zeros((nA, nB), dtype=int)
        betti1_surf = np.zeros((nA, nB), dtype=int)
        betti2_surf = np.zeros((nA, nB), dtype=int)

        # prefetch node attribute dicts
        node_atomic = {n: G.nodes[n].get('atomic_number', 0.0) for n in G.nodes()}
        node_degree = {n: G.nodes[n].get('orig_degree', float(G.degree(n))) for n in G.nodes()}

        for iA, a_eps in enumerate(thr_A):
            for iB, b_eps in enumerate(thr_B):
                # sublevel-sublevel: nodes with atomic <= a_eps AND degree <= b_eps
                active_nodes = [n for n in G.nodes() if (node_atomic[n] <= a_eps and node_degree[n] <= b_eps)]
                b0, b1, b2 = compute_betti_numbers_for_active(active_nodes, G)
                betti0_surf[iA, iB] = b0
                betti1_surf[iA, iB] = b1
                betti2_surf[iA, iB] = b2

        # vectorize surfaces into fixed-size feature vector
        feat = vectorize_betti_surfaces(betti0_surf, betti1_surf, betti2_surf)
        features.append(feat)

    X = np.vstack(features)
    y = np.array(labels)

    # Train/test split and XGBoost classifier
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE)

    clf = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=RANDOM_STATE)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)
    # handle binary or multi-class probability extraction safely
    try:
        if clf.n_classes_ > 1:
            y_proba = clf.predict_proba(X_test)[:, 1]
        else:
            y_proba = clf.predict_proba(X_test)[:, 0]
    except Exception:
        y_proba = None

    acc = accuracy_score(y_test, y_pred)
    bal_acc = balanced_accuracy_score(y_test, y_pred)
    roc = roc_auc_score(y_test, y_proba) if (y_proba is not None and len(np.unique(y_test)) == 2) else None
    cm = confusion_matrix(y_test, y_pred)

    print("=== RESULTS ===")
    print("Accuracy:", acc)
    print("Balanced Accuracy:", bal_acc)
    print("ROC AUC:", roc)
    print("Confusion matrix:\n", cm)

    # return useful objects for inspection/visualization
    return {
        "model": clf,
        "thr_A": thr_A,
        "thr_B": thr_B,
        "X": X,
        "y": y
    }

# Run the pipeline only when executed as a script; the dict returned by main()
# is kept in `res` so it stays available for inspection afterwards.
if __name__ == "__main__":
    res = main()


Loaded MUTAG with 188 graphs.
Thresholds A (atomic number): [0. 1. 6.]
Thresholds B (degree): [1. 2. 3. 4.]


Computing multipersistence surfaces: 100%|███████████████████████████████████████████| 188/188 [00:24<00:00,  7.65it/s]
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)


=== RESULTS ===
Accuracy: 0.8771929824561403
Balanced Accuracy: 0.881578947368421
ROC AUC: 0.9030470914127423
Confusion matrix:
 [[17  2]
 [ 5 33]]
