In [2]:
import sys
import pymol

# Launching PyMOL may redirect sys.stdout/sys.stderr; remember the notebook's streams
# and restore them even if launching fails, so later cell output is still visible.
_stdouterr = sys.stdout, sys.stderr
try:
    pymol.finish_launching(['/usr/bin/pymol', '-q'])  # '-q': quiet PyMOL startup
finally:
    sys.stdout, sys.stderr = _stdouterr

# PyMOL command API, used below to fetch, color and display structures
from pymol import cmd


Gtk-Message: 16:56:03.020: Failed to load module "pk-gtk-module"
Cannot open file '/home/vit/Projects/cryptic-nn/src/prediction-analysis/data/pymol/icons/icon2.svg', because: No such file or directory
Cannot open file '/home/vit/Projects/cryptic-nn/src/prediction-analysis/data/pymol/icons/icon2.svg', because: No such file or directory


Could not read PyMOL stylesheet.
DEBUG: PYMOL_DATA='./data'
 Detected OpenGL version 4.6. Shaders available.
 Geometry shaders not available
 Detected GLSL version 4.60.
 Setting: fetch_path set to /home/vit/Projects/deeplife-project/data/cif_files.
 ExecutiveLoad-Detail: Detected mmCIF


## Clustering
Group the high-scoring predicted binding residues into candidate pockets by running DBSCAN on their 3D coordinates.

In [3]:
import numpy as np
from sklearn.cluster import DBSCAN, AgglomerativeClustering

EPSILON = 7  # DBSCAN eps: max distance between neighbors, in the units of the coordinates (adjust as needed)
MIN_SAMPLES = 3  # DBSCAN min_samples: minimum points to form a cluster (adjust as needed)
SCORE_THRESHOLD = 0.65  # only points scoring strictly above this are clustered (TODO: tweak this)

def compute_clusters(points: list[list[float]], prediction_scores: list[float],
                     score_threshold=None, eps=None, min_samples=None):
    """Cluster high-scoring points into pockets with DBSCAN.

    Points whose score is strictly greater than ``score_threshold`` are clustered on
    their coordinates; every other point receives the noise label -1.

    Args:
        points: Point coordinates, one row of (x, y, z) per point.
        prediction_scores: One score per point (booleans are accepted and treated as 0/1).
        score_threshold: Exclusive lower bound on the score for a point to be clustered.
            Defaults to SCORE_THRESHOLD, read at call time.
        eps: DBSCAN neighborhood radius. Defaults to EPSILON, read at call time.
        min_samples: DBSCAN ``min_samples``. Defaults to MIN_SAMPLES, read at call time.

    Returns:
        Integer numpy array of cluster labels with one entry per point; -1 marks points
        below the threshold or labelled as noise by DBSCAN.

    Raises:
        ValueError: if ``points`` and ``prediction_scores`` differ in length.
    """
    # Resolve defaults at call time so edits to the module constants are picked up.
    if score_threshold is None:
        score_threshold = SCORE_THRESHOLD
    if eps is None:
        eps = EPSILON
    if min_samples is None:
        min_samples = MIN_SAMPLES

    coords = np.asarray(points, dtype=float)
    scores = np.asarray(prediction_scores, dtype=float).ravel()
    if len(coords) != len(scores):
        raise ValueError(
            f'points and prediction_scores must have the same length: {len(coords)} vs {len(scores)}')

    # Everything starts as noise; only high-scoring points get a DBSCAN label.
    labels = np.full(len(scores), -1, dtype=int)
    high_score_mask = scores > score_threshold
    # DBSCAN rejects an empty sample set, so with nothing above the threshold we are done.
    if not high_score_mask.any():
        return labels

    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    # dbscan = AgglomerativeClustering(distance_threshold=eps, n_clusters=None, linkage='single')
    labels[high_score_mask] = dbscan.fit_predict(coords[high_score_mask])
    return labels

In [None]:
import json
import sys, os
import numpy as np

from biotite.structure import get_residues, get_residue_starts
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
import biotite.database.rcsb as rcsb
import biotite


DATASET = 'cryptobench'
CIF_FILES = '/home/vit/Projects/deeplife-project/data/cif_files'
PREDICTIONS_PATH = '/home/vit/Projects/cryptic-nn/data/predictions/ESM2-3B-extended-finetuning'
# Per-residue probability above which a residue counts as a predicted binding residue.
BINDING_THRESHOLD = 0.5
# PyMOL color names: COLORS[0] marks unclustered residues (label -1); the remaining colors
# go to clusters 0, 1, ... and are reused cyclically when there are more clusters than colors.
COLORS = ['red', 'pink', 'purple', 'purpleblue', 'raspberry', 'ruby', 'salmon', 'sand', 'skyblue', 'slate', 'smudge', 'splitpea', 'sulfur', 'teal', 'tv_blue', 'tv_green', 'tv_orange', 'tv_red', 'tv_yellow']
with open(f'../../datasets/{DATASET}-dataset/folds/test.json', 'r') as json_file:
    dataset = json.load(json_file)

# Resume an interrupted review session: set to an apo structure id (e.g. '5wm9') to skip
# every structure listed before it; None reviews the whole fold.
# Progress note from an earlier session: finished analysis at '5wbmB'.
RESUME_FROM = None
skip = RESUME_FROM is not None

for apo_structure, holo_structures in dataset.items():
    if skip:
        if apo_structure == RESUME_FROM:
            skip = False
        else:
            continue
    chain_id = holo_structures[0]['apo_chain']

    # skip multichain structures
    if '-' in chain_id:
        continue

    cif_file_path = rcsb.fetch(apo_structure, "cif", target_path=CIF_FILES)
    cif_file = pdbx.CIFFile.read(cif_file_path)

    # filter to get correct chain; keep only peptide backbone atoms
    auth = get_structure(cif_file, model=1)
    auth = auth[
            (auth.chain_id == chain_id) &
            (biotite.structure.filter_peptide_backbone(auth))]

    protein_id = f'{apo_structure}{chain_id}'
    # skip if no residues left
    if len(auth) == 0:
        print(f'No residues left for {protein_id}')
        continue

    # first atom of every residue, so rows line up with the per-residue predictions
    auth_residues_only = auth[get_residue_starts(auth)]

    prediction_scores = np.load(f'{PREDICTIONS_PATH}/predictions/{protein_id}.npy')
    predictions = prediction_scores > BINDING_THRESHOLD

    assert len(predictions) == len(auth_residues_only), f"Length of auth residues and predictions do not match for {protein_id}: {len(auth_residues_only)} vs {len(predictions)}"
    # Nothing to cluster or color; skip instead of passing an empty selection on.
    if not predictions.any():
        print(f'No predicted binding residues for {protein_id}')
        continue
    predicted_binding_residues = auth_residues_only[predictions]
    predicted_binding_residue_coords = predicted_binding_residues.coord
    predicted_binding_residue_auth_labels = predicted_binding_residues.res_id

    # Pass the raw probabilities (the boolean mask would be all True here) so the
    # score threshold inside compute_clusters actually filters residues.
    clusters = compute_clusters(predicted_binding_residue_coords, prediction_scores[predictions])

    cmd.reinitialize()
    cmd.set('fetch_path', cmd.exp_path(CIF_FILES), quiet=0)
    cmd.fetch(protein_id)
    cmd.zoom(protein_id)
    cmd.color('grey', protein_id)

    for label in range(-1, int(clusters.max()) + 1):
        cluster_residue_auth_labels = predicted_binding_residue_auth_labels[clusters == label]
        if len(cluster_residue_auth_labels) == 0:
            continue
        color = COLORS[0] if label == -1 else COLORS[1 + label % (len(COLORS) - 1)]
        # PyMOL needs negative residue numbers backslash-escaped inside a `resi` list
        resi = '+'.join(str(res_id).replace('-', '\\-') for res_id in cluster_residue_auth_labels)
        cmd.color(color, f'{protein_id} and resi {resi}')

    cmd.show('surface', protein_id)
    input(">Press Enter for the next protein...\n")
