# Pose clustering of docking results
## 0. Information and code from:
<a id = 'references'></a>
1. https://chem-workflows.com/articles/2019/06/24/pose-clustering-of-docking-results/
2. https://joernhees.de/blog/2015/08/26/scipy-hierarchical-clustering-and-dendrogram-tutorial/
3. https://www.analyticsvidhya.com/blog/2021/06/single-link-hierarchical-clustering-clearly-explained/
4. http://cda.psych.uiuc.edu/multivariate_fall_2013/matlab_help/cluster_analysis.pdf
5. https://chem-workflows.com/articles/2019/07/18/building-a-multi-molecule-mol2-reader-for-rdkit/

## 1. Libraries

In [None]:
import os

# RDKit
from rdkit import Chem
from rdkit.Chem import rdFMCS,AllChem, Draw
from rdkit.Chem.Draw import DrawingOptions
from rdkit.Chem.Draw import IPythonConsole

# Other powerful Python libraries
import numpy as np
import math
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import gridspec
import seaborn as sns

# Clustering
from scipy import spatial
from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import dendrogram, linkage, cophenet, fcluster, inconsistent
from scipy.spatial.distance import pdist

## 2. Fix wonky connectivity and visualize
Connectivity fix from: https://bioinformatics.stackexchange.com/questions/15877/is-it-possible-to-correct-bond-order-and-conectivity-problems-in-an-sdf-file-acc

In [None]:
%%capture --no-display

goodconn = Chem.MolFromSmiles() # Ligand SMILES to fix
                                # wonky connectivity

poses = Chem.SDMolSupplier('poses.sdf') # File with all docking poses in SDF format
renamed_poses = []

for index,p in enumerate(poses):
    
    fixed_mol = AllChem.AssignBondOrdersFromTemplate(goodconn, p) # This fixes the wonky connectivity from the
                                                                  # pdbqt -> mol2 conversion
    fixed_mol.SetProp('_Name', str(index+1))
    renamed_poses.append(fixed_mol)
        
# If all molecules are correct (sanitized), Draw.MolsToGridImage must work 
img = Draw.MolsToGridImage(renamed_poses,
                           molsPerRow = (5),
                           legends = [i.GetProp('_Name') for i in renamed_poses],
                           useSVG = False,
                           maxMols = 100)
img

## 3. RMSD in place calculation: RDKit `CalcRMS` (no alignment of the poses)
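For reference, the in-place RMSD between two poses with identical atom ordering is $\sqrt{\frac{1}{N}\sum_{i=1}^{N}\lVert a_i - b_i\rVert^2}$. Below is a minimal NumPy sketch of that formula. The function name and the toy coordinates are made up for illustration, and, unlike RDKit's `CalcRMS`, the sketch does not account for symmetry-equivalent atoms.

```python
import numpy as np

def rmsd_in_place(coords_a, coords_b):
    """RMSD between two (N, 3) coordinate arrays, without any alignment."""
    coords_a = np.asarray(coords_a, dtype=float)
    coords_b = np.asarray(coords_b, dtype=float)
    if coords_a.shape != coords_b.shape:
        raise ValueError("Both poses need the same number of atoms")
    # Squared distance between each pair of corresponding atoms
    squared_dist = np.sum((coords_a - coords_b) ** 2, axis=1)
    return float(np.sqrt(squared_dist.mean()))

# Hypothetical toy poses: pose_b is pose_a shifted by 1 Å along x
pose_a = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.5, 0.0]])
pose_b = pose_a + np.array([1.0, 0.0, 0.0])
print(rmsd_in_place(pose_a, pose_b))
```

Because every atom is displaced by the same vector, the RMSD equals the length of that shift.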

In [None]:
size=len(renamed_poses)
hmap=np.empty(shape=(size,size))
table=pd.DataFrame()

for i,mol in enumerate(renamed_poses):
    for j,jmol in enumerate(renamed_poses):   
        # Substructure matching + RMSD calculation
        rmsd = AllChem.CalcRMS(jmol, mol)
        #saving the rmsd values to a matrix and a table for clustering
        hmap[i,j]=rmsd
        table.loc[mol.GetProp('_Name'),jmol.GetProp('_Name')]=rmsd

# Check if the distance matrix is symmetric
def check_symmetric(a, rtol=1e-010, atol=1e-010):
    return np.allclose(a, a.T, rtol=rtol, atol=atol)

check_symmetric(hmap)

## 4. Clustering of the poses: [SciPy average linkage](https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html#scipy.cluster.hierarchy.linkage) algorithm

In [None]:
# I will use the matrix that I generated before
hmap_condensed = spatial.distance.squareform(hmap, checks=False) # Transform uncondensed to condensed matrix
                                                                 # checks=False because I am sure the matrix is symmetrical
linked = linkage(hmap_condensed,'average')
labelList = [mol.GetProp('_Name') for mol in renamed_poses] # Labels for the dendrogram

# Plot dendrogram
plt.figure(figsize=(10,10))

ax1=plt.subplot()
o=dendrogram(linked,  
            orientation='left',
            labels=labelList,
            distance_sort='descending',
            show_leaf_counts=True,
            link_color_func=lambda k: 'black' # Make the dendrogram black
            )

ax1.spines['left'].set_visible(False)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
plt.title('Pose clustering',fontsize=20,weight='bold')
plt.tick_params ('both',width=2,labelsize=10)
plt.tight_layout()
plt.show() 

When `color_threshold` is not set, `dendrogram()` uses a default distance cut-off of 70% of the largest merge distance and colors each cluster below that cut-off differently. In the plot above, `link_color_func` overrides this and draws every link in black.

The cophenetic correlation coefficient compares (correlates) the actual pairwise distances between all samples with those implied by the hierarchical clustering. The closer the value is to 1, the better the clustering preserves the original distances.
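As a small illustration, the sketch below computes the coefficient for a made-up RMSD matrix of four poses (all `toy_*` names and values are hypothetical). Note that the second argument of `cophenet` is the condensed form of the original distances.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, cophenet
from scipy.spatial.distance import squareform

# Hypothetical symmetric RMSD matrix (Å) for four poses: two tight pairs
toy_rmsd = np.array([
    [0.0, 0.5, 4.0, 4.2],
    [0.5, 0.0, 3.9, 4.1],
    [4.0, 3.9, 0.0, 0.6],
    [4.2, 4.1, 0.6, 0.0],
])

toy_condensed = squareform(toy_rmsd)            # condensed (upper-triangle) form
toy_linked = linkage(toy_condensed, 'average')  # average-linkage clustering

# Correlate the original distances with the dendrogram-implied distances
toy_c, toy_coph_dists = cophenet(toy_linked, toy_condensed)
print(round(float(toy_c), 3))
```

With two well-separated pairs, the dendrogram reproduces the original distances almost exactly, so the coefficient is close to 1.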

In [None]:
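# Cophenetic correlation: hmap is already a distance matrix, so pass its condensed form
# (pdist(hmap) would instead compute Euclidean distances between the rows of hmap)
c, coph_dists = cophenet(linked, hmap_condensed)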
c

## 5. Determining the number of clusters
### 5.1. Automatic cut-off selection
While manual selection of a cut-off value offers a lot of benefits when it comes to checking for a meaningful clustering and cut-off, there are cases in which you want to automate this. The problem again is that there is no golden method to pick the number of clusters for all cases.

The default criterion of `SciPy`'s `fcluster` is the inconsistency method. This method compares each merge's height, $h$, with the average height, $avg$, of the $depth$ previous levels and normalizes the difference by their standard deviation, $std$.

$$
inconsistency = \frac{h - avg}{std}
$$

Based on these results, I could pick a $depth$ and an $inconsistency$ limit to generate the clusters. Choosing a low $inconsistency$ limit will generate more clusters than choosing a high $inconsistency$ limit.
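A minimal sketch of this on made-up data, assuming three hypothetical groups of 2D points (all `toy_*` names and the two limits are illustrative choices). It recomputes the coefficient by hand from the output of `inconsistent()` and compares the number of flat clusters for a low and a high limit.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, inconsistent, fcluster

# Hypothetical toy data: three well-separated groups of 2D points
rng = np.random.default_rng(0)
toy_points = np.vstack([rng.normal(loc=center, scale=0.3, size=(10, 2))
                        for center in ([0, 0], [5, 5], [0, 5])])
toy_linked = linkage(toy_points, 'average')

toy_depth = 3
# Columns: mean height, std of heights, number of links, inconsistency coefficient
toy_incons = inconsistent(toy_linked, toy_depth)

# Recompute (h - avg) / std by hand wherever the standard deviation is non-zero
nonzero = toy_incons[:, 1] > 0
by_hand = (toy_linked[nonzero, 2] - toy_incons[nonzero, 0]) / toy_incons[nonzero, 1]

# Flat clusters for a low and a high inconsistency limit
n_low = fcluster(toy_linked, 0.8, criterion='inconsistent', depth=toy_depth).max()
n_high = fcluster(toy_linked, 3.0, criterion='inconsistent', depth=toy_depth).max()
print(n_low, n_high)
```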

### 5.2. Manual cut-off selection
The inconsistency method is not ideal because it is very sensitive to the $depth$ parameter and the $inconsistency$ limit we set to create our clusters. There is no universal and consistent way to pick these parameters.

Another problem is that the heights of the previous $depth$ levels are not normally distributed but are expected to increase, so the result cannot really be interpreted as a normalized [$z$-score](https://en.wikipedia.org/wiki/Standard_score).

So, I will just manually set a cutoff distance.
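To get a feel for how the cut-off affects the result, one option is to scan a few candidate distances and count the flat clusters each one produces. The sketch below does this for a made-up RMSD matrix of five poses; the matrix, the `toy_*` names, and the candidate cut-offs are all hypothetical.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

# Hypothetical symmetric RMSD matrix (Å) for five poses
toy_rmsd = np.array([
    [0.0, 0.8, 1.1, 5.0, 5.3],
    [0.8, 0.0, 0.9, 4.8, 5.1],
    [1.1, 0.9, 0.0, 4.9, 5.2],
    [5.0, 4.8, 4.9, 0.0, 1.5],
    [5.3, 5.1, 5.2, 1.5, 0.0],
])
toy_linked = linkage(squareform(toy_rmsd), 'average')

# Number of flat clusters obtained at a few candidate cut-offs
cutoffs = [0.5, 0.9, 1.2, 3.0, 6.0]
n_clusters = [int(fcluster(toy_linked, d, criterion='distance').max()) for d in cutoffs]
for d, n in zip(cutoffs, n_clusters):
    print(f"cut-off {d:.1f} Å -> {n} cluster(s)")
```

A wide range of cut-offs that gives the same number of clusters (here, anything between the last two merges) suggests that this partition is fairly robust.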

#### Dendrogram

In [None]:
# Get clusters
max_d = 3 # Cutoff distance of 3
clusters = fcluster(linked, max_d, criterion='distance')
k = max(clusters)

# Set color palette
color_palette = list(sns.color_palette("Spectral", k).as_hex())
hierarchy.set_link_color_palette(color_palette)

# Plot dendrogram
plt.figure(figsize=(8,8))

ax1=plt.subplot()
o=dendrogram(linked,  
            orientation='left',
            labels=labelList,
            distance_sort='descending',
            show_leaf_counts=True,
            color_threshold = max_d
            )

ax1.spines['left'].set_visible(False)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
plt.title('Pose clustering',fontsize=20,weight='bold')
plt.tick_params ('both',width=2,labelsize=10)
plt.axvline(x=max_d, c='grey', lw=1, linestyle='dashed')
plt.tight_layout()
plt.show() 

#### Heatmap

In [None]:
# Get the pose labels in the same top-to-bottom order as in the last dendrogram
new_data=list(reversed(o['ivl']))

# Create a new matrix with the rows and columns in the order of the hierarchical clustering
hmap_2=np.empty(shape=(size,size))
for index,i in enumerate(new_data):
    for jndex,j in enumerate(new_data):
        hmap_2[index,jndex]=table.loc[i].at[j]

figure= plt.figure(figsize=(30,30))
gs1 = gridspec.GridSpec(2,7)
gs1.update(wspace=0.01)
ax1 = plt.subplot(gs1[0:-1, :2])
dendrogram(linked,
           orientation='left',
           distance_sort='descending',
           show_leaf_counts=True,
           no_labels=True,
           color_threshold = max_d
)
ax1.spines['left'].set_visible(False)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
plt.axvline(x=max_d, c='grey', lw=1, linestyle='dashed')

ax2 = plt.subplot(gs1[0:-1,2:6])
f=ax2.imshow (hmap_2, cmap='magma', interpolation='nearest')

ax2.set_title('Pose clustering',fontsize=20,weight='bold')
ax2.set_xticks (range(len(new_data)))
ax2.set_yticks (range(len(new_data)))
ax2.set_xticklabels (new_data,rotation=90)
ax2.set_yticklabels (new_data)

ax3 = plt.subplot(gs1[0:-1,6:7])
m=plt.colorbar(f,cax=ax3,shrink=0.75,orientation='vertical',spacing='uniform',pad=0.01)
m.set_label ('RMSD')

plt.tick_params ('both',width=2,labelsize=9)
plt.show()

In [None]:
## Keep for analysis only clusters that contain at least X (e.g. 5) poses
# Big clusters to keep
bigclust = set()

for pos, clustnum in enumerate(clusters):
    membs = sum(clusters == clustnum)
    if membs >= 5:
        bigclust.add(clustnum)

## 6. Analyze results
### 6.1. Docking score of each cluster

In [None]:
# Make a cluster -> [poses] dictionary
cluster_poses_dict = {}

for label, cluster in zip(labelList, clusters):
    cluster_poses_dict.setdefault(cluster, []).append(int(label)) # Get correct pose labels
    
# Get docking scores
results_file = 'vina_output.pdbqt' # Vina results file (.pdbqt)

scores = []

with open(results_file, "r") as f: 
    lines = f.readlines()
    for line in lines:
        line = line.strip()
        if line.startswith("REMARK VINA RESULT:"):
            fields = line.split()
            score = float(fields[3])
            scores.append(score)

# Make a pandas data frame
cluster_poses_score_todf = {'cluster':[] ,'pose':[], 'score':[]}

for cluster in cluster_poses_dict:
    for pose in cluster_poses_dict[cluster]:
        score = scores[pose - 1]
        cluster_poses_score_todf['cluster'].append(cluster)
        cluster_poses_score_todf['pose'].append(pose)
        cluster_poses_score_todf['score'].append(score)

cluster_poses_score_df = pd.DataFrame(data = cluster_poses_score_todf)

Draw boxplots.

In [None]:
# Draw a boxplot (white boxes) with the individual poses on top (black points)
sns.boxplot(data=cluster_poses_score_df[cluster_poses_score_df.cluster.isin(bigclust)],
            x="cluster", 
            y="score", 
            color = 'w')

sns.stripplot(data=cluster_poses_score_df[cluster_poses_score_df.cluster.isin(bigclust)],
              x="cluster",
              y="score", 
              color = 'k') 

Visualize poses by cluster.

In [None]:
# Select a cluster (e.g. 1) to visualize the poses that belong to it
cluster = 1
poses = cluster_poses_dict[cluster]
mols = [renamed_poses[i-1] for i in poses] # List comprehension
img = Draw.MolsToGridImage(mols, molsPerRow=3, useSVG=True, legends=[i.GetProp('_Name') for i in mols])
img

### 6.2. Calculate the RMSD from the docking poses to a reference (experimental) structure
1. Find the maximum common substructure (MCS) of the docked and the reference ligand.
2. Calculate the RMSD (MCS atoms) between the docked and the reference ligand.
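The two steps above can be sketched explicitly with RDKit as follows. This is only an illustration under stated assumptions: the function name and the ethanol toy molecules are hypothetical, and the sketch keeps only the first substructure match, so it ignores symmetry-equivalent atom mappings.

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, rdFMCS
from rdkit.Geometry import Point3D

def mcs_rmsd_in_place(probe, ref):
    """RMSD over the MCS atoms of two molecules with 3D conformers, without alignment."""
    # Step 1: maximum common substructure of both ligands
    mcs = rdFMCS.FindMCS([probe, ref])
    pattern = Chem.MolFromSmarts(mcs.smartsString)
    # Atom indices of the MCS in each molecule (first match only)
    probe_idx = list(probe.GetSubstructMatch(pattern))
    ref_idx = list(ref.GetSubstructMatch(pattern))
    # Step 2: RMSD between the coordinates of the matched atoms
    probe_xyz = probe.GetConformer().GetPositions()[probe_idx]
    ref_xyz = ref.GetConformer().GetPositions()[ref_idx]
    return float(np.sqrt(np.mean(np.sum((probe_xyz - ref_xyz) ** 2, axis=1))))

# Hypothetical example: an embedded ethanol conformer and a copy shifted by 2 Å along x
toy_ref = Chem.AddHs(Chem.MolFromSmiles('CCO'))
AllChem.EmbedMolecule(toy_ref, randomSeed=42)
toy_probe = Chem.Mol(toy_ref)
conf = toy_probe.GetConformer()
for atom_idx in range(toy_probe.GetNumAtoms()):
    pos = conf.GetAtomPosition(atom_idx)
    conf.SetAtomPosition(atom_idx, Point3D(pos.x + 2.0, pos.y, pos.z))

toy_rmsd_value = mcs_rmsd_in_place(toy_probe, toy_ref)
print(toy_rmsd_value)
```

Since the copy is a rigid translation of the original, the RMSD should equal the length of the shift.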

In [None]:
# Open reference pose (SDF file: reference_pose.sdf) and fix its bond order (fix wonky bonds)
ref_arip = AllChem.AssignBondOrdersFromTemplate(goodconn, Chem.SDMolSupplier('reference_pose.sdf')[0]) 

# Create a new, empty (NaN) numeric column in the data frame
cluster_poses_score_df["ref_rmsd"] = np.nan

# Calculate RMSD from docking poses to the reference pose
for pose in renamed_poses:
    pose_id = int(pose.GetProp('_Name'))
    rmsd = AllChem.CalcRMS(pose, ref_arip)
    # Add the RMSD to the matching row of the data frame
    # (.loc with a boolean mask avoids SettingWithCopyWarning)
    cluster_poses_score_df.loc[cluster_poses_score_df.pose == pose_id, "ref_rmsd"] = rmsd

What was the maximum common substructure used to calculate the RMSD?

In [None]:
# Find MCS
template_and_poses = [ref_arip] + list(renamed_poses)
mcs = rdFMCS.FindMCS(template_and_poses)

# Draw
DrawingOptions.bondLineWidth = 2
DrawingOptions.includeAtomNumbers = False
mcs_drawn=Draw.MolToImage(Chem.MolFromSmarts(mcs.smartsString), size = (600,400))
mcs_drawn

Draw boxplots.

In [None]:
# Draw a boxplot (white boxes) with the individual poses on top (black points)
sns.boxplot(data=cluster_poses_score_df[cluster_poses_score_df.cluster.isin(bigclust)], 
            x="cluster", 
            y="ref_rmsd", 
            color = 'w')

sns.stripplot(data=cluster_poses_score_df[cluster_poses_score_df.cluster.isin(bigclust)], 
              x="cluster", 
              y="ref_rmsd", 
              color = 'k') 

### 6.3. Docking score vs. RMSD from reference
Scatterplot.

In [None]:
# New palette because there are fewer clusters
l = len(bigclust)
reduced_palette = list(sns.color_palette("Spectral", l).as_hex())

# Set the style before creating the figure so that it applies to it, and make the plot bigger
sns.set(font_scale=1.5)
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(6, 6))

# Scatterplot
sns.scatterplot(data=cluster_poses_score_df[cluster_poses_score_df.cluster.isin(bigclust)], 
                x="ref_rmsd", 
                y="score", 
                hue = "cluster",
                palette = reduced_palette,
                legend = True)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))