# PH template: persistent-homology (Betti-curve) features for MUTAG graph classification

In [40]:
import numpy as np
import networkx as nx
from tqdm import tqdm
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, confusion_matrix
import xgboost as xgb

from torch_geometric.datasets import TUDataset

import pyflagser


## Configuration: filtration thresholds, homology dimensions and random seed

In [41]:
# Seed shared by the train/test split and the XGBoost classifier.
RANDOM_STATE = 42

# Filtration resolution:
#   N_THRESHOLD - number of candidate thresholds requested (duplicates are
#                 removed afterwards, so fewer may survive);
#   USE_QTY     - True: place thresholds at quantiles of the node values,
#                 False: space them evenly between the min and max value.
N_THRESHOLD = 12
USE_QTY = True

# Homology dimensions handed to pyflagser as min_dimension / max_dimension.
MIN_DIM = 0
MAX_DIM = 2

## Betti Numbers of the Subgraph Induced by Active Nodes
Here,
- `active_nodes`: an iterable of node indices to keep
- `full_graph`: the original (full) networkx graph
- returns: a tuple `(b0, b1, b2)`, i.e. the Betti-0, Betti-1 and Betti-2 numbers (despite its name, `compute_betti_curvature` returns Betti numbers, not a curvature)

In [42]:
def compute_betti_curvature(active_nodes, full_graph, min_dim=None, max_dim=None):
    """
    Betti numbers (b0, b1, b2) of the flag complex of the subgraph induced by `active_nodes`.

    Despite the name, no curvature is computed: the induced subgraph is turned
    into an adjacency matrix and passed to ``pyflagser.flagser_unweighted``
    (undirected, coefficients in Z/2), whose "betti" output is unpacked.

    Parameters
    ----------
    active_nodes : iterable
        Node ids to keep. Ids that are not in `full_graph` are ignored by
        networkx's ``subgraph``.
    full_graph : networkx.Graph
        The full graph the active nodes are taken from.
    min_dim, max_dim : int, optional
        Passed to pyflagser as ``min_dimension`` / ``max_dimension``.
        Default to the MIN_DIM / MAX_DIM config constants, read at call time.

    Returns
    -------
    tuple of int
        ``(b0, b1, b2)``. An empty induced subgraph gives ``(0, 0, 0)``; any
        dimension pyflagser did not report is returned as 0.
    """
    # A read-only subgraph view is enough; the original .copy() was unnecessary.
    G = full_graph.subgraph(active_nodes)

    # Fixed node order so adjacency rows/columns are reproducible.
    nodelist = sorted(G.nodes())
    if not nodelist:
        return (0, 0, 0)

    # Resolve config defaults at call time (def-time defaults go stale on cell re-runs).
    if min_dim is None:
        min_dim = MIN_DIM
    if max_dim is None:
        max_dim = MAX_DIM

    adj_mat = nx.to_numpy_array(G, nodelist=nodelist)
    flag = pyflagser.flagser_unweighted(
        adj_mat, min_dimension=min_dim, max_dimension=max_dim,
        directed=False, coeff=2, approximation=None,
    )
    betti = flag.get("betti", [])

    def betti_at(dim):
        # pyflagser's Betti list starts at min_dimension, so dimension `dim`
        # lives at index dim - min_dim. Indexing from 0 unconditionally would
        # mislabel every entry whenever min_dim != 0.
        idx = dim - min_dim
        return int(betti[idx]) if 0 <= idx < len(betti) else 0

    return (betti_at(0), betti_at(1), betti_at(2))

### NetworkX graph from a PyG `Data` object

In [43]:
def graph_from_pyg_data(data):
    """
    Construct an undirected networkx Graph from a PyG data object.

    Nodes ``0 .. data.num_nodes - 1`` are added. Every column ``(u, v)`` of
    ``data.edge_index`` becomes an edge. Each node gets a float
    ``'atomic_number'`` attribute, chosen by the first heuristic that applies:

    1. ``data.x`` with one feature per node: that feature is used as-is.
    2. ``data.x`` with several columns: ``argmax`` over the columns, i.e. the
       index of the active column when ``x`` is one-hot. NOTE: this is a
       category index, not a true atomic number, despite the attribute name.
       Its ordering is therefore arbitrary for a sublevel filtration.
    3. ``data.z`` (if present and not None): used as-is.
    4. Otherwise every node gets 0.0.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Must expose ``num_nodes`` and ``edge_index`` (shape (2, num_edges)).
        ``x`` and ``z`` are optional. Tensors may live on CPU or GPU.

    Returns
    -------
    networkx.Graph
    """
    G = nx.Graph()
    num_nodes = data.num_nodes
    G.add_nodes_from(range(num_nodes))

    # .cpu() makes the numpy conversion work for GPU tensors as well.
    edge_index = data.edge_index.detach().cpu().numpy()
    # tolist() yields plain Python ints, matching the integer node ids above.
    G.add_edges_from(edge_index.T.tolist())

    atomic_numbers = None
    x = getattr(data, 'x', None)
    if x is not None:
        x_np = x.detach().cpu().numpy()
        if x_np.ndim == 1 or x_np.shape[1] == 1:
            # Single scalar feature per node: use it directly.
            atomic_numbers = x_np.reshape(-1)
        else:
            # Multi-column (presumably one-hot) features: take the active column index.
            atomic_numbers = np.argmax(x_np, axis=1)

    # Fallback to data.z. Check for None explicitly: hasattr() alone would
    # accept z=None and then crash on .numpy().
    if atomic_numbers is None:
        z = getattr(data, 'z', None)
        if z is not None:
            atomic_numbers = z.detach().cpu().numpy().reshape(-1)

    if atomic_numbers is None:
        # Default: every node gets 0.0 (not expected for MUTAG).
        atomic_numbers = np.zeros(num_nodes, dtype=float)

    for i, val in enumerate(atomic_numbers):
        G.nodes[i]['atomic_number'] = float(val)

    return G

### Thresholds for dataset

In [44]:
def build_thresholds_for_dataset(all_node_values, n_thresholds=None, quantile=None):
    """
    Build sorted, de-duplicated filtration thresholds from pooled node values.

    Parameters
    ----------
    all_node_values : sequence of float
        Node values pooled over the dataset. Must be non-empty.
    n_thresholds : int, optional
        Number of candidate thresholds before de-duplication. Defaults to the
        current value of the ``N_THRESHOLD`` config constant.
    quantile : bool, optional
        If True, place candidates at evenly spaced quantiles (0..1) of the
        values. Otherwise space them evenly between the min and max value.
        Defaults to the current value of the ``USE_QTY`` config constant.

    Returns
    -------
    numpy.ndarray
        Unique thresholds in ascending order. Discrete or repeated values
        can leave fewer than ``n_thresholds`` entries.

    Raises
    ------
    ValueError
        If ``all_node_values`` is empty.
    """
    # Resolve defaults at call time: a def-time default (n_thresholds=N_THRESHOLD)
    # silently keeps the old value after the config cell is edited and re-run.
    if n_thresholds is None:
        n_thresholds = N_THRESHOLD
    if quantile is None:
        quantile = USE_QTY

    vals = np.asarray(all_node_values, dtype=float)
    if vals.size == 0:
        raise ValueError("cannot build thresholds from an empty set of node values")

    if quantile:
        thresholds = np.quantile(vals, np.linspace(0.0, 1.0, n_thresholds))
    else:
        thresholds = np.linspace(vals.min(), vals.max(), n_thresholds)
    # np.unique both sorts and drops duplicate thresholds.
    return np.unique(thresholds)

In [45]:
def _load_mutag_graphs():
    """Load MUTAG as networkx graphs; return (graphs, labels, node values pooled over all graphs)."""
    dataset = TUDataset(root='data/TUDataset', name='MUTAG')
    print(f"Loaded MUTAG with {len(dataset)} graphs.")

    graphs, labels, all_node_values = [], [], []
    for data in dataset:
        G = graph_from_pyg_data(data)
        graphs.append(G)
        labels.append(int(data.y.item()))
        all_node_values.extend(G.nodes[n].get('atomic_number', 0.0) for n in G.nodes())
    return graphs, labels, all_node_values


def _betti_feature_vector(G, thresholds):
    """Betti-0/1/2 curves of a node sublevel filtration, followed by mean/max/sum of each curve."""
    node_vals = {n: G.nodes[n].get('atomic_number', 0.0) for n in G.nodes()}

    curves = ([], [], [])  # betti0, betti1, betti2 curves, one entry per threshold
    for eps in thresholds:
        # Sublevel set: nodes whose value is <= eps.
        active_nodes = [n for n, v in node_vals.items() if v <= eps]
        for curve, b in zip(curves, compute_betti_curvature(active_nodes, G)):
            curve.append(b)

    # Layout: [betti0 curve | betti1 curve | betti2 curve | (mean, max, sum) per curve].
    feat = [b for curve in curves for b in curve]
    for curve in curves:
        feat.extend((np.mean(curve), np.max(curve), np.sum(curve)))
    return np.array(feat, dtype=float)


def main():
    """Featurize MUTAG with Betti curves, train XGBoost on a stratified split, print test metrics."""
    graphs, labels, all_node_values = _load_mutag_graphs()

    # NOTE(review): thresholds are computed from node values of ALL graphs, test
    # split included. This is label-free (unsupervised) leakage; confirm acceptable.
    thresholds = build_thresholds_for_dataset(all_node_values, n_thresholds=N_THRESHOLD, quantile=USE_QTY)
    print("Using thresholds:", thresholds)

    X = np.vstack([_betti_feature_vector(G, thresholds) for G in tqdm(graphs, desc="Graphs")])
    y = np.array(labels)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE
    )

    # use_label_encoder is not passed: XGBoost reports it as an unused parameter.
    clf = xgb.XGBClassifier(eval_metric='logloss', random_state=RANDOM_STATE)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)
    proba = clf.predict_proba(X_test)
    y_proba = proba[:, 1] if clf.n_classes_ > 1 else proba[:, 0]

    acc = accuracy_score(y_test, y_pred)
    bal_acc = balanced_accuracy_score(y_test, y_pred)
    try:
        roc = roc_auc_score(y_test, y_proba)
    except ValueError:
        # roc_auc_score raises ValueError e.g. when y_test contains a single class.
        roc = None

    cm = confusion_matrix(y_test, y_pred)

    print("Results:")
    print("Accuracy:", acc)
    print("Balanced Accuracy:", bal_acc)
    print("ROC AUC:", roc)
    print("Confusion matrix:\n", cm)

if __name__ == "__main__":
    main()

Loaded MUTAG with 188 graphs.
Using thresholds: [0. 1. 2. 6.]


Graphs: 100%|████████████████████████████████████████████████████████████████████████| 188/188 [00:05<00:00, 36.69it/s]
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)


Results:
Accuracy: 0.7894736842105263
Balanced Accuracy: 0.736842105263158
ROC AUC: 0.8753462603878116
Confusion matrix:
 [[11  8]
 [ 4 34]]
