# Pocket-centric metrics
Metrics like F1, MCC, ACC, AUC ROC, and AUPRC are oriented towards binary classification. However, during the binding site prediction, pockets are the actual target. Let's perform an analysis oriented towards pockets.

In [39]:
import json
import numpy as np
from scipy.spatial import distance_matrix

from biotite.structure import get_residue_starts, Atom, AtomArray
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb
import biotite

from clustering import compute_clusters, THRESHOLD

DATASET_PATH = '/home/vit/Projects/cryptic-nn/datasets/cryptobench-dataset/folds/test.json'
CIF_FILES = '/home/vit/Projects/deeplife-project/data/cif_files'
PREDICTIONS_PATH = '/home/vit/Projects/cryptic-nn/data/predictions/ESM2-3B-extended-finetuning'
K = 2

def load_dataset():
    with open(DATASET_PATH, 'r') as json_file:
        dataset = json.load(json_file)
    return dataset

def load_structure(pdb_id):
    cif_file_path = rcsb.fetch(pdb_id, "cif", target_path=CIF_FILES)
    cif_file = pdbx.CIFFile.read(cif_file_path)
    return get_structure(cif_file, model=1)

def compute_centroid(pocket_residues):
    pocket_coords = np.array([residue.coord for residue in pocket_residues])
    pocket_center = np.mean(pocket_coords, axis=0)
    return pocket_center

def compute_pocket_center(structure, pocket):
    pocket_residues = [residue for residue in structure if str(residue.res_id) in pocket]
    # sanity check
    assert len(pocket) == len(set([residue.res_id for residue in pocket_residues]))
    # get centroid of pocket residues
    return compute_centroid(pocket_residues)
    
def get_atom_array_from_pocket_center(pocket_centers):
    array = AtomArray(len(pocket_centers))
    for i, pocket_center in enumerate(pocket_centers):
        array.coord[i] = pocket_center
    return array

def get_dcc_points(structure, holo_structures):
    # load pocket centers from the dataset
    pocket_centers = []
    for holo_structure in holo_structures:
        apo_pocket = holo_structure['apo_pocket_selection']
        apo_pocket = [residue.split(
            '_')[1] for residue in apo_pocket]
        pocket_centers.append(compute_pocket_center(structure, apo_pocket))

    pocket_centers = get_atom_array_from_pocket_center(pocket_centers)

    # some pocket centers are the same/very close, let's cluster them
    cluster_centers = compute_clusters( pocket_centers, np.array([1] * len(pocket_centers)))
    clustered_centers = []
    for cluster_id in range(-1, max(cluster_centers) + 1):
        # if unclustered, then keep the original pocket center
        if cluster_id == -1:
            clustered_centers.extend(pocket_centers[cluster_centers == cluster_id].coord)
        # else compute the centroid of the cluster
        else:
            this_cluster_centers = pocket_centers[cluster_centers == cluster_id]
            # compute cluster centroid
            cluster_centroid = compute_centroid(this_cluster_centers)
            clustered_centers.append(cluster_centroid)

    return np.array(clustered_centers)

def count_correctly_predicted_pockets(points, clusters_centroids, metric='DCC'):
    CUTOFF_THRESHOLD = 12 if metric == 'DCC' else 4

    # Take only N cluster centroids with the highest average score, where N = len(points)
    N = len(points)
    # Sort clusters_centroids by average_score in descending order
    clusters_centroids_sorted = sorted(clusters_centroids, key=lambda x: x[1], reverse=True)
    # Select top N centroids
    
    top_clusters_centroids = [centroid for centroid, score in clusters_centroids_sorted[:N + K]]
    # Compute the full distance matrix between all cluster centroids and all points

    if len(top_clusters_centroids) == 0:
        print(f"No pockets found; N= {N}, len(clustered_centroids) = {len(clusters_centroids)}")
        return -1

    distances = distance_matrix(points, top_clusters_centroids)

    # loop over each actual pocket and check whether there is a predicted pocket that is close enough (<CUTOFF_THRESHOLD)
    pockets_found = 0
    for i, this_distances in enumerate(distances):
        # get the closest predicted pocket
        closest_pocket = np.argmin(this_distances)
        # check if the distance is less than the cutoff threshold
        if this_distances[closest_pocket] < CUTOFF_THRESHOLD:
            pockets_found += 1
    return pockets_found

def analyze(metric='DCC'):
    # set to 12.0 Angstroms for DCC, 4 Angstroms for DCA
    # see https://jcheminf.biomedcentral.com/articles/10.1186/s13321-024-00923-z
    CUTOFF_THRESHOLD = 12 if metric == 'DCC' else 4

    dataset = load_dataset()

    total_pockets = 0
    total_pockets_found = 0

    for apo_structure, holo_structures in dataset.items():
        chain_id = holo_structures[0]['apo_chain']

        # skip multichain structures
        if '-' in chain_id:
            continue
        
        auth = load_structure(apo_structure)
        auth = auth[
                (auth.chain_id == chain_id) &
                (biotite.structure.filter_peptide_backbone(auth))]

        points = None
        # load pocket centers
        if metric == 'DCC':
            points = get_dcc_points(auth, holo_structures)
        
        # load ligands
        if metric == 'DCA':
            raise NotImplementedError
            ligands = []
            for holo_structure in holo_structures:
                # TODO: align to holo structure (https://www.biotite-python.org/latest/apidoc/biotite.structure.superimpose.html)
                # TODO: use the AffineTransformation value to map the ligand to the apo structure
                ...
            points = np.array(ligands)

        
        protein_id = f'{apo_structure}{chain_id}'

        # filter to get correct chain; filter only for peptides
        auth_residues = auth[get_residue_starts(auth) + 1] # skip the first atom (N), second should be CA
 
        # load predictions and cluster them
        predictions = np.load(f'{PREDICTIONS_PATH}/predictions/{protein_id}.npy')
        predictions_mask = predictions > THRESHOLD
        assert len(predictions) == len(auth_residues), f'Length of predictions ({len(predictions)}) does not match length of auth residues ({len(auth_residues)})'
        
        # filter auth residues to get only the ones that are predicted to bind and cluster them into pockets
        predicted_binding_residues = auth_residues[predictions_mask]
        predicted_binding_scores = predictions[predictions_mask]
        print(len(predicted_binding_residues))
        clusters = compute_clusters(predicted_binding_residues, predictions[predictions_mask])

        clusters_centroids = []
        # loop over each cluster and compute the cluster's center with its average score
        for cluster_id in range(-1, max(clusters) + 1):
            if cluster_id == -1:
                # skip unclustered residues
                continue
            cluster_residues = predicted_binding_residues[clusters == cluster_id]
            # compute cluster centroid and compute the average score
            cluster_centroid = compute_centroid(cluster_residues)
            average_score = np.mean(predicted_binding_scores[clusters == cluster_id])
            
            clusters_centroids.append((cluster_centroid, average_score))

        pockets_found = count_correctly_predicted_pockets(points, clusters_centroids, metric=metric)
        if pockets_found == -1:
            continue

        total_pockets += len(points)
        total_pockets_found += pockets_found

    return total_pockets, total_pockets_found, total_pockets_found / total_pockets

analyze()

18
27
26
29
28
30
20
3
No pockets found; N= 1, len(clustered_centroids) = 0
21
30
29
40
13
23
14
31
31
3
No pockets found; N= 1, len(clustered_centroids) = 0
25
11
17
1
No pockets found; N= 1, len(clustered_centroids) = 0
12
55
1
No pockets found; N= 3, len(clustered_centroids) = 0
24
21
21
21
24
29
12
14
0


ValueError: Found array with 0 sample(s) (shape=(0, 3)) while a minimum of 1 is required by DBSCAN.

## DCC
Distance between the predicted and the real binding site center.

In [None]:
# TODO: DCC


DATASET = 'cryptobench'
CIF_FILES = '/home/vit/Projects/deeplife-project/data/cif_files'
PREDICTIONS_PATH = '/home/vit/Projects/cryptic-nn/data/predictions/ESM2-3B-extended-finetuning'

with open(f'../../datasets/{DATASET}-dataset/folds/test.json', 'r') as json_file:
    dataset = json.load(json_file)

skip = True
for apo_structure, holo_structures in dataset.items():

    # finished analysis at: '5wbmB' structure
    if skip:
        if apo_structure == '5wm9':
            skip = False
        else:
            continue
    # print(f'Processing {apo_structure} ...')
    binding_residues = set()
    chain_id = holo_structures[0]['apo_chain']

    # skip multichain structures
    if '-' in chain_id:
        continue

    for holo_structure in holo_structures:

        apo_pocket = holo_structure['apo_pocket_selection']
        
        new_apo_residues = [residue.split(
            '_')[1] for residue in apo_pocket]

        binding_residues.update(new_apo_residues)

    cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)

    cif_file = pdbx.CIFFile.read(cif_file_path)

    auth = get_structure(cif_file, model=1)
    auth = auth[
            (auth.chain_id == chain_id) &
            (biotite.structure.filter_peptide_backbone(auth))]
    
    protein_id = f'{apo_structure}{chain_id}'
    # skip if no residues left
    if len(auth) == 0: 
        print(f'No residues left for {protein_id}')
        continue

    # filter to get correct chain; filter only for peptides
    auth_residues_only = get_residues(auth)

    predictions = np.load(f'{PREDICTIONS_PATH}/predictions/{protein_id}.npy') > 0.5
    
    assert len(predictions) == len(auth_residues_only[0]), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only[0])} vs {len(predictions)}"
    predicted_binding_residues = auth_residues_only[0][predictions]


## DCA
Distance between the predicted binding site center and the closest ligand atom.

In [None]:
# TODO: DCA