In [1]:
import pickle
import os
from smart_open import open
import numpy as np
import random
import torch
import itertools
from tqdm import tqdm
from coatiLDM.common.utils import utc_epoch_now, batch_iterable
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem import Draw
from coatiLDM.models.io import load_score_model_from_model_doc, load_due_cg_from_model_doc
import pandas as pd
from coatiLDM.constants import FIGURE_DATA_PATH, QED_OPT_DOCS, COATI2_DOCS
from coatiLDM.common.s3 import load_figure_file

from rdkit.Chem import QED
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity

In [2]:

# Guidance model (later passed to the sampler as `cg_due`).
guide = load_due_cg_from_model_doc(QED_OPT_DOCS['guide'])
# The score-model loader returns a 3-tuple; only the network itself (later passed
# to the sampler as `uncond_score_net`) is used in this notebook.
uncond, _, _ = load_score_model_from_model_doc(QED_OPT_DOCS['score_model'])

In [3]:
# Local directory that figure/data files are fetched into.
qed_data_path = './qed_opt/'
# Starting molecules for the QED-optimization benchmark (loaded with has_header=False).
qed_smiles = load_figure_file(
    'qed_start_smiles.csv',
    local_dir=qed_data_path,
    filetype='csv',
    has_header=False,
)

In [4]:
# Display the loaded starting-molecule table for a quick sanity check.
qed_smiles

Unnamed: 0,0
0,CC(=O)NCCNC(=O)c1cnn(-c2ccc(C)c(Cl)c2)c1C1CC1
1,C[C@@H](C(=O)C1=c2ccccc2=[NH+]C1)[NH+]1CCC[C@@...
2,CCN(C[C@@H]1CCOC1)C(=O)c1ccnc(Cl)c1
3,Cc1ccccc1C[S@](=O)CCCc1ccccc1
4,CSCC(=O)NNC(=O)c1c(O)cc(Cl)cc1Cl
...,...
795,O=C(CCCn1c(=O)oc2ccccc21)Nc1ccc(Cl)cc1F
796,C[C@H](C#N)Sc1nc(-c2ccccc2)cs1
797,CS(=O)(=O)CCCOc1ccc(-n2cncn2)cc1
798,OCc1cccc(-c2ncccn2)c1


In [5]:
# Canonicalize every starting SMILES two ways:
#   'clean_smiles'  - RDKit canonical SMILES, stereochemistry kept
#   'no_iso_smiles' - RDKit canonical SMILES with stereochemistry stripped
# Each molecule is parsed once (the original parsed it twice per row via a
# row-wise apply). Unparseable input is reported explicitly instead of
# surfacing as an opaque error from Chem.MolToSmiles(None).
# iloc[:, 0] reads the first (raw SMILES) column positionally, as x[0] did.
_start_mols = qed_smiles.iloc[:, 0].map(Chem.MolFromSmiles)
_bad_rows = _start_mols.isna()
if _bad_rows.any():
    raise ValueError(
        f'{int(_bad_rows.sum())} starting SMILES failed to parse, e.g. '
        f'{qed_smiles.iloc[:, 0][_bad_rows].tolist()[:5]}'
    )
qed_smiles['clean_smiles'] = _start_mols.map(Chem.MolToSmiles)
qed_smiles['no_iso_smiles'] = _start_mols.map(lambda mol: Chem.MolToSmiles(mol, isomericSmiles=False))
del _start_mols, _bad_rows  # don't leave temporaries in the kernel namespace

In [6]:
# Plain Python lists of the canonical SMILES, in row order.
smiles_list = list(qed_smiles['clean_smiles'])
# NOTE: despite the name, these are the *non*-isomeric (stereo-stripped) SMILES.
iso_smiles = list(qed_smiles['no_iso_smiles'])

In [7]:
# One record per starting molecule. The embedding cell below adds an 'emb_smiles' key to each.
smiles_recs = [dict(SMILES=canon, no_iso=flat) for canon, flat in zip(smiles_list, iso_smiles)]

In [8]:

from coatiLDM.models.coati.io import load_coati2
DEVICE = 'cuda'
MODEL_DOC = COATI2_DOCS['qed_doc']

# COATI2 encoder + matching tokenizer for the QED task, loaded directly onto DEVICE.
# freeze=True: the encoder is only used for encoding/decoding here, never trained.
encoder, tokenizer = load_coati2(MODEL_DOC, freeze=True, device=DEVICE)


Loading model from s3://terray-ml/ben/paper_resources/models/qed_doc.pt
Loading tokenizer coati2_12_12 from s3://terray-ml/ben/paper_resources/models/qed_doc.pt
number of parameters: 50.44M
number of parameters Total: xformer: 54.81M 
vocab_name not found in tokenizer_vocabs, trying to load from file
Freezing encoder
56385536 params frozen!


In [9]:
import torch

In [10]:
# This is lazy and assumes everything tokenizes (It does)
for chunk in tqdm(batch_iterable(smiles_recs, 2048)):
    toks, _ = tokenizer.batch_smiles([row['SMILES'] for row in chunk])
    res = encoder.encode_tokens(toks.to(DEVICE), tokenizer).cpu().numpy()
    for i, entry in enumerate(chunk):
        entry['emb_smiles'] = res[i,:]

1it [00:01,  1.77s/it]


In [11]:
from coatiLDM.models.diffusion_models.ddpm_sample_routines import ddpm_cg_nearby

In [12]:
# Put the score model and the guidance model on the sampling device.
uncond, guide = uncond.to(DEVICE), guide.to(DEVICE)

In [13]:
from rdkit import Chem
from rdkit.Chem import QED
from rdkit.Chem import AllChem
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity

def check_criterion(targ_smi, gen_smiles, qed_threshold=0.9, sim_threshold=0.4,
                    radius=2, n_bits=2048):
    """Return the first generated molecule that satisfies the QED-optimization criterion.

    A generated molecule passes when its QED is >= ``qed_threshold`` AND the Tanimoto
    similarity between its Morgan bit-vector fingerprint and the target's is
    >= ``sim_threshold``.

    Args:
        targ_smi: SMILES of the starting (target) molecule.
        gen_smiles: iterable of generated SMILES. Entries RDKit cannot parse are skipped.
        qed_threshold: minimum QED score (default 0.9).
        sim_threshold: minimum Tanimoto similarity to the target (default 0.4).
        radius: Morgan fingerprint radius (default 2).
        n_bits: Morgan fingerprint length in bits (default 2048).

    Returns:
        ``(canonical_smiles, qed_score)`` for the first passing molecule, in
        ``gen_smiles`` order, or ``(None, None)`` if no molecule passes.

    Raises:
        ValueError: if ``targ_smi`` cannot be parsed by RDKit.
    """
    targ_mol = Chem.MolFromSmiles(targ_smi)
    if targ_mol is None:
        raise ValueError(f'could not parse target SMILES: {targ_smi!r}')
    # Target fingerprint is computed once and reused for every candidate.
    targ_fp = AllChem.GetMorganFingerprintAsBitVect(targ_mol, radius=radius, nBits=n_bits)

    for smi in gen_smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue  # undecodable / invalid generation
        qed_score = QED.qed(mol)
        # Filter on QED first; only fingerprint candidates that already pass it.
        if qed_score < qed_threshold:
            continue
        gen_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
        if TanimotoSimilarity(targ_fp, gen_fp) >= sim_threshold:
            return Chem.MolToSmiles(mol), qed_score

    return None, None  # no molecule meets both criteria
    



In [18]:
from coatiLDM.data.decoding import force_decode_valid_batch_efficient
# Guided "nearby" search. For each starting molecule, repeatedly:
#   1. sample `batch_size` latents near its embedding with ddpm_cg_nearby,
#   2. decode them to SMILES,
#   3. stop as soon as one decoded molecule passes check_criterion, or once
#      MAX_ORACLE_CALLS molecules have been scored.
# smiles_results maps starting SMILES -> (match_smiles, match_qed, cg_weight),
# or None when no match was found (or the search for it was interrupted).
smiles_results = {}
# To resume an interrupted run, load previously saved results instead:
# smiles_results = pickle.load(open('smiles_results_best.pkl','rb'))
weights = [10, 30, 50]          # cg_weight values tried for each T_start, in order
T_STARTS = [50, 100, 150, 200]  # T_start values, tried in increasing order
batch_size = 100                # molecules sampled (and scored) per ddpm_cg_nearby call
cdf_targ = .995                 # guidance target value used for every sample in the batch
MAX_ORACLE_CALLS = 50000        # per-molecule budget of scored molecules
SEED = 0
# Sampling is stochastic: the sweep below is repeated with identical settings.
# Seed the RNGs so a rerun reproduces the same results.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
cond_targ = torch.tensor([cdf_targ] * batch_size).reshape(-1, 1).to(DEVICE)
n_total = len(smiles_recs)  # denominator for progress messages (was hard-coded to 800)


def _search_nearby(smi_rec):
    """Sample around one starting molecule until check_criterion passes or the budget runs out.

    Sweeps every (T_start, cg_weight) pair in order, repeating the sweep while the
    budget lasts. Returns ``((match_smiles, match_qed, weight), oracle_calls)`` on
    success, else ``(None, oracle_calls)``.
    """
    rep_emb = torch.tensor(np.tile(smi_rec['emb_smiles'], (batch_size, 1))).to(DEVICE)
    oracle_calls = 0
    while oracle_calls < MAX_ORACLE_CALLS:
        for T in T_STARTS:
            for weight in weights:
                attempt = ddpm_cg_nearby(uncond_score_net=uncond, emb_batch=rep_emb, T_start=T,
                                         cg_due=guide, targets=cond_targ, cg_weight=weight)
                pred_smiles = force_decode_valid_batch_efficient(attempt, encoder, tokenizer,
                                                                 max_attempts=10, silent=True)
                smi_match, smi_qed = check_criterion(smi_rec['SMILES'], pred_smiles)
                oracle_calls += batch_size
                if smi_match is not None:
                    return (smi_match, smi_qed, weight), oracle_calls
                if oracle_calls >= MAX_ORACLE_CALLS:
                    return None, oracle_calls
    return None, oracle_calls


def _n_found():
    """Number of starting molecules with a recorded match so far."""
    return sum(v is not None for v in smiles_results.values())


for smi in smiles_recs:
    key = smi['SMILES']
    if smiles_results.get(key) is not None:
        print('skipping')
        continue
    # Mark as attempted up front, so an interrupted search still counts as tested (None).
    smiles_results[key] = None
    result, oracle_calls = _search_nearby(smi)
    smiles_results[key] = result
    if result is not None:
        print(f'FOUND {_n_found()} out of {n_total}')
        print(f'MISSED {len(smiles_results) - _n_found()} out of {n_total}')
        print(f'untested: {n_total - len(smiles_results)}')
    print(f'oracle calls {oracle_calls} for {key}')
    print(f'total found {_n_found()} out of {n_total}')
    if result is None:
        print(f'not found {key}')


FOUND 1 out of 800
MISSED 0 out of 800
untested: 799
oracle calls 400 for CC(=O)NCCNC(=O)c1cnn(-c2ccc(C)c(Cl)c2)c1C1CC1
total found 1 out of 800


KeyboardInterrupt: 

In [5]:
# Search results reported in the paper: starting SMILES -> match tuple (first item is
# the generated SMILES), or None when no match was found.
# NOTE(review): filetype='pkl' presumably means this file is unpickled; only load it
# from a trusted location.
paper_results = load_figure_file(
    'qed_opt_smiles.pkl',
    local_dir=qed_data_path,
    filetype='pkl',
)

In [6]:
# Number of starting molecules for which a passing molecule was found.
print(sum(result is not None for result in paper_results.values()))

765


In [21]:
# Success rate: fraction of starting molecules with a passing match.
# Computed from the data instead of hard-coding the count printed by the previous cell.
print(sum(result is not None for result in paper_results.values()) / len(paper_results))

0.95625


In [22]:
# validate paper results
for k in paper_results:
    targ = Chem.MolFromSmiles(k)
    if paper_results[k] is None:
        continue
    gen = Chem.MolFromSmiles(paper_results[k][0])
    gen_fp = AllChem.GetMorganFingerprintAsBitVect(gen, radius=2, nBits=2048)
    targ_fp = AllChem.GetMorganFingerprintAsBitVect(targ, radius=2, nBits=2048)
    # Calculate Tanimoto similarity
    similarity = TanimotoSimilarity(targ_fp, gen_fp)
    if similarity < 0.4:
        raise ValueError('missed sim cutoff')
    qed_score = QED.qed(gen)
    if qed_score < 0.9:
        raise ValueError('missed qed cutoff')
passed_percent = len([x for x in paper_results.values() if x is not None])/len(paper_results)
print(f'percentage success: {passed_percent}')

percentage success: 0.95625
