In [9]:
# dimensionality of the latent vectors handed to the optimiser
latent_size = 56

# Docking set-up. Each entry maps a protein name to
#   [path to the .pdbqt structure, [x, y, z] centre of the search box, box size in angstroms].
# A box size could be given per protein, but here the same 30 angstrom cube is shared by all
# of them (every entry references the same `box_size` list object).
box_size = [30, 30, 30]

# the single protein we want molecules to bind to
target_details = {
    "FLT3": ["./proteins/flt3.pdbqt", [-28.03685, -10.361925, -28.9883], box_size],
}

# proteins we want molecules NOT to bind to
off_target_details = {
    "CKIT": ["./proteins/ckit.pdbqt", [45.93035714, 97.03574286, 16.1472], box_size],
    "PDGFRA": ["./proteins/pdgfra.pdbqt", [17.58837931, 132.5595172, -6.030275862], box_size],
    "VEGFR": ["./proteins/VEGFR.pdbqt", [25.997, 28.605, 17.134], box_size],
    "MK2": ["./proteins/MK2.pdbqt", [47.6395, 34.809, 16.708], box_size],
    "JAK2": ["./proteins/JAK2.pdbqt", [-31.7445, -49.661, 35.4655], box_size],
}

# Input / output locations. Leave a load path as None to start from scratch.
load_train_path = None
load_mol_path = None
# where the latent vectors and their scores are written
pred_train_path = "./predictions/200properbayes_train4.pk1"
# where the SMILES strings and their scores are written
pred_mol_path = "./predictions/200properbayes4.pk1"
# legacy feature: re-dock a new best candidate at double exhaustiveness
retry_best = False

# objective scores above this value count as hits when seeding the acquisition optimiser
hit_threshold = 0.4
# 1/kT (NOTE(review): constants look like kcal/mol energies at 310 K - confirm)
beta = 4184/(8.3145*310)
# exhaustiveness passed to vina
exh = 8
# binding energy substituted into the objective when rdkit cannot generate a conformer
failure_bind = -5.0
# reference binding energy in the objective; rewards strong binding to the target
anybind = -6.0
# number of random initial evaluations
init_iters = 150
# number of Bayesian Optimisation iterations
num_iters = 200

In [10]:
import os
import pandas as pd
import seaborn as sns
import rdkit
import torch
import vina
import meeko
import pexpect
import pickle
import numpy as np
from scipy.special import erfinv
from selectivebayes.interfaces import vaeinterface,vinainterface
from selectivebayes.initializers import gen_batch_initial_conditions

import warnings
from botorch.exceptions.warnings import InputDataWarning
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.models import FixedNoiseGP
from botorch.optim import optimize_acqf
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood


In [11]:
#build one vina interface per protein: the target first, then every off-target
print("Starting")
#exactly one target is supported; the unpacking raises ValueError for zero or several entries
[(target_name, target_loc)] = target_details.items()
#each details list is unpacked positionally into the vinainterface constructor
target_interface = {target_name: vinainterface(*target_loc)}
off_target_interfaces = {
    name: vinainterface(*details)
    for name, details in off_target_details.items()
}

Starting
Vina Initialisation complete
Vina Initialisation complete
Vina Initialisation complete
Vina Initialisation complete
Vina Initialisation complete
Vina Initialisation complete


In [12]:
#tensor dtype used for everything handed to the GP
dtype = torch.double

#log of every evaluated molecule; each entry is [mol, prediction, all_preds] (appended by optfunction)
#resume from a previous run's log when a path is configured, otherwise start empty
if load_mol_path is not None:
    #NOTE(review): pickle.load can execute arbitrary code - only load files produced by this notebook
    with open(load_mol_path,"rb") as mol_file:
        pred_list = pickle.load(mol_file)
else:
    pred_list = []
#molecules already evaluated: collect the first field of every log entry
#(splatting pred_list into set() raised TypeError as soon as the log was non-empty)
seen_preds = {entry[0] for entry in pred_list}
#running count of objective evaluations, used to number the printed log
ind=0
#best objective value observed so far
best_pred = 0

#objective function: maps latent vectors to selectivity scores
def optfunction(vectors):
    """Decode, dock and score every row of `vectors`.

    Each row is treated as uniformly distributed on [0, 1]; it is mapped to a
    standard normal sample, decoded to a molecule by `vaeint`, and docked
    against the target and all off-targets at exhaustiveness `exh`.

    The score of a molecule is
        exp(-beta*E_target) / (exp(-beta*anybind) + sum_i exp(-beta*E_i))
    where the sum runs over the target and every off-target.

    Side effects: increments the global counter `ind`, updates the global
    `best_pred`, adds the molecule to `seen_preds`, appends
    [mol, prediction, all_preds] to `pred_list`, and prints a two-line summary
    per molecule.

    Returns a (vectors.shape[0], 1) tensor of scores with dtype `dtype`.
    """
    global ind
    global best_pred

    def selectivity(target_energy, energies):
        #Boltzmann weight of the target relative to all proteins plus the reference term
        weights = np.sum([np.exp(-beta*energy) for energy in energies])
        return np.exp(-beta*target_energy)/(np.exp(-beta*anybind)+weights)

    n_vectors = vectors.shape[0]
    predictions = torch.zeros(n_vectors,1,dtype=dtype)

    for row, latent in enumerate(vectors):
        ind+=1
        #uniform(0,1) -> standard normal via the inverse error function
        uniform = np.expand_dims(latent.numpy(),axis=0)
        gaussian = np.sqrt(2)*erfinv(2*uniform-1)
        mol=vaeint.decode(gaussian)

        #repeats are only reported; the molecule is still docked again
        if mol in seen_preds:
            print("SEEN")

        target_pred,success = target_interface[target_name].predict(mol,exh)

        if success==-1:
            #the target run reported failure: skip off-target docking and use the fallback energy
            off_target_preds = dict.fromkeys(off_target_interfaces,failure_bind)
        else:
            off_target_preds = {
                name:interface.predict(mol,exh)[0]
                for name,interface in off_target_interfaces.items()
            }
        all_preds = [target_pred,*off_target_preds.values()]

        prediction = selectivity(target_pred,all_preds)
        if prediction>best_pred and success!=-1 and retry_best:
            #a new best is confirmed by docking everything again at twice the exhaustiveness
            confirm_exh = exh*2
            target_pred = target_interface[target_name].predict(mol,confirm_exh)[0]
            off_target_preds = {
                name:interface.predict(mol,confirm_exh)[0]
                for name,interface in off_target_interfaces.items()
            }
            all_preds = [target_pred,*off_target_preds.values()]
            prediction = selectivity(target_pred,all_preds)

        best_pred = max(best_pred,prediction)
        predictions[row]=prediction

        seen_preds.add(mol)
        pred_list.append([mol,prediction,all_preds])

        print(f"{ind}: {mol}, Pred: {prediction:.4f}")
        off_target_summary = "".join([name +f": {energy:.2f}, " for name,energy in off_target_preds.items()])
        print(target_name + f":{target_pred:.2f}, " + off_target_summary)

    return predictions

In [13]:
def _save_progress():
    """Write the GP training data to pred_train_path and the molecule log to pred_mol_path."""
    #context managers make sure both files are flushed and closed, even if pickling fails
    with open(pred_train_path,"wb") as train_file:
        pickle.dump([train_X,train_Y],train_file)
    with open(pred_mol_path,"wb") as mol_file:
        pickle.dump(pred_list,mol_file,protocol=2)


with warnings.catch_warnings():
    warnings.simplefilter('ignore', InputDataWarning)
    print("Starting")
    vaeint=vaeinterface()
    vaeint.start()
    try:
        #search space: the unit hypercube in latent space
        bounds = torch.stack([torch.zeros(latent_size, dtype=dtype), torch.ones(latent_size, dtype=dtype)])

        if load_train_path is None:
            #fresh run: evaluate init_iters uniformly random points
            train_X = torch.rand(init_iters, latent_size, dtype=dtype)
            train_Y = optfunction(train_X)
        else:
            #resumed run
            #NOTE(review): pickle.load can execute arbitrary code - only load files produced by this notebook
            with open(load_train_path,"rb") as train_file:
                train_X,train_Y = pickle.load(train_file)
            #restore the incumbent, otherwise ExpectedImprovement would be measured against best_f=0
            best_pred = max(best_pred, train_Y.max().item())

        #number of leading rows treated as the initial design; clamped so that a shorter
        #loaded data set cannot be indexed out of range
        n_initial = min(init_iters, train_X.shape[0])

        print("Optimization phase starting")
        for bo_iter in range(num_iters):
            #fixed noise GP with expected improvement heuristic
            gp = FixedNoiseGP(train_X, train_Y,torch.ones(train_Y.shape,dtype=dtype)*0.2)

            mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
            fit_gpytorch_mll(mll)

            EI = ExpectedImprovement(gp, best_f=best_pred, maximize=True)
            #percentage of observations scoring above hit_threshold, passed to the initial-condition sampler
            best_pct = (train_Y>hit_threshold).sum().item()/len(train_Y)*100

            #seed acquisition function optimisation with 1024 samples around the hits, with 10 restarts
            candidate, acq_value = optimize_acqf(
                    acq_function=EI,
                    bounds=bounds,
                    q=1,
                    batch_initial_conditions=gen_batch_initial_conditions(EI,bounds,1,10,1024,best_pct,options={"sample_around_best":True,"sample_around_best_sigma":0.1}),
                    num_restarts=10,
                )

            #closest initial point to the selected candidate and its distance;
            #useful to follow exploration/exploitation (argmin keeps the first of tied minima)
            distances = (candidate-train_X[:n_initial]).norm(dim=1)
            best_i = int(distances.argmin())
            best_distance = distances[best_i]
            print(f"closest index: {best_i}, dist: {best_distance.item()},acq: {acq_value.item()}")

            train_X=torch.cat((train_X,candidate),dim=0)
            prediction = optfunction(candidate)
            train_Y=torch.cat((train_Y,prediction),dim=0)

            if bo_iter%10==0:
                #checkpoint on every 10th iteration (starting with the first)
                _save_progress()
    finally:
        #always shut the VAE interface down, including when an iteration raises
        vaeint.stop()

#final save once all iterations have completed
_save_progress()

Starting
1: Cc1nnc([C@@H]2CCCN2C(=O)C[S@](=O)c2ccccc2CO)[nH]1, Pred: 0.3190
FLT3:-5.96, CKIT: -5.59, PDGFRA: -4.47, VEGFR: -4.83, MK2: -4.54, JAK2: -4.86, 
2: CCOC(=O)c1cccc(NC(=O)[C@@H]2COCCO2)c1, Pred: 0.1311
FLT3:-5.17, CKIT: -5.42, PDGFRA: -4.25, VEGFR: -4.49, MK2: -4.54, JAK2: -4.55, 
3: Cc1ccccc1COCC(=O)NCc1[nH]nc2c1CSC2, Pred: 0.3662
FLT3:-6.53, CKIT: -6.47, PDGFRA: -4.10, VEGFR: -5.37, MK2: -4.20, JAK2: -5.59, 
4: C[C@@H]1CCC(=O)O1, Pred: 0.0181
FLT3:-3.58, CKIT: -3.48, PDGFRA: -3.38, VEGFR: -3.01, MK2: -3.25, JAK2: -3.41, 
5: Cc1nonc1C(=O)N(C1(N)C[C@@H]2C=CC=C[C@@H]2C1)[C@]1(C)CCOC1, Pred: 0.1135
FLT3:-5.08, CKIT: -4.79, PDGFRA: -4.74, VEGFR: -4.97, MK2: -4.28, JAK2: -5.15, 
6: Cc1occc1CN(C)C(=O)CN1CC[C@H](c2ccccc2)c2ccccc21, Pred: 0.2308
FLT3:-6.39, CKIT: -6.45, PDGFRA: -5.56, VEGFR: -4.46, MK2: -3.59, JAK2: -6.61, 
7: Cc1cc(C(=O)N(CC(=O)[O-])Cc2ccccc2)c2cc[nH]c2n1, Pred: 0.3157
FLT3:-5.96, CKIT: -5.33, PDGFRA: -4.47, VEGFR: -4.88, MK2: -4.66, JAK2: -5.32, 
8: CCn1ncc(C(=O)Nc