In [1]:
# lr:0.001
# anneal_rate:0.9
# batch_size:32
# clip_norm:50
# num_epochs:5
# alpha:250
# beta:0
# max_beta:1
# step_beta:0.002
# anneal_iter:40000
# kl_anneal_iter:2000
# print_iter:100
# save_iter:5000
# num_workers:4

In [2]:
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
!pip install -q dive-into-graphs

[0m

In [3]:
!pip install -q toolz
!pip install -q wandb

[0m

In [4]:
%load_ext autoreload
%autoreload 2
import os
import json
import argparse
import pickle 

import numpy as np
import pandas as pd
import torch

from molecule_optimizer.externals.fast_jtnn.datautils import SemiMolTreeFolder, SemiMolTreeFolderTest
from molecule_optimizer.runner.semi_jtvae import SemiJTVAEGeneratorPredictor
from torch_geometric.data import DenseDataLoader

import rdkit

lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

import warnings
warnings.filterwarnings("ignore")

import math

In [11]:
with open('runner_20_LogP_50_1_iter_10000.xml', 'rb') as f:
    runner = pickle.load(f)

In [32]:
from rdkit import Chem
from tqdm import tqdm

def generate_molecules(
    model, 
    training_list, # list of smiles for training
    molecules_target=3000, # try to generate this number of novel molecules 
    attempt_limit=10000, # try up to attempt limit times
    target_condition=None, # None if unconditional, target property, torch tensor of size 1
):
    num_molecules = 0
    num_attempts = 0
    
    num_invalid = 0
    num_in_training_set = 0
    num_duplicates = 0
    
    generated_set = set()
    
    with tqdm(total=molecules_target) as gen_pbar:
        while (num_molecules < molecules_target) and (num_attempts < attempt_limit):
            # Generate new molecule
            smi = model.sample_prior(label=target_condition)
            num_attempts += 1

            # Check if invalid
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                num_invalid += 1
                continue

            # Check if in training set
            if smi in training_list:
                num_in_training_set += 1
                continue 

            if smi in generated_set:
                num_duplicates += 1
                continue

            generated_set.add(smi)
            num_molecules += 1
            gen_pbar.update(1)
    
    print(f"no. generated: {num_attempts} (100%)")
    print(f"no. invalid: {num_invalid} ({num_invalid / num_attempts * 100}%)")
    print(f"no. in training set: {num_in_training_set} ({num_in_training_set / num_attempts * 100}%)")
    print(f"no. duplicated: {num_duplicates} ({num_duplicates / num_attempts * 100}%)")
    print(f"no. new unique: {num_molecules} ({num_molecules / num_attempts * 100}%)")
    
    return generated_set, num_molecules, num_attempts, num_invalid, num_in_training_set, num_duplicates

In [36]:
cond = torch.randn(1, 1)
cond

tensor([[-0.6647]])

In [37]:
generate_molecules(runner.vae, [], molecules_target=10, target_condition=cond)

100%|██████████| 10/10 [00:01<00:00,  9.93it/s]

no. generated: 10 (100%)
no. invalid: 0 (0.0%)
no. in training set: 0 (0.0%)
no. duplicated: 0 (0.0%)
no. new unique: 10 (100.0%)





({'CC(=O)NCc1cccnc1',
  'CC(=O)Nc1ccc(CS(=O)C2CCCCC2)cc1',
  'COc1ncnc(NCc2cccc3ccncc23)n1',
  'COc1nnccc1CN1CCC(C(N)=O)CC1',
  'CSCc1ccccc1NC(=O)C1CC1C',
  'Cc1ccc(NC(=O)c2ccc[nH+]c2)s1',
  'Cc1csc(NS(=O)(=O)c2ccc(C(C)O)cc2)c1',
  'NC(=O)CSc1ccc(N2CCCC2)nn1',
  'Nc1cnc(N2CCNCC2(N)C(=O)c2ccccc2Br)nc1',
  'O=C(NC1CCCCC1)c1cccs1'},
 10,
 10,
 0,
 0,
 0)

In [33]:
generate_molecules(runner.vae, [])

100%|██████████| 3000/3000 [05:29<00:00,  9.10it/s]

no. generated: 3242 (100%)
no. invalid: 0 (0.0%)
no. in training set: 0 (0.0%)
no. duplicated: 242 (7.464528069093153%)
no. new unique: 3000 (92.53547193090684%)





({'Fc1ccc(CN2CCCCC2c2nnc[nH]2)cc1',
  'Cc1cccc2c1CCN2C(=O)c1cccnc1',
  'CC(c1ccc2c(c1)OCCCO2)N1C2CCC1CN(C)CC2',
  'Cc1ccc(NC(=O)COc2ccccc2F)s1',
  'O=C(NCc1ccccc1Cl)c1cccs1',
  'CN(C)c1cncc(C(=O)NCc2ccccc2S(N)(=O)=O)c1',
  'CC1CCNC=CN1C(=O)COc1ccccn1',
  'Cc1cccc(-c2noc(N3CCC(C4SCNC4=O)CC3)n2)c1',
  'Nc1cnnc(C(=O)NCC2CCS(=O)(=O)C2)c1',
  'CC1CCCCCN1C(=O)CSc1cccc(Cl)c1',
  'NC1CC1C(=O)Nc1ccccc1CSc1ccccn1',
  'Cc1cccc(NS(=O)(=O)CC(O)c2ccsc2)c1',
  'O=C(CO)N1CCCC(c2cn[nH]c2)C1',
  'CC1COCC1NCc1ccc(C(=O)N2CCNCC2)cc1',
  'Cc1cnncc1NC(=O)CSc1ccc(Cl)cc1',
  'CN(C)C1(C)CCCN(C(=O)c2cccc3c2OCCO3)C1',
  'COc1nc(C(=O)NC2=CCCC2)no1',
  'NC(=O)c1cccc(S(=O)(=O)N2CC=CCC2)c1',
  'NC(c1ccccc1S(N)(=O)=O)C1CCOC1=O',
  'CC1CC1S(=O)(=O)NCc1ccccc1',
  'Cc1ccc(OCC(=O)Nc2ccc(F)cc2)cc1',
  'Cc1ccc(C2CCCCN2C=O)cc1',
  'COc1ccc(C(=O)NCc2noc(C)n2)cc1',
  'CC(COc1ccc2[nH]ccc2c1)N1CCCCC1=O',
  'Cc1cccc(NS(=O)(=O)C(C)O)c1',
  'Cc1cnccc1NC(=O)c1cccc[nH+]1',
  'NC(=O)c1ccoc1CN1CCCCC12OCCO2',
  'Cc1ccccc1OCC(=O)NC1CC1',