# In [None]:
from __future__ import annotations

from typing import FrozenSet, Set, Tuple
import numpy as np
from Bio.PDB import PDBParser
from Bio.PDB.Residue import Residue


# Residue names (after strip/upper-casing) that are treated as RNA nucleotides.
RNA_BASES = {"A", "U", "G", "C"}
# Element symbols of the atoms that count towards inter-residue contacts.
CONTACT_ELEMENTS = {"N", "O"}

# Residue identifier: (chain_id, residue_number, insertion_code)
ResidueID = Tuple[str, int, str]
# A base pair is an unordered pair of residue identifiers.
BasePair = FrozenSet[ResidueID]


def _res_id(res: Residue) -> ResidueID:
    """
    Convert a Biopython Residue to a stable identifier used for set comparison.

    Returns
    -------
    (chain_id, resseq, icode)
        chain_id : str  (e.g., "A")
        resseq   : int  (PDB residue number)
        icode    : str  (insertion code; '' if none)
    """
    chain_id = str(res.get_parent().id)
    resseq = int(res.id[1])
    icode = str(res.id[2]).strip()
    return (chain_id, resseq, icode)


def _contact_coords(res) -> np.ndarray:
    """Return the coordinates of the N/O atoms of ``res`` as an (n, 3) array (n may be 0)."""
    coords = [
        atom.coord
        for atom in res.get_atoms()
        if (atom.element or "").strip().upper() in CONTACT_ELEMENTS
    ]
    # reshape keeps the result two-dimensional even when no atom matched.
    return np.asarray(coords).reshape(-1, 3)


def _single_model_residues(structure) -> list:
    """Return the residues of one model only (the first model of a Structure)."""
    # A Biopython Structure (level "S") iterates over its models; walking all
    # of them would compare a residue with its own copies in other models.
    if getattr(structure, "level", None) == "S":
        first_model = next(iter(structure), None)
        return [] if first_model is None else list(first_model.get_residues())
    # Models, chains, or other objects exposing get_residues() are used as given.
    return list(structure.get_residues())


def get_base_pairs(structure, cutoff_a: float = 3.3, min_contacts: int = 2) -> Set[BasePair]:
    """
    Infer base pairs from an RNA structure using a contact-count heuristic.

    Definition
    ----------
    Consider residues with resname in {A, U, G, C}.
    Consider only atoms with element in {N, O}.
    Two residues are called a "base pair" if the number of inter-residue
    N/O atom pairs within `cutoff_a` Å is >= `min_contacts`.

    Only the first model of a multi-model Structure is analysed. Residue IDs
    carry no model number, so mixing models would pair a residue with its own
    copy in another model and collapse distinct residues onto the same ID.

    NOTE(review): backbone N/O atoms are not excluded, so residues that are
    adjacent in sequence may also satisfy the criterion - confirm that this
    is intended for the metric.

    Parameters
    ----------
    structure
        Biopython Structure (from PDBParser.get_structure). A Model or Chain
        (anything exposing ``get_residues()``) is also accepted and used as is.
    cutoff_a
        Atom-atom distance cutoff in Å (default 3.3).
    min_contacts
        Minimum number of N/O contacts needed to call a base pair (default 2).

    Returns
    -------
    set of frozenset({ResidueID, ResidueID})
        Each base pair is an unordered pair of residue IDs.
    """
    # Filter atoms and build IDs once per residue instead of once per residue pair.
    candidates = []
    for res in _single_model_residues(structure):
        if res.get_resname().strip().upper() not in RNA_BASES:
            continue
        coords = _contact_coords(res)
        if len(coords):
            candidates.append((_res_id(res), coords))

    pairs: Set[BasePair] = set()

    for i, (id1, coords1) in enumerate(candidates):
        for id2, coords2 in candidates[i + 1:]:
            # All-against-all atom distances between the two residues, shape (n1, n2).
            dists = np.linalg.norm(coords1[:, None, :] - coords2[None, :, :], axis=-1)
            if np.count_nonzero(dists <= cutoff_a) >= min_contacts:
                pairs.add(frozenset((id1, id2)))

    return pairs


def basepair_metrics(
    pred_pdb: str,
    ref_pdb: str,
    cutoff_a: float = 3.3,
    min_contacts: int = 2,
) -> Tuple[float, float, float]:
    """
    Score the base pairs of a predicted PDB against those of a reference PDB.

    Both files are parsed with a quiet ``PDBParser``, base pairs are extracted
    from each with ``get_base_pairs`` and the two resulting sets are compared.

    Parameters
    ----------
    pred_pdb
        Path to predicted PDB file.
    ref_pdb
        Path to reference PDB file.
    cutoff_a, min_contacts
        Forwarded unchanged to get_base_pairs() for both structures.

    Returns
    -------
    precision, recall, f1 : float
        Precision is the fraction of predicted pairs that also occur in the
        reference, recall the fraction of reference pairs that were predicted,
        and f1 their harmonic mean. Each value is 0.0 when its denominator
        would be zero.
    """
    pdb_reader = PDBParser(QUIET=True)
    predicted = get_base_pairs(
        pdb_reader.get_structure("pred", pred_pdb),
        cutoff_a=cutoff_a,
        min_contacts=min_contacts,
    )
    reference = get_base_pairs(
        pdb_reader.get_structure("ref", ref_pdb),
        cutoff_a=cutoff_a,
        min_contacts=min_contacts,
    )

    shared = len(predicted.intersection(reference))
    n_predicted = len(predicted)  # true positives + false positives
    n_reference = len(reference)  # true positives + false negatives

    precision = 0.0
    if n_predicted:
        precision = shared / n_predicted

    recall = 0.0
    if n_reference:
        recall = shared / n_reference

    pr_sum = precision + recall
    if not pr_sum:
        return precision, recall, 0.0
    return precision, recall, 2.0 * precision * recall / pr_sum


if __name__ == "__main__":
    # Script entry point: score the file "pred.pdb" against "ref.pdb" (paths
    # relative to the current working directory) with the default thresholds.
    pred_pdb, ref_pdb = "pred.pdb", "ref.pdb"
    scores = basepair_metrics(pred_pdb, ref_pdb)
    p, r, f1 = scores
    print(f"Precision={p:.3f}, Recall={r:.3f}, F1={f1:.3f}")
