In [2]:
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, Descriptors, rdmolfiles, rdMolAlign, rdmolops, rdchem, PyMol, Crippen, PropertyMol
from rdkit import DataStructs
from rdkit import RDLogger

from rdkit.Chem.Draw import IPythonConsole
from rdkit.Geometry import Point3D
from rdkit.Numerics.rdAlignment import GetAlignmentTransform
from rdkit.Chem.AtomPairs import Pairs
from rdkit import DataStructs
from rdkit.Chem import MACCSkeys
from rdkit.Chem import rdFMCS
from rdkit.Chem import rdMolTransforms
from tqdm import tqdm, trange
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import torch
import shutil
import os
import os.path
import random
import re
import subprocess
#import pmx
import bz2
import importlib
import sys
import gc

from scipy.spatial.transform import Rotation as R
from scipy.optimize import linear_sum_assignment

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

try:
    import cPickle as pickle
except ImportError:
    # cPickle is only available on Python 2; fall back to the standard module.
    # Catching ImportError only (instead of a bare except) keeps unrelated
    # errors such as KeyboardInterrupt from being swallowed here.
    import pickle

# Conversion factor: length of one Bohr radius expressed in Angstrom.
Bohr2Ang=0.529177249

# Current working directory at the time this cell runs; the paths used in the
# cells below are built by string concatenation onto this value.
folder=os.getcwd()

# load the generated ligands from smiles strings

In [3]:
# Read the generated ligands' SMILES strings, one per line of the input file.
# Every line gets a name of the form 'set4_<line index>' (duplicates included),
# while unique_smi keeps each distinct SMILES once, in first-seen order.
# refs=[]
unique_smi=[]
seen_smi=set()  # same content as unique_smi; gives O(1) membership tests instead of an O(n) list scan
names=[]
i=0
with open(folder+"/../cl13_detected_fragments_depth_3.smi", "r") as f:
    for line in tqdm(f):
        s=line.strip()
#         m=Chem.MolFromSmiles(s)
#         m.SetProp('ID', f'set4_{i}')
#         refs.append(m)
        if(s not in seen_smi):
            seen_smi.add(s)
            unique_smi.append(s)
        names.append(f'set4_{i}')
        i+=1
# names=[mol.GetProp('ID') for mol in refs]

print(len(names), len(unique_smi))

384481it [14:47, 433.21it/s] 

384481 384481





# filter out ligands that are too large: >36 heavy atoms

In [None]:
def get_num_heavy_atoms(m):
    """Return how many atoms of `m` have an atomic number greater than 1.

    `m` only needs to provide GetAtoms(), yielding objects with GetAtomicNum().
    """
    return(sum(1 for atom in m.GetAtoms() if atom.GetAtomicNum()>1))

# Heavy-atom count of every SMILES in the input file, stored at the index of
# its line so that it lines up with `names`.
nHeavies=np.zeros(len(names))
with open(folder+"/../cl13_detected_fragments_depth_3.smi", "r") as f:
    for i,line in tqdm(enumerate(f)):
        nHeavies[i]=get_num_heavy_atoms(Chem.MolFromSmiles(line.strip()))

#del unique_smi

In [None]:
# Histogram of ligand sizes with one bin per heavy-atom count from 20 up to the
# largest ligand; the dashed line marks the 36 heavy-atom cutoff.
largest=np.ceil(np.max(nHeavies))
plt.hist(nHeavies, bins=int(largest-20), range=(20, largest), density=False)
plt.vlines(36, 0, 40000, colors='k', linestyles='dashed', label='')
plt.xlabel("# Heavy Atoms")
plt.ylabel("# Ligands")
plt.show()

In [None]:
# Positions (in input-file order) of the ligands with at most 36 heavy atoms,
# and the names that belong to them.
sel=np.flatnonzero(nHeavies<=36)
print(len(sel))
names_filtered=[names[idx] for idx in sel]
# Draw.MolsToGridImage([Chem.MolFromSmiles(unique_smi[i]) for i in sel[::1000]], molsPerRow=5)

In [None]:
# Reference crystal-structure ligand, read from a mol file next to the notebook
# with sanitization enabled and its explicit hydrogens kept.
xray_lig_name="4d09"
fn=folder+"/4d09.mol"
xray_lig=rdmolfiles.MolFromMolFile(fn, sanitize=True, removeHs=False)

In [9]:
# Draw.MolsToGridImage([xray_lig], legends=[xray_lig_name], molsPerRow=1)

In [10]:
# Chem.AddHs(xray_lig, addCoords=True)

# Functions to evaluate overlap of sidechains

In [11]:
def overlap_measure(molA, molB):
    """Sum of distances between optimally paired heavy atoms of two molecules.

    The heavy-atom (atomic number > 1) coordinates are taken from each
    molecule's GetConformer(). Atoms of molA are paired one-to-one with atoms
    of molB such that the total pair distance is minimal
    (scipy.optimize.linear_sum_assignment); when the molecules have different
    numbers of heavy atoms, the surplus atoms of the larger one stay unpaired
    and contribute nothing.

    Parameters
    ----------
    molA, molB : molecule objects providing GetAtoms() and GetConformer()

    Returns
    -------
    float
        Total distance over all matched pairs. 0.0 if either molecule has no
        heavy atoms (no pairs can be formed).
    """
    def heavy_atom_positions(mol):
        conf=mol.GetConformer()
        coords=[list(conf.GetAtomPosition(i))
                for i,a in enumerate(mol.GetAtoms())
                if a.GetAtomicNum()>1] #not hydrogens
        # reshape keeps the array 2D (n,3) even when there are no heavy atoms
        return(np.array(coords, dtype=float).reshape(-1, 3))

    posA=heavy_atom_positions(molA)
    posB=heavy_atom_positions(molB)
    if(len(posA)==0 or len(posB)==0):
        return(0.0)

    # pairwise distance matrix: dist[a,b] = |posA[a]-posB[b]|
    dist=np.linalg.norm(posA[:,np.newaxis,:]-posB[np.newaxis,:,:], axis=2)
    A_ind, B_ind = linear_sum_assignment(dist)
    return(dist[A_ind, B_ind].sum())

# Align each ligand onto the closest X-ray structure based on common atoms

In [12]:
def _run_pmx_atom_mapping(common_args, extra_args=()):
    """Run `pmx atomMapping` once with the given arguments and wait for it to exit.

    The exit code is not checked; callers decide on success by looking at the
    output files.
    """
    # The output is never inspected, so send it to DEVNULL. Leaving it on an
    # unread PIPE while waiting can deadlock once the child fills the pipe buffer.
    subprocess.run(['pmx','atomMapping']+list(common_args)+list(extra_args),
                   stdout=subprocess.DEVNULL,
                   stderr=subprocess.DEVNULL)


def find_mapping(mol, ref, outpath, debug=False, dMCS=True):
    """Map atoms between `mol` and `ref` with the external `pmx atomMapping` tool.

    Both molecules are written as PDB files into `outpath`, pmx is run on them
    (ref as input 1, mol as input 2), and the resulting index pairs and score
    are read back.

    WARNING: every file already present in `outpath` is deleted first.

    Parameters
    ----------
    mol, ref : molecules accepted by rdmolfiles.MolToPDBBlock
    outpath  : existing directory used as scratch space for pmx input/output
    debug    : if True, print a message when the fallback attempt is started
    dMCS     : if True, first try with '--dMCS --d 0.1'; the attempt without
               these flags is made whenever no mapping file exists afterwards
               (which is always the case when dMCS is False)

    Returns
    -------
    (mol_inds, ref_inds, score_val)
        mol_inds / ref_inds: 0-based indices from the first / second column of
        the mol mapping file (the file values minus 1); score_val: last
        whitespace-separated token of the first line of the score file, as float.

    Raises
    ------
    RuntimeError
        If no mapping file exists after the second attempt, or the score file
        is empty.
    """
    #remove old output files
    for old in os.listdir(outpath):
        os.remove(os.path.join(outpath, old))

    #dump files
    mol_file=outpath+"/mol.pdb"
    ref_file=outpath+"/ref.pdb"
    with open(mol_file,"w") as f:
        f.write(rdmolfiles.MolToPDBBlock(mol))
    with open(ref_file,"w") as f:
        f.write(rdmolfiles.MolToPDBBlock(ref))

    #map atoms with pmx

    # params
    i1 = ref_file
    i2 = mol_file
    o1 = '{0}/ref_map.dat'.format(outpath)
    o2 = '{0}/mol_map.dat'.format(outpath)
    opdb1 = '{0}/out_pdb1.pdb'.format(outpath)
    opdb2 = '{0}/out_pdb2.pdb'.format(outpath)
    opdbm1 = '{0}/out_pdbm1.pdb'.format(outpath)
    opdbm2 = '{0}/out_pdbm2.pdb'.format(outpath)
    score = '{0}/score.dat'.format(outpath)
    log = '{0}/mapping.log'.format(outpath)

    common_args = ['-i1',i1,
                   '-i2',i2,
                   '-o1',o1,
                   '-o2',o2,
                   '-opdb1',opdb1,
                   '-opdb2',opdb2,
                   '-opdbm1',opdbm1,
                   '-opdbm2',opdbm2,
                   '-score',score,
                   '-log',log]

    if(dMCS):
        _run_pmx_atom_mapping(common_args, ['--dMCS', '--d', '0.1'])

    if(not os.path.isfile(o2) ): #mapping failed, use less restrictive match criteria: no distance criterion in MCS
        if(debug):
            print("Initial atom mapping filed. Retrying without --dMCS")
        _run_pmx_atom_mapping(common_args)

    if(not os.path.isfile(o2) ):
        raise RuntimeError('atomMapping failed after a second, less restrictive, attempt.')

    #read mapping: one "<mol index> <ref index>" pair per line, 1-based in the file
    mol_inds=[]
    ref_inds=[]
    with open(o2,"r") as f:
        for line in f:
            m,r=line.split()
            mol_inds.append(int(m)-1)
            ref_inds.append(int(r)-1)

    #the above mapping is in output atom order
    #pmx atomMapping can change the order from the input one though.

    #only the first line of the score file is used
    score_val=None
    with open(score,"r") as f:
        for line in f:
            score_val=float(line.split()[-1])
            break
    if(score_val is None):
        raise RuntimeError('atomMapping produced an empty score file: {}'.format(score))

    return(mol_inds, ref_inds, score_val)

# Generate all ligand structures in a parallel manner using owl

# 1. input files
Save (xray, scaffold=None, ref) tuples as a separate pickle file for each ligand.

In [13]:
# Write one (xray, scaffold, ligand) pickle per ligand into lig_structures/.
# Existing pickles are left untouched unless `overwrite` is True.
# NOTE(review): `refs` is not defined by any cell visible above (its
# construction in the SMILES-loading cell is commented out), so this cell needs
# `refs` to be provided before it can run on a fresh kernel.
overwrite=False

xray = Chem.AddHs(xray_lig, addCoords=True)
os.makedirs(folder+'/lig_structures/', exist_ok=True)

#loop over ligands most similar to this xray structure
for ref_id in trange(len(refs)):
    ref  = refs[ref_id]
    fname=folder+'/lig_structures/{}.pickle'.format(ref.GetProp('ID'))
    if(not os.path.isfile(fname) or overwrite):
        scaffold = None
        # a fresh input file must not carry over an earlier 'embedded' flag
        if(ref.HasProp('embedded')):
            ref.ClearProp('embedded')
        # context manager so the file is flushed and closed right after the dump
        with open(fname, "wb") as pickle_file:
            pickle.dump( (xray, scaffold, PropertyMol.PropertyMol(ref)), pickle_file )
gc.collect()
print("Done")

In [14]:
# reload ligands from the pickle files
def reload_refs_from_pickles():
    """Replace each entry of the global `refs` with the ligand stored in its pickle.

    For every molecule in `refs` whose file lig_structures/<ID>.pickle exists,
    the third element of the pickled (xray, scaffold, ref) tuple is loaded and
    written back into `refs` at the same position. Entries without a pickle
    file are left unchanged.

    pickle.load can execute arbitrary code, so only point this at pickle files
    you created yourself.
    """
    for ref_id in trange(len(refs)):
        ref  = refs[ref_id]
        fname=folder+'/lig_structures/{}.pickle'.format(ref.GetProp('ID'))
        if(os.path.isfile(fname)):
            with open(fname, "rb") as pickle_file:
                xray, scaffold, ref = pickle.load(pickle_file)
            refs[ref_id] = ref
    print("Done reloading")
reload_refs_from_pickles()

In [18]:
# Collect the names of ligands whose pickle exists but is neither marked as
# embedded nor as corrupt, and store that list for the embedding step.
# Starting from an empty list keeps the cell runnable on a fresh kernel and
# prevents duplicate entries when it is re-run.
names_left=[]
for n in tqdm(names):
    fname=folder+'/lig_structures/{}.pickle'.format(n)
    if(os.path.isfile(fname)):
        with open(fname, "rb") as pickle_file:
            ref = pickle.load(pickle_file)[2]

        if(ref.HasProp('embedded') and ref.GetProp('embedded')=="yes"):
            continue #already handled
        elif(ref.HasProp('corrupt') and ref.GetProp('corrupt')=="yes"): #corrupt rings from incorrect ring closure numbers in SMILES
            continue #already handled
        names_left.append(n)

with open(folder+"/set_4_lig_names_still_to_embed.pickle", "wb") as pickle_file:
    pickle.dump( names_left, pickle_file )

100%|██████████| 259901/259901 [00:00<00:00, 3040871.57it/s]


# 2B. run embedding (on owl)

In [54]:
import queue
import threading

# Build the queue of ligand pickle files that still need embedding and work out
# how many ligands each cluster job has to handle.
previously_done=0
max_debug_evals=200000

# print("Reloading ligands.")
# reload_refs_from_pickles()

q = queue.Queue()
# set lookup: testing membership in the names_left list directly would be a
# linear scan for every one of the filtered names
names_left_lookup=set(names_left)
print("Building Queue.")
for n in tqdm(names_filtered):
    if(n not in names_left_lookup):
        previously_done+=1
        continue
    fname=folder+'/lig_structures/{}.pickle'.format(n)
    if(os.path.isfile(fname)):
        with open(fname, "rb") as pickle_file:
            ref = pickle.load(pickle_file)[2]

        if(ref.HasProp('embedded') and ref.GetProp('embedded')=="yes"):
            previously_done+=1
            continue #already handled
        elif(ref.HasProp('corrupt') and ref.GetProp('corrupt')=="yes"): #corrupt rings from incorrect ring closure numbers in SMILES
            previously_done+=1
            continue #already handled
        fname=folder+'/lig_structures/{}.pickle'.format(ref.GetProp('ID'))
        q.put(fname)

        # cap on the number of ligands submitted in one go
        if(q.qsize()>=max_debug_evals):
            break

print("previously_done:", previously_done, "\t out of:", len(names_filtered), flush=True)
nligs_left=len(names_filtered)-previously_done
nligs_left=min(nligs_left, max_debug_evals)
nworkers=600
nligs_per_worker=int(np.ceil(float(nligs_left)/nworkers))
print("ligands left:", nligs_left, "\t # workers:", nworkers, "\t # ligands/worker:", nligs_per_worker,
      "\t estimated completition time:", nligs_per_worker, "min")
print("queue size:", q.qsize(), flush=True)

#raise()

os.makedirs(folder+'/lig_structures/', exist_ok=True)

# #remove old jobscripts
# print("Deleting old jobscripts")
# process = subprocess.Popen(['rm', folder+"/lig_structures_jobscripts/jobscript_*"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# process.wait()
# print(process.stdout.read().decode("utf-8"))
# print(process.stderr.read().decode("utf-8"))
# print("Finished deleting old jobscripts")
# # raise()

# Directory the jobscripts are written to; it is also used as the jobs'
# working directory.
cwd=folder+"/lig_jobscripts/"
# worker() opens files inside cwd for writing, which fails if the directory is missing
os.makedirs(cwd, exist_ok=True)
# Shell command executed on the submit host; worker() appends one 'qsub' per jobscript.
cmd_str="source /etc/profile; module load sge; cd {};".format(cwd)

def worker(job_id):
    """Write the jobscript for job `job_id` and register it for submission.

    Takes up to `nligs_per_worker` file names from the global queue `q`,
    stopping early at a None sentinel. If at least one name was taken, a
    jobscript running embed_script.py on those files is written to
    cwd/jobscript_<job_id> and a matching 'qsub' call is appended to the global
    `cmd_str`. Nothing is written when no name was taken.
    """
    global cmd_str

    picked=[]
    for _ in range(nligs_per_worker):
        item = q.get()
        if item is None:  # EOF?
            break
        picked.append(item)
    if(len(picked)==0):
        return # skip writing jobscript if it will not handle any ligands

    ligands_str="".join(" "+name for name in picked)
    jobscript_str=f"""
#!/bin/bash
#$ -S /bin/bash
#$ -pe openmp_* 1
#$ -q *
#$ -N lig_struct_gen_{job_id}
#$ -M ykhalak@gwdg.de
#$ -m n
#$ -l h_rt=4:00:00
#$ -wd {cwd}

cd $TMPDIR

source ~/.ML_profile
python {folder}/embed_script.py -f {ligands_str}
"""
    jobscript_fn=cwd+"/jobscript_{}".format(job_id)
    with open(jobscript_fn,"w") as out:
        out.write(jobscript_str)

    # submission itself happens later, in one ssh call for all jobscripts
    cmd_str+=" qsub {};".format(jobscript_fn)

#     ssh_cmd_arr=["ssh", "owl", "source /etc/profile; module load sge; cd {}; qsub {};".format(cwd, jobscript_fn)]
#     process = subprocess.Popen(ssh_cmd_arr, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
#     process.wait()


# Write one jobscript per worker. A None sentinel is queued before every call
# so that q.get() inside worker() can never block on an empty queue.
for job_id in range(nworkers):
    q.put(None)
    worker(job_id)

# Submit all jobscripts with a single ssh call to the submit host.
print("Submitting.")
ssh_cmd_arr=["ssh", "owl", cmd_str]
process = subprocess.Popen(ssh_cmd_arr, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# communicate() drains both pipes while waiting; wait() with unread PIPEs can
# deadlock once the output of the qsub calls fills the pipe buffer.
submit_out, submit_err = process.communicate()
if(process.returncode!=0):
    # make submission failures visible instead of discarding them
    print("Submission command exited with code", process.returncode)
    print(submit_err.decode("utf-8", errors="replace"))
print("Done.")

  0%|          | 25/129223 [00:00<08:46, 245.54it/s]

Building Queue.


100%|██████████| 129223/129223 [05:30<00:00, 390.49it/s]

previously_done: 73359 	 out of: 129223
ligands left: 55864 	 # workers: 600 	 # ligands/worker: 94 	 estimated completition time: 94 min
queue size: 55864





TypeError: exceptions must derive from BaseException

In [55]:
# print("Submitting.")
# ssh_cmd_arr=["ssh", "owl", cmd_str]
# process = subprocess.Popen(ssh_cmd_arr, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# process.wait()
# print("Done.")

Submitting.
Done.


# 3. check how many ligands are finished

In [57]:
# Count how many of the size-filtered ligands are finished (embedded or flagged
# corrupt) and store the names of the remaining ones for the next embedding run.
previously_done=0
n_corrupt=0
names_left=[]

for n in tqdm(names_filtered):
    fname=folder+'/lig_structures/{}.pickle'.format(n)
    if(os.path.isfile(fname)):
        with open(fname, "rb") as pickle_file:
            ref = pickle.load(pickle_file)[2]
        if(ref.HasProp('embedded') and ref.GetProp('embedded')=="yes"):
            previously_done+=1
            continue #already handled
        elif(ref.HasProp('corrupt') and ref.GetProp('corrupt')=="yes"): #corrupt rings from incorrect ring closure numbers in SMILES
            previously_done+=1
            n_corrupt+=1
            continue #already handled
        else:
            names_left.append(n)



print("previously_done:", previously_done, "\t out of:", len(names_filtered), flush=True)
print("corrupt:", n_corrupt)
print("unfinished:", len(names_left))

with open(folder+"/set_4_lig_names_still_to_embed.pickle", "wb") as pickle_file:
    pickle.dump( names_left, pickle_file )
gc.collect()

100%|██████████| 129223/129223 [10:54<00:00, 197.33it/s]

previously_done: 129223 	 out of: 129223





corrupt: 0
unfinished: 259901


94

In [58]:
# Same summary as in the progress check, but relative to all SMILES that were
# read rather than only the size-filtered ones.
print("previously_done:", previously_done, "\t out of:", len(names), flush=True)
for label, value in (("corrupt:", n_corrupt), ("unfinished:", len(names_left))):
    print(label, value)

previously_done: 129223 	 out of: 384481
corrupt: 0
unfinished: 0


# 4. Save all successfully embedded ligands as a pickle dataset.
Not all ligands embedded successfully, so we only use the ones that did and further filter out any that are too large.

In [59]:
# Build both output datasets in a single pass over the ligand pickles:
#   * all ligands marked as embedded that carry no 'corrupt' property
#   * all ligands marked as embedded whose formal charge is 0
# NOTE(review): as in the original two-pass version, the neutral-only set does
# not check the 'corrupt' property - confirm that this is intended.
# pickle.load can execute arbitrary code; only load files created by this workflow.
ligs_embedded=[]
ligs_neutral=[]
for n in tqdm(names_filtered):
    fname=folder+'/lig_structures/{}.pickle'.format(n)
    if(not os.path.isfile(fname)):
        continue
    with open(fname, "rb") as pickle_file:
        l = pickle.load(pickle_file)[2]
    if(not (l.HasProp('embedded') and l.GetProp('embedded')=="yes")):
        continue
    if(not l.HasProp('corrupt')):
        ligs_embedded.append(l)
    if(Chem.rdmolops.GetFormalCharge(l)==0):
        ligs_neutral.append(l)

with open(folder+"/set_4_filtered_embedded.pickle", "wb") as pickle_file:
    pickle.dump( ligs_embedded, pickle_file )
print(f"Saved {len(ligs_embedded)} ligands out of a total {len(names)} SMILES.")

with open(folder+"/set_4_filtered_embedded_neutral_only.pickle", "wb") as pickle_file:
    pickle.dump( ligs_neutral, pickle_file )
print(f"Saved {len(ligs_neutral)} neutral ligands out of a total {len(names)} SMILES.")

# the two-pass version left the neutral-only list behind under this name
ligs_to_save=ligs_neutral
gc.collect()

100%|██████████| 129223/129223 [00:19<00:00, 6682.41it/s]
  1%|          | 705/129223 [00:00<00:18, 7047.25it/s]

Saved 129223 ligands out of a total 384481 SMILES.


100%|██████████| 129223/129223 [00:18<00:00, 7131.68it/s]


Saved 129223 neutral ligands out of a total 384481 SMILES.


0