In [1]:
%load_ext autoreload
%autoreload 2
# python import
import logging
import os
import random
import sys
import time

import numpy as np
sys.path.append('./rxnft_vae')

# rxnft_vae imports
from rxnft_vae.reaction import ReactionTree, extract_starting_reactants, StartingReactants, Templates, extract_templates
from rxnft_vae.fragment import FragmentVocab, FragmentTree
from rxnft_vae.vae import bFTRXNVAE
from rxnft_vae.mpn import MPN
from rxnft_vae.reaction_utils import read_multistep_rxns, get_qed_score,get_clogp_score

# torch
import torch

# tqdm
from tqdm import tqdm

# my binary vae utils
import binary_vae_utils




def seed_all(seed):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's legacy global generator, and PyTorch (CPU, the current CUDA
    device, and all CUDA devices for multi-GPU runs).

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        None.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Model hyperparameters. The checkpoint path below encodes the same three values
# (hidden_size_300 / latent_size_100 / depth_2).
hidden_size, latent_size, depth = 300, 100, 2

# Input files: reaction routes and the trained VAE weights.
data_filename = "./data/data.txt"
w_save_path = "./weights/hidden_size_300_latent_size_100_depth_2_beta_1.0_lr_0.001/bvae_iter-30-with.npy"
# Name of the property to optimise; it is not referenced again in the visible cells.
metric = "qed"

# Shared seed, taken from the project utilities.
seed = binary_vae_utils.RANDOM_SEED

# Use the first GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print("hidden size:", hidden_size, "latent_size:", latent_size, "depth:", depth)
print("loading data.....")
routes, scores = read_multistep_rxns(data_filename)

# One ReactionTree per route; the SMILES of each tree's first molecule node is
# collected alongside it.
rxn_trees = list(map(ReactionTree, routes))
molecules = [tree.molecule_nodes[0].smiles for tree in rxn_trees]

# Vocabularies of starting reactants and reaction templates extracted from the trees.
reactants = extract_starting_reactants(rxn_trees)
templates, n_reacts = extract_templates(rxn_trees)
reactantDic, templateDic = StartingReactants(reactants), Templates(templates, n_reacts)

print("size of reactant dic:", reactantDic.size())
print("size of template dic:", templateDic.size())

n_pairs = len(routes)
ind_list = list(range(n_pairs))

# Build a FragmentTree from the first molecule node of every reaction tree.
# Construction can fail for some molecules; those routes are skipped (best
# effort), but the failures are recorded and reported instead of being
# silently discarded, so a sudden jump in the skip count is visible.
fgm_trees = []
valid_id = []
failed_routes = []  # (route index, repr of the exception) for every skipped route
for i in tqdm(ind_list):
    try:
        fgm_tree = FragmentTree(rxn_trees[i].molecule_nodes[0].smiles)
    except Exception as exc:
        failed_routes.append((i, repr(exc)))
        continue
    fgm_trees.append(fgm_tree)
    valid_id.append(i)

if failed_routes:
    print(
        "skipped routes without a fragment tree:", len(failed_routes),
        "first failure:", failed_routes[0],
    )

# Keep only the reaction trees that have a matching fragment tree, so both
# lists stay index-aligned.
rxn_trees = [rxn_trees[i] for i in valid_id]

print("size of fgm_trees:", len(fgm_trees))
print("size of rxn_trees:", len(rxn_trees))

# (fragment tree, reaction tree) tuples, one per surviving route.
data_pairs = list(zip(fgm_trees, rxn_trees))

# Unique fragment SMILES across all fragment trees.
# NOTE(review): the order of a set of strings can differ between Python
# processes (hash randomisation). If FragmentVocab assigns indices in the
# order given, the vocabulary may not line up with the one used when the
# checkpoint was trained - confirm against FragmentVocab before relying on it.
cset = list({node.smiles for tree in fgm_trees for node in tree.nodes})
fragmentDic = FragmentVocab(cset)

print("size of fragment dic:", fragmentDic.size())


# Standalone MPN instance; it is not passed to the model below and nothing else
# in this cell uses it.
mpn = MPN(hidden_size, depth)

# Construct the binary VAE without pre-trained embeddings, then move it to the
# selected device.
model = bFTRXNVAE(
    fragmentDic,
    reactantDic,
    templateDic,
    hidden_size,
    latent_size,
    depth,
    device,
    fragment_embedding=None,
    reactant_embedding=None,
    template_embedding=None,
)
model = model.to(device)

# Restore the trained weights, mapping tensors onto the current device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(w_save_path, map_location=device)
model.load_state_dict(checkpoint)
print("finished loading model...")


# Seed all RNGs before the dataset is prepared and the surrogate is created.
seed_all(seed)

# prepare_dataset yields four parts, unpacked as train/test features and
# targets; each part is wrapped with torch.Tensor as it is unpacked.
X_train, y_train, X_test, y_test = (
    torch.Tensor(part)
    for part in binary_vae_utils.prepare_dataset(
        model=model,
        data_pairs=data_pairs,
        latent_size=latent_size,
    )
)

# Factorization-machine surrogate over half as many binary variables as the
# latent size.
FM_surrogate = binary_vae_utils.FactorizationMachineSurrogate(
    n_binary=latent_size // 2,
    k_factors=binary_vae_utils.FACTOR_NUM,
    random_seed=seed,
)

# Gurobi Web License Service credentials are read from the environment rather
# than being stored in the notebook, where they would be committed and shared
# together with the code. Export them before starting the kernel:
#   GRB_LICENSEID, GRB_WLSACCESSID, GRB_WLSSECRET
# NOTE(review): the access id / secret that were previously hard-coded here
# must be treated as leaked and revoked in the Gurobi licence portal.
_gurobi_env_names = ("GRB_LICENSEID", "GRB_WLSACCESSID", "GRB_WLSSECRET")
_gurobi_env_missing = [name for name in _gurobi_env_names if not os.environ.get(name)]
if _gurobi_env_missing:
    raise RuntimeError(
        "Missing Gurobi WLS credentials; set the environment variable(s): "
        + ", ".join(_gurobi_env_missing)
    )

options = {
    "LICENSEID": int(os.environ["GRB_LICENSEID"]),  # Gurobi expects an integer id
    "WLSACCESSID": os.environ["GRB_WLSACCESSID"],
    "WLSSECRET": os.environ["GRB_WLSSECRET"],
}

# NOTE(review): the recorded output shows the solver wrapper echoing the full
# options dict; clear the cell output before sharing the notebook.
gurobi_solver = binary_vae_utils.GurobiQuboSolver(options)

hidden size: 300 latent_size: 100 depth: 2
loading data.....
size of reactant dic: 9766
size of template dic: 5567


100%|██████████| 21218/21218 [02:42<00:00, 130.80it/s]


size of fgm_trees: 20080
size of rxn_trees: 20080
size of fragment dic: 273




finished loading model...
number of samples: 20080


20080it [06:35, 50.71it/s]


(18072, 50) (2008, 50) (18072, 1) (2008, 1)
Using Gurobi with provided options: {'LICENSEID': 2687913, 'WLSACCESSID': '***REDACTED***', 'WLSSECRET': '***REDACTED***'}
Set parameter WLSAccessID
Set parameter WLSSecret
Set parameter LicenseID to value 2687913
Academic license 2687913 - for non-commercial use only - registered to 89___@edu.k.u-tokyo.ac.jp


In [3]:
import binary_vae_utils
# Combine the trained bVAE, the surrogate model, the prepared dataset and the
# QUBO solver into one optimizer.
optimizer = binary_vae_utils.MoleculeOptimizer(
    bvae_model=model,
    surrogate_model=FM_surrogate,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    qubo_solver=gurobi_solver,
)

start_time = time.time()
try:
    optimizer.optimize()
finally:
    # Log the elapsed time even when the run fails or is interrupted from the
    # keyboard; without the finally clause the timing was lost in those cases.
    # NOTE(review): logging.info is only displayed if logging has been
    # configured at INFO level somewhere - confirm, or call
    # logging.basicConfig(level=logging.INFO) in the setup cell.
    logging.info("Running Time: %f", time.time() - start_time)

--- Starting Iteration 0 ---
lr:  0.001
Model -- Epoch 0 error on validation set: 0.1177, r2 on validation set: -1.6264
Model -- Epoch 100 error on validation set: 0.0438, r2 on validation set: 0.0235
Model -- Epoch 200 error on validation set: 0.0437, r2 on validation set: 0.0249
Model -- Epoch 300 error on validation set: 0.0435, r2 on validation set: 0.0295
Model -- Epoch 400 error on validation set: 0.0436, r2 on validation set: 0.0273
Model -- Epoch 500 error on validation set: 0.0439, r2 on validation set: 0.0205
Model -- Epoch 600 error on validation set: 0.0438, r2 on validation set: 0.0227
Model -- Epoch 700 error on validation set: 0.0436, r2 on validation set: 0.0281
Model -- Epoch 800 error on validation set: 0.0434, r2 on validation set: 0.0310
Model -- Epoch 900 error on validation set: 0.0433, r2 on validation set: 0.0337
Model -- Epoch 1000 error on validation set: 0.0440, r2 on validation set: 0.0192
Model -- Epoch 734 has lowest error!
(2008, 1, 1) (2008, 1)
best torc

  return self._call_impl(*args, **kwargs)
  1%|          | 40/5000 [00:10<22:06,  3.74it/s]


[['Cc1cc2c(NCCc3ccc4c(c3)OCO4)nc(-n3ccnc3)nc2s1', 'Cc1cc2c(NCCc3ccc4c(c3)OCO4)nc(-n3ccnc3)nc2s1*Cc1cc2c(NCCc3ccc4c(c3)OCO4)nc(Cl)nc2s1.c1c[nH]cn1*([#16;a:5]:[c:4]:[#7;a:3]:[c;H0;D3;+0:1](:[#7;a:2])-[n;H0;D3;+0:8]1:[c:7]:[#7;a:6]:[c:10]:[c:9]:1)>>Cl-[c;H0;D3;+0:1](:[#7;a:2]):[#7;a:3]:[c:4]:[#16;a:5].[#7;a:6]1:[c:7]:[nH;D2;+0:8]:[c:9]:[c:10]:1 Cc1cc2c(NCCc3ccc4c(c3)OCO4)nc(Cl)nc2s1*NCCc1ccc2c(c1)OCO2.Cc1cc2c(Cl)nc(Cl)nc2s1*([#16;a:5]:[c:4](:[#7;a:6]):[c:3]:[c;H0;D3;+0:1](:[#7;a:2])-[NH;D2;+0:8]-[C:7])>>Cl-[c;H0;D3;+0:1](:[#7;a:2]):[c:3]:[c:4](:[#16;a:5]):[#7;a:6].[C:7]-[NH2;D1;+0:8]'], ['COc1ccc(CNc2nc(-n3ccnc3)nc3sc(C)c(Cl)c23)cc1OC', 'COc1ccc(CNc2nc(-n3ccnc3)nc3sc(C)c(Cl)c23)cc1OC*c1c[nH]cn1.COc1ccc(CNc2nc(Cl)nc3sc(C)c(Cl)c23)cc1OC*([#16;a:5]:[c:4]:[#7;a:3]:[c;H0;D3;+0:1](:[#7;a:2])-[n;H0;D3;+0:8]1:[c:7]:[#7;a:6]:[c:10]:[c:9]:1)>>Cl-[c;H0;D3;+0:1](:[#7;a:2]):[#7;a:3]:[c:4]:[#16;a:5].[#7;a:6]1:[c:7]:[nH;D2;+0:8]:[c:9]:[c:10]:1 COc1ccc(CNc2nc(Cl)nc3sc(C)c(Cl)c23)cc1OC*COc1ccc(CN)cc1OC.Cc

  return self._call_impl(*args, **kwargs)
  0%|          | 16/5000 [00:04<22:40,  3.66it/s]


[['COc1ccc(CNc2nc(-n3ccnc3)nc3sc(C)cc23)cc1OC', 'COc1ccc(CNc2nc(-n3ccnc3)nc3sc(C)cc23)cc1OC*c1c[nH]cn1.COc1ccc(CNc2nc(Cl)nc3sc(C)cc23)cc1OC*([#16;a:5]:[c:4]:[#7;a:3]:[c;H0;D3;+0:1](:[#7;a:2])-[n;H0;D3;+0:8]1:[c:7]:[#7;a:6]:[c:10]:[c:9]:1)>>Cl-[c;H0;D3;+0:1](:[#7;a:2]):[#7;a:3]:[c:4]:[#16;a:5].[#7;a:6]1:[c:7]:[nH;D2;+0:8]:[c:9]:[c:10]:1 COc1ccc(CNc2nc(Cl)nc3sc(C)cc23)cc1OC*COc1ccc(CN)cc1OC.Cc1cc2c(Cl)nc(Cl)nc2s1*([#16;a:5]:[c:4](:[#7;a:6]):[c:3]:[c;H0;D3;+0:1](:[#7;a:2])-[NH;D2;+0:8]-[C:7])>>Cl-[c;H0;D3;+0:1](:[#7;a:2]):[c:3]:[c:4](:[#16;a:5]):[#7;a:6].[C:7]-[NH2;D1;+0:8]'], ['CCc1cc2c(NCc3ccc(OC)c(OC)c3)nc(-n3ccnc3)nc2s1', 'CCc1cc2c(NCc3ccc(OC)c(OC)c3)nc(-n3ccnc3)nc2s1*c1c[nH]cn1.CCc1cc2c(NCc3ccc(OC)c(OC)c3)nc(Cl)nc2s1*([#16;a:5]:[c:4]:[#7;a:3]:[c;H0;D3;+0:1](:[#7;a:2])-[n;H0;D3;+0:8]1:[c:7]:[#7;a:6]:[c:10]:[c:9]:1)>>Cl-[c;H0;D3;+0:1](:[#7;a:2]):[#7;a:3]:[c:4]:[#16;a:5].[#7;a:6]1:[c:7]:[nH;D2;+0:8]:[c:9]:[c:10]:1 CCc1cc2c(NCc3ccc(OC)c(OC)c3)nc(Cl)nc2s1*COc1ccc(CN)cc1OC.CCc1cc2c(Cl)nc(C

KeyboardInterrupt: 

In [None]:
import numpy as np
import scipy.sparse as sp
from gurobi_optimods.qubo import solve_qubo

# Small dense 3x3 QUBO coefficient matrix (upper triangular) used to smoke-test
# the solver.
Q = np.array(
    [
        [0, -1, -2],
        [0, -3, 3],
        [0, 0, 2],
    ]
)

# Alternative: a sparse COO matrix built from explicit (row, col, weight) triplets.
# weights = [-3, 2, -1, -2, 3]
# row = [1, 2, 0, 0, 1]
# col = [1, 2, 1, 2, 2]
# Q = sp.coo_array((weights, (row, col)), shape=(3, 3))

# Solve the QUBO defined by Q.
result = solve_qubo(Q)

New QUBO solution found with objective 0.0
New QUBO solution found with objective -1.0
New QUBO solution found with objective -4.0
