# 05/05/23

This notebook evaluates a pretrained Chemformer model on a single validation batch, comparing beam-search and greedy decoding.

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from models.chemformer.molbart_dataset import Zinc
from models.chemformer.molbart import BARTModel
from models.chemformer.tokeniser import MolEncTokeniser
from models.chemformer.utils import REGEX
from models.chemformer.molbart_datamodule import MoleculeDataModule

from models.chemformer.sampler import DecodeSampler

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Build the tokeniser, the ZINC dataset and the data module that feeds the evaluation.
tokeniser = MolEncTokeniser.from_vocab_file("tempdata/chemformer/bart_vocab.txt", REGEX, 272)
dataset = Zinc("tempdata/chemformer/zinc.csv")
data_module = MoleculeDataModule(
  dataset,
  tokeniser,
  batch_size=32,
  max_seq_len=512,
  task="mask_aug",
  train_token_batch_size=None,
  num_buckets=12,
  val_idxs=dataset.val_idxs,
  test_idxs=dataset.test_idxs,
  # Previously the string "DO IT", which only enabled augmentation because any
  # non-empty string is truthy. Pass a real boolean so the intent is explicit.
  # NOTE(review): assumes MoleculeDataModule treats `augment` as a boolean flag;
  # the "Using molecule data module with augmentations." log line is consistent with that.
  augment=True,
  unified_model=False,
)

Using a batch size of 32.
Using molecule data module with augmentations.


In [4]:
# Load the pretrained checkpoint together with the sampler used for decoding.
# 512 is the same value passed as max_seq_len to the data module.
sampler = DecodeSampler(tokeniser, 512)
model = BARTModel.load_from_checkpoint("tempdata/chemformer/model.ckpt", decode_sampler=sampler)
# This notebook only evaluates the model and calls test_step directly (no Trainer),
# so switch to eval mode explicitly. Reassigning keeps the module repr from being
# echoed as the cell output.
model = model.eval()

  rank_zero_warn(
Lightning automatically upgraded your loaded checkpoint from v1.2.3 to v2.0.1.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file tempdata/chemformer/model.ckpt`


In [5]:
# Take the first validation batch; it is the fixed input for every later cell.
data_module.setup()
val_dl = data_module.val_dataloader()
obj = next(iter(val_dl), None)
if obj is None:
  # Fail loudly instead of silently keeping a stale `obj` from an earlier run.
  raise RuntimeError("Validation dataloader produced no batches")

In [6]:
# Put the model and every tensor of the batch on the same device.
# Falls back to the CPU when CUDA is unavailable so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# isinstance (rather than an exact type check) also moves tensor subclasses;
# non-tensor entries (e.g. the target SMILES strings) are passed through unchanged.
obj = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in obj.items()}

# Sanity check: report where the model and each tensor ended up.
print(f"{model.device=}")
for k, v in obj.items():
  if hasattr(v, "device"):
    print(f"{k} {v.device=}")

model.device=device(type='cuda', index=0)
encoder_input v.device=device(type='cuda', index=0)
encoder_pad_mask v.device=device(type='cuda', index=0)
decoder_input v.device=device(type='cuda', index=0)
decoder_pad_mask v.device=device(type='cuda', index=0)
target v.device=device(type='cuda', index=0)
target_mask v.device=device(type='cuda', index=0)


### Inputs

In [29]:
# Decode the encoder input ids back into strings so the masked model input can be
# compared with the target SMILES.
# The id tensor is transposed before decoding.
# NOTE(review): presumably the batch is laid out (seq_len, batch) - confirm.
tokens = tokeniser.convert_ids_to_tokens(obj["encoder_input"].T)
strs = tokeniser.detokenise(tokens)
targets = obj["target_smiles"]

for idx, (target, my_str) in enumerate(zip(targets, strs)):
  print(f"[{idx}]\n\t{target=}\n\t{my_str=}")

[0]
	target='COc1ccc(Cn2cc(NC(=O)N3CC[C@@H](Oc4ccc(Cl)cc4)C3)cn2)cc1'
	my_str='O(C<MASK>1<MASK>n2cc(NC(<MASK>N<MASK>[C@H]<MASK><MASK><MASK><MASK>(Cl)<MASK>CC3)c<MASK>c1'
[1]
	target='CC(=O)Nc1ccc(N2C[C@@H](C(=O)N[C@H]3C=CCCC3)CC2=O)cc1'
	my_str='<MASK>1=C<MASK>NC<MASK>C(=O)N(c3ccc(NC<MASK>=O)<MASK>c<MASK>)=O<MASK>CC1'
[2]
	target='O=C(NCCCCO)[C@H]1[C@H](C(F)(F)F)[C@H]2CC[C@H]1O2'
	my_str='O1[C@H]2[C@@H](C(F)(F)F)[C@H](C(=O)NCCC<MASK>C<MASK>'
[3]
	target='COCCOc1cccc(NC(=O)N(CC(=O)O)C2CCCC2)c1'
	my_str='c1<MASK><MASK><MASK>NC(<MASK>(<MASK>=O)C2C<MASK>2)=O)cc1OCCOC'
[4]
	target='COc1cc(OC)c(C(=O)N[C@H]2COc3c(Cl)cccc32)cc1OC'
	my_str='O=<MASK>c<MASK>cc(OC)c(OC)<MASK><MASK>[C@@H]1c2cccc(Cl<MASK><MASK>c<MASK><MASK>1'
[5]
	target='Cc1cccc2nnc([C@@]34CCCC[C@@H]3CNC4)n12'
	my_str='c<MASK>c([C@]34<MASK>CC3)CN<MASK>n<MASK>C)ccc2'
[6]
	target='CS(=O)(=O)CCCCNC(=O)[C@H](Sc1ccc(F)cc1F)c1ccccc1'
	my_str='c1c<MASK>([C@@H](S<MASK><MASK>F)<MASK>)cc2)C<MASK>NCCCCS(=<MASK>(C)=O<MASK>c1'
[7]
	target='CN1C

### Using Beam Search

In [38]:
from textwrap import indent

In [36]:
# Evaluate the fixed batch with beam search.
model.test_sampling_alg = "beam"
print(f"Search Algorithm: {model.test_sampling_alg}")
# test_step is called by hand (no Trainer), so nothing disables autograd for us.
# Run in eval mode under no_grad to avoid building a graph for a pure evaluation.
model.eval()
with torch.no_grad():
  beam_output = model.test_step(obj, -1)
  # Sampled again to get the decoded strings and their log-likelihoods for display.
  # NOTE: mol_strs / log_lhs are overwritten by the greedy-search cell further down.
  mol_strs, log_lhs = model.sample_molecules(obj, "beam")

Search Algorithm: beam


`test_molecular_top_k_accuracy` (see `models/chemformer/sampler.py`):
- Choose the top k beam search results (out of 10), sorted by log-likelihood.
- Check whether any of them matches the intended target. If so -> success, else -> fail.

In [43]:
# Show the aggregate beam-search metrics, then for every molecule its target,
# the masked input, the per-beam log-likelihoods and the decoded beam candidates.
display(beam_output)
for idx, (target, my_str, beam_str, log_lh) in enumerate(zip(targets, strs, mol_strs, log_lhs)):
  print(f"[{idx}]\n\t{target=}\n\t{my_str=}\n\t{log_lh=}")
  print("\tbeam_str: ")
  # beam_str holds all candidates for this molecule; print one per line.
  print(indent("\n".join(beam_str), prefix="\t\t"))

{'test_loss': 0.2379920482635498,
 'test_token_acc': tensor(0.8887, device='cuda:0'),
 'test_perplexity': tensor(0., device='cuda:0'),
 'test_invalid_smiles': 0.0,
 'test_molecular_accuracy': 0.21875,
 'test_molecular_top_1_accuracy': 0.21875,
 'test_molecular_top_2_accuracy': 0.25,
 'test_molecular_top_3_accuracy': 0.25,
 'test_molecular_top_5_accuracy': 0.28125,
 'test_molecular_top_10_accuracy': 0.3125}

[0]
	target='COc1ccc(Cn2cc(NC(=O)N3CC[C@@H](Oc4ccc(Cl)cc4)C3)cn2)cc1'
	my_str='O(C<MASK>1<MASK>n2cc(NC(<MASK>N<MASK>[C@H]<MASK><MASK><MASK><MASK>(Cl)<MASK>CC3)c<MASK>c1'
	log_lh=[-9.995026588439941, -10.035225868225098, -10.152423858642578, -10.194649696350098, -10.203709602355957, -10.253393173217773, -10.291241645812988, -10.323857307434082, -10.43252944946289, -10.60340404510498]
	beam_str: 
		c1cc(OC)ccc1Cn1cc(NC(=O)N2CCO[C@@H](c3ccc(Cl)cc3)C2)cn1
		c1cc(OC)ccc1Cn1ncc(NC(=O)N2C[C@H](Oc3ccc(Cl)cc3)CC2)c1
		c1c(OC)ccc(Cn2ncc(NC(=O)N3CC[C@@H](Oc4ccc(Cl)cc4)C3)c2)c1
		c1c(OC)ccc(Cn2ncc(NC(=O)N3CCC[C@H]3c3ccc(Cl)cc3)c2)c1
		c1cc(OC)ccc1Cn1ncc(NC(=O)N2CCO[C@@H](c3ccc(Cl)cc3)C2)c1
		c1c(OC)ccc(Cn2ncc(NC(=O)N3C[C@H](c4ccc(Cl)cc4)OCC3)c2)c1
		c1cc(OC)ccc1Cn1ncc(NC(=O)N2CC[C@@H](Oc3ccc(Cl)cc3)C2)c1
		c1cc(OC)ccc1Cn1ncc(NC(=O)N2C[C@H](c3ccc(Cl)cc3)OCC2)c1
		c1cc(OC)ccc1Cn1ncc(NC(=O)N2CCC[C@H]2c2ccc(Cl)cc2)c1
		c1cc(OC)ccc1Cn1ncc(NC(=O)N2[C@H](c3ccc(Cl)cc3)CC2)c1
[1]
	target='CC(=O)Nc1ccc(N2C[

### Using Greedy Search

In [44]:
# Evaluate the same batch with greedy decoding.
model.test_sampling_alg = "greedy"
print(f"Search Algorithm: {model.test_sampling_alg}")
# test_step is called by hand (no Trainer), so nothing disables autograd for us.
# Without no_grad the returned metrics carry a grad_fn, i.e. a graph is being tracked.
model.eval()
with torch.no_grad():
  greedy_output = model.test_step(obj, -1)
  # Sampled again to get the decoded strings and their log-likelihoods for display.
  # NOTE: this overwrites the mol_strs / log_lhs produced by the beam-search cell.
  mol_strs, log_lhs = model.sample_molecules(obj, "greedy")

Search Algorithm: greedy


In [45]:
# Aggregate greedy metrics first, then one record per molecule:
# target, masked input, greedy prediction and its log-likelihood.
display(greedy_output)
per_molecule = zip(targets, strs, mol_strs, log_lhs)
for idx, (target, my_str, greedy_str, log_lh) in enumerate(per_molecule):
  fields = (f"{target=}", f"{my_str=}", f"{greedy_str=}", f"{log_lh=}")
  # Header line, then every field on its own tab-indented line.
  print(f"[{idx}]", *fields, sep="\n\t")

{'test_loss': 0.2379920482635498,
 'test_token_acc': tensor(0.8887, device='cuda:0'),
 'test_perplexity': tensor(0., device='cuda:0', grad_fn=<MeanBackward0>),
 'test_invalid_smiles': 0.0,
 'test_molecular_accuracy': 0.15625}

[0]
	target='COc1ccc(Cn2cc(NC(=O)N3CC[C@@H](Oc4ccc(Cl)cc4)C3)cn2)cc1'
	my_str='O(C<MASK>1<MASK>n2cc(NC(<MASK>N<MASK>[C@H]<MASK><MASK><MASK><MASK>(Cl)<MASK>CC3)c<MASK>c1'
	greedy_str='c1c(Cl)ccc([C@H]2CN(C(=O)Nc3cnn(Cc4ccc(OC)cc4)c3)CCO2)c1'
	log_lh=-8.724369049072266
[1]
	target='CC(=O)Nc1ccc(N2C[C@@H](C(=O)N[C@H]3C=CCCC3)CC2=O)cc1'
	my_str='<MASK>1=C<MASK>NC<MASK>C(=O)N(c3ccc(NC<MASK>=O)<MASK>c<MASK>)=O<MASK>CC1'
	greedy_str='C1CCC=C1CCNC(=O)[C@@H]1CN(c2ccc(NC(=O)C)cc2)C(=O)C1'
	log_lh=-11.33339786529541
[2]
	target='O=C(NCCCCO)[C@H]1[C@H](C(F)(F)F)[C@H]2CC[C@H]1O2'
	my_str='O1[C@H]2[C@@H](C(F)(F)F)[C@H](C(=O)NCCC<MASK>C<MASK>'
	greedy_str='C(CCNC(=O)[C@@H]1[C@@H]2O[C@H](CC2)[C@H]1C(F)(F)F)CCC'
	log_lh=-25.318500518798828
[3]
	target='COCCOc1cccc(NC(=O)N(CC(=O)O)C2CCCC2)c1'
	my_str='c1<MASK><MASK><MASK>NC(<MASK>(<MASK>=O)C2C<MASK>2)=O)cc1OCCOC'
	greedy_str='C(COc1cc(NC(=O)N[C@@H](C2CCCCC2)CC(=O)N)ccc1)OC'
	log_lh=-16.216228485107422
[4]
	target='COc1cc(OC)c(C(=O)N[C@H]2COc3c(Cl)cccc32