
Run GNNExplainer on the Test/Val graph splits for CYP isoform 1A2, saving PNG visualizations and CSV indexes

Produces per-sample PNGs (grouped by TP/TN/FP/FN outcome), a per-split index CSV and qc_summary.txt, plus a combined master_index.csv and README.md.

In [1]:
import os, glob, math, json, csv, time
from pathlib import Path
from typing import Optional

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torch_geometric.explain import Explainer, Explanation
from torch_geometric.explain.algorithm import GNNExplainer
from torch_geometric.explain.config import ModelConfig

from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdMolDescriptors 


from model_GINE import GINEModel as GraphModelClass

In [2]:
# ------------------------------
# Config / paths
# ------------------------------
ISOFORM = "1A2"
GRAPH_ROOT_BASE = os.path.join("..", "GraphDataset", ISOFORM)
CSV_BASE = os.path.join("..", "data", "processed")
CSV_TEST = os.path.join(CSV_BASE, f"{ISOFORM}_test.csv")
CSV_VAL  = os.path.join(CSV_BASE, f"{ISOFORM}_val.csv")
MODEL_PATH = f"models/GINE_CYP{ISOFORM}.pth"

OUT_ROOT = os.path.join("..", "GNNExplainer", ISOFORM)
# split name -> folder of per-molecule *.pt graph files
SPLITS = {"Test": os.path.join(GRAPH_ROOT_BASE, "test"),
          "Val":  os.path.join(GRAPH_ROOT_BASE, "val")}

os.makedirs(OUT_ROOT, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds CUDA; GNNExplainer's mask init is random

# Explainer params. This dict is serialised verbatim into every CSV row and the
# README, so each entry must describe what the pipeline actually does.
EXPLAINER_PARAMS = {
    "algorithm_epochs": 100,
    "algorithm_lr": 0.01,
    "node_mask_type": "object",
    "edge_mask_type": "object",
    "normalize_per_molecule": True,
    # Atoms/bonds are ranked purely by |raw importance| and the first top_k kept.
    "top_selection": "top_k",
    "top_k": 5,
    # NOTE: not applied by the selection logic; kept so consumers of the
    # params JSON still find the key.
    "threshold_frac": 0.10,  # 10% of max
}

# CSV header template: base fields, Top_Atom_1..k, Top_Bond_1..k, run metadata.
_BASE_COLUMNS = [
    "Drug_ID", "Split", "Isoform", "True_Label", "Pred_Label", "Logit", "Prob", "PNG_Path",
    "Mol_SMILES", "Num_Atoms", "Num_Bonds", "Num_Rings",
]
_META_COLUMNS = ["Explainer_Params", "Model_Checkpoint", "Date", "Seed", "Notes"]
CSV_COLUMNS = (
    _BASE_COLUMNS
    + [f"Top_Atom_{rank}" for rank in range(1, EXPLAINER_PARAMS["top_k"] + 1)]
    + [f"Top_Bond_{rank}" for rank in range(1, EXPLAINER_PARAMS["top_k"] + 1)]
    + _META_COLUMNS
)

In [3]:
# ------------------------------
# Helpers
# ------------------------------
def load_graph_pt(pt_path):
    """Load a pickled graph object (e.g. a torch_geometric ``Data``) from ``pt_path``.

    ``weights_only=False`` is passed explicitly. A PyG ``Data`` object is an
    arbitrary pickled class, which the restricted ``weights_only=True`` loader
    (torch's default from 2.6) rejects. Passing it also silences the
    FutureWarning that torch emits when the argument is omitted.

    SECURITY: this unpickles arbitrary objects -- only load trusted files.
    """
    return torch.load(pt_path, weights_only=False)
def drug_id_from_filename(fname: str) -> str:
    """Return the drug identifier encoded in a graph filename.

    The identifier is the part of the file stem before the first underscore,
    e.g. ``".../12345_1.pt" -> "12345"``. A stem without underscores is
    returned whole.
    """
    stem = Path(fname).stem
    head, _sep, _tail = stem.partition("_")
    return head
def normalize_mask(x: np.ndarray) -> np.ndarray:
    """Min-max scale ``x`` into [0, 1] as a float array of the same shape.

    Empty input is returned as an empty float array. If max and min are
    ``math.isclose``, the result is all zeros, which avoids dividing by ~0.
    """
    arr = np.asarray(x, dtype=float)
    if not arr.size:
        return arr
    lo = float(arr.min())
    hi = float(arr.max())
    return np.zeros_like(arr) if math.isclose(hi, lo) else (arr - lo) / (hi - lo)
def ensure_dirs(out_root=None, splits=None):
    """Create ``<out_root>/<split>/<category>`` output directories.

    Categories are the four confusion-matrix buckets plus ``Unk``. ``Unk`` is
    where samples whose (true, pred) pair matches none of the four buckets are
    written (e.g. an unparseable true label of -1). Without it, saving those
    PNGs fails.

    Args:
        out_root: root output directory; defaults to the module-level ``OUT_ROOT``.
        splits: iterable of split names (a dict such as ``SPLITS`` yields its
            keys); defaults to ``SPLITS``.
    """
    root = OUT_ROOT if out_root is None else out_root
    split_names = SPLITS.keys() if splits is None else splits
    categories = ("True_Positive", "True_Negative", "False_Positive", "False_Negative", "Unk")
    for split in split_names:
        for cat in categories:
            # makedirs also creates the intermediate <root>/<split> directory
            os.makedirs(os.path.join(root, split, cat), exist_ok=True)

def rdkit_draw_and_save(mol: Chem.Mol, atom_scores: Optional[np.ndarray], bond_scores: Optional[np.ndarray],
                        out_path: str, size=(600,600)) -> None:
    """Draw ``mol`` with atoms/bonds coloured by importance and save a PNG to ``out_path``.

    Scores are min-max normalised with ``normalize_mask`` and mapped through the
    matplotlib ``OrRd`` colormap. ``atom_scores[i]`` colours atom ``i`` and
    ``bond_scores[j]`` colours the bond with RDKit index ``j``; aligning the
    scores to those indices is the caller's job. Entries beyond the molecule's
    atom/bond count are dropped before normalising. The graph tensors and the
    SMILES-derived mol can disagree in size, and out-of-range highlight indices
    would otherwise make RDKit fail the whole drawing.

    If ``mol`` is None or drawing raises, a black placeholder of ``size``
    (width, height) is written instead.
    """
    def _save_placeholder():
        plt.imsave(out_path, np.zeros((size[1], size[0], 3), dtype=np.uint8))

    if mol is None:
        _save_placeholder()
        return

    cmap = plt.get_cmap("OrRd")

    def _colours(scores, limit):
        # index -> RGB tuple for the first `limit` scores (valid RDKit indices)
        if scores is None:
            return {}
        norm = normalize_mask(np.asarray(scores).ravel()[:limit])
        colours = {}
        for idx, val in enumerate(norm):
            red, green, blue, _alpha = cmap(float(val))
            colours[idx] = (float(red), float(green), float(blue))
        return colours

    atom_colors = _colours(atom_scores, mol.GetNumAtoms())
    bond_colors = _colours(bond_scores, mol.GetNumBonds())

    try:
        drawer = rdMolDraw2D.MolDraw2DCairo(size[0], size[1])
        rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol,
                                           highlightAtoms=list(atom_colors),
                                           highlightAtomColors=atom_colors,
                                           highlightBonds=list(bond_colors),
                                           highlightBondColors=bond_colors)
        drawer.FinishDrawing()
        with open(out_path, "wb") as f:
            f.write(drawer.GetDrawingText())
    except Exception as e:
        print("RDKit draw failed for", out_path, e)
        # best-effort: still leave an image so every sample has a PNG
        _save_placeholder()

In [4]:
# ------------------------------
# Main flow
# ------------------------------
# (true_label, pred_label) -> (output sub-directory, qc counter key)
_CONFUSION_CATEGORIES = {
    (1, 1): ("True_Positive", "TP"),
    (0, 0): ("True_Negative", "TN"),
    (0, 1): ("False_Positive", "FP"),
    (1, 0): ("False_Negative", "FN"),
}


def _load_smiles_maps():
    """Read the optional Test/Val CSVs into a ``{(split, drug_id): smiles}`` dict.

    IDs that parse as integers are normalised with ``str(int(...))`` so they
    match the IDs taken from graph filenames.
    """
    smiles_map = {}
    for csv_path, split_name in [(CSV_TEST, "Test"), (CSV_VAL, "Val")]:
        if not os.path.exists(csv_path):
            continue
        df = pd.read_csv(csv_path)
        id_col = "Drug_ID" if "Drug_ID" in df.columns else "Drug"
        smi_col = "Drug" if "Drug" in df.columns else df.columns[-1]
        for raw_id, smi in zip(df[id_col], df[smi_col]):
            try:
                key = str(int(raw_id))
            except (TypeError, ValueError, OverflowError):
                key = str(raw_id)
            smiles_map[(split_name, key)] = str(smi)
    return smiles_map


def _build_model(sample):
    """Build GraphModelClass sized from ``sample``, load MODEL_PATH into it, return it in eval mode.

    Raises:
        RuntimeError: if no checkpoint entry matches any model parameter/buffer.
            With ``strict=False`` that case used to go unnoticed and leave the
            model with random weights.
    """
    in_node_feats = sample.x.shape[1]
    sample_edge_attr = getattr(sample, "edge_attr", None)
    in_edge_feats = sample_edge_attr.shape[1] if sample_edge_attr is not None else None

    # Must match the hyper-parameters the checkpoint was trained with.
    model_params = {
        "model_embedding_size": 128,
        "model_layers": 4,
        "model_dropout_rate": 0.2,
        "model_dense_neurons": 256,
    }
    try:
        model = GraphModelClass(in_node_feats, in_edge_feats, model_params)
    except TypeError:
        # fallback for a constructor that takes no edge-feature dimension
        model = GraphModelClass(in_node_feats, model_params)

    print("Loading model from", MODEL_PATH)
    ckpt = torch.load(MODEL_PATH, map_location="cpu")
    state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    if isinstance(state, torch.nn.Module):
        state = state.state_dict()
    # Checkpoints saved from nn.DataParallel / DDP prefix every key with "module.".
    prefix = "module."
    if any(key.startswith(prefix) for key in state):
        state = {(key[len(prefix):] if key.startswith(prefix) else key): val
                 for key, val in state.items()}

    result = model.load_state_dict(state, strict=False)
    n_expected = len(model.state_dict())
    if n_expected and len(result.missing_keys) == n_expected:
        raise RuntimeError(f"No entries of checkpoint {MODEL_PATH} match the model's parameters")
    if result.missing_keys or result.unexpected_keys:
        print("WARNING: partial checkpoint load -- missing keys:", result.missing_keys,
              "unexpected keys:", result.unexpected_keys)
    return model.to(DEVICE).eval()


def _sigmoid(z):
    """Numerically stable logistic function (``math.exp(-z)`` overflows for z < ~-709)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def _forward_logit(model, x, edge_index, edge_attr, batch):
    """Run one no-grad forward pass and return the first output element as a float.

    Tries progressively shorter call signatures, because the model's forward
    signature is not known here.
    """
    with torch.no_grad():
        try:
            out = model(x, edge_index, edge_attr, batch)
        except TypeError:
            try:
                out = model(x, edge_index, batch)
            except Exception:
                out = model(x, edge_index)
    if isinstance(out, torch.Tensor):
        return float(out.detach().cpu().numpy().ravel()[0])
    return float(out)


def _true_label(data, pt_path):
    """Return the integer label from ``data.y``, else the filename's trailing ``_<int>``, else -1."""
    y = getattr(data, "y", None)
    if y is not None:
        try:
            return int(y.detach().cpu().numpy().ravel()[0])
        except Exception:
            try:
                return int(y)
            except Exception:
                pass
    stem = Path(pt_path).stem
    if "_" in stem:
        try:
            return int(stem.split("_")[-1])
        except ValueError:
            pass  # non-numeric suffix -> unknown label instead of failing the sample
    return -1


def _bond_scores(mol, ei, edge_mask):
    """Collapse per-edge mask values into per-bond scores, taking the max over both directions.

    With an RDKit ``mol`` the result has ``mol.GetNumBonds()`` entries indexed by
    ``bond.GetIdx()``. Bonds with no matching edge score 0. Edges that match no
    bond, or reference atoms the mol lacks, are ignored. Without a mol, every
    edge column (``ei`` is the [2, E] edge_index as numpy) is its own entry,
    in column order.
    """
    num_edges = ei.shape[1]
    if mol is None:
        return np.array([float(edge_mask[col]) for col in range(num_edges)], dtype=float)
    num_atoms = mol.GetNumAtoms()
    per_bond = {}
    for col in range(num_edges):
        u, v = int(ei[0, col]), int(ei[1, col])
        if u >= num_atoms or v >= num_atoms:
            continue
        bond = mol.GetBondBetweenAtoms(u, v)
        if bond is not None:
            per_bond.setdefault(bond.GetIdx(), []).append(float(edge_mask[col]))
    scores = np.zeros(mol.GetNumBonds(), dtype=float)
    for bidx, vals in per_bond.items():
        scores[bidx] = float(np.max(vals))
    return scores


def _top_atoms(atom_scores, mol, k):
    """Return [(atom_idx, raw_score, element_symbol)] for the k largest |scores| ("" symbol if unknown)."""
    n_mol_atoms = mol.GetNumAtoms() if mol is not None else 0
    top = []
    for idx in map(int, np.argsort(-np.abs(atom_scores))[:k]):
        symbol = mol.GetAtomWithIdx(idx).GetSymbol() if idx < n_mol_atoms else ""
        top.append((idx, float(atom_scores[idx]), symbol))
    return top


def _top_bonds(bond_scores, mol, k):
    """Return [(bond_idx, raw_score, "i-j" atom pair or "-")] for the k largest |scores|."""
    n_mol_bonds = mol.GetNumBonds() if mol is not None else 0
    top = []
    for idx in map(int, np.argsort(-np.abs(bond_scores))[:k]):
        if idx < n_mol_bonds:
            bond = mol.GetBondWithIdx(idx)
            atoms = f"{bond.GetBeginAtomIdx()}-{bond.GetEndAtomIdx()}"
        else:
            atoms = "-"
        top.append((idx, float(bond_scores[idx]), atoms))
    return top


def _fill_top_columns(row, prefix, items, k):
    """Write ``<prefix>_1..k`` columns as "idx,score,label" strings ("" when fewer than k items)."""
    for i in range(k):
        if i < len(items):
            idx, score, label = items[i]
            row[f"{prefix}_{i + 1}"] = f"{idx},{score:.6f},{label}"
        else:
            row[f"{prefix}_{i + 1}"] = ""


def _explain_one(pt, split, model, explainer, smiles_map, smiles_any):
    """Predict, explain and draw a single graph file.

    Returns:
        (row, qc_key): the CSV row dict (keys in CSV_COLUMNS order) and the
        confusion-matrix counter to increment ("TP"/"TN"/"FP"/"FN"/"Unk").
    """
    data = load_graph_pt(pt)
    x = data.x.to(DEVICE).float()
    edge_index = data.edge_index.to(DEVICE)
    raw_edge_attr = getattr(data, "edge_attr", None)
    edge_attr = raw_edge_attr.to(DEVICE).float() if raw_edge_attr is not None else None
    batch = torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)  # single-graph batch

    logit = _forward_logit(model, x, edge_index, edge_attr, batch)
    prob = _sigmoid(logit)
    pred_label = 1 if prob >= 0.5 else 0
    true_label = _true_label(data, pt)

    explanation: Explanation = explainer(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
    node_mask = (explanation.node_mask.detach().cpu().numpy().ravel()
                 if explanation.node_mask is not None else np.zeros(x.size(0)))
    edge_mask = (explanation.edge_mask.detach().cpu().numpy().ravel()
                 if explanation.edge_mask is not None else np.zeros(edge_index.shape[1]))

    drug_id = drug_id_from_filename(pt)
    # Prefer this split's CSV; otherwise accept the drug from any other split's CSV.
    smiles = smiles_map.get((split, drug_id), smiles_any.get(drug_id))
    mol = Chem.MolFromSmiles(smiles) if smiles is not None else None

    atom_scores_raw = np.array(node_mask, dtype=float)
    bond_scores_raw = _bond_scores(mol, edge_index.detach().cpu().numpy(), edge_mask)
    num_atoms = mol.GetNumAtoms() if mol is not None else int(x.size(0))
    num_bonds = len(bond_scores_raw)

    if EXPLAINER_PARAMS["normalize_per_molecule"]:
        atom_scores_norm = normalize_mask(atom_scores_raw)
        bond_scores_norm = normalize_mask(bond_scores_raw)
    else:
        atom_scores_norm, bond_scores_norm = atom_scores_raw, bond_scores_raw

    k = EXPLAINER_PARAMS["top_k"]
    top_atoms = _top_atoms(atom_scores_raw, mol, k)
    top_bonds = _top_bonds(bond_scores_raw, mol, k)

    category, qc_key = _CONFUSION_CATEGORIES.get((true_label, pred_label), ("Unk", "Unk"))
    out_dir = os.path.join(OUT_ROOT, split, category)
    os.makedirs(out_dir, exist_ok=True)  # "Unk" is not guaranteed to exist yet
    stem = f"{drug_id}__true{true_label}__pred{pred_label}__prob{prob:.4f}__logit{logit:.4f}"
    out_png = os.path.join(out_dir, stem + ".png")
    rdkit_draw_and_save(mol, atom_scores_norm, bond_scores_norm, out_png, size=(400, 400))

    row = {
        "Drug_ID": drug_id,
        "Split": split,
        "Isoform": ISOFORM,
        "True_Label": int(true_label),
        "Pred_Label": int(pred_label),
        "Logit": float(logit),
        "Prob": float(prob),
        "PNG_Path": os.path.relpath(out_png),
        "Mol_SMILES": smiles if smiles is not None else "",
        "Num_Atoms": int(num_atoms),
        "Num_Bonds": int(num_bonds),
        "Num_Rings": int(rdMolDescriptors.CalcNumRings(mol)) if mol is not None else 0,
    }
    _fill_top_columns(row, "Top_Atom", top_atoms, k)
    _fill_top_columns(row, "Top_Bond", top_bonds, k)
    row["Explainer_Params"] = json.dumps(EXPLAINER_PARAMS)
    row["Model_Checkpoint"] = MODEL_PATH
    row["Date"] = time.strftime("%Y-%m-%d %H:%M:%S")
    row["Seed"] = SEED
    row["Notes"] = ""
    return row, qc_key


def _write_readme():
    """Write OUT_ROOT/README.md describing the run configuration."""
    readme = os.path.join(OUT_ROOT, 'README.md')
    with open(readme, 'w') as f:
        f.write(f"# GNNExplainer outputs for {ISOFORM}\n\n")
        f.write(f"Model checkpoint: {MODEL_PATH}\n")
        f.write(f"Explainer params: {json.dumps(EXPLAINER_PARAMS)}\n")
        f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Seed: {SEED}\n")
        f.write("Normalization: per-molecule min-max scaling to [0, 1] is applied to atom/bond importances "
                "for the PNG colouring; the Top_* CSV columns record raw (unnormalized) importances.\n")
        f.write(f"Top selection: top {EXPLAINER_PARAMS['top_k']} atoms/bonds by absolute raw importance "
                "(threshold_frac is not applied).\n")
        f.write("Train split outputs (if present) are for debugging only and should not be used for claims about generalization.\n")


def main():
    """Run GNNExplainer over every *.pt graph in SPLITS and write all outputs under OUT_ROOT.

    Per split: PNGs in ``<split>/<category>/``, ``index_<ISOFORM>_<split>.csv``
    (one row per successfully processed graph) and ``<split>/qc_summary.txt``.
    Afterwards: ``master_index.csv`` (all rows, if any) and ``README.md``.
    Per-sample errors are counted as failures and printed, not raised.

    Raises:
        RuntimeError: if no .pt graph files exist in any split, or if the
            checkpoint matches none of the model's parameters.
    """
    ensure_dirs()
    smiles_map = _load_smiles_maps()
    # Cross-split fallback drug_id -> SMILES, built once; entries from earlier CSVs (Test) win.
    smiles_any = {}
    for (_split, key), smi in smiles_map.items():
        smiles_any.setdefault(key, smi)

    # Model input dims are read from the first graph file found.
    any_pt = None
    for folder in SPLITS.values():
        files = sorted(glob.glob(os.path.join(folder, "*.pt")))
        if files:
            any_pt = files[0]
            break
    if any_pt is None:
        raise RuntimeError("No .pt graph files found in splits")
    model = _build_model(load_graph_pt(any_pt))

    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=EXPLAINER_PARAMS["algorithm_epochs"], lr=EXPLAINER_PARAMS["algorithm_lr"]),
        explanation_type='model',
        model_config=ModelConfig(mode='binary_classification', task_level='graph', return_type='raw'),
        node_mask_type=EXPLAINER_PARAMS.get('node_mask_type', 'object'),
        edge_mask_type=EXPLAINER_PARAMS.get('edge_mask_type', 'object'),
    )

    master_rows = []
    for split, folder in SPLITS.items():
        print("Processing split:", split, "folder:", folder)
        pt_files = sorted(glob.glob(os.path.join(folder, "*.pt")))
        qc = {"TP": 0, "TN": 0, "FP": 0, "FN": 0, "Unk": 0, "images": 0, "failures": 0}
        csv_out_path = os.path.join(OUT_ROOT, f"index_{ISOFORM}_{split}.csv")

        # "w" starts the index fresh; one handle per split, flushed after each row
        # so a crash mid-split still leaves every completed row on disk.
        with open(csv_out_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=CSV_COLUMNS)
            writer.writeheader()
            for pt in pt_files:
                try:
                    row, qc_key = _explain_one(pt, split, model, explainer, smiles_map, smiles_any)
                except Exception as e:
                    qc['failures'] += 1
                    print("Failure processing", pt, e)
                    continue
                qc[qc_key] += 1
                qc['images'] += 1
                writer.writerow(row)
                csvfile.flush()
                master_rows.append(row)

        qc_path = os.path.join(OUT_ROOT, split, 'qc_summary.txt')
        with open(qc_path, 'w') as f:
            f.write(json.dumps(qc, indent=2))
        print(f"Finished split {split}. QC:", qc)

    if master_rows:
        pd.DataFrame(master_rows).to_csv(os.path.join(OUT_ROOT, 'master_index.csv'), index=False)

    _write_readme()
    print("All done. Outputs under:", OUT_ROOT)




In [5]:

if __name__ == "__main__":  # also true inside a Jupyter kernel, so running this cell executes the pipeline
    main()

Loading model from models/GINE_CYP1A2.pth


  ckpt = torch.load(MODEL_PATH, map_location="cpu")
  return torch.load(pt_path)


Processing split: Test folder: ..\GraphDataset\1A2\test
Finished split Test. QC: {'TP': 553, 'TN': 519, 'FP': 114, 'FN': 73, 'images': 1259, 'failures': 0}
Processing split: Val folder: ..\GraphDataset\1A2\val
Finished split Val. QC: {'TP': 571, 'TN': 478, 'FP': 140, 'FN': 68, 'images': 1257, 'failures': 0}
All done. Outputs under: ..\GNNExplainer\1A2
