In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
from pprint import pprint
import tqdm

from scipy.stats import chi2_contingency

In [3]:
from rdkit import Chem

In [4]:
from molbart.decoder import DecodeSampler
from molbart.tokeniser import MolEncTokeniser
from molbart.models.pre_train import BARTModel

import molbart.util as util
import torch
from molbart.data.datamodules import MoleculeDataModule

from rdkit import Chem

## Read ZINC SMILES

In [5]:
# Start from an empty list, then replace it with the raw lines of the .smi file
# (each entry still carries its trailing newline at this point).
smiles_to_check = list()
with open("data/now.smi", mode="r") as smi_file:
    smiles_to_check = [raw_line for raw_line in smi_file]

In [6]:
# Keep only the first whitespace-separated field of each line (the SMILES itself).
# Blank / whitespace-only lines are skipped: `"".split()` is empty, so indexing
# `[0]` on it would raise IndexError. `split()` with no arguments already removes
# surrounding whitespace, so no extra `.strip()` is needed.
smiles_to_check = [line.split()[0] for line in smiles_to_check if line.strip()]

## Read MU-targeting SMILES (alternative input; overwrites `smiles_to_check`)

In [8]:
# Load the SMILES table; later cells read its "smiles" and "set" columns.
dframe = pd.read_csv("data/smiles_data.csv")

In [32]:
# Select the SMILES of the rows labelled "val" as a NumPy array.
is_val_row = dframe["set"] == "val"
smiles_to_check = dframe.loc[is_val_row, "smiles"].values

In [31]:
# Display the loaded frame to inspect its columns and the "set" labels.
dframe

Unnamed: 0.1,Unnamed: 0,smiles,fname,set
0,0,c1ccccc1N(CNC2=O)C23CCN(CC3)C(=O)OCc4cccc(c4)O...,mu_human_IC50_binding_all_inactives_450_struct...,train
1,1,c1cccc(c12)n(c(=O)[nH]2)C3CCN(CC3)C(=O)OCc(c4)...,mu_human_IC50_binding_all_inactives_450_struct...,train
2,2,c1cccc(c12)n(c(=O)[nH]2)C3CCN(CC3)C(=O)OCc(c4)...,mu_human_IC50_binding_all_inactives_450_struct...,train
3,3,COc1cc(N)c(Cl)cc1C(=O)NCC[N@H+](C)C[C@H](CCCC2...,mu_human_IC50_binding_all_inactives_450_struct...,val
4,4,c1ccccc1N(CNC2=O)C23CCN(CC3)C(=O)OCc4cccc(c4)O...,mu_human_IC50_binding_all_inactives_450_struct...,train
...,...,...,...,...
71217,71217,C=CC[NH+](CC=C)CCOCCCc1ccccc1,mu_rat_Ki_all_inactives_1006_structures.sdf:428,train
71218,71218,C=CC[NH+](CC=C)CCOCCCc1ccccc1,mu_rat_Ki_all_inactives_1006_structures.sdf:428,train
71219,71219,C=CCN(CC=C)CCOC/C=C/c1ccccc1,mu_rat_Ki_all_inactives_1006_structures.sdf:644,train
71220,71220,C=CC[NH+](CC=C)CCOCCCc1ccccc1,mu_rat_Ki_all_inactives_1006_structures.sdf:428,train


## Load model

In [41]:
# Settings holder handed to `util.load_bart`; only the checkpoint path is set here.
class Config:
    # model_path="./pre-trained/mask/step=1000000.ckpt"
    model_path = "./weights/mask/version_16/checkpoints/epoch=479-step=54719.ckpt"


# NOTE(review): absolute, machine-specific vocabulary path — confirm before reuse.
vocab_path = "/home/wwydmanski/Chemformer/bart_vocab.txt"

tokeniser = util.load_tokeniser(vocab_path, util.DEFAULT_CHEM_TOKEN_START)
sampler = DecodeSampler(tokeniser, util.DEFAULT_MAX_SEQ_LEN)

# Load the checkpoint and move the model to the second CUDA device in one step.
model = util.load_bart(Config, sampler).to('cuda:1')

In [42]:
def resample_molecule(smiles, num_beams=10):
    """Feed SMILES through the loaded model and return its beam-search samples.

    Uses the notebook-level `tokeniser` and `model` objects.

    Args:
        smiles: A single SMILES string, or a sized sequence (list, tuple,
            NumPy array) of SMILES strings.
        num_beams: Value assigned to `model.num_beams` before sampling.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        The first element of the pair returned by
        `model.sample_molecules(..., "beam")` (the sampled SMILES batch), or an
        empty list when `smiles` is empty.
    """
    if isinstance(smiles, str):
        smiles = [smiles]

    # Nothing to tokenise or sample; avoid handing an empty batch to the model.
    if len(smiles) == 0:
        return []

    # Put the inputs on whatever device the model currently lives on instead of
    # repeating a hard-coded device string that must be kept in sync by hand.
    device = next(model.parameters()).device

    token_output = tokeniser.tokenise(smiles, pad=True)
    enc_token_ids = tokeniser.convert_tokens_to_ids(token_output['original_tokens'])

    # Both tensors are transposed before use.
    # NOTE(review): presumably the model expects (seq_len, batch) — confirm.
    token_ids = torch.tensor(enc_token_ids).to(device).T
    pad_mask = torch.tensor(token_output["original_pad_masks"]).to(device).T

    # The decoder receives independent copies of the encoder inputs, as before.
    token_output["encoder_input"] = token_ids
    token_output["encoder_pad_mask"] = pad_mask
    token_output["decoder_input"] = token_ids.clone()
    token_output["decoder_pad_mask"] = pad_mask.clone()

    model.num_beams = num_beams
    smiles_batch, _log_lhs_batch = model.sample_molecules(token_output, "beam")

    return smiles_batch

## Check seq2seq consistency

In [43]:
batch_size = 8


def _canonical_smiles(smi):
    """Return RDKit's canonical SMILES for `smi`, or None if RDKit cannot parse it."""
    mol = Chem.MolFromSmiles(smi)
    return None if mol is None else Chem.MolToSmiles(mol)


# Running counters over every (input, samples) pair processed so far.
matching = 0        # input reproduced by at least one sampled candidate
valid_any = 0       # at least one sampled candidate is parseable by RDKit
valid_rigorous = 0  # the first sampled candidate is parseable by RDKit
processed = 0

# Truncate results left over from a previous run; batches are appended below.
with open("res.tsv", "w") as f:
    pass

with tqdm.trange(0, len(smiles_to_check), batch_size) as t:
    for i in t:
        smiles = smiles_to_check[i:i+batch_size]
        batch = resample_molecule(smiles)
        with open("res.tsv", "a") as f:
            for sm, res in zip(smiles, batch):
                processed += 1
                f.write(f"{sm}\t{res}\n")

                # Parse each candidate once; None marks a SMILES RDKit rejects.
                canon_res = [_canonical_smiles(candidate) for candidate in res]
                canon_sm = _canonical_smiles(sm)

                # A raw string match still counts. In addition, compare canonical
                # forms so that a correct reconstruction written as a different
                # but equivalent SMILES string is not reported as a miss.
                if sm in res or (canon_sm is not None and canon_sm in canon_res):
                    matching += 1

                if any(canon is not None for canon in canon_res):
                    valid_any += 1

                # Guard against an empty candidate list before looking at the top hit.
                if canon_res and canon_res[0] is not None:
                    valid_rigorous += 1

        # Overwrite the running percentages after every batch so partial results
        # survive an interrupted run.
        with open("stats.tsv", "w") as f:
            f.write(str(matching/processed*100) + "\t" + str(valid_any/processed*100) + "\t" + str(valid_rigorous/processed*100))
        t.set_postfix(accuracy=str(matching/processed*100), valid_any=valid_any/processed*100, valid_rigorous=valid_rigorous/processed*100)

  beam_idxs_list = list((top_idxs // vocab_size).T)
  0%|                                                                                          | 0/1618 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [18]:
# Inspect the input SMILES of the most recent batch (`smiles` is assigned inside
# the consistency loop).
smiles

array(['CC(C)(C)NC(=O)NC[C@H](C1)N(Cc(c12)cccc2)C(=S)NC[C@H]3CCCN3C(=O)Nc4ccccc4',
       'c1ccccc1[C@H](CCCCC)[N@H+]2CC[C@H]2[C@H](N)c3cccc(Cl)c3',
       'Clc1ccccc1C(c2ccccc2Cl)N(CC3)CCC34N(CNC4=O)c5ccccc5',
       'Clc1ccccc1C(c2ccccc2Cl)N(CC3)CCC34N(CNC4=O)c5ccccc5',
       'C1CCC[C@H]([C@@H]12)N(C(=O)N2)[C@@H]3CC[N@H+](CC3)CC4CCCCCCC4',
       'FC(F)(F)c1ccc(cc1)N2CCN(CC2)Cc(c[nH]3)c(c34)cccn4',
       'c1cc(F)cc(c12)N(C(=O)CC(=O)N2)[C@@H]3CC[N@H+](CC3)C4CCCCCCC4',
       'c1cc(F)cc(c12)N(C(=O)CC(=O)N2)[C@@H]3CC[N@H+](CC3)C4CCCCCCC4'],
      dtype=object)

In [37]:
# Pair the last batch's inputs with the first candidate sampled for each of them.
target_smiles = smiles
sampled_smiles = [candidates[0] for candidates in batch]

In [40]:
# Compute the sampler's own metrics for the last batch: `batch` holds the sampled
# candidates and `smiles` the corresponding inputs from the consistency loop.
model.sampler.calc_sampling_metrics(batch, smiles)

{'top_1_accuracy': 0.0,
 'invalid': 0.875,
 'top_2_accuracy': 0.0,
 'top_3_accuracy': 0.0,
 'top_5_accuracy': 0.0,
 'top_10_accuracy': 0.0}

In [39]:
# NOTE(review): `perc_invalid` is not assigned in any cell visible here, so this
# cell presumably fails on a fresh kernel — confirm where it was defined.
perc_invalid

0.875