In [None]:
!pip install torch torch_geometric networkx numpy scipy scikit-learn xgboost tqdm GraphRicciCurvature


"""
mutag_hks_forman_mlp.py

Pipeline:
 - load MUTAG from torch_geometric.TUDataset
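 - convert graphs to networkx with an 'atomic_number' node attribute
   (for one-hot node features such as MUTAG's, this is the atom-type index
   taken by argmax, not the chemical atomic number)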
 - compute HKS (multiple time scales) via Laplacian eigendecomposition
 - compute Forman-Ricci curvature via GraphRicciCurvature.FormanRicci
 - aggregate features per graph, train MLPClassifier, print metrics
"""

import numpy as np
import networkx as nx
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, confusion_matrix
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from torch_geometric.datasets import TUDataset
from scipy.linalg import eigh  # for symmetric eigenproblem
from GraphRicciCurvature.FormanRicci import FormanRicci  # package: GraphRicciCurvature

# ---------------- SETTINGS (one visible place to change) ----------------
RANDOM_STATE = 42  # seed shared by the train/test split and the MLP

# HKS diffusion time scales (edit here).
# np.logspace(-2, 2, 6) gives 6 log-spaced times from 1e-2 to 1e2:
#   ~[0.01, 0.063, 0.398, 2.51, 15.85, 100]
# Use np.logspace(-2, 2, 5) for the pure decades [1e-2, 1e-1, 1e0, 1e1, 1e2].
# Each time contributes 3 graph features (mean/std/max of HKS over nodes).
HKS_TIMES = np.logspace(-2, 2, 6)

# Laplacian type: 'combinatorial' (L = D - A) or 'normalized' (symmetric normalized)
LAPLACIAN_TYPE = 'combinatorial'

# Number of smallest eigenpairs used in the HKS sum. None uses all of them
# (cheap for small molecular graphs); e.g. 20 truncates to the 20 smallest.
NUM_EIGENPAIRS = None

# MLP params
MLP_HIDDEN = (100,)      # hidden layer sizes, e.g. (200, 100) for two hidden layers
MLP_MAX_ITER = 500       # maximum training iterations (epochs) for the MLP
TEST_SIZE = 0.3          # fraction of graphs held out for the test split
# ------------------------------------------------------------------------

def _tensor_to_numpy(t):
    """Return a NumPy array for torch tensor ``t`` (detached and moved to CPU first)."""
    return t.detach().cpu().numpy()


def graph_from_pyg_data(data):
    """Convert a PyG ``Data`` object to an undirected networkx Graph.

    Nodes are ``0 .. data.num_nodes - 1``. Every (u, v) column of
    ``data.edge_index`` becomes an undirected edge; reverse duplicates
    collapse because ``nx.Graph`` is undirected.

    Each node receives a float ``'atomic_number'`` attribute chosen heuristically:
      * ``data.x`` that is 1-D or has one column: that value;
      * ``data.x`` with several columns: the argmax column index. For one-hot
        features (as in MUTAG) this is the atom-type category index, not a
        chemical atomic number;
      * otherwise ``data.z`` if present and not None;
      * otherwise 0.0 for every node.
    """
    n = data.num_nodes
    G = nx.Graph()
    G.add_nodes_from(range(n))
    # .detach().cpu() so tensors on an accelerator or with grad tracking convert cleanly
    edge_index = _tensor_to_numpy(data.edge_index)
    G.add_edges_from((int(u), int(v)) for u, v in edge_index.T)

    atomic_numbers = None
    x = getattr(data, 'x', None)
    if x is not None:
        x_np = _tensor_to_numpy(x)
        if x_np.ndim == 1 or x_np.shape[1] == 1:
            atomic_numbers = x_np.reshape(-1)
        elif x_np.shape[1] > 1:
            atomic_numbers = np.argmax(x_np, axis=1)
        # a 2-D x with zero columns carries no node information: fall through
    z = getattr(data, 'z', None)
    if atomic_numbers is None and z is not None:
        atomic_numbers = _tensor_to_numpy(z).reshape(-1)
    if atomic_numbers is None:
        atomic_numbers = np.zeros(n, dtype=float)
    for i, val in enumerate(atomic_numbers):
        G.nodes[i]['atomic_number'] = float(val)
    return G

def build_laplacian_matrix(G, lap_type='combinatorial'):
    """Return a dense Laplacian of ``G`` as an (n x n) NumPy array.

    Rows and columns follow ``sorted(G.nodes())``, the same node order that
    ``compute_hks_for_graph`` uses for its output rows.

    Parameters
    ----------
    G : networkx.Graph
    lap_type : {'combinatorial', 'normalized'}
        'combinatorial' -> L = D - A.
        'normalized'    -> L_sym = D^{-1/2} (D - A) D^{-1/2}, with D^{-1/2}
        taken as 0 for isolated nodes. For nodes of nonzero degree this equals
        I - D^{-1/2} A D^{-1/2}. For an isolated node the diagonal entry is 0
        rather than 1, so every connected component (including a lone vertex)
        contributes a zero eigenvalue, as with the combinatorial Laplacian.

    Raises
    ------
    ValueError
        If ``lap_type`` is not one of the supported values.
    """
    if lap_type not in ('combinatorial', 'normalized'):
        raise ValueError("lap_type must be 'combinatorial' or 'normalized'")
    A = nx.to_numpy_array(G, nodelist=sorted(G.nodes()))
    degs = A.sum(axis=1)
    L = np.diag(degs) - A
    if lap_type == 'normalized':
        with np.errstate(divide='ignore'):
            d_root_inv = 1.0 / np.sqrt(degs)
        d_root_inv[np.isinf(d_root_inv)] = 0.0  # isolated nodes (degree 0)
        # Row/column scaling by d_root_inv == diag(d) @ L @ diag(d) without the matmuls.
        L = d_root_inv[:, None] * L * d_root_inv[None, :]
    return L

def compute_hks_for_graph(G, times=HKS_TIMES, lap_type=LAPLACIAN_TYPE, num_eigs=NUM_EIGENPAIRS):
    """
    Compute the Heat Kernel Signature of every node at every time in ``times``.

    HKS_t(x) = sum_i exp(-lam_i * t) * phi_i(x)^2, where (lam_i, phi_i) are the
    eigenpairs of the Laplacian returned by ``build_laplacian_matrix``.

    Parameters
    ----------
    G : networkx.Graph
    times : float or iterable of float
        Diffusion times. The default is bound once, at definition time, to HKS_TIMES.
    lap_type : {'combinatorial', 'normalized'}
        Passed through to ``build_laplacian_matrix``.
    num_eigs : int or None
        Use only the ``num_eigs`` smallest eigenpairs; None (or >= n) uses all.

    Returns
    -------
    np.ndarray of shape (n_nodes, len(times)); rows follow ``sorted(G.nodes())``.

    Raises
    ------
    ValueError
        If ``num_eigs`` is not None and smaller than 1.
    """
    times = np.asarray(times, dtype=float).reshape(-1)
    n = G.number_of_nodes()
    if n == 0:
        return np.zeros((0, times.size))
    if num_eigs is not None and num_eigs < 1:
        raise ValueError("num_eigs must be None or a positive integer")
    L = build_laplacian_matrix(G, lap_type=lap_type)
    # eigh returns eigenvalues in ascending order for symmetric matrices
    if num_eigs is None or num_eigs >= n:
        lam, phi = eigh(L)
    else:
        # Ask LAPACK only for the smallest num_eigs eigenpairs instead of
        # computing the whole spectrum and slicing it afterwards.
        lam, phi = eigh(L, subset_by_index=[0, int(num_eigs) - 1])
    # Round-off can push the smallest eigenvalues slightly below zero.
    lam = np.clip(lam, 0.0, None)
    # (n, m) @ (m, T) -> (n, T); column j is sum_i exp(-lam_i * t_j) * phi_i(x)^2
    return (phi ** 2) @ np.exp(-np.outer(lam, times))

def compute_forman_curvature_stats(G):
    """
    Compute Forman-Ricci curvature of ``G`` and aggregate edge/node statistics.

    Uses GraphRicciCurvature.FormanRicci. ``compute_ricci_curvature`` writes the
    'formanCurvature' attributes onto the graph held by the FormanRicci object
    (``frc.G``), so the values are read back from there.
    NOTE(review): released GraphRicciCurvature versions appear to keep a *copy*
    of the input graph in ``frc.G``. Reading the attributes from the graph passed
    to the constructor would then find none, and every statistic would silently
    collapse to 0.0. Reading ``frc.G`` is correct whether or not the library copies.

    Parameters
    ----------
    G : networkx.Graph
        Not modified; a copy is handed to FormanRicci.

    Returns
    -------
    dict
        'edge_mean', 'edge_std', 'edge_min', 'edge_max' over edge curvatures;
        'node_mean', 'node_std', 'node_min', 'node_max' over node curvatures;
        'edge_pos_frac' / 'edge_neg_frac': fraction of edges with curvature > 0 / < 0.
        An empty set of edge (or node) values is replaced by [0.0] so that every
        statistic is defined.
    """
    # Defensive copy keeps G untouched even with a library version that works in place.
    frc = FormanRicci(G.copy(), verbose="ERROR")
    frc.compute_ricci_curvature()
    H = frc.G  # the graph that actually carries the computed 'formanCurvature' attributes

    edge_vals = np.array(
        [float(d['formanCurvature'])
         for _, _, d in H.edges(data=True)
         if d.get('formanCurvature') is not None],
        dtype=float,
    )
    if edge_vals.size == 0:
        edge_vals = np.array([0.0])

    node_vals = []
    for node in H.nodes():
        attr = H.nodes[node].get('formanCurvature')
        if attr is not None:
            node_vals.append(float(attr))
        else:
            # No node value from the library: fall back to the mean over incident edges.
            incident = [H[u][v].get('formanCurvature', 0.0) for u, v in H.edges(node)]
            node_vals.append(float(np.mean(incident)) if incident else 0.0)
    node_vals = np.array(node_vals, dtype=float) if node_vals else np.array([0.0])

    return {
        'edge_mean': float(np.mean(edge_vals)),
        'edge_std': float(np.std(edge_vals)),
        'edge_min': float(np.min(edge_vals)),
        'edge_max': float(np.max(edge_vals)),
        'node_mean': float(np.mean(node_vals)),
        'node_std': float(np.std(node_vals)),
        'node_min': float(np.min(node_vals)),
        'node_max': float(np.max(node_vals)),
        'edge_pos_frac': float((edge_vals > 0).mean()),
        'edge_neg_frac': float((edge_vals < 0).mean()),
    }

def vectorize_graph(G):
    """
    Build the fixed-length feature vector describing one graph.

    Layout, in order:
      1. HKS summaries: per-time mean over nodes, then per-time std, then
         per-time max (3 * len(HKS_TIMES) values; all zeros for an empty graph);
      2. the 10 Forman-curvature statistics from compute_forman_curvature_stats;
      3. node count, edge count and average degree;
      4. mean, std, min and max of the 'atomic_number' node attribute.
    """
    hks = compute_hks_for_graph(G, times=HKS_TIMES, lap_type=LAPLACIAN_TYPE, num_eigs=NUM_EIGENPAIRS)
    if hks.size:
        hks_block = np.concatenate([hks.mean(axis=0), hks.std(axis=0), hks.max(axis=0)])
    else:
        # empty graph: keep the vector length fixed
        hks_block = np.zeros(3 * len(HKS_TIMES))

    curvature = compute_forman_curvature_stats(G)
    forman_keys = (
        'edge_mean', 'edge_std', 'edge_min', 'edge_max',
        'node_mean', 'node_std', 'node_min', 'node_max',
        'edge_pos_frac', 'edge_neg_frac',
    )
    forman_block = np.array([curvature[key] for key in forman_keys], dtype=float)

    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()
    mean_degree = 2.0 * num_edges / num_nodes if num_nodes > 0 else 0.0
    size_block = np.array([num_nodes, num_edges, mean_degree])

    if num_nodes > 0:
        atoms = np.array([G.nodes[v].get('atomic_number', 0.0) for v in sorted(G.nodes())])
    else:
        atoms = np.array([0.0])
    atom_block = np.array([atoms.mean(), atoms.std(), atoms.min(), atoms.max()], dtype=float)

    return np.concatenate([hks_block, forman_block, size_block, atom_block])

def main():
    """Run the MUTAG pipeline end to end and print test-set metrics.

    Loads MUTAG through TUDataset (root 'data/TUDataset') and vectorizes every
    graph with ``vectorize_graph``. It then makes one stratified train/test split,
    fits an MLP on standardized features, and prints accuracy, balanced accuracy,
    ROC AUC (None when it is undefined) and the confusion matrix.
    """
    dataset = TUDataset(root='data/TUDataset', name='MUTAG')
    print("Loaded MUTAG with", len(dataset), "graphs.")
    graphs = []
    labels = []
    for data in dataset:
        graphs.append(graph_from_pyg_data(data))
        labels.append(int(data.y.item()))

    X = np.vstack([vectorize_graph(G) for G in tqdm(graphs, desc="Vectorizing graphs")])
    y = np.array(labels)
    print("Feature shape:", X.shape)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, test_size=TEST_SIZE, random_state=RANDOM_STATE)

    # MLPs are sensitive to feature scale, and these features mix node/edge counts,
    # HKS values and curvature statistics on different scales. Putting the scaler in
    # a pipeline fits it on the training split only, so no test data leaks in.
    model = make_pipeline(
        StandardScaler(),
        MLPClassifier(hidden_layer_sizes=MLP_HIDDEN, max_iter=MLP_MAX_ITER,
                      random_state=RANDOM_STATE),
    )
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    proba = model.predict_proba(X_test)
    # ROC AUC needs a score for the positive (second) class; fall back to column 0 if only one class was seen
    y_proba = proba[:, 1] if proba.shape[1] > 1 else proba[:, 0]

    acc = accuracy_score(y_test, y_pred)
    bal = balanced_accuracy_score(y_test, y_pred)
    try:
        roc = roc_auc_score(y_test, y_proba)
    except ValueError:
        # ROC AUC is undefined, e.g. when y_test contains only one class
        roc = None
    cm = confusion_matrix(y_test, y_pred)
    print("Results:")
    print("Accuracy:", acc)
    print("Balanced Accuracy:", bal)
    print("ROC AUC:", roc)
    print("Confusion matrix:\n", cm)

if __name__ == "__main__":
    main()
