In [10]:
from argparse import Namespace
from re import S
from data.util import load_tokenizer
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader
from rdkit import Chem

import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

#from module.factory import load_encoder, load_decoder
from module.decoder.lstm import LSTMDecoder
from module.decoder.transformer import TransformerDecoder

from module.encoder.gnn import GNNEncoder
from module.vq_layer import FlattenedVectorQuantizeLayer
from pl_module.util import compute_sequence_cross_entropy, compute_sequence_accuracy
from data.factory import load_dataset, load_collate
from data.smiles.util import canonicalize

from pl_module.autoencoder import AutoEncoderModule
from data.util import load_smiles_list, load_tokenizer
from tokenizers import Tokenizer
from tokenizers import pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing


In [2]:
batch_size = 256
num_workers = 8
decoder_select = 'lstm'

# for latent space
code_dim = 256

# for encoder GNN
encoder_num_layers = 5
encoder_hidden_dim = 256

# for decoder lstm
decoder_num_layers = 3
decoder_hidden_dim = 1024

# for decoder transformer
encoder_layers = 6
emb_size = 1024
nhead = 8
dim_feedforward = 2048
dropout = 0.1

In [3]:
train_dataset = load_dataset("graph2seq", "zinc", "train")
val_dataset = load_dataset("graph2seq", "zinc", "valid")
collate = load_collate("graph2seq")

train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate,
            num_workers=num_workers,
        )

val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate,
            num_workers=num_workers,
        )

In [11]:
encoder = GNNEncoder(
            encoder_num_layers,
            encoder_hidden_dim,
            code_dim
)
if decoder_select == "lstm":
    decoder = LSTMDecoder(
            num_layers=decoder_num_layers, 
            hidden_dim=decoder_hidden_dim, 
            code_dim=code_dim
            )
elif decoder_select == "transformer":
    decoder = TransformerDecoder(
                num_encoder_layers = encoder_layers,
                emb_size = emb_size,
                nhead = nhead,
                dim_feedforward = dim_feedforward,
                dropout = dropout,
                code_dim = code_dim
            )

encoder_dict = OrderedDict()
decoder_dict = OrderedDict()
#print(torch.load("../resource/checkpoint/run0_autoencoder/lstm/best.ckpt")["state_dict"].keys())
for k,v in torch.load("../resource/checkpoint/run0_autoencoder/lstm/best.ckpt")["state_dict"].items():
    if k.startswith("encoder"):
        encoder_dict['.'.join(k.split(".")[1:])] = v
    elif k.startswith("decoder"):
        decoder_dict['.'.join(k.split(".")[1:])] = v
        
encoder.load_state_dict(encoder_dict)
decoder.load_state_dict(decoder_dict)
tokenizer = load_tokenizer()
eos = tokenizer.token_to_id("[EOS]")

pos = 0
neg = 0
non_valid = 0

with torch.no_grad():
    for data in val_dataloader:
        batched_input_data, batched_target_data = data
        codes = encoder(batched_input_data)
        logits = decoder(batched_target_data, codes)

        batch_size = batched_target_data.size(0)
        logits = logits[:, :-1]
        targets = batched_target_data[:, 1:]
        preds = torch.argmax(logits, dim=-1)
        # correct = preds[6] == targets[6]
        # correct[targets[6] == 0] = True
        # elem_acc = correct[targets[6] != 0].float().mean()
        # print(preds[6])
        # print(elem_acc)
        for i in range(len(targets)):
            target = list(targets[i])
            pred = list(preds[i])
            try:
                target = target[:target.index(3)]
                pred = pred[:pred.index(3)]
            except:
                continue
            target_seq = tokenizer.decode(target).replace(" ","")
            pred_seq = tokenizer.decode(pred).replace(" ","")
            try:
                Chem.MolFromSmiles(pred_seq)
                print('ok')
            except:
                print('not okay')
                non_valid += 1
            if target_seq == pred_seq:
                pos += 1
                print(target_seq)
                print(pred_seq)
            else:
                print(target_seq)
                print(pred_seq)
                neg += 1
print(pos)
print(neg)
print(non_valid)
        #recon_loss = compute_sequence_cross_entropy(logits, batched_target_data)
        #elem_acc, seq_acc = compute_sequence_accuracy(logits, batched_target_data)


ok
N#Cc1cccc(CS(=O)(=O)N[C@H](c2cccc(C(F)(F)F)c2)C2CC2)c1
N#Cc1cccc(CS(=O)(=O)N[C@H](c2cccc(C(F)(F)F)c2)C2CC2)c1
ok
Cc1cccc(C(=O)N2CCCN(C(=O)Cn3cnnn3)CC2)c1
Cc1cccc(C(=O)N2CCCN(C(=O)Cn3cnnn3)CC2)c1
ok
Cc1cc([C@@H](C)[NH2+][C@@H]2CSCC(C)(C)C2)c(C)o1
Cc1cc([C@@H](C)[NH2+][C@@H]2CSCC(C)(C)C2)c(C)o1
ok
CC(=O)c1ccc(NC(=O)[C@H](C)Oc2ccc(C#N)cc2)cc1C
CC(=O)c1ccc(NC(=O)[C@H](C)Oc2ccc(C#N)cc2CC
ok
Cc1ccc(NC2CCC([NH3+])CC2)cc1C(F)(F)F
Cc1ccc(N2CCC(C[NH3+])C2)cc1C(F)(F)F
ok
COc1ccc(Nc2nnc(SCC(=O)NC3CCCC3)s2)cc1
COc1ccc(Nc2nnc(SCC(=O)NC3CCCC3)s2)cc1
ok
Cc1cc(F)c(CCCCCC(=O)[O-])cc1F
Cc1cc(F)c(CCCCC(=O)[O-](=O)[O-])cc1F
ok
Cc1n[nH]c(C)c1CCC(=O)Nc1cc(N2CCN(Cc3ccccc3)CC2)ncn1
Cc1n[nH]c(C)c1CCC(=O)NC1CC(N2CCN(Cc3ccccc3)CC2)ncn1
ok
COc1ccc([C@@H]2CCCN2C(=O)c2nnn[n-]2)cc1
COc1ccc([C@@H]2CCCN2C(=O)c2nnn[n-]2)cc1
ok
C[C@@H](Sc1nncn1Cc1ccco1)C(=O)c1c[nH]c2ccccc12
C[C@@H](Sc1nncn1Cc1ccco1)C(=O)c1c[nH]c2ccccc12
ok
COc1ccc(O)c(-c2nc3ccc(N)nc3[nH]2)c1
COc1ccc(O)c(-c2nc3n(N)ccc3[nH]2)c1
ok
NC(=O)c1cnc(NC2CCN(c3n

In [12]:
encoder = GNNEncoder(
            encoder_num_layers,
            encoder_hidden_dim,
            code_dim
)
if decoder_select == "lstm":
    decoder = LSTMDecoder(
            num_layers=decoder_num_layers, 
            hidden_dim=decoder_hidden_dim, 
            code_dim=code_dim
            )
elif decoder_select == "transformer":
    decoder = TransformerDecoder(
                num_encoder_layers = encoder_layers,
                emb_size = emb_size,
                nhead = nhead,
                dim_feedforward = dim_feedforward,
                dropout = dropout,
                code_dim = code_dim
            )

encoder_dict = OrderedDict()
decoder_dict = OrderedDict()
#print(torch.load("../resource/checkpoint/run0_autoencoder/lstm/best.ckpt")["state_dict"].keys())
for k,v in torch.load("../resource/checkpoint/run0_autoencoder/lstm/best.ckpt")["state_dict"].items():
    if k.startswith("encoder"):
        encoder_dict['.'.join(k.split(".")[1:])] = v
    elif k.startswith("decoder"):
        decoder_dict['.'.join(k.split(".")[1:])] = v
        
encoder.load_state_dict(encoder_dict)
decoder.load_state_dict(decoder_dict)
tokenizer = load_tokenizer()
eos = tokenizer.token_to_id("[EOS]")

pos = 0
neg = 0
non_valid = 0
with open('check.txt', 'w') as f:
    with torch.no_grad():
        for data in val_dataloader:
            batched_input_data, batched_target_data = data
            codes = encoder(batched_input_data)
            logits = decoder(batched_target_data, codes)

            batch_size = batched_target_data.size(0)
            logits = logits[:, :-1]
            targets = batched_target_data[:, 1:]
            preds = torch.argmax(logits, dim=-1)
            # correct = preds[6] == targets[6]
            # correct[targets[6] == 0] = True
            # elem_acc = correct[targets[6] != 0].float().mean()
            # print(preds[6])
            # print(elem_acc)
            for i in range(len(targets)):
                target = list(targets[i])
                pred = list(preds[i])
                try:
                    target = target[:target.index(3)]
                    pred = pred[:pred.index(3)]
                except:
                    continue
                target_seq = tokenizer.decode(target).replace(" ","")
                pred_seq = tokenizer.decode(pred).replace(" ","")
                f.write(pred_seq)
                f.write('\n')
                try:
                    Chem.MolFromSmiles(pred_seq)
                    print('ok')
                except:
                    print('not okay')
                    non_valid += 1
                if target_seq == pred_seq:
                    pos += 1
                    print(target_seq)
                    print(pred_seq)
                else:
                    print(target_seq)
                    print(pred_seq)
                    neg += 1
print(pos)
print(neg)
print(non_valid)
        #recon_loss = compute_sequence_cross_entropy(logits, batched_target_data)
        #elem_acc, seq_acc = compute_sequence_accuracy(logits, batched_target_data)


ok
N#Cc1cccc(CS(=O)(=O)N[C@H](c2cccc(C(F)(F)F)c2)C2CC2)c1
N#Cc1cccc(CS(=O)(=O)N[C@H](c2cccc(C(F)(F)F)c2)C2CC2)c1
ok
Cc1cccc(C(=O)N2CCCN(C(=O)Cn3cnnn3)CC2)c1
Cc1cccc(C(=O)N2CCCN(C(=O)Cn3cnnn3)CC2)c1
ok
Cc1cc([C@@H](C)[NH2+][C@@H]2CSCC(C)(C)C2)c(C)o1
Cc1cc([C@@H](C)[NH2+][C@@H]2CSCC(C)(C)C2)c(C)o1
ok
CC(=O)c1ccc(NC(=O)[C@H](C)Oc2ccc(C#N)cc2)cc1C
CC(=O)c1ccc(NC(=O)[C@H](C)Oc2ccc(C#N)cc2CC
ok
Cc1ccc(NC2CCC([NH3+])CC2)cc1C(F)(F)F
Cc1ccc(N2CCC(C[NH3+])C2)cc1C(F)(F)F
ok
COc1ccc(Nc2nnc(SCC(=O)NC3CCCC3)s2)cc1
COc1ccc(Nc2nnc(SCC(=O)NC3CCCC3)s2)cc1
ok
Cc1cc(F)c(CCCCCC(=O)[O-])cc1F
Cc1cc(F)c(CCCCC(=O)[O-](=O)[O-])cc1F
ok
Cc1n[nH]c(C)c1CCC(=O)Nc1cc(N2CCN(Cc3ccccc3)CC2)ncn1
Cc1n[nH]c(C)c1CCC(=O)NC1CC(N2CCN(Cc3ccccc3)CC2)ncn1
ok
COc1ccc([C@@H]2CCCN2C(=O)c2nnn[n-]2)cc1
COc1ccc([C@@H]2CCCN2C(=O)c2nnn[n-]2)cc1
ok
C[C@@H](Sc1nncn1Cc1ccco1)C(=O)c1c[nH]c2ccccc12
C[C@@H](Sc1nncn1Cc1ccco1)C(=O)c1c[nH]c2ccccc12
ok
COc1ccc(O)c(-c2nc3ccc(N)nc3[nH]2)c1
COc1ccc(O)c(-c2nc3n(N)ccc3[nH]2)c1
ok
NC(=O)c1cnc(NC2CCN(c3n