In [1]:
import sys
import pymol
# Remember the current stdout/stderr objects so they can be put back once PyMOL is running.
_stdouterr = sys.stdout, sys.stderr
# Start the PyMOL application with the `-q` flag (quiet start-up).
pymol.finish_launching(['/usr/bin/pymol', '-q'])
# Restore the streams saved above (presumably PyMOL replaces them while launching — TODO confirm).
sys.stdout, sys.stderr = _stdouterr
# Score cut-off: compute_clusters and the visualisation cells treat scores above this value as positive.
DECISION_THRESHOLD = 0.7

# load something into the PyMOL window
from pymol import cmd


## Clustering
Cluster the predictions into pockets.

In [36]:
import numpy as np
from sklearn.cluster import DBSCAN, AgglomerativeClustering
from biotite.structure import sasa, AtomArray

# Parameters read by compute_clusters (looked up as globals when the function is called).
EPSILON = 5  # DBSCAN `eps`: max distance between neighbouring points (adjust as needed)
MIN_SAMPLES = 5  # DBSCAN `min_samples`: minimum points to form a cluster (adjust as needed)
SASA_THRESHOLD = 0.5  # with check_sasa=True only points whose SASA is above this value are clustered (adjust as needed)
DATASET = 'cryptobench'
# Machine-specific absolute paths: directory the mmCIF files are fetched into, and directory holding the
# per-chain prediction arrays (`<protein_id>.npy`).
CIF_FILES = '/home/vit/Projects/deeplife-project/data/cif_files'
PREDICTIONS_PATH = '/home/vit/Projects/cryptoshow-analysis/data/D-visualize/predictions/finetuning-without-smoothing'
# PyMOL colour names. The colouring loops index this list as COLORS[label + 1], so entry 0 is the colour
# applied to label -1 (points that belong to no cluster).
COLORS = ['grey', 'green', 'purple', 'purpleblue', 'raspberry', 'ruby', 'salmon', 'sand', 'skyblue', 'slate', 'smudge', 'splitpea', 'sulfur', 'teal', 'tv_blue', 'tv_green', 'tv_orange', 'tv_red', 'tv_yellow']


def compute_clusters(points: AtomArray, prediction_scores: np.array, check_sasa=False):
    """Cluster the high-scoring points with DBSCAN.

    A point takes part in the clustering when its score is above
    ``DECISION_THRESHOLD`` and, if ``check_sasa`` is set, its solvent
    accessible surface area (as returned by ``biotite.structure.sasa``) is
    above ``SASA_THRESHOLD``. DBSCAN runs with ``EPSILON`` and
    ``MIN_SAMPLES``. All four constants are module-level globals that are
    looked up when the function is called.

    Parameters
    ----------
    points : AtomArray
        Atoms to cluster; their ``coord`` attribute supplies the positions.
    prediction_scores : np.array
        One score per element of ``points``.
    check_sasa : bool, optional
        When true, additionally drop points whose SASA is not above
        ``SASA_THRESHOLD``.

    Returns
    -------
    np.ndarray
        Integer cluster label for every element of ``points``. ``-1`` marks
        points that were filtered out or that DBSCAN labelled as noise. If no
        point survives the filters, every label is ``-1``.

    Raises
    ------
    AssertionError
        If ``points`` and ``prediction_scores`` differ in length.
    """
    points_array = points.coord
    scores_array = np.asarray(prediction_scores)

    assert len(points_array) == len(scores_array), f"Length of points and scores do not match: {len(points_array)} vs {len(scores_array)}"

    # Every point starts as "no cluster"; only points that pass the filters are overwritten below.
    all_labels = np.full(len(points), -1, dtype=int)
    high_score_mask = scores_array > DECISION_THRESHOLD

    # SASA is only worth computing when at least one point passed the score filter.
    if check_sasa and high_score_mask.any():
        sasa_mask = sasa(points) > SASA_THRESHOLD
        high_score_mask = high_score_mask & sasa_mask

    # DBSCAN raises ValueError when fitted on zero samples, so report "no clusters" instead.
    if not high_score_mask.any():
        return all_labels

    dbscan = DBSCAN(eps=EPSILON, min_samples=MIN_SAMPLES)
    all_labels[high_score_mask] = dbscan.fit_predict(points_array[high_score_mask])

    return all_labels


In [3]:
# Structure visualised by this cell: the entry ID passed to rcsb.fetch and the chain to keep.
# The two strings are concatenated into `protein_id` further down.
apo_structure = '9atc'
chain_id = 'A'

import json
import sys, os
import numpy as np

from biotite.structure import get_residues, get_residue_starts, spread_residue_wise
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb
import biotite


# Fetch the mmCIF file into CIF_FILES and parse it.
cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)
cif_file = pdbx.CIFFile.read(cif_file_path)

# `atom_id` is loaded as an extra field so atoms can be addressed in PyMOL through `id`.
auth = get_structure(cif_file, model=1, extra_fields=['atom_id'])
# Keep only the amino-acid atoms of the requested chain.
auth = auth[
        (auth.chain_id == chain_id) &
        (biotite.structure.filter_amino_acids(auth))]

protein_id = f'{apo_structure}{chain_id}'

# Nothing below works on an empty structure, so stop with a clear message
# instead of failing later in an unrelated place.
if len(auth) == 0:
    raise ValueError(f'No residues left for {protein_id}')

# One atom per residue (the first atom of every residue); used to validate the prediction length.
auth_residues_only = auth[get_residue_starts(auth)]

residue_wise_predictions = np.load(f'{PREDICTIONS_PATH}/{protein_id}.npy')
assert len(residue_wise_predictions) == len(auth_residues_only), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only)} vs {len(residue_wise_predictions)}"

# Repeat each residue score for all atoms of that residue.
atom_wise_predictions = spread_residue_wise(auth, residue_wise_predictions)
assert len(atom_wise_predictions) == len(auth), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth)} vs {len(atom_wise_predictions)}"
# NOTE(review): this makes numpy print arrays in full for the rest of the session; looks like leftover debugging.
np.set_printoptions(threshold=sys.maxsize)
clusters = compute_clusters(auth, atom_wise_predictions, check_sasa=True)

cmd.reinitialize()
cmd.set('fetch_path', cmd.exp_path(CIF_FILES), quiet=0)
cmd.fetch(protein_id)
cmd.zoom(protein_id)
cmd.color('grey', protein_id)

# COLORS[0] is reserved for label -1; real clusters cycle through the remaining
# colours so that more clusters than colours no longer raises IndexError.
noise_color, cluster_colors = COLORS[0], COLORS[1:]
for cluster_label in range(-1, clusters.max() + 1):
    cluster_atom_ids = auth.atom_id[clusters == cluster_label]
    # An empty id list would produce an invalid PyMOL selection.
    if len(cluster_atom_ids) == 0:
        continue
    if cluster_label == -1:
        color = noise_color
    else:
        color = cluster_colors[cluster_label % len(cluster_colors)]
    atom_selection = "+".join(str(atom_id) for atom_id in cluster_atom_ids)
    cmd.color(color, f'{protein_id} and id {atom_selection}')

cmd.show('surface', protein_id)
decision = input(">Press Enter for the next protein...\n")

 Setting: fetch_path set to /home/vit/Projects/deeplife-project/data/cif_files.
 ExecutiveLoad-Detail: Detected mmCIF


In [None]:
import json
import sys, os
import numpy as np

from biotite.structure import get_residues, get_residue_starts
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb
import biotite


DATASET = 'cryptobench'
CIF_FILES = '/home/vit/Projects/deeplife-project/data/cif_files'
# PyMOL colours, indexed by cluster label + 1: entry 0 is applied to label -1
# (residues predicted as binding that ended up in no cluster).
COLORS = ['green', 'pink', 'purple', 'purpleblue', 'raspberry', 'ruby', 'salmon', 'sand', 'skyblue', 'slate', 'smudge', 'splitpea', 'sulfur', 'teal', 'tv_blue', 'tv_green', 'tv_orange', 'tv_red', 'tv_yellow']
with open(f'../../datasets/{DATASET}-dataset/folds/test.json', 'r') as json_file:
    dataset = json.load(json_file)

# Set `skip = True` to resume an interrupted session: every structure that
# comes before '5wm9' in the dataset is then skipped.
skip = False

for apo_structure, holo_structures in dataset.items():

    if skip:
        if apo_structure == '5wm9':
            skip = False
        else:
            continue
    chain_id = holo_structures[0]['apo_chain']

    # skip multichain structures
    if '-' in chain_id:
        continue

    cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)
    cif_file = pdbx.CIFFile.read(cif_file_path)

    # keep only the amino-acid atoms of the requested chain
    auth = get_structure(cif_file, model=1)
    auth = auth[
            (auth.chain_id == chain_id) &
            (biotite.structure.filter_amino_acids(auth))]

    protein_id = f'{apo_structure}{chain_id}'
    # skip if no residues left
    if len(auth) == 0:
        print(f'No residues left for {protein_id}')
        continue

    # one atom per residue: the first atom of every residue
    auth_residues_only = auth[get_residue_starts(auth)]

    # PREDICTIONS_PATH already points at the directory with the .npy files
    # (same location the other cells of this notebook load from).
    predictions = np.load(f'{PREDICTIONS_PATH}/{protein_id}.npy') > DECISION_THRESHOLD

    assert len(predictions) == len(auth_residues_only), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only)} vs {len(predictions)}"

    # Without any positive prediction there is nothing to cluster or colour,
    # and the steps below would fail on empty arrays.
    if not predictions.any():
        print(f'No predicted binding residues for {protein_id}')
        continue

    predicted_binding_residues = auth_residues_only[predictions]
    predicted_binding_residue_auth_labels = predicted_binding_residues.res_id

    # Every residue passed in is already above the threshold, hence the all-True score vector.
    clusters = compute_clusters(predicted_binding_residues, predictions[predictions], check_sasa=True)

    cmd.reinitialize()
    cmd.set('fetch_path', cmd.exp_path(CIF_FILES), quiet=0)
    cmd.fetch(protein_id)
    cmd.zoom(protein_id)
    cmd.color('grey', protein_id)

    # Real clusters cycle through COLORS[1:], so having more clusters than
    # colours no longer raises IndexError.
    noise_color, cluster_colors = COLORS[0], COLORS[1:]
    for cluster_label in range(-1, clusters.max() + 1):
        cluster_residue_auth_labels = predicted_binding_residue_auth_labels[clusters == cluster_label]
        # An empty residue list would produce an invalid PyMOL selection.
        if len(cluster_residue_auth_labels) == 0:
            continue
        if cluster_label == -1:
            color = noise_color
        else:
            color = cluster_colors[cluster_label % len(cluster_colors)]
        residue_selection = "+".join(str(res_id) for res_id in cluster_residue_auth_labels)
        cmd.color(color, f'{protein_id} and resi {residue_selection}')

    cmd.show('surface', protein_id)
    decision = input(">Press Enter for the next protein...\n")
    # typing q (any case) ends the session
    if decision.strip().lower() == 'q':
        break


 Setting: fetch_path set to /home/vit/Projects/deeplife-project/data/cif_files.
 ExecutiveLoad-Detail: Detected mmCIF


### Use all atoms
In the previous approach we used only one atom per residue; now let's try clustering on all atoms.

In [None]:
import json
import sys, os
import numpy as np

from biotite.structure import get_residues, get_residue_starts, spread_residue_wise
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb
import biotite


DATASET = 'cryptobench'
CIF_FILES = '/home/vit/Projects/deeplife-project/data/cif_files'
# PyMOL colours, indexed by cluster label + 1: entry 0 is applied to label -1 (atoms in no cluster).
COLORS = ['blue', 'red', 'pink', 'purple', 'purpleblue', 'raspberry', 'ruby', 'salmon', 'sand', 'skyblue', 'slate', 'smudge', 'splitpea', 'sulfur', 'teal', 'tv_blue', 'tv_green', 'tv_orange', 'tv_red', 'tv_yellow']

with open(f'/home/vit/Projects/cryptoshow-analysis/datasets/cryptobench-dataset/folds/test.json', 'r') as json_file:
    dataset = json.load(json_file)

# Set `skip = True` to resume an interrupted session: every structure that
# comes before '5wm9' in the dataset is then skipped.
skip = False

for apo_structure, holo_structures in dataset.items():

    if skip:
        if apo_structure == '5wm9':
            skip = False
        else:
            continue
    chain_id = holo_structures[0]['apo_chain']

    # skip multichain structures
    if '-' in chain_id:
        continue

    cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)
    cif_file = pdbx.CIFFile.read(cif_file_path)

    # `atom_id` is loaded as an extra field so atoms can be addressed in PyMOL through `id`;
    # only the amino-acid atoms of the requested chain are kept.
    auth = get_structure(cif_file, model=1, extra_fields=['atom_id'])
    auth = auth[
            (auth.chain_id == chain_id) &
            (biotite.structure.filter_amino_acids(auth))]

    protein_id = f'{apo_structure}{chain_id}'

    # skip if no residues left (the steps below cannot handle an empty structure)
    if len(auth) == 0:
        print(f'No residues left for {protein_id}')
        continue

    # one atom per residue (the first atom of every residue); used to validate the prediction length
    auth_residues_only = auth[get_residue_starts(auth)]

    # PREDICTIONS_PATH already points at the directory with the .npy files
    # (same location the other cells of this notebook load from).
    residue_wise_predictions = np.load(f'{PREDICTIONS_PATH}/{protein_id}.npy')
    assert len(residue_wise_predictions) == len(auth_residues_only), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only)} vs {len(residue_wise_predictions)}"

    # repeat each residue score for all atoms of that residue
    atom_wise_predictions = spread_residue_wise(auth, residue_wise_predictions)
    assert len(atom_wise_predictions) == len(auth), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth)} vs {len(atom_wise_predictions)}"

    clusters = compute_clusters(auth, atom_wise_predictions, check_sasa=True)

    cmd.reinitialize()
    cmd.set('fetch_path', cmd.exp_path(CIF_FILES), quiet=0)
    cmd.fetch(protein_id)
    cmd.zoom(protein_id)
    cmd.color('grey', protein_id)

    # Real clusters cycle through COLORS[1:], so having more clusters than
    # colours no longer raises IndexError.
    noise_color, cluster_colors = COLORS[0], COLORS[1:]
    for cluster_label in range(-1, clusters.max() + 1):
        cluster_atom_ids = auth.atom_id[clusters == cluster_label]
        # An empty id list would produce an invalid PyMOL selection.
        if len(cluster_atom_ids) == 0:
            continue
        if cluster_label == -1:
            color = noise_color
        else:
            color = cluster_colors[cluster_label % len(cluster_colors)]
        atom_selection = "+".join(str(atom_id) for atom_id in cluster_atom_ids)
        cmd.color(color, f'{protein_id} and id {atom_selection}')

    cmd.show('surface', protein_id)
    decision = input(">Press Enter for the next protein...\n")
    # typing q (any case) ends the session, as in the residue-wise cell
    if decision.strip().lower() == 'q':
        break

 Setting: fetch_path set to /home/vit/Projects/deeplife-project/data/cif_files.
 ExecutiveLoad-Detail: Detected mmCIF


## Sphere around each residue
Draw a small sphere around each unassigned atom. If the sphere contains at least `N` atoms that already belong to one pocket, add the atom to that pocket. (The implementation below works atom-wise; `N` is `NUMBER_OF_POINTS`.)

In [None]:
import numpy as np
import json

from biotite.structure import get_residues, get_residue_starts, spread_residue_wise
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb    
import biotite
# 3h8a
# Parameters read by spread_using_atom_spheres.
RADIUS = 5  # atoms closer than this to the examined atom count as being inside its sphere
NUMBER_OF_POINTS = 7  # minimum number of atoms of one cluster inside the sphere for the atom to join that cluster
ITERATIONS = 2  # number of passes over all atoms
# NOTE(review): this rebinds the module-level SASA_THRESHOLD (set to 0.5 in the clustering cell).
# compute_clusters reads the same global, so after this cell runs its SASA filter uses 0.1 as well.
SASA_THRESHOLD = 0.1
def spread_using_atom_spheres(points: AtomArray, clusters: np.array, check_sasa=True) -> np.array:
    """Grow existing clusters by absorbing nearby unassigned atoms.

    For ITERATIONS passes, every atom that still carries label -1 is examined:
    the atoms closer than RADIUS to it are collected, and the atom receives
    the label of the first cluster (in ascending label order) that has at
    least NUMBER_OF_POINTS members among them. With ``check_sasa`` set, atoms
    whose SASA is below SASA_THRESHOLD are never absorbed.

    ``clusters`` is modified in place and also returned, so labels assigned
    earlier in a pass are already visible to atoms examined later.
    """
    exposure = sasa(points) if check_sasa else None
    all_coords = points.coord

    for _ in range(ITERATIONS):
        for atom_index in range(len(points)):
            # only unassigned atoms are candidates
            already_assigned = clusters[atom_index] != -1
            if already_assigned or (check_sasa and exposure[atom_index] < SASA_THRESHOLD):
                continue

            # membership mask of the sphere centred on this atom
            distances = np.linalg.norm(all_coords[atom_index] - all_coords, axis=1)
            in_sphere = distances < RADIUS

            for label in np.unique(clusters):
                if label == -1:
                    continue
                # number of atoms of this cluster that fall inside the sphere
                members_in_sphere = np.count_nonzero(in_sphere & (clusters == label))
                if members_in_sphere >= NUMBER_OF_POINTS:
                    clusters[atom_index] = label
                    break
    return clusters

with open(f'/home/vit/Projects/cryptoshow-analysis/datasets/cryptobench-dataset/folds/test.json', 'r') as json_file:
    dataset = json.load(json_file)

for apo_structure, holo_structures in dataset.items():
    chain_id = holo_structures[0]['apo_chain']
    # To inspect a single structure, filter here, e.g.:
    # if apo_structure != '7e5q':
    #     continue

    # skip multichain structures
    if '-' in chain_id:
        continue
    cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)
    cif_file = pdbx.CIFFile.read(cif_file_path)

    # `atom_id` is loaded as an extra field so atoms can be addressed in PyMOL through `id`;
    # only the amino-acid atoms of the requested chain are kept.
    auth = get_structure(cif_file, model=1, extra_fields=['atom_id'])
    auth = auth[
            (auth.chain_id == chain_id) &
            (biotite.structure.filter_amino_acids(auth))]

    protein_id = f'{apo_structure}{chain_id}'
    # skip if no residues left (the steps below cannot handle an empty structure)
    if len(auth) == 0:
        print(f'No residues left for {protein_id}')
        continue

    # one atom per residue (the first atom of every residue); used to validate the prediction length
    auth_residues_only = auth[get_residue_starts(auth)]

    residue_wise_predictions = np.load(f'{PREDICTIONS_PATH}/{protein_id}.npy')
    assert len(residue_wise_predictions) == len(auth_residues_only), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only)} vs {len(residue_wise_predictions)}"
    # repeat each residue score for all atoms of that residue
    atom_wise_predictions = spread_residue_wise(auth, residue_wise_predictions)
    assert len(atom_wise_predictions) == len(auth), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth)} vs {len(atom_wise_predictions)}"

    # initial DBSCAN clusters, then grown by the atom-sphere expansion
    clusters = compute_clusters(auth, atom_wise_predictions, check_sasa=True)
    clusters = spread_using_atom_spheres(auth, clusters)

    # residue ids of the residues whose own score is above the threshold
    residue_ids, _ = get_residues(auth)
    predicted_binding_residues = residue_ids[residue_wise_predictions > DECISION_THRESHOLD]

    cmd.reinitialize()
    cmd.set('fetch_path', cmd.exp_path(CIF_FILES), quiet=0)
    cmd.fetch(protein_id)
    cmd.zoom(protein_id)
    cmd.color('grey', protein_id)

    # Every atom that belongs to any cluster is painted green (all clusters share one colour here).
    clustered_atom_ids = auth.atom_id[clusters != -1]
    if len(clustered_atom_ids) > 0:
        atom_selection = "+".join(str(atom_id) for atom_id in clustered_atom_ids)
        cmd.color('green', f'{protein_id} and id {atom_selection}')

    # Residues predicted directly are painted blue on top, so green marks what the expansion added.
    # The guard avoids building an invalid selection ("resi" followed by nothing).
    if len(predicted_binding_residues) > 0:
        residue_selection = "+".join(str(res_id) for res_id in predicted_binding_residues)
        cmd.color('blue', f'{protein_id} and resi {residue_selection}')
    cmd.show('surface', protein_id)
    decision = input(">Press Enter for the next protein...\n")
    if decision == 'q':
        break

 Setting: fetch_path set to /home/vit/Projects/deeplife-project/data/cif_files.
 ExecutiveLoad-Detail: Detected mmCIF
 Setting: fetch_path set to /home/vit/Projects/deeplife-project/data/cif_files.
 ExecutiveLoad-Detail: Detected mmCIF


## Sphere around the pocket's centre
Draw a sphere around the pocket's centre and add every residue with at least one atom inside the sphere to the pocket.

In [42]:
import json
import numpy as np

from biotite.structure import apply_residue_wise, get_residue_starts
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb
import biotite

# Sphere radius = this factor times the largest centroid-to-member distance of the cluster.
SPHERE_RADIUS_RATIO = 2

with open(f'/home/vit/Projects/cryptoshow-analysis/datasets/cryptobench-dataset/folds/test.json', 'r') as json_file:
    dataset = json.load(json_file)

# Resume mode: every structure that comes before '4oqo' in the dataset is skipped.
# Set to False to start from the first structure.
skip = True

for apo_structure, holo_structures in dataset.items():

    if skip:
        if apo_structure == '4oqo':
            skip = False
        else:
            continue
    chain_id = holo_structures[0]['apo_chain']

    # skip multichain structures
    if '-' in chain_id:
        continue

    cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)
    cif_file = pdbx.CIFFile.read(cif_file_path)

    # keep only the amino-acid atoms of the requested chain
    auth = get_structure(cif_file, model=1)
    auth = auth[
            (auth.chain_id == chain_id) &
            (biotite.structure.filter_amino_acids(auth))]

    protein_id = f'{apo_structure}{chain_id}'
    # skip if no residues left
    if len(auth) == 0:
        print(f'No residues left for {protein_id}')
        continue

    # one atom per residue: the first atom of every residue
    auth_residues_only = auth[get_residue_starts(auth)]

    prediction_probabilities = np.load(f'{PREDICTIONS_PATH}/{protein_id}.npy')
    predictions = prediction_probabilities > DECISION_THRESHOLD

    assert len(predictions) == len(auth_residues_only), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only)} vs {len(predictions)}"

    # nothing predicted -> nothing to cluster
    if not predictions.any():
        continue
    # -1 denotes residues that are not part of any cluster
    clusters = compute_clusters(auth_residues_only, prediction_probabilities)

    # Decide whether to skip before touching PyMOL, so no structure is fetched just to be discarded.
    if clusters.max() == -1:
        continue

    # Residues that have at least one atom inside the sphere of any cluster.
    in_any_sphere = np.zeros(len(auth_residues_only), dtype=bool)
    for cluster_index in range(clusters.max() + 1):
        cluster_coords = auth_residues_only.coord[clusters == cluster_index]
        centroid = np.mean(cluster_coords, axis=0)

        # get radius of the sphere
        radius = np.max(np.linalg.norm(cluster_coords - centroid, axis=1)) * SPHERE_RADIUS_RATIO
        # get all atoms inside the sphere
        additional_atoms_mask = np.linalg.norm(centroid - auth.coord, axis=1) < radius
        # map the atom-wise mask to residue-wise
        additional_residues_mask = apply_residue_wise(auth, additional_atoms_mask, np.any)
        assert len(additional_residues_mask) == len(auth_residues_only), f"Length of auth residues and additional residues mask do not match for {protein_id}: {len(auth_residues_only)} vs {len(additional_residues_mask)}"

        in_any_sphere |= np.asarray(additional_residues_mask, dtype=bool)

    cmd.reinitialize()
    cmd.set('fetch_path', cmd.exp_path(CIF_FILES), quiet=0)
    cmd.fetch(protein_id)
    cmd.zoom(protein_id)
    cmd.color('grey', protein_id)

    # Colour in two passes: all sphere residues green first, then all original
    # cluster members blue. Colouring cluster by cluster let the green sphere of
    # a later cluster repaint the blue members of an earlier one.
    residue_ids = auth_residues_only.res_id
    if in_any_sphere.any():
        sphere_selection = "+".join(str(res_id) for res_id in residue_ids[in_any_sphere])
        cmd.color('green', f'{protein_id} and resi {sphere_selection}')
    # non-empty here because at least one cluster exists
    member_selection = "+".join(str(res_id) for res_id in residue_ids[clusters != -1])
    cmd.color('blue', f'{protein_id} and resi {member_selection}')

    cmd.show('surface', protein_id)
    decision = input(">Press Enter for the next protein...\n")
    # typing q (any case) ends the session, as in the other visualisation cells
    if decision.strip().lower() == 'q':
        break
