# Single-Structure Clash Score Calculation

In [3]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
from Bio.PDB import PDBParser
from Bio.PDB.Atom import Atom
from Bio.PDB.Structure import Structure
from scipy.spatial import cKDTree


# -----------------------------
# Parameters / constants
# -----------------------------

# Van der Waals radii in Å, keyed by upper-case element symbol (the form that
# _normalize_element returns). The values are consistent with Bondi (1964);
# verify them against that table, cite the source when reporting results, and do
# not mix radii from different sources within a single analysis.
VDW_RADIUS_A: Dict[str, float] = dict(
    H=1.20,
    C=1.70,
    N=1.55,
    O=1.52,
    S=1.80,
    F=1.47,
    P=1.80,
    CL=1.75,
    MG=1.73,
)


@dataclass(eq=True, frozen=True)
class ClashResult:
    """
    Immutable summary of one clash-score computation.

    Attributes
    ----------
    num_atoms
        Number of atoms that entered the calculation (the normalization denominator).
    num_pairs_tested
        Candidate atom pairs that survived the residue-based exclusions and had
        their overlap evaluated.
    num_clashes
        Number of tested pairs whose overlap reached the clash threshold.
    clash_score_per_1000_atoms
        1000 * num_clashes / num_atoms (0.0 when no atoms were used).
    """

    num_atoms: int
    num_pairs_tested: int
    num_clashes: int
    clash_score_per_1000_atoms: float


def _normalize_element(atom: Atom) -> Optional[str]:
    """
    Return the atom's upper-case element symbol if it has a radius in VDW_RADIUS_A, else None.

    A real element symbol in ``atom.element`` is always trusted. If that element
    has no entry in VDW_RADIUS_A (e.g. FE, ZN, SE, NA), the result is None, so
    the atom is skipped. It is never re-mapped onto a different element that
    happens to share the first letter(s) of its atom name.

    The atom-name heuristic is used only when the element is blank or the
    placeholder "X". It tries the first two characters of the name (e.g. 'CL',
    'MG') and then the first character (e.g. 'C', 'N', 'O', 'S', 'H', 'P', 'F').
    This fallback is imperfect; for publication, make sure your PDBs have correct
    element columns.
    """
    elem = (atom.element or "").strip().upper()
    if elem and elem != "X":
        # A concrete element was supplied. Guessing from the atom name here would
        # turn iron ("FE") into fluorine or a calcium ion ("CA") into carbon.
        return elem if elem in VDW_RADIUS_A else None

    # Element unknown: derive it from the atom name (e.g. "CL", "MG", " C1 ").
    name = atom.get_name().strip().upper()
    for candidate in (name[:2], name[:1]):
        if candidate and candidate in VDW_RADIUS_A:
            return candidate
    return None


def _collect_atoms(structure: Structure) -> Tuple[List[Atom], np.ndarray, np.ndarray, np.ndarray]:
    """
    Collect the atoms of the structure's first model that have a known vdW radius.

    Only the first model is used. Multi-model files (e.g. NMR ensembles) repeat
    every atom once per model, so pooling all models would pair each atom with
    its own copies and report those pairs as clashes.

    Returns
    -------
    atoms
        Atom objects, in the model's iteration order.
    coords
        (N, 3) float array of atom coordinates.
    radii
        (N,) float array of vdW radii looked up in VDW_RADIUS_A.
    res_keys
        (N, 3) object array. Row k is (chain_id, resseq, icode) for atom k's
        residue, with the insertion code whitespace-stripped. The residue key is
        used to exclude same-residue and adjacent-residue checks. When no atom
        qualifies, this is an empty array of shape (0,).
    """
    atoms: List[Atom] = []
    coords: List[np.ndarray] = []
    radii: List[float] = []
    # Residue identity: chain + residue sequence number + insertion code.
    res_keys: List[Tuple[str, int, str]] = []

    first_model = next(iter(structure), None)  # None for a structure with no models
    if first_model is not None:
        for atom in first_model.get_atoms():
            elem = _normalize_element(atom)
            if elem is None:
                continue

            residue = atom.get_parent()
            chain = residue.get_parent()
            # Residue id is (hetero_flag, resseq, icode); the hetero flag is not part of the key.
            _, resseq, icode = residue.id

            atoms.append(atom)
            coords.append(atom.get_coord().astype(float))
            radii.append(VDW_RADIUS_A[elem])
            res_keys.append((str(chain.id), int(resseq), str(icode).strip()))

    if not atoms:
        return [], np.zeros((0, 3), dtype=float), np.zeros((0,), dtype=float), np.zeros((0,), dtype=object)

    return (
        atoms,
        np.vstack(coords),
        np.asarray(radii, dtype=float),
        np.asarray(res_keys, dtype=object),
    )


def _encode_residues(
    res_keys: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Turn per-atom residue keys into integer arrays for vectorized pair filtering.

    `res_keys` holds one (chain_id, resseq, icode) row per atom. Returns four
    int64 arrays, each with one entry per atom:
      - chain code: integer id of the atom's chain
      - resseq: the residue sequence number
      - residue code: integer id of the full (chain_id, resseq, icode) key
      - ordinal: index of the residue within its chain, counted in order of
        first appearance (i.e. the order in which atoms were collected)
    """
    chain_codes: Dict[str, int] = {}
    residue_codes: Dict[Tuple[str, int, str], int] = {}
    residue_ordinals: Dict[Tuple[str, int, str], int] = {}
    residues_in_chain: Dict[str, int] = {}

    chain_col: List[int] = []
    resseq_col: List[int] = []
    residue_col: List[int] = []
    ordinal_col: List[int] = []

    for chain, resseq, icode in res_keys:
        key = (chain, int(resseq), icode)
        if key not in residue_codes:
            residue_codes[key] = len(residue_codes)
            ordinal = residues_in_chain.get(chain, 0)
            residue_ordinals[key] = ordinal
            residues_in_chain[chain] = ordinal + 1
        chain_col.append(chain_codes.setdefault(chain, len(chain_codes)))
        resseq_col.append(key[1])
        residue_col.append(residue_codes[key])
        ordinal_col.append(residue_ordinals[key])

    return (
        np.asarray(chain_col, dtype=np.int64),
        np.asarray(resseq_col, dtype=np.int64),
        np.asarray(residue_col, dtype=np.int64),
        np.asarray(ordinal_col, dtype=np.int64),
    )


def calculate_clash_score(
    pdb_file: str,
    clash_threshold_a: float = 0.40,
    *,
    ignore_same_residue: bool = True,
    ignore_adjacent_residue: bool = True,
    include_hydrogens: bool = True,
) -> ClashResult:
    """
    Compute a simple steric clash score for a PDB structure.

    Definition
    ----------
    For each atom pair (i, j), define vdW overlap:
        overlap(i, j) = (r_i + r_j) - d_ij
    where r_i and r_j are van der Waals radii (Å), and d_ij is Euclidean distance (Å).
    A clash occurs if:
        overlap(i, j) >= clash_threshold_a

    Exclusions
    ----------
    Optionally excludes atom pairs within:
      - the same residue: identical (chain, resseq, insertion code);
      - adjacent residues: same chain, consecutive in file order within that
        chain, and residue numbers differing by at most 1.
    File order makes insertion codes behave correctly: in 52 -> 52A -> 53, the
    pairs 52/52A and 52A/53 are neighbours, while 52 and 53 are not. The
    numbering check prevents residues on either side of a numbering gap (e.g. 10
    followed by 15) from being treated as neighbours. These rules reduce
    counting "local geometry" contacts between sequence neighbours.

    Normalization
    -------------
    clash_score_per_1000_atoms = 1000 * (#clashes) / (N_atoms)

    Parameters
    ----------
    pdb_file
        Path to input PDB file.
    clash_threshold_a
        Clash overlap threshold in Å (default 0.40 Å).
    ignore_same_residue
        If True, do not count pairs from the same residue.
    ignore_adjacent_residue
        If True, do not count pairs from adjacent residues (see Exclusions).
    include_hydrogens
        If False, hydrogen atoms are excluded (even if present).

    Returns
    -------
    ClashResult
        Counts and the normalized clash score. num_pairs_tested is the number of
        KD-tree candidate pairs that survived the exclusions.

    Notes
    -----
    - This is a custom clash metric and is not identical to MolProbity clashscore.
    - Accurate element annotation in the PDB is strongly recommended.
    - No other pair filtering is applied. Hydrogen-bonded pairs (when hydrogens
      are included) and covalent bonds between non-adjacent residues (e.g.
      disulfide SG–SG) can exceed the threshold; they are then counted as clashes.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("structure", pdb_file)

    atoms, coords, radii, res_keys = _collect_atoms(structure)

    if not include_hydrogens:
        keep = np.array([_normalize_element(a) != "H" for a in atoms], dtype=bool)
        coords = coords[keep]
        radii = radii[keep]
        res_keys = res_keys[keep]

    n = int(coords.shape[0])
    if n == 0:
        return ClashResult(num_atoms=0, num_pairs_tested=0, num_clashes=0, clash_score_per_1000_atoms=0.0)

    # Candidate pairs come from a KD-tree. overlap >= t is equivalent to d <= r_i + r_j - t,
    # and r_i + r_j <= 2 * max_r, so a single global radius of 2 * max_r - t cannot miss a clash.
    max_r = float(np.max(radii))
    search_radius = max(0.0, 2.0 * max_r - float(clash_threshold_a))
    pairs = cKDTree(coords).query_pairs(r=search_radius, output_type="ndarray")
    pairs = np.asarray(pairs, dtype=np.intp).reshape(-1, 2)  # keep (M, 2) shape even when M == 0
    i_idx, j_idx = pairs[:, 0], pairs[:, 1]

    # Residue-based exclusions, evaluated for all candidate pairs at once.
    chain, resseq, residue, ordinal = _encode_residues(res_keys)
    same_residue = residue[i_idx] == residue[j_idx]
    adjacent_residue = (
        (chain[i_idx] == chain[j_idx])
        & (np.abs(ordinal[i_idx] - ordinal[j_idx]) == 1)
        & (np.abs(resseq[i_idx] - resseq[j_idx]) <= 1)
    )
    tested = np.ones(len(pairs), dtype=bool)
    if ignore_same_residue:
        tested &= ~same_residue
    if ignore_adjacent_residue:
        tested &= ~adjacent_residue

    i_t, j_t = i_idx[tested], j_idx[tested]
    distances = np.linalg.norm(coords[i_t] - coords[j_t], axis=1)
    overlaps = (radii[i_t] + radii[j_t]) - distances

    pairs_tested = int(i_t.size)
    clashes = int(np.count_nonzero(overlaps >= clash_threshold_a))

    clash_score = 1000.0 * clashes / n
    return ClashResult(
        num_atoms=n,
        num_pairs_tested=pairs_tested,
        num_clashes=clashes,
        clash_score_per_1000_atoms=clash_score,
    )


def main() -> None:
    """Compute and print the clash score of one hard-coded example model."""
    example_pdb = "../../../data/results/dataset1/6P2H/dataset1.AF3.6P2H.model_0.fixed.pdb"
    summary = calculate_clash_score(example_pdb, clash_threshold_a=0.40, include_hydrogens=True)

    report_lines = (
        f"Atoms used: {summary.num_atoms}",
        f"Pairs tested: {summary.num_pairs_tested}",
        f"Clashes: {summary.num_clashes}",
        f"Clash score (per 1000 atoms): {summary.clash_score_per_1000_atoms:.2f}",
    )
    for line in report_lines:
        print(line)


if __name__ == "__main__":
    main()


Atoms used: 2183
Pairs tested: 555
Clashes: 87
Clash score (per 1000 atoms): 39.85


# Batch Clash Score Calculation

In [None]:
from __future__ import annotations

import argparse
import os
from dataclasses import asdict
from pathlib import Path
from typing import Iterable, List, Optional, Sequence

import numpy as np
import pandas as pd

# Assumes calculate_clash_score is imported from your module:
# from yourmodule.clash import calculate_clash_score, ClashResult


def list_pdb_files(root_dir: str | os.PathLike, *, recursive: bool = True) -> List[str]:
    """
    List all PDB files under `root_dir`.

    A file counts as a PDB file when its name ends in ".pdb" in any letter case
    (".pdb", ".PDB", ...). The result therefore does not depend on whether the
    platform's glob matching is case-sensitive.

    Parameters
    ----------
    root_dir
        Directory to search.
    recursive
        If True, search recursively; otherwise only the top-level directory.

    Returns
    -------
    list of str
        Sorted, de-duplicated list of absolute (symlink-resolved) PDB file paths.

    Raises
    ------
    FileNotFoundError
        If `root_dir` does not exist.
    """
    root = Path(root_dir)
    if not root.exists():
        raise FileNotFoundError(f"Root directory not found: {root}")

    entries = root.rglob("*") if recursive else root.glob("*")

    # resolve() can map distinct paths (e.g. symlinks) onto the same file; the set
    # keeps each file once. Sorting gives a deterministic order for reproducibility.
    return sorted(
        {
            str(p.resolve())
            for p in entries
            if p.name.lower().endswith(".pdb") and p.is_file()
        }
    )


def batch_clash_scores(
    root_dir: str | os.PathLike,
    output_csv: str | os.PathLike,
    *,
    recursive: bool = True,
    clash_threshold_a: float = 0.40,
    ignore_same_residue: bool = True,
    ignore_adjacent_residue: bool = True,
    include_hydrogens: bool = True,
    fail_fast: bool = False,
) -> pd.DataFrame:
    """
    Compute clash scores for all PDB files under `root_dir` and write a CSV.

    The CSV (and returned DataFrame) has one row per file with columns:
      - pdb_path
      - clash_score_per_1000_atoms (NaN / empty for failed files)
      - num_atoms, num_pairs_tested, num_clashes (nullable integers; missing for
        failed files, so they stay integral in the CSV even when some files fail)
      - error ("ExceptionType: message", empty if ok)

    Parameters
    ----------
    root_dir
        Directory containing PDB files.
    output_csv
        CSV output path; parent directories are created if needed.
    recursive
        Whether to recursively search for PDB files.
    clash_threshold_a
        Clash overlap threshold in Å.
    ignore_same_residue, ignore_adjacent_residue, include_hydrogens
        Passed to calculate_clash_score.
    fail_fast
        If True, re-raise the first per-file exception (no CSV is written).

    Returns
    -------
    pandas.DataFrame
        Results table.

    Raises
    ------
    FileNotFoundError
        If `root_dir` does not exist (from list_pdb_files).
    RuntimeError
        If no .pdb files are found.
    """
    pdb_files = list_pdb_files(root_dir, recursive=recursive)
    if not pdb_files:
        raise RuntimeError(f"No .pdb files found under: {root_dir}")

    out_path = Path(output_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    count_cols = ["num_atoms", "num_pairs_tested", "num_clashes"]
    # Column order for human reading.
    col_order = ["pdb_path", "clash_score_per_1000_atoms", *count_cols, "error"]

    rows: list[dict] = []

    for pdb_path in pdb_files:
        row = {"pdb_path": pdb_path, "error": ""}

        # Keep only the computation inside `try`, so that a failure while reporting
        # success (e.g. printing) is not recorded as a computation error.
        try:
            result = calculate_clash_score(
                pdb_path,
                clash_threshold_a=clash_threshold_a,
                ignore_same_residue=ignore_same_residue,
                ignore_adjacent_residue=ignore_adjacent_residue,
                include_hydrogens=include_hydrogens,
            )
        except Exception as e:  # batch boundary: record the failure and move on
            row.update(
                {
                    "clash_score_per_1000_atoms": np.nan,
                    **{col: np.nan for col in count_cols},
                    "error": f"{type(e).__name__}: {e}",
                }
            )
            print(f"❌ {pdb_path}: {row['error']}")
            if fail_fast:
                raise
        else:
            row.update(
                {
                    "clash_score_per_1000_atoms": float(result.clash_score_per_1000_atoms),
                    "num_atoms": int(result.num_atoms),
                    "num_pairs_tested": int(result.num_pairs_tested),
                    "num_clashes": int(result.num_clashes),
                }
            )
            print(f"✅ {pdb_path}: {row['clash_score_per_1000_atoms']:.2f}")

        rows.append(row)

    df = pd.DataFrame(rows, columns=col_order)
    # NaN would otherwise turn the count columns into floats ("2183.0" in the CSV);
    # the nullable Int64 dtype keeps them integral with <NA> for failed files.
    df = df.astype({col: "Int64" for col in count_cols})

    df.to_csv(out_path, index=False, encoding="utf-8")
    print(f"\nAll done. Results saved to: {out_path}")
    return df


def build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the batch clash-score script."""
    parser = argparse.ArgumentParser(description="Batch compute vdW clash score for PDB files.")

    # (flag, add_argument keyword arguments), listed in --help order.
    option_specs = (
        ("--root_dir", {"type": str, "required": True, "help": "Directory containing PDB files."}),
        ("--output_csv", {"type": str, "required": True, "help": "Output CSV path."}),
        ("--no_recursive", {"action": "store_true", "help": "Do not search recursively."}),
        ("--threshold", {"type": float, "default": 0.40, "help": "Clash threshold (Å). Default: 0.40"}),
        ("--exclude_h", {"action": "store_true", "help": "Exclude hydrogens."}),
        ("--keep_same_residue", {"action": "store_true", "help": "Do NOT exclude same-residue pairs."}),
        ("--keep_adjacent_residue", {"action": "store_true", "help": "Do NOT exclude adjacent-residue pairs."}),
        ("--fail_fast", {"action": "store_true", "help": "Stop immediately on first failure."}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser


def main(argv: Optional[Sequence[str]] = None) -> None:
    """
    Command-line entry point: parse arguments and run batch_clash_scores.

    Parameters
    ----------
    argv
        Arguments to parse instead of ``sys.argv[1:]``. Leave as None for normal
        command-line use. Pass an explicit list when calling from Python code or a
        notebook, where ``sys.argv`` typically holds the host process's own
        arguments rather than this script's.
    """
    args = build_argparser().parse_args(argv)

    batch_clash_scores(
        args.root_dir,
        args.output_csv,
        recursive=not args.no_recursive,
        clash_threshold_a=args.threshold,
        ignore_same_residue=not args.keep_same_residue,
        ignore_adjacent_residue=not args.keep_adjacent_residue,
        include_hydrogens=not args.exclude_h,
        fail_fast=args.fail_fast,
    )


if __name__ == "__main__":
    main()
