In [1]:
import os
import rdkit
import random
import pickle
import numpy as np
import rdkit.Chem.QED as QED
# import pandas as pd
# import seaborn as sns
# import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem

import sa_scorer
from measures import *
from utils import *

  from .autonotebook import tqdm as notebook_tqdm


## Load Molecular Database

In [None]:
# Fresh, independent containers: raw SMILES strings, molecule objects, fingerprint vectors.
smiles, mols, vecs = [], [], []

### ChEMBL & ENAMINE

In [None]:
# Choose the SDF library to read by leaving exactly one `file_name` line active.
file_name = 'chembl_31.sdf'
# file_name = 'Enamine_Discovery_Diversity_Set_50240_DDS-50_20211022.sdf'
# file_name = 'Enamine_Hit_Locator_Library_HLL-460_460160cmpds_20220221.sdf'
sdf_path = os.path.join('data', file_name)
with Chem.SDMolSupplier(sdf_path) as suppl:
    # Entries that come back as None are skipped; the rest are appended to `mols`.
    mols.extend(entry for entry in tqdm(suppl) if entry is not None)

len(mols)

### ZINC 250k

In [None]:
# NOTE(review): presumably every record in this CSV spans two physical lines
# (a quoted SMILES followed by a line holding the remaining columns) -- confirm
# against the file before changing the slicing below.
with open('data/250k_rndm_zinc_drugs_clean_3.csv', 'r') as f:
    lines = f.readlines()
    # Keep odd-indexed lines only, trimming the first and the last character of each.
    smiles.extend(row[1:-1] for row in lines[1::2])

### GDB-17

In [None]:
# Replaces (rather than extends) `smiles` with the whitespace-stripped lines of the file.
with open('data/GDB17.50000000LL.smi', 'r') as f:
    lines = [row.strip() for row in f]
    smiles = lines

### MOSES

In [None]:
with open('data/moses_dataset_v1.csv', 'r') as f:
    lines = f.readlines()
    # Skip the first line, then keep the first comma-separated field of every row.
    smiles.extend(row.strip().split(',')[0] for row in lines[1:])

### Filtered mols

In [None]:
# Rebuild `mols` from the per-database pickles under mols/databases/.
# SECURITY NOTE: pickle.load can execute arbitrary code -- only load files you produced or trust.
mols = []
file_names = ['zinc', 'enamine', 'moses', 'chembl', 'gdb']
for file_name in file_names:
    pkl_path = os.path.join('mols/databases/', 'jnk3_{}.pkl'.format(file_name))
    # `with` closes the file handle even if unpickling fails.
    with open(pkl_path, 'rb') as f:
        m = pickle.load(f)
    print('loaded', len(m), 'mols for', file_name)
    mols += m
print(len(mols))

### MARS generated

In [2]:
import pandas as pd

file_name = 'mols/mars+nc/mols5.csv'
MAX_STEP = 2000  # rows with a larger 'Step' are discarded below

df = pd.read_csv(file_name, header=0, sep=',')

# Every column except 'SMILES' is coerced to numeric; unparsable cells become NaN.
numeric_cols = [col for col in df.columns if col != 'SMILES']
for col in numeric_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')

# Drop rows containing NaN, then keep the first occurrence of each SMILES.
df = df.dropna().drop_duplicates(subset='SMILES')

df = df[df['Step'] <= MAX_STEP]
# df = df[df['Step'] == MAX_STEP]
df

Unnamed: 0,Step,Score,jnk3,sa,qed,nov,SMILES
0,-1,0.268196,0.00,0.805826,0.372786,0.000000,CC
1,0,0.552634,0.00,0.914243,0.569358,0.941176,CCC(=O)NC(C)C
2,0,0.567260,0.00,0.935456,0.612518,0.956522,CCc1cccc(C(F)(F)F)c1
3,0,0.523465,0.00,0.876067,0.460528,0.933333,CCNCCO
4,0,0.558080,0.01,0.939418,0.572320,0.950000,CCc1ccc(C#N)cc1
...,...,...,...,...,...,...,...
1448316,2000,0.698430,0.61,0.773309,0.727438,0.793721,N#Cc1cc2nc(C#N)c(-c3ccnc(NC4CCOCC4)c3)c(F)c2cc1F
1448317,2000,0.644364,0.38,0.641971,0.724970,0.855486,NC1CN(C2CN(C3COCC(c4nc(NC5CCOC5)ncc4F)C3)C2)C1
1448318,2000,0.612459,0.19,0.823917,0.726583,0.859834,Nc1ccc(-c2cnccn2)cc1-c1ccc2c(c1)OC(F)(F)O2
1448319,2000,0.699544,0.70,0.738538,0.791601,0.798175,O=C1NCC(N2CCC2)N1c1ccc(-c2ccnc(NC3CCOCC3)c2)c(...


In [3]:
# A row counts as a success only when all three property thresholds are met.
meets_jnk3 = df['jnk3'] >= 0.5
meets_qed = df['qed'] >= 0.6
meets_sa = df['sa'] >= .67
df['succ'] = (meets_jnk3 & meets_qed & meets_sa).tolist()

# Unweighted sum of the three property columns.
df['score'] = df['jnk3'] + df['qed'] + df['sa']

df_succ = df[df['succ']]
df_succ

Unnamed: 0,Step,Score,jnk3,sa,qed,nov,SMILES,succ,score
23693,11,0.728684,0.61,0.871392,0.719205,0.914735,COc1ccccc1Nc1ccnc(Nc2ccc3c(c2)CCO3)c1,True,2.200597
26177,12,0.710447,0.59,0.860728,0.628671,0.923117,Oc1ccccc1Nc1ccnc(Nc2ccc3c(c2)CCO3)c1,True,2.079399
33990,15,0.701315,0.50,0.907774,0.692441,0.912821,Clc1cnc(Nc2ccccc2)cc1Nc1ccccc1,True,2.100215
36466,16,0.728143,0.65,0.921498,0.722039,0.912570,c1ccc(Nc2ccnc(Nc3ccccc3)c2)cc1,True,2.293537
45562,20,0.723569,0.58,0.914314,0.761208,0.914274,Brc1ccc(Nc2ncc(-c3ccccc3)cn2)cc1,True,2.255522
...,...,...,...,...,...,...,...,...,...
1448313,2000,0.697573,0.60,0.782899,0.691552,0.798738,CN(C)c1c(F)cc(Nc2cc(-c3occ(F)c3F)ccn2)cc1F,True,2.074451
1448314,2000,0.701573,0.65,0.742537,0.695607,0.810686,CN(C)c1cc(Nc2cc(C#N)ccn2)ccc1-c1ccnc(N2CCC2N)c1,True,2.088144
1448315,2000,0.697049,0.61,0.750664,0.830441,0.788196,Fc1cc(F)c(-c2ccnc(NC3CCOCC3)c2)nc1C1=NCCN1F,True,2.191105
1448316,2000,0.698430,0.61,0.773309,0.727438,0.793721,N#Cc1cc2nc(C#N)c(-c3ccnc(NC4CCOCC4)c3)c(F)c2cc1F,True,2.110747


In [4]:
# Carry the SMILES strings of the successful rows forward as a plain list.
smiles = df_succ['SMILES'].tolist()

### DST generated

In [None]:
pkl_file = "mols/mols_DST2.pkl"
# SECURITY NOTE: pickle.load can execute arbitrary code -- only load trusted files.
# `with` closes the file handle, which the bare open() call did not.
with open(pkl_file, 'rb') as f:
    idx_2_smiles2f, trace_dict = pickle.load(f)
topk = 100  # only referenced by the commented-out top-k selection below

# Merge the per-index {smiles: f} dictionaries; later entries overwrite earlier ones.
whole_smiles2f = dict()
for smiles2f, _current_set in tqdm(idx_2_smiles2f.values()):
    whole_smiles2f.update(smiles2f)

smiles = list(whole_smiles2f)
# smiles_f_lst = [(smiles,f) for smiles,f in whole_smiles2f.items()]
# smiles_f_lst.sort(key=lambda x:x[1], reverse=True)
# best_smiles_lst = [smiles for smiles,f in smiles_f_lst[:topk]]
# best_f_lst = [f for smiles,f in smiles_f_lst[:topk]]

### JANUS generated

In [None]:
# Collect the space-separated entries on the first line of each population file, runs 0-9.
smiles = []
for run_idx in range(10):
    for file_name in ('population_explore', 'population_local_search'):
        population_path = 'mols/mols_JANUS/{}_DATA/{}.txt'.format(run_idx, file_name)
        with open(population_path, 'r') as f:
            first_line = f.readlines()[0]
        smiles.extend(first_line.strip().split(' '))

### RationaleRL generated

In [None]:
# Replaces `smiles` with the whitespace-stripped lines of the file.
with open('mols/mols_RationaleRL3.txt', 'r') as f:
    lines = [row.strip() for row in f]
    smiles = lines

## Transforming and filtering

In [5]:
def mol_from_smiles(x):
    """Parse `x` with `Chem.MolFromSmiles` and return whatever it returns.

    Defined at module level so it can be handed to `process_map`.
    """
    parsed = Chem.MolFromSmiles(x)
    return parsed

# Deduplicate, then sort: iteration order of a set of strings depends on the
# per-process hash seed, so without sorting the molecule order (and anything
# order-dependent computed from it later) changes from one kernel to the next.
smiles = sorted(set(smiles))
mols = process_map(mol_from_smiles, smiles, chunksize=1000)
# Parsing failures come back as None; drop them.
mols = [mol for mol in mols if mol is not None]
len(mols)

100%|████████████████████████████████████████████████████████████████████████| 223071/223071 [00:17<00:00, 13048.26it/s]


223071

In [None]:
# Round-trip through Chem.MolToSmiles and deduplicate on the resulting strings.
# One set is enough (the second list(set(...)) pass was redundant); sorting makes
# the order independent of the per-process string hash seed.
smiles = sorted({Chem.MolToSmiles(mol) for mol in mols})
mols = process_map(mol_from_smiles, smiles, chunksize=1000)
len(mols)

In [None]:
_fscores = None
def readFragmentScores(name='fpscores'):
    """Load a gzipped, pickled fragment-score table into the global `_fscores`.

    Args:
        name: Path of the score file without its '.pkl.gz' suffix. The default
            'fpscores' resolves to 'fpscores.pkl.gz' in the working directory,
            which is the file the previous hard-coded path opened.

    Side effects:
        Rebinds the module-level `_fscores` to a dict that maps every element
        after the first of each record to `float(record[0])`.
    """
    import gzip
    import pickle
    global _fscores
    # SECURITY NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    # The context manager closes the gzip handle, which was previously leaked.
    with gzip.open('%s.pkl.gz' % name) as fh:
        records = pickle.load(fh)
    outDict = {}
    for record in records:
        score = float(record[0])
        for fragment_id in record[1:]:
            outDict[fragment_id] = score
    _fscores = outDict
# Populate the module-level `_fscores` table before any scoring takes place.
readFragmentScores(name='fpscores')

def sa(mol):
    """Return `(10 - x) / 9`, where `x` is `sa_scorer.calculateScore` for `mol`.

    A raw score between 1 and 10 therefore lands in [0, 1], with higher values
    corresponding to lower raw scores.
    NOTE(review): the 1..10 raw range is assumed from the formula -- confirm
    against `sa_scorer`.
    """
    raw_score = sa_scorer.calculateScore(mol, _fscores=_fscores)
    rescaled = (10. - raw_score) / 9.
    return rescaled

def map_filter(mol):
    """Return `mol` when it passes both property cut-offs, otherwise None.

    A molecule is rejected when `sa(mol)` is below .67 or `QED.qed(mol)` is
    below 0.6. QED is only evaluated when the SA check has passed.
    """
    if sa(mol) < .67 or QED.qed(mol) < 0.6:
        return None
    return mol

def map_filter_from_smi(smi):
    """Parse `smi` and apply the same SA/QED filter as `map_filter`.

    Args:
        smi: SMILES string.

    Returns:
        The parsed molecule when it passes `map_filter`; None when the SMILES
        cannot be parsed or the molecule is filtered out.
    """
    mol = mol_from_smiles(smi)
    # An unparsable SMILES yields None; previously that None was passed straight
    # to the scorers and crashed the worker.
    if mol is None:
        return None
    # Delegate so the thresholds live in exactly one place.
    return map_filter(mol)

# map_filter(mols[0])
# map_filter_from_smi(smiles[0])

In [None]:
# Run the SA/QED filter in worker processes; rejected molecules come back as None.
filter_results = process_map(map_filter, mols, chunksize=1000)
# Alternatives kept for reference:
# mols = process_map(map_filter_from_smi, smiles, chunksize=1000)
# mols = [map_filter(mol_from_smiles(smi)) for smi in tqdm(smiles)]
mols = [kept for kept in filter_results if kept is not None]
len(mols)

In [None]:
# pickle.dump(mols, open('qed_gdb.pkl', 'wb'))
# mols = pickle.load(open('qed_gdb.pkl', 'rb'))
# len(mols)

In [None]:
ROOT_DIR = ''
TASKS = ['gsk3b', 'jnk3']
SPLITS = ['val', 'dev']

# Cache of unpickled models, keyed by task name.
models = {}

def load_model(task):
    """Unpickle `kinase_rf/<task>.pkl` under ROOT_DIR and store it in `models[task]`.

    SECURITY NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    """
    model_path = os.path.join(ROOT_DIR, 'kinase_rf/%s.pkl' % task)
    with open(model_path, 'rb') as f:
        loaded = pickle.load(f, encoding='iso-8859-1')
    models[task] = loaded

load_model('jnk3')

In [None]:
# mols = pickle.load(open('mols/databases/qed_moses.pkl', 'rb'))
# len(mols)

In [None]:
def fingerprints_from_mol(mol):
    """Return the Morgan bit-vector fingerprint of `mol` as a NumPy int8 array.

    The fingerprint is built with `AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)`
    and copied into a NumPy array via `DataStructs.ConvertToNumpyArray`.
    """
    # Imported locally (as readFragmentScores does for its helpers) because the
    # notebook never imports DataStructs explicitly; it was only reachable if one
    # of the star imports happened to re-export it.
    from rdkit import DataStructs

    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)
    # Starts empty; ConvertToNumpyArray is relied on to size and fill it.
    nfp = np.zeros((0, ), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, nfp)
    return nfp

def map_predict(args):
    """Return column 1 of `model.predict_proba(batch)` for a `(model, batch)` pair.

    Takes a single tuple so it can be mapped over by `process_map`.
    """
    classifier, fp_batch = args
    probabilities = classifier.predict_proba(fp_batch)
    return probabilities[:, 1]

def get_scores(task, mols, batch_size=512):
    """Score molecules with the model registered for `task`.

    Args:
        task: Key into the module-level `models` cache; the model is loaded
            through `load_model` on first use.
        mols: Sequence of molecules accepted by `fingerprints_from_mol`.
        batch_size: Upper bound on the number of fingerprints per prediction
            batch (defaults to the previously hard-coded 512).

    Returns:
        A list with one `predict_proba(...)[:, 1]` value per molecule, in input
        order; an empty list for empty input.
    """
    if len(mols) == 0: return []
    model = models.get(task)
    if model is None:
        load_model(task)
        model = models[task]

    fps = np.stack([fingerprints_from_mol(mol) for mol in mols], axis=0)

    # Ceil division with a floor of one section: the old `shape[0] // 512` was 0
    # for fewer than 512 molecules, which makes np.array_split raise.
    n_batches = max(1, -(-fps.shape[0] // batch_size))
    batches = np.array_split(fps, n_batches)

    # A list (not a bare zip) so process_map can infer the total for its progress bar.
    args = [(model, batch) for batch in batches]

    # About one chunk per CPU: the old chunksize=1000 exceeded the batch count for
    # typical inputs, so all batches landed in a single task and ran on one worker.
    # Keeping many batches per chunk also avoids sending the model with every batch.
    chunksize = max(1, -(-len(args) // (os.cpu_count() or 1)))
    scores = process_map(map_predict, args, chunksize=chunksize)
    return np.concatenate(scores).tolist()

In [None]:
# One 'jnk3' score per entry of `mols`, in the same order.
scores = get_scores('jnk3', mols)

In [None]:
# Keep the molecules whose jnk3 score reaches the 0.5 cut-off.
mols_jnk3 = []
for mol, score in zip(mols, scores):
    if score >= 0.5:
        mols_jnk3.append(mol)
len(mols_jnk3)

In [None]:
# From here on `mols` is the jnk3-filtered subset (same list object as `mols_jnk3`).
mols = mols_jnk3

In [None]:
# pickle.dump(mols_jnk3, open('jnk3_gdb.pkl', 'wb'))

In [6]:
# One vector per molecule, computed in worker processes.
# NOTE(review): `fingerprint` is not defined in this notebook -- presumably it
# comes from one of the star imports (`utils` / `measures`); confirm.
vecs = process_map(fingerprint, mols, chunksize=1000)

100%|████████████████████████████████████████████████████████████████████████| 223071/223071 [00:10<00:00, 20774.75it/s]


## Measuring

In [19]:
def define_measures(sweep=True):
    """Build the name -> measure mapping evaluated by this notebook.

    Args:
        sweep: When True (the default, and the only behavior that was reachable
            before), return one `NCircles` measure per threshold in
            `np.arange(0.4, 1.1, 0.05)`. When False, return the fixed set that
            used to sit, unreachable, after the early `return`:
            '#Circles (0.75)' and 'Richness'.

    Returns:
        dict mapping a display name to a measure instance.
    """
    vectorizer = fingerprints
    sim_mat_func = similarity_matrix_tanimoto

    if sweep:
        measures = {}
        # Keys are formatted with two decimals, e.g. '#Circles (0.40)'.
        for t in np.arange(0.4, 1.1, 0.05):
            measures['#Circles (%.2f)' % t] = NCircles(vectorizer=vectorizer, sim_mat_func=sim_mat_func, threshold=t)
        return measures

    return {
        '#Circles (0.75)': NCircles(vectorizer=vectorizer, sim_mat_func=sim_mat_func, threshold=0.75),
        'Richness': Richness(),
    }

# define_measures()

In [20]:
measures = define_measures()
results = {}
circs = None  # overwritten by every NCircles measure, so it ends up holding the last one's value

for name, measure in measures.items():
    print('measureing', name, '...')
    if isinstance(measure, DissimilarityBasedMeasure):
        # To subsample instead of using all vectors:
        # idxs = [i for i in range(len(mols)) if random.random() < 0.01]
        # vecs_ = [vecs[i] for i in idxs]
        val = measure.measure(vecs, is_vec=True)
    elif isinstance(measure, NCircles):
        val, circs = measure.measure(vecs, is_vec=True, n_chunk=64)
    else:
        # Everything else is measured on the molecule objects, not the vectors.
        val = measure.measure(mols)
    results[name] = val
    print(name, ': ', val)

measureing #Circles (0.40) ...


64it [00:15,  4.25it/s]
32it [00:16,  1.92it/s]
16it [00:18,  1.17s/it]
100%|████████████████████████████████████████████████████████████████████████████| 70971/70971 [05:56<00:00, 199.15it/s]

#Circles (0.40) :  20744
measureing #Circles (0.45) ...



64it [00:17,  3.58it/s]
32it [00:09,  3.32it/s]
16it [00:08,  1.95it/s]
100%|████████████████████████████████████████████████████████████████████████████| 43730/43730 [01:58<00:00, 368.54it/s]


#Circles (0.45) :  11454
measureing #Circles (0.50) ...


64it [00:12,  5.04it/s]
32it [00:03,  8.65it/s]
16it [00:02,  6.53it/s]
100%|████████████████████████████████████████████████████████████████████████████| 22555/22555 [00:23<00:00, 971.34it/s]

#Circles (0.50) :  5463
measureing #Circles (0.55) ...



64it [00:07,  8.66it/s]
32it [00:01, 20.24it/s]
16it [00:01, 15.81it/s]
100%|███████████████████████████████████████████████████████████████████████████| 11374/11374 [00:05<00:00, 2073.81it/s]

#Circles (0.55) :  2602
measureing #Circles (0.60) ...



64it [00:04, 14.05it/s]
32it [00:00, 66.34it/s]
16it [00:00, 54.47it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 4752/4752 [00:00<00:00, 4830.92it/s]

#Circles (0.60) :  963
measureing #Circles (0.65) ...



64it [00:03, 19.49it/s]
32it [00:00, 180.49it/s]
16it [00:00, 155.02it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 1816/1816 [00:00<00:00, 7389.27it/s]

#Circles (0.65) :  353
measureing #Circles (0.70) ...



64it [00:02, 24.76it/s]
32it [00:00, 351.45it/s]
16it [00:00, 266.68it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 544/544 [00:00<00:00, 17159.26it/s]

#Circles (0.70) :  101
measureing #Circles (0.75) ...



64it [00:02, 26.33it/s]
32it [00:00, 705.89it/s]
16it [00:00, 668.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 165/165 [00:00<00:00, 58963.97it/s]

#Circles (0.75) :  27
measureing #Circles (0.80) ...



64it [00:02, 30.31it/s]
32it [00:00, 933.36it/s]
16it [00:00, 1449.25it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 65/65 [00:00<00:00, 18392.35it/s]

#Circles (0.80) :  9
measureing #Circles (0.85) ...



64it [00:02, 29.04it/s]
32it [00:00, 1152.14it/s]
16it [00:00, 742.68it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 35358.96it/s]

#Circles (0.85) :  3
measureing #Circles (0.90) ...



64it [00:02, 30.10it/s]
32it [00:00, 1691.32it/s]
16it [00:00, 1256.56it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 44844.76it/s]

#Circles (0.90) :  1
measureing #Circles (0.95) ...



64it [00:02, 29.43it/s]
32it [00:00, 1483.12it/s]
16it [00:00, 1038.97it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 17439.93it/s]

#Circles (0.95) :  1
measureing #Circles (1.00) ...



64it [00:02, 28.82it/s]
32it [00:00, 1366.68it/s]
16it [00:00, 1236.89it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 19345.31it/s]

#Circles (1.00) :  1
measureing #Circles (1.05) ...



64it [00:02, 29.39it/s]
32it [00:00, 1391.41it/s]
16it [00:00, 1234.05it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 22482.03it/s]

#Circles (1.05) :  1





In [21]:
# Dump the full name -> value mapping collected by the measurement loop.
print(results)

{'#Circles (0.40)': 20744, '#Circles (0.45)': 11454, '#Circles (0.50)': 5463, '#Circles (0.55)': 2602, '#Circles (0.60)': 963, '#Circles (0.65)': 353, '#Circles (0.70)': 101, '#Circles (0.75)': 27, '#Circles (0.80)': 9, '#Circles (0.85)': 3, '#Circles (0.90)': 1, '#Circles (0.95)': 1, '#Circles (1.00)': 1, '#Circles (1.05)': 1}


In [22]:
# Comma-separated integer values in insertion order, leaving out the last three entries.
formatted_counts = ['%i' % v for v in results.values()]
print(','.join(formatted_counts[:-3]))

20744,11454,5463,2602,963,353,101,27,9,3,1


In [23]:
# Shuffle the vectors in place. A dedicated, seeded generator makes the permutation
# reproducible after Restart & Run All; change SHUFFLE_SEED to draw a different one.
# NOTE: only `vecs` is permuted, so after this cell vecs[i] no longer corresponds
# to mols[i].
SHUFFLE_SEED = 0
random.Random(SHUFFLE_SEED).shuffle(vecs)

In [24]:
# pickle.dump(circs, open('circs_qed_enamine.pkl', 'wb'))