#### 确定生成分子的数量
受到参考文献的启发，我们基于多样性指标，研究合适的生成分子数量。即当生成分子的多样性大致收敛时，便为生成上限。

#### Determine the number of molecules to generate
Inspired by the reference, we study the appropriate number of generated molecules based on diversity metrics. Specifically, once the diversity of the generated molecules approximately converges, that number serves as the upper limit for generation.

#### Reference
```
Özçelik, R., & Grisoni, F. (2024). The Jungle of Generative Drug Discovery: Traps, Treasures, and Ways Out. ArXiv, abs/2501.05457.
```

In [1]:
from utils.io import *
from eval.similarity import Similarity
from pathlib import Path
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

  from .autonotebook import tqdm as notebook_tqdm


- 由于多样性在生成数量较少的分子时变化较大，因此我们采用指数衰减加权随机采样，对前期生成的分子较多地采样，而后期则减少采样。  
- Because diversity fluctuates strongly when only a small number of molecules has been generated, we choose the evaluated sample sizes by exponentially decaying weighted random sampling: small sample sizes are sampled densely, and larger ones progressively more sparsely.

In [2]:
def generate_indices(total_samples=10000, n_samples=48, alpha=4, seed=42):
    """Draw sorted, distinct sample sizes biased towards small values.

    Each candidate ``i`` in ``[0, total_samples)`` is weighted by
    ``exp(-alpha * i / total_samples)``. Small values are therefore drawn far
    more often than large ones. Note that 0 is a possible draw.

    Args:
        total_samples: Size of the candidate range ``[0, total_samples)``.
        n_samples: Number of distinct values to draw (sampling without replacement).
        alpha: Decay rate of the exponential weights; larger means a stronger
            bias towards small values.
        seed: Seed for a private legacy ``RandomState``. It reproduces exactly
            what ``np.random.seed(seed)`` followed by ``np.random.choice``
            produced, without resetting NumPy's global RNG as a side effect.

    Returns:
        Sorted ``np.ndarray`` of ``n_samples`` distinct integers.
    """
    rng = np.random.RandomState(seed)

    candidates = np.arange(total_samples)
    weights = np.exp(-alpha * candidates / total_samples)
    weights /= weights.sum()  # choice() requires probabilities that sum to 1

    picked = rng.choice(candidates, size=n_samples, replace=False, p=weights)
    return np.sort(picked)

# Evaluation grid: the weighted random draw, bracketed by the smallest (1)
# and the full (10000) sample sizes.
indices = np.concatenate(([1], generate_indices(), [10000]))
indices

array([    1,    51,    85,   116,   146,   164,   251,   319,   368,
         415,   458,   491,   496,   545,   584,   732,   841,   844,
         886,   888,   913,  1114,  1145,  1379,  1414,  1484,  1664,
        1757,  1786,  1809,  1939,  2178,  2214,  2229,  2268,  2294,
        2627,  2784,  2969,  3169,  3682,  3944,  4248,  4747,  5579,
        6702,  6768,  7388,  7598, 10000])

In [3]:
def read_in(path):
    """Load every ``*.sdf`` file directly under *path* and return the processed molecules.

    Files are read in sorted name order. ``Path.glob`` yields entries in an
    arbitrary, filesystem-dependent order, and the molecule order matters here:
    every downstream metric is computed on prefixes ``mols[:idx]``.

    Args:
        path: Directory containing the generated ``.sdf`` files.

    Returns:
        The result of ``Preprocess(mols).unique()``. Judging from the INFO logs,
        this filters to valid, unique SMILES (TODO confirm in utils.io).
    """
    mols = []
    for sdffile in sorted(Path(path).glob('*.sdf')):
        mols.extend(read_sdf(sdffile))

    return Preprocess(mols).unique()

def _prefix_metric(mols, indices, metric):
    """Return ``metric(Similarity(mols[:idx]))`` for every prefix size ``idx`` in *indices*."""
    return [metric(Similarity(mols[:idx])) for idx in tqdm(indices)]


def intdiv(mols, indices=indices):
    """Internal-diversity curve: ``1 - Similarity(mols[:idx]).similarity()`` per prefix size.

    Args:
        mols: Molecules in generation order; each point uses the first ``idx``.
        indices: Prefix sizes to evaluate. The default is the notebook-level
            ``indices`` array, bound when this function is defined.

    Returns:
        List of floats, one per entry of *indices*.
    """
    sims = _prefix_metric(mols, indices, lambda simi: simi.similarity())
    return [1 - sim for sim in sims]


def sca_intdiv(mols, indices=indices):
    """Scaffold internal-diversity curve: ``1 - Similarity(mols[:idx]).scaffold_similarity()``.

    Args:
        mols: Molecules in generation order; each point uses the first ``idx``.
        indices: Prefix sizes to evaluate (default: notebook-level ``indices``).

    Returns:
        List of floats, one per entry of *indices*.
    """
    sims = _prefix_metric(mols, indices, lambda simi: simi.scaffold_similarity())
    return [1 - sim for sim in sims]


def circle(mols, indices=indices):
    """#Circle curve: ``Similarity(mols[:idx]).circle()`` for every prefix size.

    Args:
        mols: Molecules in generation order; each point uses the first ``idx``.
        indices: Prefix sizes to evaluate (default: notebook-level ``indices``).

    Returns:
        List with one ``circle()`` result per entry of *indices*.
    """
    return _prefix_metric(mols, indices, lambda simi: simi.circle())

def smooth(data):
    """Return *data* smoothed by a 1-D Gaussian filter with sigma = 2 samples."""
    sigma = 2
    return gaussian_filter1d(data, sigma)

In [4]:
# Load the DiffSBDD-generated molecules for each target.
ht2a_mols, btk_mols, nampt_mols = map(read_in, [
    '../testfile/DiffSBDD/5HT2A',
    '../testfile/DiffSBDD/BTK',
    '../testfile/DiffSBDD/NAMPT',
])

INFO: Valid SMILES: 10489 out of 10752
INFO: Unique SMILES: 10447 out of 10752
INFO: Valid SMILES: 10485 out of 10752
INFO: Unique SMILES: 10429 out of 10752
INFO: Valid SMILES: 10981 out of 11264
INFO: Unique SMILES: 10979 out of 11264


#### Reference-based diversity

In [None]:
def refer_div(valids, ref_mol, indices=indices):
    """Reference-based diversity curve: ``1 - similarity`` of each prefix to the reference.

    Args:
        valids: Molecules in generation order; each point uses the first ``idx``.
        ref_mol: Reference ligand(s). ``read_sdf`` returns a sequence of
            molecules (it is passed to ``list.extend`` in ``read_in``), so a
            list/tuple is used as-is; a single molecule is wrapped in a list.
        indices: Prefix sizes to evaluate (default: notebook-level ``indices``).

    Returns:
        List of floats, one per entry of *indices*.
    """
    # Avoid double-wrapping the list returned by read_sdf into [[mol, ...]].
    ref_mols = list(ref_mol) if isinstance(ref_mol, (list, tuple)) else [ref_mol]

    divs = []
    for idx in tqdm(indices):
        simi = Similarity(mols=valids[:idx], ref_mols=ref_mols, compared_mode=True)
        divs.append(1 - simi.similarity())
    return divs

In [5]:
# Reference ligand file for each target.
ht2a_ref, btk_ref, nampt_ref = map(read_sdf, [
    '../Targets/5HT2A/5HT2A_7wc7_ligand_A.sdf',
    '../Targets/BTK/BTK_8fll_ligand_A.sdf',
    '../Targets/NAMPT/NAMPT_7ppe_ligand_A.sdf',
])

In [None]:
# Reference-based diversity curve per target (evaluated in 5HT2A, BTK, NAMPT order).
ht2a_refdiv, btk_refdiv, nampt_refdiv = (
    refer_div(mols, ref)
    for mols, ref in ((ht2a_mols, ht2a_ref), (btk_mols, btk_ref), (nampt_mols, nampt_ref))
)

In [None]:
fig = plt.figure(dpi=150)

# Reference-based diversity vs. number of generated molecules, one curve per target.
ax1 = plt.subplot(1, 1, 1)
sns.lineplot(x=indices, y=ht2a_refdiv, label="5HT2A", ax=ax1)
sns.lineplot(x=indices, y=btk_refdiv, label="BTK", ax=ax1)
sns.lineplot(x=indices, y=nampt_refdiv, label="NAMPT", ax=ax1)
ax1.set_xscale('log')  # sample sizes span several orders of magnitude
ax1.set_ylim(0.909, 0.951)
ax1.set_ylabel('Reference Div')
ax1.set_xlabel('Generated Size')

#### Internal Diversity

In [None]:
# Fingerprint internal-diversity curve per target (5HT2A, BTK, NAMPT).
ht2a, btk, nampt = (intdiv(mols) for mols in (ht2a_mols, btk_mols, nampt_mols))

In [None]:
# Scaffold internal-diversity curve per target; every call takes the molecules themselves.
ht2a_sca = sca_intdiv(ht2a_mols)
btk_sca = sca_intdiv(btk_mols)
nampt_sca = sca_intdiv(nampt_mols)

In [None]:
fig = plt.figure(dpi=150, figsize=(8, 4))

# Left: fingerprint internal diversity.
ax1 = plt.subplot(1, 2, 1)
sns.lineplot(x=indices, y=ht2a, label="5HT2A", ax=ax1)
sns.lineplot(x=indices, y=btk, label="BTK", ax=ax1)
sns.lineplot(x=indices, y=nampt, label="NAMPT", ax=ax1)
ax1.set_xscale('log')
ax1.set_ylim(-0.05, 1.05)
ax1.set_ylabel('IntDiv')
ax1.set_xlabel('Generated Size')

# Right: scaffold internal diversity.
ax2 = plt.subplot(1, 2, 2)
sns.lineplot(x=indices, y=ht2a_sca, label="5HT2A", ax=ax2)
sns.lineplot(x=indices, y=btk_sca, label="BTK", ax=ax2)
sns.lineplot(x=indices, y=nampt_sca, label="NAMPT", ax=ax2)
ax2.set_xscale('log')
ax2.set_ylim(-0.05, 1.05)
ax2.set_ylabel('Scaffold IntDiv')
ax2.set_xlabel('Generated Size')

plt.tight_layout()

#### #Circle

In [None]:
# #Circle curve per target. Reuse the molecule sets already loaded by read_in
# instead of re-reading and re-processing every SDF file.
cir_ht2a = circle(ht2a_mols)
cir_btk = circle(btk_mols)
cir_nampt = circle(nampt_mols)

In [None]:
plt.figure(dpi=150, figsize=(8, 4))

# Left: absolute #Circle on log-log axes.
ax1 = plt.subplot(1, 2, 1)
sns.lineplot(x=indices, y=cir_ht2a, label="5HT2A", ax=ax1)
sns.lineplot(x=indices, y=cir_btk, label="BTK", ax=ax1)
sns.lineplot(x=indices, y=cir_nampt, label="NAMPT", ax=ax1)
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_ylim(0.875, 12500)
ax1.set_ylabel('#Circle')
ax1.set_xlabel('Generated Size')

# Right: #Circle per generated molecule. circle() returns a list, so convert
# explicitly to make the element-wise division by the sample sizes obvious.
ax2 = plt.subplot(1, 2, 2)
sns.lineplot(x=indices, y=np.asarray(cir_ht2a) / indices, label="5HT2A", ax=ax2)
sns.lineplot(x=indices, y=np.asarray(cir_btk) / indices, label="BTK", ax=ax2)
sns.lineplot(x=indices, y=np.asarray(cir_nampt) / indices, label="NAMPT", ax=ax2)
ax2.set_xscale('log')
ax2.set_ylabel('Aver #Circle')
ax2.set_xlabel('Generated Size')

plt.tight_layout()

#### Fragments Diversity
Calculated with ECFP4 fingerprints.

In [None]:
from utils.measure import morgan_frags
def frag_diversity(fps):
    """Count the distinct fragments across all fingerprints in *fps*.

    ``morgan_frags`` is applied to each fingerprint in worker processes
    (``process_map``, progress bar disabled). Its per-fingerprint outputs are
    pooled and the distinct values counted with ``np.unique``.
    """
    per_fp = process_map(morgan_frags, fps, chunksize=1000, disable=True)
    pooled = []
    for frags in per_fp:
        pooled.extend(frags)
    return len(np.unique(pooled))

def check_frag(mols, indices=indices):
    """Fragment-diversity curve: distinct-fragment count of the first ``idx`` fingerprints.

    Fingerprints are computed once for all *mols* via ``Similarity(mols=mols).fps``
    and then sliced per prefix size in *indices*.
    """
    fps = Similarity(mols=mols).fps
    return [frag_diversity(fps[:idx]) for idx in tqdm(indices)]

In [None]:
# Fragment-diversity curve per target (5HT2A, BTK, NAMPT).
ht2a_frag, btk_frag, nampt_frag = (
    check_frag(mols) for mols in (ht2a_mols, btk_mols, nampt_mols)
)

In [None]:
fig = plt.figure(dpi=150, figsize=(8, 4))

# Left: fragment counts on linear axes.
ax1 = plt.subplot(1, 2, 1)
sns.lineplot(x=indices, y=ht2a_frag, label="5HT2A", ax=ax1)
sns.lineplot(x=indices, y=btk_frag, label="BTK", ax=ax1)
sns.lineplot(x=indices, y=nampt_frag, label="NAMPT", ax=ax1)
ax1.set_ylabel('Fragments Div')
ax1.set_xlabel('Generated Size')

# Right: the same curves on log-log axes.
ax2 = plt.subplot(1, 2, 2)
sns.lineplot(x=indices, y=ht2a_frag, label="5HT2A", ax=ax2)
sns.lineplot(x=indices, y=btk_frag, label="BTK", ax=ax2)
sns.lineplot(x=indices, y=nampt_frag, label="NAMPT", ax=ax2)
ax2.set_xscale('log')
ax2.set_yscale('log')
ax2.set_ylabel('Fragments Div')
ax2.set_xlabel('Generated Size')

plt.tight_layout()

#### Single Sphere exclusion clustering
Adapted from [This blog](https://greglandrum.github.io/rdkit-blog/posts/2020-11-18-sphere-exclusion-clustering.html).

In [6]:
from rdkit.SimDivFilters import rdSimDivPickers

# Leader picker used for the sphere-exclusion picks below.
lp = rdSimDivPickers.LeaderPicker()

In [18]:
# Sphere-exclusion picking on the 5HT2A set; thresh is passed to
# LazyBitVectorPick as its threshold argument.
thresh = 0.65
fps = Similarity(mols=ht2a_mols).fps
pool_size = len(fps)
picks = lp.LazyBitVectorPick(fps, pool_size, thresh)

In [19]:
# Number of centroids selected by the preceding LazyBitVectorPick call.
len(picks)

10082

In [20]:
# Same picking with a larger threshold argument (0.75).
thresh = 0.75
fps = Similarity(mols=ht2a_mols).fps
pool_size = len(fps)
picks = lp.LazyBitVectorPick(fps, pool_size, thresh)

In [21]:
# Number of centroids selected by the preceding LazyBitVectorPick call.
len(picks)

8494

In [16]:
from rdkit import DataStructs

# For each picked centroid, compute Tanimoto similarity to every *other*
# centroid; keep the maximum (nearest neighbour) and the full distribution.
pickfps = [fps[x] for x in picks]
nearest = []
simhist = []
for i, fpi in enumerate(pickfps):
    others = pickfps[:i] + pickfps[i + 1:]  # all centroids except fpi itself
    sims = DataStructs.BulkTanimotoSimilarity(fpi, others)
    nearest.append(max(sims))
    simhist.extend(sims)
sorted(nearest, reverse=True)[:10]

[0.248,
 0.248,
 0.24793388429752067,
 0.24793388429752067,
 0.24786324786324787,
 0.24786324786324787,
 0.24778761061946902,
 0.24778761061946902,
 0.24770642201834864,
 0.24770642201834864]

In [17]:
# All off-diagonal pairwise Tanimoto similarities between picked centroids
# (each pair appears twice); for n picks this holds n*(n-1) values, so the
# displayed list can be very long.
simhist

[0.1,
 0.09243697478991597,
 0.06306306306306306,
 0.07608695652173914,
 0.0673076923076923,
 0.08433734939759036,
 0.08108108108108109,
 0.07086614173228346,
 0.09848484848484848,
 0.05737704918032787,
 0.07920792079207921,
 0.06896551724137931,
 0.013333333333333334,
 0.04672897196261682,
 0.08888888888888889,
 0.08227848101265822,
 0.072992700729927,
 0.07017543859649122,
 0.059322033898305086,
 0.10909090909090909,
 0.06779661016949153,
 0.07079646017699115,
 0.06097560975609756,
 0.011904761904761904,
 0.02631578947368421,
 0.06722689075630252,
 0.05504587155963303,
 0.09322033898305085,
 0.014705882352941176,
 0.07207207207207207,
 0.09523809523809523,
 0.07920792079207921,
 0.11764705882352941,
 0.08571428571428572,
 0.06666666666666667,
 0.1092436974789916,
 0.11023622047244094,
 0.08,
 0.06201550387596899,
 0.06363636363636363,
 0.07692307692307693,
 0.07894736842105263,
 0.1,
 0.039603960396039604,
 0.09821428571428571,
 0.037037037037037035,
 0.125,
 0.1016949152542373,
 0.0

#### Save results

In [None]:
import pickle
from easydict import EasyDict

# Per-target diversity curves computed above; the '5HT2A' entry must hold the
# ht2a curve (not the builtin ``int``).
di = EasyDict({
    '5HT2A': EasyDict({'intdiv': ht2a, 'scaff_intdiv': ht2a_sca, 'circle': cir_ht2a}),
    'BTK': EasyDict({'intdiv': btk, 'scaff_intdiv': btk_sca, 'circle': cir_btk}),
    'NAMPT': EasyDict({'intdiv': nampt, 'scaff_intdiv': nampt_sca, 'circle': cir_nampt})
})


In [None]:
# Persist the collected curves. The file name ('divesity.pkl') is kept exactly
# as-is because code reading these results may depend on it.
out_path = Path('../testfile/DiffSBDD/divesity.pkl')
with out_path.open('wb') as fh:
    pickle.dump(di, fh)