In [None]:
import sys
sys.path.append('/workspace/')

import sqlite3
import mols2grid
import importlib
import pickle
import itertools
import concurrent
import pandas as pd

import numpy as np
import pandas as pd

import cuml
import cupy as cp

from functools import partial
from subprocess import run
from rdkit import Chem
from rdkit.Chem import Draw, QED, Descriptors, Lipinski, rdDistGeom, rdmolfiles

from flow.pipeline.screening.pose_generate import score_molecule, generate_conformers
from flow.utils.megamolbart import smiles_to_embedding, embedding_to_smiles, sample

import warnings
# Blanket filter: silence every Python warning raised after this point
# (including deprecation warnings), keeping sampling/scoring output readable.
warnings.filterwarnings('ignore')

# Mute RDKit's own logging; RDKit does not route its messages through the
# Python `warnings` module, so it needs separate handling.
import rdkit.rdBase as rkrb
import rdkit.RDLogger as rkl

log = rkl.logger()
# Raise the RDKit logger threshold to ERROR, dropping info/warning messages ...
log.setLevel(rkl.ERROR)
# ... and additionally switch off the 'rdApp.error' channel, so messages such
# as those from Chem.MolFromSmiles on unparsable samples are not printed either.
rkrb.DisableLog('rdApp.error')

In [None]:
# Run-specific workspace directory; its `inputs/` folder holds the files
# passed to score_molecule as receptor_file and score_config.
workspace = '/content/28b0e566-a15a-11ec-83ca-7de881115940'

receptor_file, score_config = (f'{workspace}/inputs/{name}' for name in ('rec.pdbqt', 'config'))

#### Start the MegaMolBART service before running the next cells
The sampling cell expects the service to be reachable at `localhost:50052`.
```bash
docker-compose --env-file .env \
                -f support/docker/megamolbart/docker-compose.yml \
                up --scale megamolbart=2
```

In [None]:
%%time
from contextlib import closing

db_url = '/data/chembl.db'
generation = 0

# Random sample of 100 seed SMILES from ChEMBL.
# NOTE: SQLite's random() cannot be seeded, so the sample changes on every run.
seed_query = '''
     SELECT canonical_smiles as smiles
     FROM compound_structures order by random()
     LIMIT 100
     '''

# Close the connection as soon as the query has been read instead of leaving
# it open for the rest of the kernel session.
with closing(sqlite3.connect(db_url, uri=True)) as conn:
    df = pd.read_sql(seed_query, con=conn)
# df.head()
# df.head()

In [None]:
%%time

# Pair every ChEMBL seed molecule (x0) with one MegaMolBART sample (x1).
# A sample is accepted only if RDKit can parse it and its SMILES differs from
# the seed's. Rejected samples are redrawn at most MAX_SAMPLE_ATTEMPTS times,
# so a seed that never yields a usable sample is skipped instead of hanging
# the cell. The score lists are reset here and filled by the scoring cell.
MAX_SAMPLE_ATTEMPTS = 10

x0_smis = []
x0_dims = []
x0_embs = []
y0_scrs = []

x1_smis = []
x1_dims = []
x1_embs = []
y1_scrs = []

for smi in df['smiles'].tolist():
    for _attempt in range(MAX_SAMPLE_ATTEMPTS):
        try:
            embs = sample(smi,
                          num_sample=1,
                          padding_size=512,
                          service_port='localhost:50052')
        except Exception as ex:
            # Best effort: report the service failure and move on to the next seed.
            print(ex)
            break

        # embs[0] is treated as the seed (x0) and embs[1] as the sample (x1).
        if Chem.MolFromSmiles(embs[1]['smiles']) is None or embs[0]['smiles'] == embs[1]['smiles']:
            continue

        emb = embs[0]['embedding']
        x0_embs.append(cp.reshape(cp.array(emb.embedding), emb.dim).squeeze())
        x0_dims.append(list(emb.dim))
        x0_smis.append(embs[0]['smiles'])

        emb = embs[1]['embedding']
        x1_embs.append(cp.reshape(cp.array(emb.embedding), emb.dim).squeeze())
        x1_dims.append(list(emb.dim))
        x1_smis.append(embs[1]['smiles'])
        break
    else:
        # Only reached when every attempt was rejected (no break).
        print(f'{smi}: no valid, novel sample after {MAX_SAMPLE_ATTEMPTS} attempts; skipped')

x0_smis, x1_smis

In [None]:
# Score every seed (x0) and sampled (x1) molecule against the receptor.
# The score lists are rebuilt from scratch, so re-running this cell does not
# append a second copy of every score.
score_fn = partial(score_molecule,
                   receptor_file=receptor_file,
                   score_config=score_config,
                   conformer_dir='/tmp')

# score_molecule returns (min_score, score_model); only min_score is kept.
y0_scrs = [score_fn(smi)[0] for smi in x0_smis]
y1_scrs = [score_fn(smi)[0] for smi in x1_smis]
y0_scrs, y1_scrs

In [None]:
# %%time

# db_url = f'{workspace}/common.sqlite3'
# conn = sqlite3.connect(db_url, uri=True)
# df = pd.read_sql(
#     '''
#     Select smiles, embedding, embedding_dim, score from generated_smiles
#     LIMIT 100
#     ''', 
#     con=conn)

# for i in range(df.shape[0]):
#     dim = pickle.loads(df.embedding_dim[i])
#     x0_smi = df.smiles[i]
#     x0_smis.append(x0_smi)
#     x0_dims.append(dim)
#     x0_embs.append(cp.reshape(cp.array(pickle.loads(df.embedding[i])), dim))
#     y0_scrs.append(df.score[i])
    
#     while(True):
#         x1_samples = sample(x0_smi, service_port='localhost:50052', num_sample=1)
#         x1_sample = x1_samples[1]
#         x1_smi = x1_sample['smiles']
        
#         #TODO: This could lead to an infinite loop
#         if Chem.MolFromSmiles(x1_smi) is None or x0_smi == x1_smi:
#             # print(f'{x1_smi} is invalid or same as input {x0_smi}')
#             continue

#         dim = x1_sample['embedding'].dim
#         x1_smis.append(x1_smi)
#         x1_dims.append(cp.array(dim))
#         x1_embs.append(cp.reshape(cp.array(x1_sample['embedding'].embedding), dim))
#         break
#     # y1_scrs.append(df.score[i])

# # smi, emb, dim, scr

In [None]:
# # Reshape embedding to equal shape
# import itertools
# max_size = x1_embs[0].shape
# for emb in x1_embs:
#     print(emb.shape)
#     if max_size[0] < emb.shape[0]:
#         max_size = emb.shape

# for i in range(len(x1_embs)):
#     emb = x0_embs[i]
#     dim = x0_dims[i]
#     smi = x0_smis[i]
#     mask = itertools.repeat(False, max_size[0])
#     emb = cp.resize(emb, max_size)
#     flattened = emb.flatten().tolist()
#     re_smi = embedding_to_smiles(flattened, max_size, mask, service_port='localhost:50052')
#     print(type(flattened), dim, len(smi), smi, re_smi)
    
# max_size

Filter the matching input and sample pairs.

In [None]:
# %%time
# from functools import partial
# import concurrent

# receptor_file = f'{workspace}/inputs/rec.pdbqt'
# score_config = f'{workspace}/inputs/config'

# partial_func = partial(score_molecule,
#                        receptor_file=receptor_file,
#                        score_config=score_config,
#                        conformer_dir='/tmp')

# with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
#     futures = {executor.submit(partial_func, smi): smi for smi in x1_smis}

#     y1_map = dict(zip(x1_smis, itertools.repeat(None, len(x1_smis))))
#     for future in concurrent.futures.as_completed(futures):
#         min_score, score_model = future.result()
#         y1_map[futures[future]] = min_score
#     y1_scrs = list(y1_map.values())
# y1_scrs

In [None]:
# Bundle the paired seed (x0) / sample (x1) data: SMILES, embedding shapes,
# embeddings (cupy arrays) and scores, then pickle the bundle to disk.
data = dict(
    x0_smis=x0_smis,
    x0_dims=x0_dims,
    x0_embs=x0_embs,
    y0_scrs=y0_scrs,
    x1_smis=x1_smis,
    x1_dims=x1_dims,
    x1_embs=x1_embs,
    y1_scrs=y1_scrs,
)

# 'wb' creates the file, or truncates an existing one, before writing.
with open('/workspace/test_data.pkl', 'wb') as pkl_file:
    pickle.dump(data, pkl_file)

In [None]:
# Reload the pickled pair bundle from disk (read-only; nothing is created).
# NOTE: pickle.load can execute arbitrary code -- only load files you wrote.
# The embeddings were pickled as cupy arrays, so cupy must be importable here.
with open('/workspace/test_data.pkl', 'rb') as file:
    data = pickle.load(file)

# Fail early if the file is not the bundle this notebook writes.
expected_keys = {'x0_smis', 'x0_dims', 'x0_embs', 'y0_scrs',
                 'x1_smis', 'x1_dims', 'x1_embs', 'y1_scrs'}
assert expected_keys <= data.keys(), f'missing keys: {expected_keys - data.keys()}'

In [None]:
# Coerce every recorded embedding shape to a plain Python list. Slice
# assignment updates each list in place, so any other reference to the same
# list object sees the converted shapes as well.
for dims_list in (x1_dims, x0_dims):
    dims_list[:] = [list(shape) for shape in dims_list]


In [None]:
# Summarise the bundle (entries per key) instead of echoing every embedding array.
{key: len(values) for key, values in data.items()}