In [1]:
import torch
from transformers import AutoTokenizer, EsmModel
import torch.nn as nn
import numpy as np

sys.path.append('/home/skrhakv/cryptic-nn/src')
import finetuning_utils

# some constants
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
MAX_LENGTH = 1024
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
OUTPUT_SIZE = 1
DROPOUT = 0.25

# UPDATE THIS!
MODEL_PATH = "/home/skrhakv/cryptic-nn/final-data/trained-models/multitask-finetuned-model-with-ligysis.pt"

# define the model - if we do not define the model then the loading of the model will fail
class FinetuneESM(nn.Module):
    def __init__(self, esm_model: str) -> None:
        super().__init__()
        self.llm = EsmModel.from_pretrained(esm_model)
        self.dropout = nn.Dropout(DROPOUT)
        self.classifier = nn.Linear(self.llm.config.hidden_size, OUTPUT_SIZE)
        self.plDDT_regressor = nn.Linear(self.llm.config.hidden_size, OUTPUT_SIZE)
        self.distance_regressor = nn.Linear(self.llm.config.hidden_size, OUTPUT_SIZE)

    def forward(self, batch: dict[str, np.ndarray]) -> torch.Tensor:
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        token_embeddings = self.llm(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        
        return self.classifier(token_embeddings), self.plDDT_regressor(token_embeddings), self.distance_regressor(token_embeddings)

# load the model
model = torch.load(MODEL_PATH, weights_only=False)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.eval()

with open('/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB.txt', 'r') as f:
    sequence = f.readline()


In [2]:
import sys
sys.path.append('/home/skrhakv/cryptoshow-analysis/src/B-evaluate-cryptoshow')
import eval_utils

DECISION_THRESHOLD = 0.8
DROPOUT = 0.3
LAYER_WIDTH = 256
ESM2_DIM  = 1280 * 2

class CryptoBenchClassifier(nn.Module):
    def __init__(self, input_dim=ESM2_DIM):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=input_dim, out_features=LAYER_WIDTH)
        self.dropout1 = nn.Dropout(DROPOUT)

        self.layer_2 = nn.Linear(in_features=LAYER_WIDTH, out_features=LAYER_WIDTH)
        self.dropout2 = nn.Dropout(DROPOUT)

        self.layer_3 = nn.Linear(in_features=LAYER_WIDTH, out_features=1)

        self.relu = nn.ReLU()

    def forward(self, x):
        # Intersperse the ReLU activation function between layers
        return self.layer_3(self.dropout2(self.relu(self.layer_2(self.dropout1(self.relu(self.layer_1(x)))))))
    

HIGH_SCORE_THRESHOLD = 0.7  # Threshold to consider a point as high score
SMOOTHENED_THRESHOLD = 0.7 # this is defined by the training data - best F1 score was achieved with this threshold 
SMOOTHING_MODEL_PATH = '/home/skrhakv/cryptoshow-analysis/data/C-optimize-smoother/smoother.pt'
smoothing_model = torch.load(SMOOTHING_MODEL_PATH, weights_only=False)

In [3]:
OUTPUT_PATH = '/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase'
pdb_id = '3b6r'
chain_id = 'B'

embedding_path = f'{OUTPUT_PATH}/{pdb_id}{chain_id}_embedding-3B.npy'
coordinates_path = f'{OUTPUT_PATH}/{pdb_id}{chain_id}.npy'

prediction = eval_utils.compute_prediction(
    sequence,
    embedding_path,
    model,
    tokenizer
)

coordinates = np.load(coordinates_path)
clusters = eval_utils.compute_clusters(
        coordinates,
        prediction,
        decision_threshold=DECISION_THRESHOLD,
        method='dbscan',
        eps=3,
        min_samples=1
    )


distance_matrix = eval_utils.compute_distance_matrix(coordinates)

# enhance predicted pockets using the smoothing model
predicted_binding_sites = []
for cluster_label in np.unique(clusters):
    if cluster_label == -1:
        continue
    cluster_residue_indices = np.where(clusters == cluster_label)[0]
    embeddings, indices = eval_utils.process_single_sequence(pdb_id, chain_id, cluster_residue_indices, '/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB-nonfinetuned-embedding.npy', distance_matrix) 
    
    prediction = eval_utils.predict_single_sequence(embeddings, indices, smoothing_model)

    enhanced_residue_indices = np.concatenate((indices[prediction['predictions'] > SMOOTHENED_THRESHOLD], cluster_residue_indices))
    predicted_binding_sites.append(enhanced_residue_indices)

import pickle
with open(f'{OUTPUT_PATH}/{pdb_id}{chain_id}.pkl', 'wb') as f:
    pickle.dump(predicted_binding_sites, f)

predicted_binding_sites

[array([53]),
 array([54]),
 array([57]),
 array([63]),
 array([64]),
 array([277, 278, 279,  53,  65]),
 array([66]),
 array([67]),
 array([278,  68]),
 array([277,  69]),
 array([335,  90]),
 array([122]),
 array([286, 126, 124]),
 array([286, 126]),
 array([234, 185]),
 array([199]),
 array([280, 277]),
 array([284, 278]),
 array([279]),
 array([277, 286, 280]),
 array([126, 286]),
 array([288]),
 array([329, 288, 314, 124, 290]),
 array([314]),
 array([316]),
 array([317]),
 array([315, 318]),
 array([319]),
 array([329]),
 array([329, 286, 290, 314, 124, 331])]

In [None]:
[array([ 67, 195, 198, 199,  53,  54,  55,  57,  56]),
 array([ 65,  67,  68,  69, 195,  22,  23, 277, 278, 279,  53,  54,  63,
         66]),
 array([286, 126, 124]),
 array([286, 230, 124, 126]),
 array([195, 197, 198, 199, 200, 201, 202, 203, 225,  53, 196]),
 array([276,  53,  54,  56,  67,  68, 195, 196, 198, 225, 199, 200, 201]),
 array([230, 222]),
 array([222, 225, 126, 230]),
 array([286,  66,  67, 199, 201, 276, 277, 278, 279, 280]),
 array([280, 124, 126, 286]),
 array([320, 321, 318, 319, 315, 316, 317, 314]),
 array([320, 321, 329, 290, 185, 314, 315,  63, 316, 317, 318, 319])]


## Try-out clustering

In [5]:
CLUSTERING_MODEL_STATE_DICT_PATH = '/home/skrhakv/cryptoshow-analysis/data/F-clustering/clustering.pt'
clustering_model = torch.load(CLUSTERING_MODEL_STATE_DICT_PATH, weights_only=False)
CLUSTERING_DECISION_THRESHOLD = 0.9

### Greedy approach
If the subpocket is predicted as positive with any other subpocket, then add it into the mix

In [6]:
def get_index_in_final_clustering(index, final_clustering):
    for cluster_index, cluster in enumerate(final_clustering):
        if index in cluster:
            return True, cluster_index
    return False, -1

def merge_indices_in_final_clustering(index_1, index_2, final_clustering):
    is_index_1_in_final, index_1_in_final = get_index_in_final_clustering(index_1, final_clustering)
    is_index_2_in_final, index_2_in_final = get_index_in_final_clustering(index_2, final_clustering)
    
    if is_index_1_in_final and is_index_2_in_final:
        if index_1_in_final != index_2_in_final:
            final_clustering[index_1_in_final].extend(final_clustering[index_2_in_final])
            del final_clustering[index_2_in_final]
    elif is_index_1_in_final:
        final_clustering[index_1_in_final].append(index_2)
    elif is_index_2_in_final:
        final_clustering[index_2_in_final].append(index_1)
    else:
        final_clustering.append([index_1, index_2])

    return final_clustering

def get_embedding(embeddings, predicted_binding_site_i, predicted_binding_site_ii):
    embeddings_i = embeddings[predicted_binding_site_i]
    embeddings_ii = embeddings[predicted_binding_site_ii]

    mean_embedding_i = np.mean(embeddings_i, axis=0)
    mean_embedding_ii = np.mean(embeddings_ii, axis=0)

    combined_embedding1 = np.concatenate((mean_embedding_i, mean_embedding_ii), axis=0)
    combined_embedding1 = torch.tensor(combined_embedding1, dtype=torch.float32)

    combined_embedding2 = np.concatenate((mean_embedding_ii, mean_embedding_i), axis=0)
    combined_embedding2 = torch.tensor(combined_embedding2, dtype=torch.float32)

    return (combined_embedding1, combined_embedding2)

def predict(embedding, model):
    combined_embedding1, combined_embedding2 = embedding

    probability1 = torch.sigmoid(model(combined_embedding1.to(DEVICE)))
    probability2 = torch.sigmoid(model(combined_embedding2.to(DEVICE)))

    prediction1 = (probability1 > CLUSTERING_DECISION_THRESHOLD).float()
    prediction2 = (probability2 > CLUSTERING_DECISION_THRESHOLD).float()

    print(probability1.detach().cpu().numpy(), probability2.detach().cpu().numpy())

    # this fails sometimes, but it is a close call, so we will ignore it
    # assert prediction1 == prediction2, "Inconsistent clustering predictions! {}, {}".format(prediction1, prediction2)

    return (prediction1 == 1.0) and (prediction2 == 1.0)

embedding_path = f'/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB-nonfinetuned-embedding.npy'
embeddings = np.load(embedding_path)

final_clustering = []

for i, predicted_binding_site_i in enumerate(predicted_binding_sites):
    was_added = False
    for ii, predicted_binding_site_ii in enumerate(predicted_binding_sites):
        if ii <= i:
            continue

        combined_embedding = get_embedding(embeddings, predicted_binding_site_i, predicted_binding_site_ii)

        should_cluster_together = predict(combined_embedding, clustering_model)
        if should_cluster_together:
            final_clustering = merge_indices_in_final_clustering(i, ii, final_clustering)
            was_added = True
            
    if not was_added:
        final_clustering.append([i])

print('Final clustering:', final_clustering)
final_predicted_binding_sites = []
for cluster in final_clustering:
    this_final_binding_site = []
    for binding_site_index in cluster:
        this_final_binding_site.extend(predicted_binding_sites[binding_site_index])
    final_predicted_binding_sites.append(this_final_binding_site)

final_predicted_binding_sites

[0.73246205] [0.72424614]
[0.9998828] [0.99975115]
[0.9987476] [0.9977325]
[0.82559454] [0.88224685]
[0.9685042] [0.9410501]
[0.9702105] [0.96801555]
[0.42557165] [0.21715511]
[0.9170285] [0.923472]
[0.9931484] [0.9885749]
[0.99744284] [0.99628735]
[0.8916586] [0.9513483]
[0.9979038] [0.9972186]
[0.99803215] [0.99770516]
[0.91990185] [0.889909]
[0.00350576] [0.00565635]
[0.95828855] [0.9394036]
[0.40025485] [0.16760966]
[0.63469607] [0.31079873]
[0.99502796] [0.99186945]
[0.99803215] [0.99770516]
[0.3548842] [0.23747057]
[0.9943773] [0.9891289]
[0.97693765] [0.95589286]
[0.55231136] [0.6516763]
[0.48167482] [0.55293876]
[0.23479061] [0.16633058]
[0.9977977] [0.99440825]
[0.95203006] [0.90067524]
[0.99448293] [0.9913897]
[0.99983394] [0.999708]
[0.99765944] [0.99353576]
[0.5259135] [0.6433668]
[0.96922964] [0.9219403]
[0.9737176] [0.96466863]
[0.30592206] [0.13826908]
[0.93118924] [0.8730327]
[0.9947713] [0.98940337]
[0.9972946] [0.99683446]
[0.93268824] [0.9296092]
[0.99878544] [0.9983

[[53,
  57,
  63,
  277,
  278,
  279,
  53,
  65,
  66,
  278,
  68,
  277,
  69,
  335,
  90,
  286,
  126,
  124,
  286,
  126,
  280,
  277,
  277,
  286,
  280,
  126,
  286,
  329,
  288,
  314,
  124,
  290,
  314,
  319,
  329,
  329,
  286,
  290,
  314,
  124,
  331,
  54,
  122,
  234,
  185,
  64,
  67,
  199,
  284,
  278,
  279,
  288,
  316,
  317,
  315,
  318],
 [199],
 [329, 286, 290, 314, 124, 331]]

### Unified-vote approach
Every subpocket has to agree with adding another subpocket.


In [21]:
embedding_path = f'/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB-nonfinetuned-embedding.npy'
embeddings = np.load(embedding_path)

decisions = {}

VERBOSE = True
final_clustering = []

# this is similar as before
for i, predicted_binding_site_i in enumerate(predicted_binding_sites):
    for ii, predicted_binding_site_ii in enumerate(predicted_binding_sites):
        if ii <= i:
            continue
        if VERBOSE:
            print('Considering merging', i, f'({predicted_binding_site_i})', ii, f'({predicted_binding_site_ii})')
        combined_embedding = get_embedding(embeddings, predicted_binding_site_i, predicted_binding_site_ii)
        should_cluster_together = predict(combined_embedding, clustering_model)
        
        if (i, ii) not in decisions:
            decisions[(i, ii)] = should_cluster_together
            decisions[(ii, i)] = should_cluster_together
        else:
            if decisions[(i, ii)] != should_cluster_together:
                if VERBOSE:
                    print('Inconsistent decision for', i, ii, 'existing:', decisions[(i, ii)], 'new:', should_cluster_together)

        if not should_cluster_together:
            if VERBOSE:
                print('Not merging', i, f'({predicted_binding_site_i})' , ii, f'({predicted_binding_site_ii})', 'based on initial prediction')
        merge = True if should_cluster_together else False # potential merging candidate if the clusters agree
        # here, if we want to merge, we need to check that all subpockets agree:
        if should_cluster_together:
            if VERBOSE:
                print('Initial agreement for merging', i, f'({predicted_binding_site_i})', ii, f'({predicted_binding_site_ii})' f'; check for all members of {i}th cluster')
            # first check if i-th cluster is OK with merging the ii-th subpocket
            is_index_in_final_clustering_i, index_in_final_i = get_index_in_final_clustering(i, final_clustering)
            if VERBOSE:
                print('is_index_in_final_clustering_i:', is_index_in_final_clustering_i, 'index_in_final_i:', index_in_final_i)
            if is_index_in_final_clustering_i: # is the i-th subpocket actually a part of some cluster? Because if not, then we already collected that one vote (and it was positive)
                i_th_cluster = final_clustering[index_in_final_i]
                for subpocket_index in i_th_cluster:
                    if VERBOSE:
                        print('\tChecking member', subpocket_index, 'of cluster', i_th_cluster)
                    if subpocket_index == i:
                        continue
                    # check if all members agree with merging ii-th subpocket
                    member_subpocket = predicted_binding_sites[subpocket_index] # get the residues of the member subpocket
                    print('\tmember_subpocket:', member_subpocket, 'predicted_binding_site_ii:', predicted_binding_site_ii)

                    combined_embedding = get_embedding(embeddings, member_subpocket, predicted_binding_site_ii)
                    should_cluster_together_member = predict(combined_embedding, clustering_model)

                    if VERBOSE:
                        print('\tMember', subpocket_index, 'predicts merging with', ii, ':', should_cluster_together_member)

                    if (subpocket_index, ii) not in decisions:
                        decisions[(subpocket_index, ii)] = should_cluster_together_member
                        decisions[(ii, subpocket_index)] = should_cluster_together_member
                    else:
                        if decisions[(subpocket_index, ii)] != should_cluster_together_member:
                            if VERBOSE: 
                                print('\tInconsistent decision for', subpocket_index, ii, 'existing:', decisions[(subpocket_index, ii)], 'new:', should_cluster_together_member)

                    if not should_cluster_together_member:
                        if VERBOSE:
                            print('\tNot merging', i, ii, 'because of member', subpocket_index, ' in cluster', i_th_cluster)
                        merge = False
                        break
            
            is_index_in_final_clustering_ii, index_in_final_ii = get_index_in_final_clustering(ii, final_clustering)

            if VERBOSE:
                print('check second candidate; check for all members of', ii, 'th cluster')
                print('is_index_in_final_clustering_ii:', is_index_in_final_clustering_ii, 'index_in_final_ii:', index_in_final_ii)

            if is_index_in_final_clustering_ii: # is the ii-th subpocket actually a part of some cluster? Because if not, then we already collected that one vote (and it was positive)
                ii_th_cluster = final_clustering[index_in_final_ii]
                for subpocket_index in ii_th_cluster:
                    if VERBOSE:
                        print('Checking member', subpocket_index, 'of cluster', ii_th_cluster)
                    if subpocket_index == ii:
                        continue
                    # check if all members agree with merging ii-th subpocket

                    member_subpocket = predicted_binding_sites[subpocket_index] # get the residues of the member subpocket
                    print('\tmember_subpocket:', member_subpocket, 'predicted_binding_site_i:', predicted_binding_site_i)

                    combined_embedding = get_embedding(embeddings, member_subpocket, predicted_binding_site_i)
                    should_cluster_together_member = predict(combined_embedding, clustering_model)

                    if VERBOSE:
                        print('\tMember', subpocket_index, 'predicts merging with', i, ':', should_cluster_together_member)

                    if (subpocket_index, i) not in decisions:
                        decisions[(subpocket_index, i)] = should_cluster_together_member
                        decisions[(i, subpocket_index)] = should_cluster_together_member
                    else:
                        if decisions[(subpocket_index, i)] != should_cluster_together_member:
                            if VERBOSE:
                                print('\tInconsistent decision for', subpocket_index, i, 'existing:', decisions[(subpocket_index, i)], 'new:', should_cluster_together_member)

                    if not should_cluster_together_member:
                        if VERBOSE:
                            print('\tNot merging', i, ii, 'because of member', subpocket_index)
                        merge = False
                        break
            if VERBOSE:
                print('All members agree for merging', i, ii, '; final clustering:', final_clustering)
            
        if merge:
            final_clustering = merge_indices_in_final_clustering(i, ii, final_clustering)
            print('Merging', i, ii, final_clustering)

print('Final clustering:', final_clustering)
final_predicted_binding_sites = []
for cluster in final_clustering:
    this_final_binding_site = []
    for binding_site_index in cluster:
        this_final_binding_site.extend(predicted_binding_sites[binding_site_index])
    final_predicted_binding_sites.append(this_final_binding_site)

final_predicted_binding_sites

Considering merging 0 ([53]) 1 ([54])
[0.73246205] [0.72424614]
Not merging 0 ([53]) 1 ([54]) based on initial prediction
Considering merging 0 ([53]) 2 ([57])
[0.9998828] [0.99975115]
Initial agreement for merging 0 ([53]) 2 ([57]); check for all members of 0th cluster
is_index_in_final_clustering_i: False index_in_final_i: -1
check second candidate; check for all members of 2 th cluster
is_index_in_final_clustering_ii: False index_in_final_ii: -1
All members agree for merging 0 2 ; final clustering: []
Merging 0 2 [[0, 2]]
Considering merging 0 ([53]) 3 ([63])
[0.9987476] [0.9977325]
Initial agreement for merging 0 ([53]) 3 ([63]); check for all members of 0th cluster
is_index_in_final_clustering_i: True index_in_final_i: 0
	Checking member 0 of cluster [0, 2]
	Checking member 2 of cluster [0, 2]
	member_subpocket: [57] predicted_binding_site_ii: [63]
[0.99998915] [0.9999747]
	Member 2 predicts merging with 3 : tensor([True], device='cuda:0')
check second candidate; check for all mem

[[53,
  57,
  63,
  277,
  278,
  279,
  53,
  65,
  66,
  278,
  68,
  277,
  69,
  335,
  90,
  286,
  126,
  124,
  286,
  126,
  280,
  277,
  277,
  286,
  280,
  126,
  286,
  329,
  288,
  314,
  124,
  290,
  314,
  319,
  329,
  329,
  286,
  290,
  314,
  124,
  331],
 [54, 122, 234, 185]]

## Try to cluster embeddings
Take the predicted binding residues, and use their embeddings to cluster them.

In [31]:
embedding_path = f'/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB-nonfinetuned-embedding.npy'
embeddings = np.load(embedding_path)

predicted_binding_sites_merged = np.concatenate(predicted_binding_sites)
binding_site_embeddings = embeddings[predicted_binding_sites_merged]

from sklearn.cluster import KMeans
clustering = KMeans(n_clusters=len(predicted_binding_sites_merged) // 15)
labels = clustering.fit_predict(binding_site_embeddings)

final_predicted_binding_sites = [[] for _ in range(np.max(labels) + 1)]
for residue_index, cluster_label in zip(predicted_binding_sites_merged, labels):
    final_predicted_binding_sites[cluster_label].append(residue_index)
final_predicted_binding_sites

[[53,
  54,
  57,
  63,
  64,
  53,
  65,
  66,
  67,
  68,
  69,
  90,
  122,
  126,
  124,
  126,
  234,
  185,
  126,
  329,
  124,
  317,
  315,
  318,
  319,
  329,
  329,
  124,
  331],
 [277, 278, 279, 278, 277, 277, 278, 279, 277, 316],
 [335,
  286,
  286,
  199,
  280,
  284,
  286,
  280,
  286,
  288,
  288,
  314,
  290,
  314,
  286,
  290,
  314]]

### PCA
First reduce the dimensionality of the embedding, then do clustering.

#### Note
This was a bit better I guess, but still there was one cluster that had residues all over the place, and it didn't make sense spatially.

In [None]:
embedding_path = f'/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB-nonfinetuned-embedding.npy'
embeddings = np.load(embedding_path)

predicted_binding_sites_merged = np.concatenate(predicted_binding_sites)
binding_site_embeddings = embeddings[predicted_binding_sites_merged]

from sklearn.decomposition import PCA

pca = PCA(n_components=15, random_state=42)
projections = pca.fit_transform(binding_site_embeddings)

from sklearn.cluster import KMeans
clustering = KMeans(n_clusters=len(predicted_binding_sites_merged) // 15)
labels = clustering.fit_predict(projections)

final_predicted_binding_sites = [[] for _ in range(np.max(labels) + 1)]
for residue_index, cluster_label in zip(predicted_binding_sites_merged, labels):
    final_predicted_binding_sites[cluster_label].append(residue_index)
final_predicted_binding_sites

[[277,
  278,
  279,
  278,
  277,
  234,
  199,
  277,
  278,
  279,
  277,
  329,
  316,
  317,
  315,
  318,
  319,
  329,
  329,
  331],
 [53,
  54,
  57,
  63,
  64,
  53,
  65,
  66,
  67,
  68,
  69,
  335,
  90,
  122,
  126,
  124,
  126,
  185,
  126,
  314,
  124,
  314,
  314,
  124],
 [286, 286, 280, 284, 286, 280, 286, 288, 288, 290, 286, 290]]

### PCAed embeddings + distances
Do the same as before (use PCA on embedding), but merge it with the distances and after then run the clustering.

In [None]:
embedding_path = f'/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB-nonfinetuned-embedding.npy'
coordinates_path = f'{OUTPUT_PATH}/{pdb_id}{chain_id}.npy'
embeddings = np.load(embedding_path)
coordinates = np.load(coordinates_path)

predicted_binding_sites_merged = np.concatenate(predicted_binding_sites)

binding_site_embeddings = embeddings[predicted_binding_sites_merged]
binding_site_coordinates = coordinates[predicted_binding_sites_merged]

from sklearn.decomposition import PCA

pca = PCA(n_components=10, random_state=42)
projections = pca.fit_transform(binding_site_embeddings)

concatenated_features = np.concatenate((projections, binding_site_coordinates), axis=1)


from sklearn.cluster import KMeans
clustering = KMeans(n_clusters=len(predicted_binding_sites_merged) // 15)
labels = clustering.fit_predict(concatenated_features)

final_predicted_binding_sites = [[] for _ in range(np.max(labels) + 1)]
for residue_index, cluster_label in zip(predicted_binding_sites_merged, labels):
    final_predicted_binding_sites[cluster_label].append(residue_index)
final_predicted_binding_sites

(56, 13)

## Sanity check
Do the same on coordinates, compare the results visually

### Note
Visually, this makes much more sense than clustering the embeddings.

In [32]:
coordinates_path = f'{OUTPUT_PATH}/{pdb_id}{chain_id}.npy'
coordinates = np.load(coordinates_path)

predicted_binding_sites_merged = np.concatenate(predicted_binding_sites)
binding_site_coordinates = coordinates[predicted_binding_sites_merged]

from sklearn.cluster import KMeans
clustering = KMeans(n_clusters=len(predicted_binding_sites_merged) // 15)
labels = clustering.fit_predict(binding_site_coordinates)

final_predicted_binding_sites = [[] for _ in range(np.max(labels) + 1)]
for residue_index, cluster_label in zip(predicted_binding_sites_merged, labels):
    final_predicted_binding_sites[cluster_label].append(residue_index)
final_predicted_binding_sites

[[122,
  286,
  126,
  124,
  286,
  126,
  234,
  185,
  286,
  126,
  286,
  288,
  329,
  288,
  314,
  124,
  290,
  314,
  316,
  317,
  315,
  318,
  319,
  329,
  329,
  286,
  290,
  314,
  124,
  331],
 [277, 278, 279, 278, 277, 335, 90, 280, 277, 284, 278, 279, 277, 280],
 [53, 54, 57, 63, 64, 53, 65, 66, 67, 68, 69, 199]]

# Things to try:
1. change clustering decision threshold (0.9)
2. design train data in reverse (now I stack them (v1, v2), do (v2, v1) as well) to do data augmenting


In [None]:
# Considering merging 0 ([53]) 1 ([54])
# [0.73246205] [0.72424614]
# Not merging 0 ([53]) 1 ([54]) based on initial prediction


# Initial agreement for merging 1 ([54]) 3 ([63]); check for all members of 1th cluster
# is_index_in_final_clustering_i: True index_in_final_i: 0
# 	Checking member 0 of cluster [0, 2, 3, 5, 6, 8, 9, 10, 12, 13, 16, 19, 20, 22, 23, 27, 28, 29, 1]
# 	member_subpocket: [53] predicted_binding_site_ii: [63]
# [0.9987476] [0.9977325]


# Considering merging 0 ([53]) 1 ([54])
# [0.73246205] [0.72424614]
# TODO: why not merging when next to each other?