In this notebook we implement approaches to finding cryptic binding sites
based on the predicted structures downloaded from AlphaFold DB.

In the `HACKATHON_DATA`

1. **pLDDT based approach** - we look for pockets with low predicted
pLDDT score, which indicates low confidence of structure prediction = high
probability of conformational changes

In [3]:
import sys
from pathlib import Path

project_root = Path().resolve().parents[1]
sys.path.append(str(project_root))

In [62]:
import re

from Bio.PDB import PDBParser, MMCIFParser
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from src.constants import HACKATHON_DATA, ALPHAFOLD_STRUCTURES_DIR, PDB_STRUCTURES_DIR

In [5]:
data = pd.read_csv(HACKATHON_DATA)
data.head()

Unnamed: 0,pdb_id,query_poi,chain,uniprot_id,pdb_pocket_selection,alphafold_pocket_selection
0,3kr4,K_BES_1003,K,Q8IL11,3kr4 and ( (chain K and resi 374+379+386+392+3...,AF-Q8IL11-F1-model_v4 and ( resi 374+379+386+3...
1,3kun,B_HEM_139,B,Q9NAV8,3kun and ( (chain B and resi 24+31+34+35+36+51...,AF-Q9NAV8-F1-model_v4 and ( resi 25+32+35+36+3...
2,5t67,A_SAH_502,A,B3TMQ9,5t67 and ( (chain A and resi 72+76+78+80+90+11...,AF-B3TMQ9-F1-model_v4 and ( resi 72+76+78+80+9...
3,3kr6,A_FFQ_500,A,P0A749,3kr6 and ( (chain A and resi 22+23+49+91+114+1...,AF-P0A749-F1-model_v4 and ( resi 22+23+49+91+1...
4,3kpz,A_ZNE_525,A,P11473,3kpz and ( (chain A and resi 143+147+150+227+2...,AF-P11473-F1-model_v4 and ( resi 143+147+150+2...


In [42]:
def parse_pocket_selection(
    selection: str,
    default_chain: str = "A",
) -> tuple[str, list[int]]:
    """
    Parse a PyMOL selection string to extract chain ID and residue indices.
    """
    chain_match = re.search(r'chain\s+([A-Za-z])', selection)
    chain_id = chain_match.group(1) if chain_match else default_chain

    resi_match = re.search(r'resi\s+([0-9+\s]+)', selection)
    if not resi_match:
        raise ValueError("No residue information found in selection string.")

    resi_list = [int(r.strip()) for r in resi_match.group(1).split('+') if r.strip().isdigit()]

    return chain_id, resi_list

### 1. pLDDT based approach

In [5]:
def extract_plddt_from_pdb(pdb_file: Path) -> list[tuple[int, float]] | list[float]:
    """
    Extract pLDDT values from a PDB file downloaded from AlphaFold DB.

    Returns:
        A list of tuples, where each tuple contains the residue index and its
        corresponding pLDDT value.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("model", pdb_file)

    # AlphaFold predictions contain one model and one chain
    model = next(structure.get_models())
    chain = next(model.get_chains())

    residue_plddt_scores = []

    for residue in chain:
        atoms = list(residue.get_atoms())
        # All atoms in a residue a have the same pLDDT in AF outputs, so
        # we get the pLDDT value from the B-factor field of the first atom
        plddt = atoms[0].get_bfactor()
        residue_plddt_scores.append(plddt)

    return residue_plddt_scores

In [6]:
def residues_mean_plddt(residues: list[int], plddt_scores: list[float]) -> float:
    """
    Calculate the mean pLDDT score for a list of residues.
    """
    selected_scores = [plddt_scores[r - 1] for r in residues] # residues are 1-indexed
    return sum(selected_scores) / len(selected_scores)

Now we iterate over all AlphaFold pocket selections in our dataset, and check the mean value of pLDDT.
If it is below threshold (low confidence prediction = possible high structure variation), classify the selected pocket as *cryptic*. 

In [7]:
pockets_plddt_df = data[['uniprot_id', 'alphafold_pocket_selection']].drop_duplicates()
pockets_plddt_df.head()

Unnamed: 0,uniprot_id,alphafold_pocket_selection
0,Q8IL11,AF-Q8IL11-F1-model_v4 and ( resi 374+379+386+3...
1,Q9NAV8,AF-Q9NAV8-F1-model_v4 and ( resi 25+32+35+36+3...
2,B3TMQ9,AF-B3TMQ9-F1-model_v4 and ( resi 72+76+78+80+9...
3,P0A749,AF-P0A749-F1-model_v4 and ( resi 22+23+49+91+1...
4,P11473,AF-P11473-F1-model_v4 and ( resi 143+147+150+2...


In [8]:
pocket_mean_plddt = []

for idx, row in tqdm(pockets_plddt_df.iterrows(), total=len(pockets_plddt_df)):
    uniprot_id = row['uniprot_id']
    selection_str = row['alphafold_pocket_selection']
    residues = extract_residues(selection_str)
    pdb_file = ALPHAFOLD_STRUCTURES_DIR / f"{uniprot_id}.pdb"

    if not pdb_file.exists():
        print(f"Warning: PDB file for {uniprot_id} not found. Skipping.")
        pocket_mean_plddt.append(None)
        continue

    plddt_scores = extract_plddt_from_pdb(pdb_file)

    try:
        mean_plddt = residues_mean_plddt(residues, plddt_scores)
        pocket_mean_plddt.append(mean_plddt)
    except Exception as e:
        print(f"Error calculating mean pLDDT for {idx}, {uniprot_id}: {e}")
        pocket_mean_plddt.append(None)
        continue



 14%|█▎        | 4664/34260 [01:31<09:56, 49.60it/s]



 14%|█▎        | 4684/34260 [01:32<09:01, 54.63it/s]



 59%|█████▉    | 20218/34260 [06:17<03:01, 77.19it/s]

Error calculating mean pLDDT for 36266, Q7YTB0: list index out of range


 59%|█████▉    | 20226/34260 [06:17<03:29, 67.12it/s]

Error calculating mean pLDDT for 36299, Q7YTB0: list index out of range


 75%|███████▌  | 25797/34260 [07:58<02:45, 51.00it/s]

Error calculating mean pLDDT for 46941, P9WFR9: list index out of range


 75%|███████▌  | 25823/34260 [07:58<02:27, 57.17it/s]

Error calculating mean pLDDT for 46985, P9WFR9: list index out of range


 75%|███████▌  | 25855/34260 [07:59<02:48, 49.84it/s]

Error calculating mean pLDDT for 47038, P9WFR9: list index out of range
Error calculating mean pLDDT for 47054, P9WFR9: list index out of range


 76%|███████▌  | 25869/34260 [07:59<02:24, 58.16it/s]

Error calculating mean pLDDT for 47062, P9WFR9: list index out of range


100%|██████████| 34260/34260 [10:35<00:00, 53.95it/s]


In [9]:
pockets_plddt_df['mean_plddt'] = pocket_mean_plddt

In [10]:
pockets_plddt_df

Unnamed: 0,uniprot_id,alphafold_pocket_selection,mean_plddt
0,Q8IL11,AF-Q8IL11-F1-model_v4 and ( resi 374+379+386+3...,98.145500
1,Q9NAV8,AF-Q9NAV8-F1-model_v4 and ( resi 25+32+35+36+3...,96.951364
2,B3TMQ9,AF-B3TMQ9-F1-model_v4 and ( resi 72+76+78+80+9...,97.025217
3,P0A749,AF-P0A749-F1-model_v4 and ( resi 22+23+49+91+1...,90.284167
4,P11473,AF-P11473-F1-model_v4 and ( resi 143+147+150+2...,95.835385
...,...,...,...
64377,Q16698,AF-Q16698-F1-model_v4 and ( resi 66+69+70+71+8...,97.465417
64379,Q16698,AF-Q16698-F1-model_v4 and ( resi 66+69+70+71+7...,95.601538
64380,Q96X16,AF-Q96X16-F1-model_v4 and ( resi 396+398+401+4...,97.343750
64384,P69834,AF-P69834-F1-model_v4 and ( resi 84+85+86+87+9...,92.081333


In [61]:
pockets_plddt_df['mean_plddt'].hist(bins=50)
plt.xlabel('Mean pLDDT')
plt.ylabel('Count')
plt.title('Distribution of Mean pLDDT for Pockets')
plt.show()

NameError: name 'pockets_plddt_df' is not defined

In [16]:
for threshold in [60, 65, 70, 75, 80, 85, 90]:
    count = (pockets_plddt_df['mean_plddt'] <= threshold).sum()
    print(f"Number of pockets with mean pLDDT <= {threshold}: {count}")

Number of pockets with mean pLDDT <= 60: 19
Number of pockets with mean pLDDT <= 65: 30
Number of pockets with mean pLDDT <= 70: 61
Number of pockets with mean pLDDT <= 75: 90
Number of pockets with mean pLDDT <= 80: 248
Number of pockets with mean pLDDT <= 85: 840
Number of pockets with mean pLDDT <= 90: 3009


### 2. RMSD based approach

In [63]:
from Bio.PDB import Superimposer
from Bio.PDB.Atom import Atom
from Bio.PDB.Structure import Structure


def get_pocket_atoms(
    structure: Structure,
    chain_id: str,
    pocket_residues: list[int],
    atom_names=['CA'],
) -> list[Atom]:
    """
    Extract atoms from a specified pocket in a PDB structure.
    """
    pocket_atoms = []
    model = structure[0]
    chain = model[chain_id]
    for res_id in pocket_residues:
        residue = chain[res_id]
        for atom_name in atom_names:
            if atom_name in residue:
                pocket_atoms.append(residue[atom_name])
    return pocket_atoms


def calculate_pocket_rmsd(
    holo_struct: Structure,
    apo_struct: Structure,
    holo_chain: str,
    apo_chain: str,
    holo_residues: list[int],
    apo_residues: list[int],
) -> float:
    """
    Calculate the RMSD between two pockets in holo and apo structures.
    This function assumes that the residues in both pockets are aligned by their
    C-alpha atoms.
    """
    holo_atoms = get_pocket_atoms(holo_struct, holo_chain, holo_residues)
    apo_atoms = get_pocket_atoms(apo_struct, apo_chain, apo_residues)

    if len(holo_atoms) != len(apo_atoms):
        raise ValueError("The number of atoms in the holo and apo pockets must be the same.")
    if len(holo_atoms) == 0:
        raise ValueError("No atoms found in the specified pockets.")
    if len(holo_atoms) < 3:
        raise ValueError("At least 3 atoms are required to calculate RMSD.")

    superimposer = Superimposer()
    superimposer.set_atoms(holo_atoms, apo_atoms)
    return superimposer.rms


In [None]:
pdb_parser = PDBParser(QUIET=True)
cif_parser = MMCIFParser(QUIET=True)

for _, row in tqdm(data.iterrows(), total=len(data)):
    pdb_id = row['pdb_id']
    uniprot_id = row['uniprot_id']
    holo_pocket_selection = row['pdb_pocket_selection']
    apo_pocket_selection = row['alphafold_pocket_selection']

    if pd.isna(holo_pocket_selection) or pd.isna(apo_pocket_selection):
        continue

    holo_chain, holo_residues = parse_pocket_selection(holo_pocket_selection)
    apo_chain, apo_residues = parse_pocket_selection(apo_pocket_selection)

    holo_file = PDB_STRUCTURES_DIR / f"{pdb_id}.cif"
    apo_file = ALPHAFOLD_STRUCTURES_DIR / f"{uniprot_id}.pdb"

    if not holo_file.exists() or not apo_file.exists():
        print(f"Warning: PDB files for {uniprot_id} not found. Skipping.")
        continue

    holo_structure = cif_parser.get_structure(pdb_id, holo_file)
    apo_structure = pdb_parser.get_structure(uniprot_id, apo_file)

    try:
        rmsd = calculate_pocket_rmsd(
            holo_structure,
            apo_structure,
            holo_chain,
            apo_chain,
            holo_residues,
            apo_residues
        )
        # print(f"RMSD for {uniprot_id} - {pdb_id}: {rmsd:.2f} Å")
    except Exception as e:
        print(f"Error calculating RMSD for {uniprot_id}: {e}")