In [1]:
import sys
sys.path.append('/workspace/')

import sqlite3
import mols2grid
import importlib
import pickle
import itertools
import concurrent

import numpy as np
import pandas as pd

import cuml
import cupy as cp

from functools import partial
from subprocess import run
from rdkit import Chem
from rdkit.Chem import Draw, QED, Descriptors, Lipinski, rdDistGeom, rdmolfiles

from flow.pipeline.screening.pose_generate import score_molecule, generate_conformers
from flow.utils.megamolbart import smiles_to_embedding, embedding_to_smiles, sample

# Code to disable rdkit errors and warning
import rdkit.rdBase as rkrb
import rdkit.RDLogger as rkl

import warnings
warnings.filterwarnings('ignore')

log = rkl.logger()
log.setLevel(rkl.ERROR)
rkrb.DisableLog('rdApp.error')

In [2]:
with open('/workspace/test_data.pkl', 'rb') as file:
    # A new file will be created
    data = pickle.load(file)
    x0_smis = data['x0_smis']
    x0_dims = data['x0_dims']
    x0_embs = data['x0_embs']
    y0_scrs = data['y0_scrs']

    x1_smis = data['x1_smis']
    x1_dims = data['x1_dims']
    x1_embs = data['x1_embs']
    y1_scrs = data['y1_scrs']
    
y0_scrs = cp.array(y0_scrs, dtype=np.float)
y1_scrs = cp.array(y1_scrs, dtype=np.float)
x1_embs = x1_embs[:50]
# x0_smis, x0_dims, x0_embs, y0_scrs, x1_smis, x1_dims, x1_embs, y1_scrs

In [3]:
for i in range(len(x0_embs)):
    x0_embs[i] = cp.reshape(x0_embs[i], x0_dims[i]).squeeze().flatten()

for i in range(len(x1_embs)):
    x1_embs[i] = cp.reshape(x1_embs[i], x1_dims[i]).squeeze().flatten()
    
cp.asarray(x0_embs).shape, cp.asarray(x1_embs).shape

((100, 262144), (50, 262144))

Multiquadric:

Where ${\varepsilon}$ is the shape parameter

$ {\displaystyle \varphi (r)={\sqrt {1+(\varepsilon r)^{2}}}}$

In [4]:
epsilon = 0.01
# Multiquadric

def multiquadric_rbf(radius, epsilon):
    return cp.sqrt(1 + (epsilon * radius)**2)

Now compute pairewise Euclidean distance.

$ {\textstyle r=\left\|\mathbf {\vec{x}} -\mathbf {\vec{x}} _{i}\right\|} $

In [5]:
r = cuml.metrics.pairwise_distances(
    cp.asarray(x0_embs),
    cp.asarray(x1_embs),
    metric='euclidean').T
r.shape

(50, 100)

$ {\textstyle \vec{r}= {\vec{x}} -\mathbf {\vec{x}} _{i}} $

In [6]:
r1 = []
# TODO: Revisit for bigger dataset
for emb1 in x1_embs:
    row = []
    for emb0 in x0_embs:
        row.append(emb0.squeeze().flatten() - emb1.squeeze().flatten())
    
    r1.append(cp.asarray(row))

r1 = cp.asarray(r1)

Now compute the gradient

$
{\displaystyle y(\mathbf {x} )=\sum _{i=1}^{N}w_{i}\,\varphi (\left\|\mathbf {\vec{x}} -\mathbf {\vec{x}} _{i}\right\|)}
$

Therefore applying 
$
y_i = w_{i} A
$


Where
$
A =
  \begin{pmatrix}
    \varphi (\left\|\mathbf {\vec{x}}_{0} -\mathbf {\vec{x}} _{0}\right\|) & {...} &\varphi (\left\|\mathbf {\vec{x}}_{i} -\mathbf {\vec{x}} _{0}\right\|)\\
    \vdots & \ddots & \vdots\\
    \varphi (\left\|\mathbf {\vec{x}}_{0} -\mathbf {\vec{x}} _{i}\right\|) & {...} &\varphi (\left\|\mathbf {\vec{x}}_{i} -\mathbf {\vec{x}} _{i}\right\|)
  \end{pmatrix}
$

In [7]:
A = multiquadric_rbf(r, epsilon)
w = cp.matmul(cp.linalg.pinv(A).T, cp.asarray(y0_scrs))
A.shape, w.shape

((50, 100), (50,))

In [8]:
r.shape , r1.shape

((50, 100), (50, 100, 262144))

$ \vec{\nabla} $$y(\vec{x}) = \sum \vec{\nabla}\varphi \left\|\mathbf {\vec{x}} -\mathbf {\vec{x}} _{i}\right\|$

$ {\qquad =\sum _{i=1}^{N}w_{i} {\dfrac {\varepsilon^{2}}{\sqrt {1+(\varepsilon^{2} r^{2})}}}} \vec{r1}$

In [9]:
tmp = ((epsilon**2) / cp.sqrt(1 + (epsilon**2 * r**2)))
tmp = cp.multiply(tmp.reshape(tmp.shape + (1,)), r1)
tmp.shape

(50, 100, 262144)

In [10]:
grad = cp.einsum("i,ijk->jk", w, tmp)

In [11]:
cp.asarray(grad).max()

array(0.10021946)

In [12]:
# coeff = cp.matmul(((epsilon**2) / cp.sqrt(1 + (epsilon**2 * r**2))), r1)
# coeff.shape
# grad = cp.sum(cp.matmul(w, ), axis=1)
# grad

In [13]:
for i in range(len(x0_embs)):
    emb = x0_embs[i]
    dim = x0_dims[i]
    # projected_emb = emb
    projected_emb = emb - grad[i], dim
    projected_emb = cp.reshape(cp.array(emb), dim)
    projected_emb = projected_emb.flatten().tolist()
    mask = list(itertools.repeat(False, dim[0]))
    result = embedding_to_smiles(projected_emb, list(dim), mask, service_port='localhost:50052')
    gsmi = result.generatedSmiles[0]
    mol_wt = 0
    op_mol = Chem.MolFromSmiles(gsmi)
    if op_mol:
        mol_wt =  Descriptors.MolWt(op_mol)
    print(y0_scrs[i], mol_wt, gsmi)

500.4780000000002 0 C(F)(F)(c1ccc(-c2cn3ncc(C(N)=O)c(N[C@H]4[C@@H](CF)CN(S(=O)(=O)C)C[C@H](CF)[C@H]4C4)c3n2)cn1)F
303.74500000000006 325.7950000000001 OCCNCc1ccc2oc3cc(Cl)ccc3c3ccc=3c2c1
352.3990000000001 427.46600000000024 Cc1nc(NCC(N)=O)nc(NCC(=O)NCCN(CCO)CC(=O)NCCO)n1
281.31500000000005 266.3 CC(=O)NN1C(=O)c2ccccc2C1c1ccccc1
475.45800000000025 0 C1(n2c3ncnc(NCCc4cc(OC)c(OC)c(OC)c4)c3)C(=O)C2=C(C(=O)C2)CN1
229.70699999999997 0 C1CC(N)CC2CCC(C3CCC3)CCC1
388.47000000000014 237.34299999999996 COC(=O)CN1C2CCC(CC3CC1(C)C3)C2
290.407 274.40799999999996 CCCN(CCC)CCc1ccc(C)c2c1CC(=O)N2
631.1730000000001 0 c1cc(COc2ccc(Nc3nccc4c5c6c5cc6c(cc5)CC5)ccc6n(c3)CC4)cc2c1cccc2O
248.28500000000003 259.308 O=C(c1ccccc1)c1ccccc1-c1ccccn1
402.45400000000006 0 c1c(-c2ccnc3c(C)ccnc2-c2ccncc2)ccc(C)c1CC(C(=O)NCc1ccncc1)(C)O
266.301 278.44 C/C=C1\CC(C)/C(=C/NCCN(C)C)C(C)C(C)C1=O
439.4750000000002 0 c1ccc2c(c1)-c1nc(-n3c(C)c(C(NCc4cccc5ccccc5)cc4)cn3)ncc1CC2
377.25700000000006 295.224 Brc1cccc(CN2CCN(C3CC3)CC