# In [45]:
from collections import Counter
import numpy as np
import sys

sys.path.append('../../utils')
import clustering_utils

def map_residue_numbering_to_auth(
    pdb_path: str,
    binding_residues: dict[str, np.ndarray],
    binding_scores: dict[str, np.ndarray],
) -> tuple[dict[str, list[int]], dict[str, list[float]]]:
    """
    Map binding residues from zero-based positions to the residue numbers stored in the PDB file.

    Position ``i`` denotes the i-th residue of a chain that has a carbon "CA" atom, in file
    order (0 = first residue, 1 = second residue, ...). The output numbers are the residue IDs
    biotite reads from the file (the author labeling).

    Args:
        pdb_path (str): Path to the PDB file.
        binding_residues (dict[str, np.ndarray]): Chain ID -> array of zero-based residue
            indices predicted as binding.
        binding_scores (dict[str, np.ndarray]): Chain ID -> per-residue scores, indexed by the
            same zero-based position as ``binding_residues``.

    Returns:
        tuple[dict[str, list[int]], dict[str, list[float]]]: ``(mapped_residues, mapped_scores)``.
            For each chain, these are the file residue numbers of the binding residues (in chain
            order) and their scores (in the same order).

    Raises:
        ValueError: If some binding index has no matching residue in the chain (e.g. it is out
            of range, or the chain ID is not present in the file).
    """
    import biotite.structure.io.pdb as pdb
    from biotite.structure.io.pdb import get_structure
    from biotite.structure import get_residues

    pdb_file = pdb.PDBFile.read(pdb_path)

    protein = get_structure(pdb_file, model=1) #, use_author_fields=False) - UNCOMMENT THIS TO USE LABEL_SEQ_ID FIELDS
    # One atom per residue: "CA" atoms whose element is carbon
    # (presumably to exclude calcium ions, which are also named "CA").
    protein = protein[(protein.atom_name == "CA") & (protein.element == "C")]

    mapped_residues = {}
    mapped_scores = {}
    for chain_id, indices in binding_residues.items():
        residue_ids, _ = get_residues(protein[protein.chain_id == chain_id])

        # Set lookup instead of scanning the index array once per residue.
        wanted = {int(index) for index in indices}
        positions = [i for i in range(len(residue_ids)) if i in wanted]

        mapped_residues[chain_id] = [residue_ids[i] for i in positions]
        # Scores are indexed by zero-based position, not by file residue number.
        mapped_scores[chain_id] = [binding_scores[chain_id][i] for i in positions]

        if len(mapped_residues[chain_id]) != len(indices):
            raise ValueError(
                f"Chain {chain_id} has different number of residues in mapped residues and original binding residues"
            )

    return mapped_residues, mapped_scores

def keep_only_standard_residues(structure):
    """
    Strip every residue whose name is not in ``clustering_utils.aal_prot``.

    The structure is modified in place (residues are detached from their chains)
    and the same object is returned for convenience.
    """
    for chain in list(structure):
        # Collect the IDs first so no child is detached while the chain is being iterated.
        doomed = [res.id for res in chain if res.get_resname() not in clustering_utils.aal_prot]
        for res_id in doomed:
            chain.detach_child(res_id)
    return structure

def get_protein_surface_points(pdb_path, predicted_binding_sites):
    """
    Collect solvent-accessible surface points of the atoms of predicted binding residues.

    Only the first model is used, and non-standard residues are removed first
    (``keep_only_standard_residues``). Per-atom SASA is computed with Biopython's
    ``ShrakeRupley`` using ``clustering_utils.POINTS_DENSITY_PER_ATOM`` and
    ``clustering_utils.PROBE_RADIUS``.

    Args:
        pdb_path: Path to the PDB file.
        predicted_binding_sites: Mapping chain ID -> iterable of residue numbers. These are
            compared with Biopython's ``residue.get_id()[1]``. Chains missing from the
            mapping contribute no surface points.

    Returns:
        surface_points: ``np.vstack`` of ``atom.sasa_points`` over all atoms of predicted
            residues. When there are none, it is an empty ``(0, 3)`` array; this assumes
            the points are 3-D coordinates (TODO confirm against the SASA implementation
            that sets ``sasa_points``).
        map_surface_points_to_atom_id: Integer array giving, for each surface point, the
            serial number of the atom it belongs to.
        map_atoms_to_residue_id: Dict atom serial -> (chain ID, residue number), for atoms
            of predicted residues.
        atom_coords: Dict atom serial -> ``atom.get_vector()``, for atoms of predicted residues.
        residue_coords: Dict (chain ID, residue number) -> ``get_vector()`` of the residue's CA
            atom (or of its first atom when there is no CA), for every kept residue.
    """
    from Bio.PDB import PDBParser
    from Bio.PDB.SASA import ShrakeRupley

    parser = PDBParser(QUIET=1)
    struct = parser.get_structure("protein", pdb_path)[0]
    struct = keep_only_standard_residues(struct)

    # compute per-atom SASA
    sr = ShrakeRupley(n_points=clustering_utils.POINTS_DENSITY_PER_ATOM, probe_radius=clustering_utils.PROBE_RADIUS)
    sr.compute(struct, level="A")

    # Normalize the residue numbers once (they may be numpy integers) and use sets for O(1) lookup.
    wanted = {
        chain: {int(residue_number) for residue_number in residues}
        for chain, residues in predicted_binding_sites.items()
    }

    surface_points = []
    map_surface_points_to_atom_id = []
    atom_coords = {}
    residue_coords = {}
    map_atoms_to_residue_id = {}
    for residue in struct.get_residues():
        residue_chain = residue.get_full_id()[2]
        residue_id = residue.get_id()[1]

        # Representative coordinate for every kept residue: CA if present, else the first atom.
        anchor = residue['CA'] if 'CA' in residue else next(residue.get_atoms())
        residue_coords[(residue_chain, residue_id)] = anchor.get_vector()

        # Surface points are gathered only for predicted binding residues.
        if residue_id not in wanted.get(residue_chain, ()):
            continue

        for atom in residue.get_atoms():
            atom_id = atom.get_serial_number()
            surface_points.append(atom.sasa_points)
            map_surface_points_to_atom_id.extend([atom_id] * len(atom.sasa_points))
            atom_coords[atom_id] = atom.get_vector()
            map_atoms_to_residue_id[atom_id] = (residue_chain, residue_id)

    # np.vstack([]) raises, so return an explicitly empty array instead.
    # Callers check shape[0] == 0 to detect this case.
    if surface_points:
        surface_points = np.vstack(surface_points)
    else:
        surface_points = np.empty((0, 3))
    map_surface_points_to_atom_id = np.array(map_surface_points_to_atom_id, dtype=int)
    return surface_points, map_surface_points_to_atom_id, map_atoms_to_residue_id, atom_coords, residue_coords



def execute_atom_clustering(pdb_path, predictions, probabilities, eps=10):
    """
    Cluster predicted binding residues by clustering the surface points of their atoms.

    Pipeline:
        1. Map zero-based predictions to file residue numbers.
        2. Collect the SASA surface points of those residues' atoms.
        3. Cluster the points into per-atom labels (``clustering_utils.cluster_atoms_by_surface``).
        4. Assign each residue to the cluster that holds most of its atoms.

    Args:
        pdb_path: Path to the PDB file.
        predictions: Dict chain ID -> array of zero-based indices of predicted binding residues.
        probabilities: Dict chain ID -> per-residue scores, indexed by the same zero-based position.
        eps: Forwarded unchanged to ``clustering_utils.cluster_atoms_by_surface``.

    Returns:
        clusters: Dict {cluster_id: [atom_id, ...], ...}.
        residue_clusters: Dict {cluster_id: ["<chain>_<residue_id>", ...], ...} with a key for
            every ID in ``range(max(cluster_id) + 1)``. Each residue goes to the cluster holding
            most of its atoms; ties go to the lowest ID. Residues with no labelled atom are omitted.
        cluster_scores: List of length ``max(cluster_id) + 1``. Entry i is the mean, over the
            labelled atoms of cluster i, of their residue's score. Residues with more atoms
            therefore weigh more. The entry is 0.0 for an ID with no atoms.
        atom_coords: Dict {atom_id: Bio.PDB Vector}.
        residue_coords: Dict {(chain_id, residue_id): Bio.PDB Vector}.
        All five values are None when there are no surface points or no labelled atoms.
    """
    # 1. Zero-based indices -> file (auth) residue numbers, plus the matching scores.
    mapped_prediction, mapped_scores = map_residue_numbering_to_auth(
        pdb_path=pdb_path, binding_residues=predictions, binding_scores=probabilities)

    # 2. Get surface points and their mapping to atoms, atom coordinates, and atom to residue mapping.
    all_points, map_point_to_atom, map_atoms_to_residue_id, atom_coords, residue_coords = \
        get_protein_surface_points(pdb_path, mapped_prediction)

    if all_points.shape[0] == 0:
        return None, None, None, None, None

    # 3. Cluster surface points and propagate labels to atoms.
    atom_labels = clustering_utils.cluster_atoms_by_surface(
        all_points, map_point_to_atom, eps=eps)
    if not atom_labels:
        # max() below would fail on an empty label set; treat it like "no surface points".
        return None, None, None, None, None

    # Cluster dictionary {cluster_id: [atom_id, ...], ...}.
    clusters = {}
    for atom_id, cluster_label in atom_labels.items():
        clusters.setdefault(cluster_label, []).append(atom_id)

    # NOTE(review): labels are used as list indices below, so they must be non-negative ints.
    # A noise label such as -1 would silently land in the last cluster. Confirm that
    # cluster_atoms_by_surface never emits one.
    n_clusters = max(clusters) + 1

    # Per-chain lookup: residue number -> score. setdefault keeps the first occurrence,
    # matching a "first match" search over the mapped predictions.
    score_lookup = {}
    for chain_id, residues in mapped_prediction.items():
        chain_lookup = score_lookup.setdefault(chain_id, {})
        for residue_id, score in zip(residues, mapped_scores[chain_id]):
            chain_lookup.setdefault(int(residue_id), score)

    # 4.1 For each labelled atom, record its residue's score and its residue under its cluster.
    cluster_scores = [[] for _ in range(n_clusters)]
    cluster_residues = [[] for _ in range(n_clusters)]
    for atom_id, cluster_label in atom_labels.items():
        chain_id, residue_id = map_atoms_to_residue_id[atom_id]  # this is auth residue id
        cluster_scores[cluster_label].append(score_lookup[chain_id][int(residue_id)])
        cluster_residues[cluster_label].append(f'{chain_id}_{residue_id}')

    # 4.2 Majority vote: each residue ("chain_residueid", e.g. "A_123") goes to the cluster
    # holding most of its atoms. np.argmax returns the first maximum, so ties go to the lowest ID.
    votes_per_cluster = [Counter(labels) for labels in cluster_residues]
    predicted_residues = dict.fromkeys(  # ordered and de-duplicated
        f'{chain_id}_{residue_id}'
        for chain_id, residues in mapped_prediction.items()
        for residue_id in residues
    )
    residue_clusters = {label: [] for label in range(n_clusters)}
    for residue in predicted_residues:
        votes = [cluster_votes[residue] for cluster_votes in votes_per_cluster]
        if not any(votes):
            # No labelled atom (e.g. the residue was dropped as non-standard or has no surface
            # points). Do not silently assign it to cluster 0.
            continue
        residue_clusters[int(np.argmax(votes))].append(residue)

    # 5. Mean atom-level score per cluster (0.0 for an empty cluster).
    cluster_scores = [np.mean(scores) if scores else 0.0 for scores in cluster_scores]

    return clusters, residue_clusters, cluster_scores, atom_coords, residue_coords


# Load the per-residue scores of each chain of 1a00 (one float per line, in zero-based
# residue order). Residues scoring above 0.7 are treated as predicted binding residues.
predictions = {}
probabilities = {}
for chain_id in ['A', 'B', 'C', 'D']:
    with open(f'data/reference_input/pdb1a00_{chain_id}_predictions.csv') as f:
        prediction = [float(line) for line in f.read().splitlines()]
    probabilities[chain_id] = np.array(prediction)
    predictions[chain_id] = np.flatnonzero(probabilities[chain_id] > 0.7)

clusters, residue_clusters, cluster_scores, _, _ = execute_atom_clustering(
    pdb_path='data/reference_input/pdb1a00.pdb',
    predictions=predictions,
    probabilities=probabilities,
)


# In [46]:
def run_assertions(clusters):
    """
    Assert that no member is shared between two different clusters.

    Args:
        clusters: Mapping {cluster_id: iterable of hashable members (atom IDs, residue labels, ...)}.
            Repeats inside a single cluster are allowed; only cross-cluster overlap is an error.

    Raises:
        AssertionError: If a member appears in more than one cluster.
    """
    # One pass: remember the first cluster that owns each member and flag any later owner.
    # This replaces the nested all-pairs scan and avoids rebinding the loop variable.
    first_owner = {}
    for cluster_id, members in clusters.items():
        for member in set(members):
            owner = first_owner.setdefault(member, cluster_id)
            assert owner == cluster_id, f'Atom {member} is in both cluster {owner} and cluster {cluster_id}'
# Sanity checks: no atom and no residue may be shared between two clusters,
# and the three per-cluster outputs must have the same length.
for cluster_mapping in (clusters, residue_clusters):
    run_assertions(cluster_mapping)
assert len(cluster_scores) == len(clusters) == len(residue_clusters), "Number of clusters, cluster scores, and residue clusters should be the same"