In [None]:
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances
import pickle
import faiss
import rdkit
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator as rdGen
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina
from rdkit.Chem import rdMolDescriptors as rdmd
import time
import warnings
warnings.filterwarnings("ignore")

In [17]:
# (embedding model, search distance, FAISS flat-index class) triples to evaluate.
# IndexFlatIP ranks by raw inner product, which equals cosine similarity only for
# L2-normalized vectors — presumably the "*_cosine.pkl" embeddings are stored
# normalized; TODO confirm.
_FAISS_INDEX_BY_DISTANCE = {
    "euclidean": faiss.IndexFlatL2,
    "cosine": faiss.IndexFlatIP,
}
embedding_distance = [
    (emb_name, dist_name, _FAISS_INDEX_BY_DISTANCE[dist_name])
    for emb_name, dist_name in (
        ("cddd", "euclidean"),
        ("cddd", "cosine"),
        ("molformer", "euclidean"),
        ("molformer", "cosine"),
        ("macaw", "cosine"),
        ("mol2vec", "cosine"),
    )
]

In [None]:
# Library SMILES: first row is a header and is dropped.
smiles_list = np.genfromtxt("data/smiles.csv", dtype=str, comments=None)[1:]
print(f"Loaded SMILES: {smiles_list.shape}")

# Morgan fingerprints, radius 2, 2048 bits, in the same order as smiles_list.
fpgen = rdGen.GetMorganGenerator(radius=2, fpSize=2048)
fingerprint_list = []
for row, smiles in enumerate(smiles_list):
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None for unparsable input; fail with a clear message
    # instead of letting GetFingerprint(None) raise an opaque argument error.
    if mol is None:
        raise ValueError(f"Could not parse SMILES at row {row}: {smiles!r}")
    fingerprint_list.append(fpgen.GetFingerprint(mol))
print(f"Fingerprints generated: {len(fingerprint_list)}")

## Butina clustering by Tanimoto distance on Morgan fingerprints (radius 2, 2048 bits)

In [None]:
def butina_cluster(fingerprint_list, threshold=0.35):
    """Butina-cluster fingerprints using Tanimoto distance (1 - similarity).

    Parameters
    ----------
    fingerprint_list : sequence of RDKit fingerprints
        Fingerprints to cluster; the returned labels follow this order.
    threshold : float
        Distance cut-off handed to ``Butina.ClusterData`` as ``distThresh``.

    Returns
    -------
    np.ndarray of shape (n,)
        Float array with the cluster number of every fingerprint, numbered from 1
        in the order Butina returns the clusters.
    """
    t0 = time.time()
    n_fps = len(fingerprint_list)

    # Condensed lower triangle: for row i, distances to items 0..i-1, appended in order.
    condensed = []
    for row in range(1, n_fps):
        similarities = DataStructs.BulkTanimotoSimilarity(
            fingerprint_list[row], fingerprint_list[:row]
        )
        condensed += [1 - sim for sim in similarities]

    clusters = Butina.ClusterData(
        condensed, nPts=n_fps, distThresh=threshold, isDistData=True
    )
    print(f"Time: {time.time() - t0}")

    labels = np.zeros((n_fps,))
    for cluster_number, members in enumerate(clusters, start=1):
        labels[list(members)] = cluster_number
    return labels

In [None]:
# Reuse the Morgan fingerprints (radius 2, 2048 bits) already built from smiles_list;
# recomputing them with the deprecated rdMolDescriptors.GetMorganFingerprintAsBitVect
# would only repeat the same work.
clusters_fp = butina_cluster(fingerprint_list)
print(f"Number of clusters: {int(np.max(clusters_fp))}")

## Cluster embeddings with k-means (FAISS) and compare with the fingerprint clusters via the Rand index

In [3]:
def kmeans_cluster(embedding_list, n_clusters=500, n_iter=20):
    """Cluster embeddings with FAISS k-means and return one label per row.

    Parameters
    ----------
    embedding_list : array-like of shape (n, d)
        Embedding matrix; cast once to contiguous float32 for FAISS.
    n_clusters : int
        Number of k-means centroids.
    n_iter : int
        Number of k-means iterations.

    Returns
    -------
    np.ndarray of shape (n,)
        Index of the nearest centroid for each row. The labels are flattened from the
        ``(n, 1)`` matrix returned by ``search`` so they can be compared element-wise.
    """
    vectors = np.ascontiguousarray(embedding_list, dtype=np.float32)
    kmeans = faiss.Kmeans(d=vectors.shape[1], k=n_clusters, niter=n_iter, verbose=True)
    kmeans.train(vectors)
    # k=1: nearest centroid only; the distances are not needed.
    _, labels = kmeans.index.search(vectors, 1)
    return labels.ravel()

In [None]:
def rand_index(clusters_fp, clusters_emb):
    """Rand index between two labelings of the same items, with a pair-count breakdown.

    Every unordered pair of items is classified by whether the fingerprint clustering
    and the embedding clustering put it in the same cluster. The 2x2 table is printed,
    and the fraction of pairs on which both clusterings agree is returned.

    Parameters
    ----------
    clusters_fp, clusters_emb : array-like
        One cluster label per item, in the same item order. Inputs are flattened, so
        ``(n, 1)`` label arrays (as returned by a FAISS search) work directly.

    Returns
    -------
    float
        (pairs together in both + pairs apart in both) / total pairs; 1.0 when there are
        fewer than two items (no pair can disagree).

    Raises
    ------
    ValueError
        If the two labelings have different lengths.
    """
    labels_fp = np.asarray(clusters_fp).ravel()
    labels_emb = np.asarray(clusters_emb).ravel()
    if labels_fp.size != labels_emb.size:
        raise ValueError(
            f"Label arrays differ in length: {labels_fp.size} vs {labels_emb.size}"
        )
    n = labels_fp.size

    def _pairs_within(group_sizes):
        # Sum of C(size, 2): number of unordered pairs that share a group.
        sizes = np.asarray(group_sizes, dtype=np.int64)
        return int(np.sum(sizes * (sizes - 1) // 2))

    # Contingency-table counts in O(n log n) instead of visiting all n*(n-1)/2 pairs.
    _, fp_ids, fp_sizes = np.unique(labels_fp, return_inverse=True, return_counts=True)
    emb_values, emb_ids, emb_sizes = np.unique(
        labels_emb, return_inverse=True, return_counts=True
    )
    joint_ids = fp_ids.ravel().astype(np.int64) * len(emb_values) + emb_ids.ravel()
    _, joint_sizes = np.unique(joint_ids, return_counts=True)

    total = n * (n - 1) // 2
    same_both = _pairs_within(joint_sizes)
    same_fp = _pairs_within(fp_sizes)
    same_emb = _pairs_within(emb_sizes)
    # counts[a][b]: a = same fingerprint cluster?, b = same embedding cluster?
    counts = [
        [total - same_fp - same_emb + same_both, same_emb - same_both],
        [same_fp - same_both, same_both],
    ]
    print(f"Counts:\nFingerprints - different cluster, embedding - different cluster: {counts[0][0]}")
    print(f"Fingerprints - different cluster, embedding - same cluster: {counts[0][1]}")
    print(f"Fingerprints - same cluster, embedding - different cluster: {counts[1][0]}")
    print(f"Fingerprints - same cluster, embedding - same cluster: {counts[1][1]}")
    if total == 0:
        return 1.0
    return (counts[0][0] + counts[1][1]) / total

In [None]:
for emb_name, dist_name, _index_cls in embedding_distance:
    # Rows are assumed to follow data/smiles.csv order so labels line up with clusters_fp.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(f"embedding/embedding_{emb_name}_{dist_name}.pkl", "rb") as file:
        embedding_list = pickle.load(file)
    # Cast before nan_to_num so replaced infinities are clipped to float32 limits.
    embedding_list = np.ascontiguousarray(
        np.nan_to_num(np.asarray(embedding_list, dtype=np.float32))
    )
    print(embedding_list.shape)
    if len(embedding_list) != len(clusters_fp):
        raise ValueError(
            f"{emb_name}/{dist_name}: {len(embedding_list)} embeddings but "
            f"{len(clusters_fp)} fingerprint cluster labels"
        )

    # k-means searches its own centroid index (kmeans.index); no separate flat index is needed.
    start_time = time.time()
    clusters_emb = kmeans_cluster(embedding_list, n_clusters=100, n_iter=10)
    end_time = time.time()
    print(f"Time: {end_time - start_time}\n")
    rand_idx = rand_index(clusters_fp, clusters_emb)
    print(f"Rand index for {emb_name} - {dist_name}: {rand_idx}")

## Compare similarity search results

In [None]:
import time

def get_results_fp(smiles, smiles_list, threshold, fingerprints=None):
    """Find library molecules within a Tanimoto-distance threshold of a query SMILES.

    Parameters
    ----------
    smiles : str
        Query SMILES.
    smiles_list : sequence of str
        Library SMILES; used to check that ``fingerprints`` is aligned with it.
    threshold : float
        Maximum Tanimoto distance (1 - similarity) for a hit, inclusive.
    fingerprints : sequence of RDKit fingerprints, optional
        Library fingerprints in ``smiles_list`` order. Defaults to the notebook-level
        ``fingerprint_list``.

    Returns
    -------
    (np.ndarray, float)
        Hit indices sorted by increasing distance, and the seconds spent on the
        similarity computation and sort (query fingerprinting is excluded).

    Raises
    ------
    ValueError
        If the query SMILES cannot be parsed, or ``fingerprints`` and ``smiles_list``
        differ in length.
    """
    if fingerprints is None:
        fingerprints = fingerprint_list
    if len(fingerprints) != len(smiles_list):
        raise ValueError(
            f"{len(fingerprints)} fingerprints but {len(smiles_list)} library SMILES"
        )
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse query SMILES: {smiles!r}")
    fpgen = rdGen.GetMorganGenerator(radius=2, fpSize=2048)
    query_fp = fpgen.GetFingerprint(mol)

    start_time = time.time()
    similarities = [DataStructs.TanimotoSimilarity(fp, query_fp) for fp in fingerprints]
    distances = 1 - np.array(similarities)
    order = np.argsort(distances)
    total_time = time.time() - start_time

    # Number of hits with distance <= threshold. Unlike argmax(sorted > threshold),
    # this is correct when every molecule (or none) is a hit and for an empty library.
    n_hits = np.searchsorted(distances[order], threshold, side="right")
    return order[:n_hits], total_time

def get_results_emb(embedding, index, threshold):
    """Find library molecules close to a query embedding using a FAISS index.

    All ``index.ntotal`` neighbours are retrieved and their scores are min-max
    normalized to [0, 1] per query. Similarity scores (inner-product indexes) are
    turned into distances, and hits are those with normalized distance
    <= ``threshold**2``.

    Parameters
    ----------
    embedding : array-like of shape (d,)
        Query embedding.
    index : FAISS index
        Searched with ``index.search`` over all of its ``index.ntotal`` vectors.
    threshold : float
        Distance cut-off; squared because FAISS L2 indexes return squared distances.

    Returns
    -------
    (np.ndarray, float)
        Hit indices ordered best-first, and the seconds spent in ``index.search``.
    """
    if index.ntotal == 0:
        return np.empty(0, dtype=np.int64), 0.0
    # FAISS operates on contiguous float32 matrices of shape (n_queries, d).
    query_emb = np.ascontiguousarray([embedding], dtype=np.float32)
    start_time = time.time()
    distances, indices = index.search(query_emb, k=index.ntotal)
    results = indices[0]
    total_time = time.time() - start_time

    scores = distances[0]
    spread = np.ptp(scores)
    if spread > 0:
        scores = (scores - np.min(scores)) / spread
    else:
        # All scores identical: every molecule is equally close (avoids 0/0 -> NaN).
        scores = np.zeros_like(scores)
    # FAISS returns neighbours best-first. L2 scores are ascending distances; inner-product
    # scores are descending similarities. Converting similarities to distances (instead of
    # reversing the arrays) keeps `results` best-first with ascending distances.
    if scores[0] > scores[-1]:
        scores = 1.0 - scores
    # Results from FAISS L2 indexes are squared euclidean distances, hence threshold**2.
    # NOTE(review): the same squared cut-off is applied to inner-product scores — confirm.
    n_hits = np.searchsorted(scores, threshold**2, side="right")
    return results[:n_hits], total_time

In [None]:
# Query SMILES, one per line; delimiter='\n' keeps each whole line as a single field.
query_smiles_raw = np.genfromtxt(
    "data/smiles_query.csv",
    dtype=str,
    delimiter='\n',
    comments=None,
)
query_smiles_list = query_smiles_raw[1:]  # drop the header row
print(len(query_smiles_list))

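Embeddings are loaded from the `embedding/` directory: library embeddings from *embedding_\<emb_name\>_\<dist_name\>.pkl* and query-compound embeddings from *embedding_\<emb_name\>_\<dist_name\>_query.pkl*. Their rows must be in the same order as `data/smiles.csv` and `data/smiles_query.csv` respectively, because hits from the fingerprint and embedding searches are compared by row index.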

In [None]:
threshold = 0.15  # distance cut-off shared by the fingerprint and embedding searches


def _load_embeddings(path):
    """Load a pickled embedding matrix as contiguous float32 with NaN/inf replaced."""
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, "rb") as file:
        matrix = pickle.load(file)
    return np.ascontiguousarray(np.nan_to_num(np.asarray(matrix, dtype=np.float32)))


def _recall(found, expected):
    """Fraction of `expected` hits also present in `found`; 1.0 when nothing is expected."""
    expected = set(expected)
    if not expected:
        return 1.0
    return len(set(found) & expected) / len(expected)


# Fingerprint hits do not depend on the embedding model, so compute them once.
fp_results = []
total_time_fp = 0
for query_smiles in query_smiles_list:
    results_fp, time_fp = get_results_fp(query_smiles, smiles_list, threshold)
    fp_results.append(results_fp)
    total_time_fp += time_fp

for emb_name, dist_name, index_cls in embedding_distance:
    query_embedding_list = _load_embeddings(f"embedding/embedding_{emb_name}_{dist_name}_query.pkl")
    print(f"Loaded query embeddings: {query_embedding_list.shape}")
    embedding_list = _load_embeddings(f"embedding/embedding_{emb_name}_{dist_name}.pkl")
    print(f"Loaded embeddings: {embedding_list.shape}")
    # Hits are compared by row index and zip() would silently truncate, so both
    # matrices must align with their SMILES lists.
    if len(query_embedding_list) != len(query_smiles_list):
        raise ValueError(
            f"{emb_name}/{dist_name}: {len(query_embedding_list)} query embeddings but "
            f"{len(query_smiles_list)} query SMILES"
        )
    if len(embedding_list) != len(smiles_list):
        raise ValueError(
            f"{emb_name}/{dist_name}: {len(embedding_list)} embeddings but "
            f"{len(smiles_list)} library SMILES"
        )

    index = index_cls(embedding_list.shape[1])
    index.add(embedding_list)
    print(f"Added to index: {index.ntotal}")

    total_common_count = 0
    total_time_emb = 0
    for query_emb, results_fp in zip(query_embedding_list, fp_results):
        results_emb, time_emb = get_results_emb(query_emb, index, threshold)
        total_time_emb += time_emb
        total_common_count += _recall(results_emb, results_fp)
    print(f"\nRecall {emb_name} - {dist_name}: {total_common_count / len(query_smiles_list)}")
    print(f"Total time emb / fp {emb_name} - {dist_name}: {total_time_emb} / {total_time_fp}")