
# ThreeGraphX — Transductive Explanation Demo (QM9, SchNet/DimeNet++)

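This tutorial shows an end-to-end **transductive** explanation workflow, where the node mask is optimized directly for a single molecule rather than produced by a separately trained explainer: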

1. Load QM9 and a backbone model (SchNet or DimeNet++).
2. Visualize the molecule graph with node order and element type.
3. Run the transductive explainer to optimize a node mask for this molecule.
4. Inspect & visualize the explanation (mask heatmap, explanatory subgraph, clusters).
5. (Optional) Evaluate top-k fidelity: the L1 gap between the model's prediction on only the k highest-scoring atoms and its full-molecule prediction.

> **Note:** This notebook assumes you're running from the repo root with the package at `src/threegraphx/`. If not, adapt the `sys.path` setup in the import cell (or set `PYTHONPATH`) accordingly.


In [None]:

# If you're running on Google Colab, uncomment the next lines to install dependencies.
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# !pip install networkx matplotlib

import os
import os.path as osp
import sys

import numpy as np
import torch

print("Torch:", torch.__version__, "| CUDA:", torch.cuda.is_available())

# Use the GPU whenever PyTorch can see one; later cells move data/models to `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device


In [None]:

# Make the local package importable (expects repo layout with src/threegraphx/...).
# src/ is prepended only if missing, so re-running this cell does not grow
# sys.path, and the in-repo sources take precedence over any installed copy.
import os
import os.path as osp
import sys

HERE = os.getcwd()
SRC = osp.join(HERE, "src")
_src_on_path = SRC in sys.path
if not _src_on_path:
    sys.path.insert(0, SRC)

from torch_geometric.datasets import QM9
from torch_geometric.nn import knn_graph
from torch_geometric.explain import Explainer, ModelConfig
from torch_geometric.nn.models.schnet import SchNet
from torch_geometric.nn.models.dimenet import DimeNet, DimeNetPlusPlus

# Import our project modules
from threegraphx.hooks.base import MaskPoint
from threegraphx.hooks.schnet import SchNetHooks
from threegraphx.hooks.dimenet import DimeNetHooks
from threegraphx.transductive import GraphXTransductive
from threegraphx.viz import visualize_graph, visualize_mask, visualize_clusters, visualize_explanatory_subgraph


In [None]:

# ---- Configuration ----
BACKBONE = "schnet"      # choices: "schnet", "dimenet"
USE_DPP = True           # if BACKBONE == "dimenet": True -> DimeNet++, False -> DimeNet
TARGET_ATTR = 0          # QM9 property index (0..11). For dimenet path, we remap y accordingly (see below).
EPOCHS = 30              # inner steps for transductive optimization
LR = 1e-2
MASK_POINT = "embed"     # "embed" or "pre_agg" (where to multiply node masks)
KNN_K = 2                # build a 2-NN graph over coordinates for clustering
SEED = 42

# Fail fast on typos: the cells below branch with plain `if/else`
# (e.g. anything other than "schnet" builds DimeNet, anything other than
# "embed" masks at PRE_AGG), so an unrecognized value would silently pick
# the other option instead of erroring.
if BACKBONE not in ("schnet", "dimenet"):
    raise ValueError(f"BACKBONE must be 'schnet' or 'dimenet', got {BACKBONE!r}")
if MASK_POINT not in ("embed", "pre_agg"):
    raise ValueError(f"MASK_POINT must be 'embed' or 'pre_agg', got {MASK_POINT!r}")
if not 0 <= TARGET_ATTR <= 11:
    raise ValueError(f"TARGET_ATTR must be in 0..11, got {TARGET_ATTR!r}")

torch.manual_seed(SEED)
np.random.seed(SEED)


In [None]:

def build_schnet_and_data(device: str, data_root: str, target_attr: int):
    """Load QM9 together with PyG's pretrained SchNet for one QM9 target.

    Args:
        device: Device the returned model is moved to.
        data_root: Directory holding (or receiving) the QM9 dataset; also
            passed to ``SchNet.from_qm9_pretrained``.
        target_attr: QM9 property index forwarded to
            ``SchNet.from_qm9_pretrained``.

    Returns:
        ``(model, splits)`` where ``model`` lives on ``device`` and
        ``splits`` is the ``(train, val, test)`` tuple returned by
        ``from_qm9_pretrained``.
    """
    qm9 = QM9(data_root)
    net, splits = SchNet.from_qm9_pretrained(data_root, qm9, target_attr)
    net = net.to(device)
    return net, splits

def split_qm9(dataset, train=2048, val=500, test=1024, seed=42):
    """Draw disjoint, reproducible random train/val/test subsets of ``dataset``.

    ``torch.utils.data.random_split`` with integer lengths requires the
    lengths to sum to exactly ``len(dataset)``, which made the original
    implementation fail for any subset smaller than the whole dataset.
    Instead, a seeded permutation of all indices is sliced into
    ``train``, ``val`` and ``test`` consecutive chunks. Any remaining
    examples are left unused.

    Args:
        dataset: Any indexable dataset supporting ``len()``.
        train: Number of training examples.
        val: Number of validation examples.
        test: Number of test examples.
        seed: Seed for the permutation, so the split is reproducible.

    Returns:
        List ``[train_subset, val_subset, test_subset]`` of
        ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if a split size is negative or the sizes sum to more
            than ``len(dataset)``.
    """
    from torch.utils.data import Subset

    sizes = (train, val, test)
    n = len(dataset)
    if min(sizes) < 0:
        raise ValueError(f"Split sizes must be non-negative, got {sizes}")
    if sum(sizes) > n:
        raise ValueError(
            f"Requested split larger than dataset ({sum(sizes)} > {n})"
        )

    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=g).tolist()

    subsets = []
    start = 0
    for size in sizes:
        subsets.append(Subset(dataset, perm[start:start + size]))
        start += size
    return subsets

def build_dimenet_and_data(device: str, data_root: str, target_attr: int, use_pp: bool, seed=42):
    """Build a DimeNet / DimeNet++ regressor plus random QM9 splits.

    NOTE(review): the model is constructed with random weights (no
    checkpoint is loaded here), so its predictions — and explanations of
    them — are only meaningful after training. TODO: train or load weights.

    Args:
        device: Device the returned model is moved to.
        data_root: Directory holding (or receiving) the QM9 dataset.
        target_attr: QM9 property index. Not used to build the model (it has
            a single output); kept for signature parity with
            ``build_schnet_and_data``.
        use_pp: ``True`` builds DimeNet++, ``False`` builds DimeNet.
        seed: Seed for the random train/val/test split.

    Returns:
        ``(model, (train_ds, val_ds, test_ds))`` with sizes 2048/500/1024.
    """
    dataset = QM9(data_root)
    # Common DimeNet setup: select 12 targets in this specific order (the
    # same column selection PyG's DimeNet.from_qm9_pretrained applies).
    idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 12, 13, 14, 15, 11])
    dataset.data.y = dataset.data.y[:, idx]

    # DimeNet/DimeNet++ have no defaults for their architecture arguments,
    # so calling them with only `out_channels` raises TypeError. The values
    # below are the QM9 hyperparameters from the respective papers.
    if use_pp:
        model = DimeNetPlusPlus(
            hidden_channels=128,
            out_channels=1,
            num_blocks=4,
            int_emb_size=64,
            basis_emb_size=8,
            out_emb_channels=256,
            num_spherical=7,
            num_radial=6,
        )
    else:
        model = DimeNet(
            hidden_channels=128,
            out_channels=1,
            num_blocks=6,
            num_bilinear=8,
            num_spherical=7,
            num_radial=6,
        )
    train_ds, val_ds, test_ds = split_qm9(dataset, 2048, 500, 1024, seed)
    return model.to(device), (train_ds, val_ds, test_ds)


In [None]:

DATA_ROOT = os.path.join(HERE, "data", "QM9")
os.makedirs(DATA_ROOT, exist_ok=True)

# Build hooks + model/splits for the configured backbone, and record a
# readable variant name for the summary (the previous one-liner printed
# "schnet ()" and never named plain DimeNet).
if BACKBONE == "schnet":
    hooks = SchNetHooks()
    model, (train_ds, val_ds, test_ds) = build_schnet_and_data(device, DATA_ROOT, TARGET_ATTR)
    variant = "SchNet"
else:
    hooks = DimeNetHooks()
    model, (train_ds, val_ds, test_ds) = build_dimenet_and_data(device, DATA_ROOT, TARGET_ATTR, USE_DPP)
    variant = "DimeNet++" if USE_DPP else "DimeNet"

print(f"Backbone: {BACKBONE} ({variant})")
print("Train/Val/Test sizes:", len(train_ds), len(val_ds), len(test_ds))


In [None]:

import matplotlib.pyplot as plt

# Pick a single molecule from the test set and move it to the compute device.
data = test_ds[0].to(device)
print(
    "z shape:", tuple(data.z.shape),
    "| pos shape:", tuple(data.pos.shape),
    "| y shape:", tuple(data.y.shape),
)

# Build a KNN_K-nearest-neighbour graph over atom positions (used for
# clustering and visualization), as a NumPy edge index on the CPU.
edges = knn_graph(data.pos, k=KNN_K)
edges = edges.detach().cpu().numpy()

# Visualize the original molecule
fig, ax = plt.subplots(figsize=(5.5, 4.8))
visualize_graph(
    z=data.z,
    pos=data.pos,
    edge_index=edges,
    ax=ax,
    title="Molecule (index:symbol)",
)
plt.show()


In [None]:

# Transductive explainer: optimizes a node mask for this one molecule.
explainer = Explainer(
    model=model,
    algorithm=GraphXTransductive(
        hooks=hooks,
        epochs=EPOCHS,
        lr=LR,
        mask_point=MaskPoint.EMBED if MASK_POINT == "embed" else MaskPoint.PRE_AGG,
    ),
    explanation_type="phenomenon",  # explain w.r.t. the supplied `target`
    node_mask_type="object",   # cluster-parameterized node mask
    edge_mask_type=None,
    model_config=ModelConfig(
        mode="regression",
        task_level="graph",
        return_type="raw",
    ),
)

# Reference prediction on the full, unmasked molecule.
with torch.no_grad():
    y_full = model(data.z, data.pos)[0]

# Explain the configured property (TARGET_ATTR) rather than a hard-coded
# column 0, and keep the target in the model's output dtype: mixing a
# float64 target with a float32 prediction in losses such as F.mse_loss
# fails during backward.
target = data.y[:, TARGET_ATTR].to(y_full.dtype)

explanation = explainer(
    x=data.z,
    edge_index=data.pos,   # hooks interpret 'edge_index' as 'pos' for molecules
    target=target,
    edges=edges,           # for clustering
)

print("Full-model prediction:", y_full.detach().cpu().numpy().ravel())
print("Target:", target.detach().cpu().numpy().ravel())
print("Node mask shape:", tuple(explanation.node_mask.shape))
if hasattr(explanation, "clusters"):
    print("Num clusters:", len(explanation.clusters))


In [None]:

# Visualize continuous mask and clusters
MASK_THRESHOLD = 0.6  # threshold shared by the mask plot and the subgraph plot
mol_kwargs = dict(z=data.z, pos=data.pos, edge_index=edges)
node_mask = explanation.node_mask

fig, ax = plt.subplots(1, 2, figsize=(11, 4.4))
mask_ax, cluster_ax = ax
visualize_mask(
    **mol_kwargs,
    mask=node_mask,
    threshold=MASK_THRESHOLD,
    ax=mask_ax,
    title="Node importance (mask)",
)
if not hasattr(explanation, "clusters"):
    cluster_ax.axis("off")
    cluster_ax.set_title("No clusters available")
else:
    visualize_clusters(
        **mol_kwargs,
        clusters=explanation.clusters,
        ax=cluster_ax,
        title="Clusters",
    )
plt.show()

# Visualize only the explanatory subgraph (above threshold)
fig, ax = plt.subplots(figsize=(5.5, 4.8))
visualize_explanatory_subgraph(
    **mol_kwargs,
    mask=node_mask,
    threshold=MASK_THRESHOLD,
    ax=ax,
    title="Explanatory subgraph",
)
plt.show()


In [None]:

# Optional: simple top-k fidelity check
import torch.nn.functional as F

def eval_topk(model, z, pos, node_mask, ks=(2, 3, 5, 8)):
    """Measure how well the top-k masked atoms alone reproduce the full prediction.

    For each ``k``, the atoms with the ``k`` largest mask values are kept
    (re-sorted into their original index order). The model is re-run on
    that sub-molecule, and the L1 distance to the full-molecule prediction
    is recorded. Lower means the top-k atoms better explain the prediction.
    All forward passes run under ``torch.no_grad()``.

    Args:
        model: Callable as ``model(z, pos)``; element ``[0]`` of its output
            is taken as the prediction.
        z: Per-atom tensor indexed along dim 0.
        pos: Per-atom positions, indexed along dim 0 like ``z``.
        node_mask: Node importance scores, flattened to 1-D. Assumes one
            score per atom (e.g. shape ``[N]`` or ``[N, 1]``) — TODO confirm
            for other mask layouts.
        ks: Subgraph sizes to evaluate. Values larger than the number of
            atoms effectively use all atoms.

    Returns:
        Dict ``{k: {"nodes": l1_gap}}`` with a Python float per ``k``.

    Raises:
        ValueError: if any ``k`` is smaller than 1 (an empty or
            negatively-sliced selection would be silently meaningless).
    """
    # reshape (not view) so non-contiguous masks are accepted too.
    scores = node_mask.detach().reshape(-1)
    order_nodes = torch.argsort(scores, descending=True).tolist()

    out = {}
    with torch.no_grad():
        full = model(z, pos)[0]
        for k in ks:
            if k < 1:
                raise ValueError(f"k must be >= 1, got {k}")
            idx_nodes = sorted(order_nodes[:k])
            pred_n = model(z[idx_nodes], pos[idx_nodes])[0]
            out[k] = {"nodes": F.l1_loss(pred_n, full).item()}
    return out

# Top-k fidelity for the learned mask on the current molecule.
res = eval_topk(
    model=model,
    z=data.z,
    pos=data.pos,
    node_mask=explanation.node_mask,
    ks=[2, 3, 5, 8],
)
res
