In [None]:
import os
import rdkit
import random
import pickle
import numpy as np
import rdkit.Chem.QED as QED
# import pandas as pd
# import seaborn as sns
# import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem

import sa_scorer
from measures import *
from utils import *

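## Load molecular databases

Each subsection below loads one molecule source. Run only the source(s) you want to evaluate: some cells append to `smiles`/`mols`, while others replace `smiles` outright.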

In [None]:
# Accumulators filled by the loader cells below: raw SMILES strings,
# RDKit molecules, and (later) fingerprint vectors.
smiles, mols, vecs = [], [], []

### ChEMBL & ENAMINE

In [None]:
file_name = 'chembl_31.sdf'
# file_name = 'Enamine_Discovery_Diversity_Set_50240_DDS-50_20211022.sdf'
# file_name = 'Enamine_Hit_Locator_Library_HLL-460_460160cmpds_20220221.sdf'
sdf_path = os.path.join('data', file_name)
# Records RDKit fails to read come back as None; keep only real molecules.
with Chem.SDMolSupplier(sdf_path) as suppl:
    mols.extend(mol for mol in tqdm(suppl) if mol is not None)

len(mols)

### ZINC 250k

In [None]:
with open('data/250k_rndm_zinc_drugs_clean_3.csv', 'r') as f:
    rows = f.readlines()
# Keep only odd-indexed physical lines (1, 3, 5, ...). From each, drop the first
# character and the trailing newline character.
# NOTE(review): assumes each SMILES sits on its own line behind an opening quote.
smiles.extend(row[1:-1] for row in rows[1::2])

### GDB-17

In [None]:
# One SMILES per line. This cell replaces (does not extend) `smiles`.
with open('data/GDB17.50000000LL.smi', 'r') as f:
    # Blank lines carry no molecule; drop them so empty strings never reach
    # the SMILES parser or the SA/QED scorers downstream. Iterating the file
    # directly also avoids holding a second full copy from readlines().
    smiles = [line.strip() for line in f if line.strip()]

### MOSES

In [None]:
with open('data/moses_dataset_v1.csv', 'r') as f:
    rows = f.readlines()[1:]  # skip the first (header) line
# The SMILES string is the first comma-separated field of each row.
smiles.extend(row.strip().split(',')[0] for row in rows)

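### Pre-filtered molecules

Load previously saved JNK3-filtered molecule pickles (`mols/databases/jnk3_<source>.pkl`) and concatenate them into `mols`.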

In [None]:
# import _pickle as pickle
# import cPickle as pickle

mols = []
file_names = ['zinc', 'enamine', 'moses', 'chembl', 'gdb']
for file_name in file_names:
    pkl_path = os.path.join('mols/databases/', 'jnk3_{}.pkl'.format(file_name))
    # `with` closes each handle; the previous pickle.load(open(...)) leaked them.
    # NOTE: pickle.load can execute arbitrary code - only load trusted local files.
    with open(pkl_path, 'rb') as fh:
        m = pickle.load(fh)
    print('loaded', len(m), 'mols for', file_name)
    mols += m
print(len(mols))

### MARS generated

In [None]:
import pandas as pd

file_name = 'mols/mars+nc/mols5.csv'
MAX_STEP = 2000

df = pd.read_csv(file_name, header=0, sep=',')
# Every column except SMILES is numeric; unparsable cells become NaN.
numeric_cols = [col for col in df.columns if col != 'SMILES']
for col in numeric_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')

# Drop incomplete rows, keep the first row per SMILES, and cap the step.
df = df.dropna().drop_duplicates(subset='SMILES')
df = df.loc[df['Step'] <= MAX_STEP]
# df = df[df['Step'] == MAX_STEP]
df

In [None]:
# A molecule "succeeds" when it clears all three property thresholds.
is_success = (df['jnk3'] >= 0.5) & (df['qed'] >= 0.6) & (df['sa'] >= .67)
df['succ'] = is_success.tolist()
df['score'] = df['jnk3'] + df['qed'] + df['sa']
df_succ = df[df['succ'].eq(True)]
df_succ

In [None]:
smiles = df_succ['SMILES'].tolist()  # successful MARS molecules only

### DST generated

In [None]:
pkl_file = "mols/mols_DST2.pkl"
# `with` closes the handle (previously leaked by pickle.load(open(...))).
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(pkl_file, 'rb') as fh:
    idx_2_smiles2f, trace_dict = pickle.load(fh)
# bestvalue, best_smiles = 0, ''
topk = 100  # only used by the commented-out top-k ranking below
# Merge every {smiles: f} dict; a later entry overwrites an earlier f for the
# same SMILES. The companion `current_set` in each value is not needed here.
whole_smiles2f = dict()
for smiles2f, _current_set in tqdm(idx_2_smiles2f.values()):
    whole_smiles2f.update(smiles2f)

smiles = list(whole_smiles2f)
# smiles_f_lst = [(smiles,f) for smiles,f in whole_smiles2f.items()]
# smiles_f_lst.sort(key=lambda x:x[1], reverse=True)
# best_smiles_lst = [smiles for smiles,f in smiles_f_lst[:topk]]
# best_f_lst = [f for smiles,f in smiles_f_lst[:topk]]

### JANUS generated

In [None]:
smiles = []
for i in range(10):
    for file_name in ['population_explore', 'population_local_search']:
        with open('mols/mols_JANUS/{}_DATA/{}.txt'.format(i, file_name), 'r') as f:
            # Only the first line is used: a whitespace-separated list of SMILES.
            # split() (rather than strip().split(' ')) yields no empty tokens for
            # repeated/trailing spaces, and readline() returns '' for an empty
            # file instead of raising IndexError on lines[0].
            smiles += f.readline().split()

### RationaleRL generated

In [None]:
# One SMILES per line. This cell replaces (does not extend) `smiles`.
with open('mols/mols_RationaleRL3.txt', 'r') as f:
    # Skip blank lines so empty strings never reach the SMILES parser/scorers.
    smiles = [line.strip() for line in f if line.strip()]

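## Transform and filter

Parse SMILES into RDKit molecules and deduplicate by canonical SMILES. Keep molecules with normalized SA ≥ 0.67 and QED ≥ 0.6, then keep those whose predicted JNK3 activity is ≥ 0.5.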

In [None]:
def mol_from_smiles(x):
    """Parse the SMILES string ``x`` with RDKit.

    Returns ``None`` when RDKit rejects the string; callers filter those out.
    Kept as a named module-level function so it can be passed to ``process_map``.
    """
    mol = Chem.MolFromSmiles(x)
    return mol

# sorted() instead of list(): set iteration order for strings depends on the
# per-process hash seed, and order-sensitive measures (e.g. circle counting)
# should see the same molecule order on every run.
smiles = sorted(set(smiles))
print(len(smiles))  # a bare expression mid-cell would never be displayed
mols = process_map(mol_from_smiles, smiles, chunksize=1000)
mols = [mol for mol in mols if mol is not None]  # drop unparsable SMILES
len(mols)

In [None]:
# Deduplicate on canonical SMILES (one set is enough) in a reproducible order,
# then re-parse the molecules from their canonical form.
smiles = sorted({Chem.MolToSmiles(mol) for mol in mols})
mols = process_map(mol_from_smiles, smiles, chunksize=1000)
# A canonical SMILES can occasionally fail to re-parse; drop those so the
# SA/QED filter below never receives None.
mols = [mol for mol in mols if mol is not None]
len(mols)

In [None]:
_fscores = None
def readFragmentScores(name='fpscores'):
    """Load the SA-score fragment table from ``<name>.pkl.gz`` into ``_fscores``.

    The pickle holds rows whose first element is a score and whose remaining
    elements are fragment keys. It is flattened into a
    ``{fragment_key: float(score)}`` dict, stored in the module-level
    ``_fscores`` and also returned.

    Args:
        name: path prefix of the gzipped pickle; ``.pkl.gz`` is appended. The
            default resolves to ``fpscores.pkl.gz`` relative to the working
            directory, the same file as before (``name`` used to be ignored).

    Returns:
        The flattened fragment-score dict.
    """
    import gzip
    import pickle
    global _fscores
    # `with` closes the gzip handle, which was previously left open.
    # NOTE: pickle.load can execute arbitrary code - only load trusted files.
    with gzip.open('%s.pkl.gz' % name) as fh:
        rows = pickle.load(fh)
    outDict = {}
    for row in rows:
        for key in row[1:]:
            outDict[key] = float(row[0])
    _fscores = outDict
    return outDict
readFragmentScores('fpscores')  # populate _fscores, which sa() passes to the scorer

def sa(mol):
    """Synthetic-accessibility score of ``mol`` rescaled as ``(10 - raw) / 9``.

    The raw score comes from ``sa_scorer.calculateScore`` using the loaded
    ``_fscores`` table. A lower raw score maps to a higher value; the mapping
    lands in [0, 1] for raw scores in [1, 10].
    """
    raw_score = sa_scorer.calculateScore(mol, _fscores=_fscores)
    return (10.0 - raw_score) / 9.0

def map_filter(mol):
    """Return ``mol`` if it passes the SA and QED thresholds, else ``None``.

    A ``None`` input (e.g. a failed SMILES parse) is passed through as
    ``None`` instead of being handed to the scorers. Rejection is signalled
    with ``None`` so the ``process_map`` output can be filtered with
    ``is not None``.
    """
    if mol is None:
        return None
    if sa(mol) < .67:
        return None
    if QED.qed(mol) < 0.6:
        return None
    return mol

def map_filter_from_smi(smi):
    """Parse ``smi`` and apply the SA/QED filter.

    Returns the molecule when it parses and passes ``map_filter``; otherwise
    ``None``. Delegating to ``map_filter`` keeps the thresholds in one place.
    """
    mol = mol_from_smiles(smi)
    # Guard here so an unparsable SMILES never reaches the scorers.
    if mol is None:
        return None
    return map_filter(mol)

# map_filter(mols[0])
# map_filter_from_smi(smiles[0])

In [None]:
filtered = process_map(map_filter, mols, chunksize=1000)
# mols = process_map(map_filter_from_smi, smiles, chunksize=1000)
# mols = [map_filter(mol_from_smiles(smi)) for smi in tqdm(smiles)]
# map_filter signals rejection with None; keep only the survivors.
mols = list(filter(lambda m: m is not None, filtered))
len(mols)

In [None]:
# pickle.dump(mols, open('qed_gdb.pkl', 'wb'))
# mols = pickle.load(open('qed_gdb.pkl', 'rb'))
# len(mols)

In [None]:
ROOT_DIR = ''
TASKS = ['gsk3b', 'jnk3']
SPLITS = ['val', 'dev']

# Cache of unpickled per-task classifiers, keyed by task name.
models = {}
def load_model(task):
    """Unpickle ``kinase_rf/<task>.pkl`` (relative to ``ROOT_DIR``) into ``models[task]``.

    ``encoding='iso-8859-1'`` lets Python 3 decode 8-bit strings in pickles
    written by Python 2 (presumably how these models were saved).
    """
    model_path = os.path.join(ROOT_DIR, 'kinase_rf/%s.pkl' % task)
    with open(model_path, 'rb') as fh:
        loaded = pickle.load(fh, encoding='iso-8859-1')
    models[task] = loaded
load_model(task='jnk3')  # cache the JNK3 classifier used by get_scores below

In [None]:
# mols = pickle.load(open('mols/databases/qed_moses.pkl', 'rb'))
# len(mols)

In [None]:
def fingerprints_from_mol(mol, radius=2, n_bits=1024):
    """Morgan bit-vector fingerprint of ``mol`` as a 1-D ``np.int8`` array.

    Args:
        mol: RDKit molecule.
        radius: Morgan radius (default 2, as before).
        n_bits: fingerprint length (default 1024, as before).

    Returns:
        numpy int8 array; ``ConvertToNumpyArray`` resizes the empty seed
        array to the fingerprint length.
    """
    # Import explicitly rather than relying on a star import to supply it.
    from rdkit import DataStructs

    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, n_bits)
    nfp = np.zeros((0, ), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, nfp)
    return nfp

def map_predict(args):
    """Positive-class probabilities for one ``(model, batch)`` pair.

    Takes a single tuple argument so it can be mapped with ``process_map``.
    """
    model, batch = args
    probabilities = model.predict_proba(batch)
    return probabilities[:, 1]

def get_scores(task, mols, batch_size=512, chunksize=1000):
    """Predicted positive-class probability for each molecule under ``task``'s model.

    Fingerprints every molecule, splits the fingerprint matrix into
    ``max(1, n // batch_size)`` batches and scores them with ``process_map``.

    Args:
        task: key into ``models``; loaded with ``load_model`` on first use.
        mols: sequence of RDKit molecules.
        batch_size: target minimum rows per prediction batch (default 512).
        chunksize: batches per worker task handed to ``process_map``
            (default 1000, as before). Each item is already a whole batch, so
            a smaller value spreads work over more workers, at the cost of
            pickling the model once per chunk.

    Returns:
        list of floats, one per molecule, in input order.
    """
    if len(mols) == 0:
        return []
    model = models.get(task)
    if model is None:
        load_model(task)
        model = models[task]

    fps = np.stack([fingerprints_from_mol(mol) for mol in mols], axis=0)
    # np.array_split needs at least one section: with fewer than batch_size
    # molecules, n // batch_size is 0 and array_split raised ValueError.
    n_batches = max(1, fps.shape[0] // batch_size)
    batches = np.array_split(fps, n_batches)
    # A list (not a zip iterator) lets process_map report a total.
    args = [(model, batch) for batch in batches]
    scores = process_map(map_predict, args, chunksize=chunksize)
    return np.concatenate(scores).tolist()

In [None]:
scores = get_scores(task='jnk3', mols=mols)  # one JNK3 probability per molecule

In [None]:
# Keep molecules whose predicted JNK3 probability is at least 0.5.
mols_jnk3 = [m for m, prob in zip(mols, scores) if prob >= 0.5]
len(mols_jnk3)

In [None]:
# From here on `mols` is the JNK3-active subset (the same list object as mols_jnk3).
mols = mols_jnk3

In [None]:
# pickle.dump(mols_jnk3, open('jnk3_gdb.pkl', 'wb'))

In [None]:
# Fingerprint the remaining molecules in parallel. `fingerprint` is not defined
# in this notebook (presumably it comes from the measures/utils star imports).
vecs = process_map(fingerprint, mols, chunksize=1000)

## Measuring

In [None]:
def define_measures(sweep=True):
    """Build the ``{display_name: measure}`` dict evaluated by the measuring loop.

    Args:
        sweep: if True (the default, and the only behaviour reachable before),
            return one ``NCircles`` measure per threshold in
            ``np.arange(0.4, 1.1, 0.05)``, keyed ``'#Circles (<t>)'``. If False,
            return the curated set below, which previously sat after an
            unconditional ``return`` and could never run.

    Returns:
        dict mapping display names to measure objects, in insertion order.
    """
    vectorizer = fingerprints
    sim_mat_func = similarity_matrix_tanimoto
    if sweep:
        # Float arange: the last threshold depends on rounding, so inspect the
        # generated keys rather than assuming a particular endpoint.
        return {
            '#Circles (%.2f)' % t: NCircles(vectorizer=vectorizer, sim_mat_func=sim_mat_func, threshold=t)
            for t in np.arange(0.4, 1.1, 0.05)
        }
    measures = {
        # 'Diversity' : Diversity(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # 'SumDiversity' : SumDiversity(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # 'Diameter' : Diameter(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # 'SumDiameter' : SumDiameter(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # 'Bottleneck' : Bottleneck(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # 'SumBottleneck' : SumBottleneck(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # 'DPP' : DPP(vectorizer=vectorizer, sim_mat_func=sim_mat_func),
        # '#FG' : NFragment(frag='FG'),
        # '#RS' : NFragment(frag='RS'),
        # '#BM' : NBM(),
        '#Circles (0.75)': NCircles(vectorizer=vectorizer, sim_mat_func=sim_mat_func, threshold=0.75),
        'Richness': Richness(),
    }
    return measures

# define_measures()

In [None]:
measures = define_measures()
results = dict()
circs = None          # circles of the most recent NCircles measure (as before)
circs_by_name = {}    # circles of every NCircles measure, keyed by measure name
for name, measure in measures.items():
    print('measuring', name, '...')
    if isinstance(measure, DissimilarityBasedMeasure):
        # idxs = [i for i in range(len(mols)) if random.random() < 0.01]
        # vecs_ = [vecs[i] for i in idxs]
        vecs_ = vecs
        val = measure.measure(vecs_, is_vec=True)
    elif isinstance(measure, NCircles):
        val, circs = measure.measure(vecs, is_vec=True, n_chunk=64)
        # Keep every threshold's circles; `circs` alone only held the last one.
        circs_by_name[name] = circs
    else:
        val = measure.measure(mols)
    results[name] = val
    print(name, ': ', val)

In [None]:
# Bare last expression: Jupyter renders it via rich display instead of print().
results

In [None]:
# Integer-formatted measure values, comma-joined, excluding the last three.
# NOTE(review): with the default threshold sweep the last three entries are the
# highest thresholds - confirm that dropping them is intended.
formatted = ['%i' % value for value in results.values()]
print(','.join(formatted[:-3]))

In [None]:
# Seeded local RNG: reproducible under Restart & Run All and leaves the global
# `random` state untouched.
# NOTE(review): only `vecs` is reordered, so vecs[i] no longer matches mols[i].
random.Random(0).shuffle(vecs)

In [None]:
# pickle.dump(circs, open('circs_qed_enamine.pkl', 'wb'))