In [1]:
import sys
sys.path.append("/work01/home/wxxie/project/drug-gen/mollvae/MolLVAE/code")
from dataset import DatasetSplit
from opt import get_parser
from model.model import LVAE
from utils import set_seed

import torch

from tqdm import tqdm

from moses.utils import CircularBuffer

########## config

parser = get_parser()
config = parser.parse_args("--device cuda:0 \
                           --n_enc_zs 1 --n_dec_xs 1 --gen_bsz 128".split())

device = torch.device(config.device)
set_seed(config.seed)

#n_sample = config.n_sample
test_csv = config.test_load
load_model_from = "../res/exp/model_049.pt"

n_latcode = config.n_enc_zs # get n_latcode latent codes for each mol
n_dec_xs = config.n_dec_xs
gen_bsz = config.gen_bsz
max_len = config.max_len

########## utils

def tensor2string_ad(tensor, vocab):
    """ Adapted from model.tensor2string. Consider pad indx in tensor. """
    
    ids = tensor.tolist()
    if vocab.pad in ids:
        pad_idx = ids.index(vocab.pad)
        ids = ids[:pad_idx]
    string = vocab.ids2string(ids, rem_bos=True, rem_eos=True)
    return string


########## get data and load trained model

test_split = DatasetSplit("test", test_csv)
test_dataloader = test_split.get_dataloader(batch_size=gen_bsz, shuffle=False)

vocab = test_split._vocab

print("Loading trained model...")
model = LVAE(vocab, config)
model.load_state_dict(torch.load(load_model_from))
model.to(device)
model.eval()

Loading vocab...
Loading trained model...


LVAE(
  (embedding): Embedding(37, 128, padding_idx=35)
  (encoder): LSTM_encoder(
    (embedding): Embedding(37, 128, padding_idx=35)
    (lstm): LSTM(128, 256, batch_first=True, bidirectional=True)
  )
  (decoder): LSTM_decoder(
    (x_emb): Embedding(37, 128, padding_idx=35)
    (map_z2hc): Linear(in_features=28, out_features=512, bias=True)
    (decoder_fc): Linear(in_features=256, out_features=37, bias=True)
    (lstm): LSTM(156, 256, batch_first=True)
  )
  (top_down_layers): ModuleList(
    (0): MLP(
      (layer1): Linear(in_features=4, out_features=8, bias=True)
      (layer2): Linear(in_features=8, out_features=8, bias=True)
      (mu): Linear(in_features=8, out_features=8, bias=True)
      (var): Linear(in_features=8, out_features=8, bias=True)
    )
    (1): MLP(
      (layer1): Linear(in_features=8, out_features=16, bias=True)
      (layer2): Linear(in_features=16, out_features=16, bias=True)
      (mu): Linear(in_features=16, out_features=16, bias=True)
      (var): Linea

In [1]:



########## Check test set reconstruction rate
########## v1

success_cnt = 0
for batch in tqdm(test_dataloader):
    
    with torch.no_grad():
        
        batch = (batch[0].to(device), batch[1])
        
        _,h = model.encoder(batch)
        
        z_mu_q_d, z_log_var_q_d = model.bottom_up(h)
        
        ## sample n_latcode times in top ladder z
        qd_mu_top = z_mu_q_d[-1].unsqueeze(1).repeat(1, n_latcode, 1) # (bsz, n_latcode, z_size[-1])
        qd_logvar_top = z_log_var_q_d[-1].unsqueeze(1).repeat(1, n_latcode, 1)
        z_sample_top = model.sample_z(qd_mu_top, qd_logvar_top)
        
        
        ## Input SMILES for ref
        padded_x, _ = batch
        input_seqs = []
        for s in padded_x:
            input_seqs.append(tensor2string_ad(s, vocab))     
        
        for j in range(n_latcode):
            
            z_sample = []
            z_sample.append(z_sample_top[:,j,:])
            
            _,_,_,_,samples = model.top_down(z_mu_q_d,z_log_var_q_d, z_sample=z_sample, mode="eval")
            
            recon_seqs = model.sample(len(input_seqs), max_len=max_len, z_in=samples)
            
            
            for k, (r_s,s) in enumerate(zip(recon_seqs, input_seqs)): # loop over mols
                if r_s == s:
                    success_cnt += 1


total_trials = len(test_split.split_dataset.data) * n_latcode * n_dec_xs
print(f"Test set reconstruction rate: {1.*success_cnt / total_trials * 100}%")

Loading vocab...
Loading trained model...


100%|██████████| 871/871 [02:03<00:00,  7.07it/s]

Test set reconstruction rate: 2160.045924225029%





In [5]:
########## Check test set reconstruction rate
########## v2

success_cnt = 0
for batch in tqdm(test_dataloader):
    
    with torch.no_grad():
    
        batch = (batch[0].to(device), batch[1])
        
        ## Input SMILES for ref
        padded_x, _ = batch
        input_seqs = []
        for s in padded_x:
            input_seqs.append(tensor2string_ad(s, vocab))         
        
        _,h = model.encoder(batch)
        z,_ = model.forward_latent(h) # z already concated
        
        recon_seqs = model.sample(len(input_seqs), max_len=max_len, z_in=z, concated=True)
        
        success_cnt += sum(1 for r_s,s in zip(recon_seqs,input_seqs) if r_s==s)
        
total_trials = len(test_split.split_dataset.data) * n_latcode * n_dec_xs
print(f"Test set reconstruction rate: {1.*success_cnt / total_trials * 100}%")

100%|██████████| 871/871 [02:02<00:00,  7.11it/s]

Test set reconstruction rate: 17.03599989229651%





In [6]:
########## check recon_loss for test set

recon_loss_values = CircularBuffer(config.loss_buf_sz)
data = tqdm(test_dataloader)
for batch in data:
    
    with torch.no_grad():
        
        batch = (batch[0].to(device), batch[1])
        kl_loss, recon_loss = model(batch)
        
        recon_loss_values.add(recon_loss.item())
        recon_loss_value = recon_loss_values.mean()
        
        postfix = [f'recon={recon_loss_value:.5f})']
        data.set_postfix_str(' '.join(postfix))

NameError: name 'CircularBuffer' is not defined

In [8]:
success_cnt/len(test_split.split_dataset.data)

0.16886112532198858

In [3]:
from rdkit import Chem
val = [1 for s in recon_seqs if Chem.MolFromSmiles(s)]
sum(val)/len(val)

1.0

In [4]:
recon_seqs

['CC1CN(c2ccc(N)cc2)C(=O)c2cc(CN3CCS(=O)(=O)C(N)C(C)S3)c(F)c(Cl)c2=N1',
 'CCc1cn2c(c1S(=O)(=O)N1CCN(C(C)C)nn2-c1ccc(OC)c(OC)c1)S(C)(=O)=O',
 'CCN(CCO)CC(C)C1CC2OC3(C)C(O)C(C)(O)CCC4CC(C)C24CC(C)C(C1=O)C3',
 'Cc1ccc(-n2cnc(NC=C3C=CS(=N)(=O)N3)n2)c(N2CCN(c3cccc(C)c3)CC2)c1',
 'O=C(C1CCCN1C(=O)Oc1ccc(C(=O)N2CCCN3CCCC3(Cc3ccccc3)CC2)cc1)NCCC',
 'Cn1nccc1CN1C(NC(=O)C(Cl)(Cl)c2c[nH]cnc2=O)C(=O)c2ccccc2N1',
 'CCC1(O)C=C2CCC3(CC1)CCCC3c1ccc(CN4CCC5(CO3)C(=O)N4)cc1C2=O',
 'CC(Cn1c2c(=O)n(-c3ccc(F)c(-c4ccccc4C)cn3)cc2c1CCCO)C(=O)O',
 'COc1ccccc1C1C(=O)NC(c2ccc(C)cc2)N(c1ccnc(F)c1)c1ccccc1',
 'COc1ccc(C(=O)N2CCCC3(CCN(C(=O)c4cncc(C)c4C)C3)C2)cc1',
 'NS(=O)(=O)c1ccc(-n2nc(C(F)(F)F)cc3c2-c2ccccc2-3)s1',
 'Cc1ccc(S(=O)(=O)Nc2ccc3c(c2)C(=O)C(=O)c3ccccc2S3)cc1',
 'COc1cccc(C2CS(=O)(Nc3ccc(F)cc3)CN(C)c3ccc(O)cc32)c1',
 'CC(C)CC(NC(=O)c1n[nH]c2cc(C(F)(F)F)cn2c1O)c1ccccc1F',
 'CC(C)(C)OC(=O)N1CCN(C(=O)SCn2c(=O)cc(-c3scs3)ncc2=O)C1',
 'COc1ccc(Cc2nc(-c3ccc(OCCC(=O)NCCN)cc3)co2)cc1Br',
 'COc1ccc(-c2ccc3c(N

In [1]:
import sys
sys.path.append("/work01/home/wxxie/project/drug-gen/mollvae/MolLVAE/code")
from dataset import DatasetSplit
from opt import get_parser
from model.model import LVAE
from utils import set_seed

import torch

from tqdm import tqdm

In [2]:
########## config

parser = get_parser()
config = parser.parse_args("--device cuda:0 \
                           --n_enc_zs 10 --n_dec_xs 1 --gen_bsz 128".split())

device = torch.device(config.device)
set_seed(config.seed)

#n_sample = config.n_sample
test_csv = config.test_load
load_model_from = "../res/exp/model_049.pt"

n_latcode = config.n_enc_zs # get n_latcode latent codes for each mol
n_dec_xs = config.n_dec_xs
gen_bsz = config.gen_bsz
max_len = config.max_len

config

Namespace(clip_grad=50, dec_hid_sz=256, dec_n_layer=1, dec_type='lstm', device='cuda:0', dropout=0.1, emb_sz=128, enc_bidirectional=True, enc_hidden_size=256, enc_num_layers=1, enc_sorted_seq=True, enc_type='lstm', gen_bsz=128, kl_anr_type='cyclic', kl_e_start=0, kl_n_cycle=1, kl_w_end=0.001, kl_w_start=0.0001, ladder_d_size=[128, 64, 32], ladder_z2z_layer_size=[8, 16], ladder_z_size=[16, 8, 4], log_path=None, loss_buf_sz=20, lr_anr_type='SGDR', lr_end=1e-06, lr_mult_coeff=1, lr_n_restarts=5, lr_period=10, lr_start=0.00030000000000000003, max_len=150, model_save=None, n_dec_xs=1, n_enc_zs=10, n_epoch=100, n_sample=1000, ratio=0.2, save_frequency=10, seed=56, test_load='../data/test.csv', train_bsz=512, train_load='../data/train.csv', valid_load='../data/valid.csv')

In [3]:

########## get data and load trained model

test_split = DatasetSplit("test", test_csv)
test_dataloader = test_split.get_dataloader(batch_size=gen_bsz, shuffle=False)

vocab = test_split._vocab

print("Loading trained model...")
model = LVAE(vocab, config)
model.load_state_dict(torch.load(load_model_from))
model.to(device)
model.eval()

Loading vocab...
Loading trained model...


LVAE(
  (embedding): Embedding(37, 128, padding_idx=35)
  (encoder): LSTM_encoder(
    (embedding): Embedding(37, 128, padding_idx=35)
    (lstm): LSTM(128, 256, batch_first=True, bidirectional=True)
  )
  (decoder): LSTM_decoder(
    (x_emb): Embedding(37, 128, padding_idx=35)
    (map_z2hc): Linear(in_features=28, out_features=512, bias=True)
    (decoder_fc): Linear(in_features=256, out_features=37, bias=True)
    (lstm): LSTM(156, 256, batch_first=True)
  )
  (top_down_layers): ModuleList(
    (0): MLP(
      (layer1): Linear(in_features=4, out_features=8, bias=True)
      (layer2): Linear(in_features=8, out_features=8, bias=True)
      (mu): Linear(in_features=8, out_features=8, bias=True)
      (var): Linear(in_features=8, out_features=8, bias=True)
    )
    (1): MLP(
      (layer1): Linear(in_features=8, out_features=16, bias=True)
      (layer2): Linear(in_features=16, out_features=16, bias=True)
      (mu): Linear(in_features=16, out_features=16, bias=True)
      (var): Linea

In [4]:
########## utils

def tensor2string_ad(tensor, vocab):
    ids = tensor.tolist()
    if vocab.pad in ids:
        pad_idx = ids.index(vocab.pad)
        ids = ids[:pad_idx]
    string = vocab.ids2string(ids, rem_bos=True, rem_eos=True)
    return string

In [11]:
train_split = DatasetSplit("train", config.train_load)
train_dataloader = train_split.get_dataloader(batch_size=gen_bsz, shuffle=False)
test_data_loader = train_dataloader

Loading vocab...


In [12]:
for batch in tqdm(test_dataloader):
    
    success_cnt = 0
    
    with torch.no_grad():
        
        batch = (batch[0].to(device), batch[1])
        
        _,h = model.encoder(batch)
        
        z_mu_q_d, z_log_var_q_d = model.bottom_up(h)
        
        ## get 10 encoded latent codes for each mol
        qd_mu_top = z_mu_q_d[-1].unsqueeze(1).repeat(1, n_latcode, 1) # (bsz, n_latcode, z_size[-1])
        qd_logvar_top = z_log_var_q_d[-1].unsqueeze(1).repeat(1, n_latcode, 1)
        z_sample_top = model.sample_z(qd_mu_top, qd_logvar_top)
        
        ## get input smiles
        seqs,_ = batch
        input_seqs = []
        for s in seqs:
            input_seqs.append(tensor2string_ad(s, vocab))
        #print(input_seqs)        
        
        success_batch_cnt = [0 for _ in range(gen_bsz)]
        for j in range(n_latcode):
            
            z_sample = []
            z_sample.append(z_sample_top[:,j,:])
            
            _,_,_,_,samples = model.top_down(z_mu_q_d,z_log_var_q_d, z_sample=z_sample, mode="eval")
            
            recon_seqs = model.sample(gen_bsz, max_len=max_len, z_in=samples)
            
            for k, (r_s,s) in enumerate(zip(recon_seqs, input_seqs)): # loop over mols
                if r_s == s:
                    success_batch_cnt[k] += 1
             
        #? debug
        break
        
    success_cnt += sum(success_batch_cnt)

n_x = gen_bsz #len(test_dataloader)
total_trials = n_x * n_latcode * n_dec_xs
print(f"Test set reconstruction rate: {1.*success_cnt / total_trials * 100}%")

  0%|          | 0/871 [00:04<?, ?it/s]

Test set reconstruction rate: 0.0%





In [9]:
from rdkit import Chem

tof = [1 for s in recon_seqs if Chem.MolFromSmiles(s)]

sum(tof)/len(tof)



1.0

In [8]:
qd_mu_top[0]

tensor([[ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181],
        [ 0.0304, -0.0196,  0.0056, -0.0181]], device='cuda:0')

In [34]:
input_seqs

['CC1CN(c2ccc(F)cc2C(F)(F)F)CCN1S(=O)(=O)c1ncc(C(O)(C(N)=O)C(F)(F)F)s1',
 'CCc1c2c(nn1C(=O)N(C)C)C(=O)N(c1cc(C)c(=O)n(C)c1)C2c1ccc(Cl)cc1',
 'CC(CCC(=O)O)C1CCC2(C)C3C(=O)CC4C(C)(C)C(O)CCC4(C)C3C(=O)CC12C',
 'Cc1ccc(-n2nccn2)c(C(=O)N2CCN(c3ncc(C(F)(F)F)c(C)n3)CCC2C)c1',
 'O=C(C1CCCN1C(=O)c1ccc(C(=O)N2CCCC2C(=O)N2CCCC2)cc1)N1CCCC1',
 'Cn1cncc1CN1CC(N(CC(N)=O)S(=O)(=O)c2ccccn2)Cc2cc(C#N)ccc21',
 'CC#CC1(O)CCC2C3CCC4=CC(=O)CCC4=C3C(c3ccc(N(C)C)cc3)=CC21C',
 'CC(Cn1c2c(c3cc(NS(=O)(=O)c4ccc(F)cc4)ccc31)CCCC2)C(=O)O',
 'COc1ccccc1C1C(=O)N(C)c2ccc(C(F)(F)F)cc2N(c2ccccc2)C1=O',
 'COc1ccc(C(=O)N2CCCC3(CCN(C(=O)Nc4cccc(C#N)c4)C3)C2)cc1',
 'NS(=O)(=O)c1ccc(-n2nc(C(F)(F)F)cc2-c2cc3ccccc3o2)cc1',
 'Cc1ccc(S(=O)(=O)Nc2ccc3c(c2)C(=O)C(=O)c2ccccc2-3)cc1',
 'COc1cccc(C2C([N+](=O)[O-])=C(N)Oc3cc(C)oc(=O)c32)c1',
 'CC(C)CC(NC(=O)c1cn(-c2ccncc2)c(-c2ccccc2)n1)C(=O)O',
 'CC(C)(C)OC(=O)N1CCN(C(=S)SCc2cn(Cc3ccccc3F)nn2)CC1',
 'COc1ccc(Cc2nc(-c3ccc(OCCCCN(C)C)cc3)c[nH]c2=O)cc1',
 'COc1ccc(-c2ccc3c(N)c(C(=O)Nc4

In [40]:
from rdkit import Chem
m=Chem.MolFromSmiles('OC1CCN2Cc3ccccc3N=C12')

In [42]:
m

<rdkit.Chem.rdchem.Mol at 0x7f595c319350>

In [33]:
z_sample_top[0]

tensor([[-0.1563,  0.8204,  0.3920, -0.4285]], device='cuda:0')

In [None]:
['Cc1cc(C(=O)N2CC3CC(=NNC(C)CS(=O)(=O)NCCCCC3)COC)nc1)c1', 'Oc1ccc(NN2CCOCC34cc(-c4ccccc4C)nc33)CCC2)c1']
['Cc1cccc(N2CCN(CC(=O)Nc3ccccc3[N+](=O)[O-])CC2)c1C', 'Nc1cnc(-c2ccc(C3CCC3)c(OCC3CNC3)c2F)cn1']

In [26]:
ids=seqs.tolist()

In [10]:
test_split.split_dataset.data[:2]

['Cc1cccc(N2CCN(CC(=O)Nc3ccccc3[N+](=O)[O-])CC2)c1C',
 'Nc1cnc(-c2ccc(C3CCC3)c(OCC3CNC3)c2F)cn1']

In [11]:
vocab.pad

35

In [13]:
seqs,_ = batch
smis = []
for s in seqs:
    pad_idx = s.index(vocab.pad)
    smis.append(model.tensor2string(s[:pad_idx]))
print(smis)

AttributeError: 'Tensor' object has no attribute 'index'

In [19]:
t=torch.randn(1,2)
a=torch.tensor([1,2])

In [20]:
a + t

tensor([[1.3853, 1.5616]])

In [21]:
t

tensor([[ 0.3853, -0.4384]])

In [22]:
t ** 2

tensor([[0.1484, 0.1922]])

In [23]:
t.size()

torch.Size([1, 2])

In [25]:
t=torch.randn(2,2,4)

In [26]:
t

tensor([[[-0.5933, -0.1151,  0.5760,  1.1554],
         [ 0.2207,  0.7938,  0.0841, -1.7516]],

        [[ 0.9030, -0.2162, -0.6445,  0.8059],
         [-0.7157, -0.5792,  0.0085,  0.6672]]])

In [27]:
torch.ones(2,2,4)+t

tensor([[[ 0.4067,  0.8849,  1.5760,  2.1554],
         [ 1.2207,  1.7938,  1.0841, -0.7516]],

        [[ 1.9030,  0.7838,  0.3555,  1.8059],
         [ 0.2843,  0.4208,  1.0085,  1.6672]]])