In [1]:
import torch

from models.ranked_transformer import Moonshot
from models.chemformer.tokeniser import MolEncTokeniser
from pathlib import Path
from models.chemformer.utils import REGEX, DEFAULT_MAX_SEQ_LEN
from datasets.generic_index_dataset import GenericIndexedModule

from datasets.dataset_utils import pad, tokenise_and_mask

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Run configuration:
#   path             - Lightning checkpoint later restored with Moonshot.load_from_checkpoint
#   vocab_path       - Chemformer BART vocabulary file read by MolEncTokeniser.from_vocab_file
#   chem_token_start - third argument to from_vocab_file; presumably the vocab index where
#                      chemistry tokens begin (TODO confirm against MolEncTokeniser)
path, vocab_path, chem_token_start = (
  "./tempdata/epoch=76-step=264264.ckpt",
  "tempdata/chemformer/bart_vocab.txt",
  272,
)

In [21]:
# Build the SMILES tokeniser from the vocabulary configured above.
tokeniser = MolEncTokeniser.from_vocab_file(vocab_path, REGEX, chem_token_start)


def tam(a):
  """Feature handler for the "SMILES" feature: tokenise and mask `a` with `tokeniser`."""
  return tokenise_and_mask(a, tokeniser)


# Feature names and their handlers are passed as two parallel lists; presumably
# GenericIndexedModule pairs them by position (HSQC -> pad, SMILES -> tam) - TODO confirm.
direct = Path("tempdata/SMILES_dataset")
features = ["HSQC", "SMILES"]
feature_handlers = [pad, tam]

# len_override=5 is forwarded to GenericIndexedModule; NOTE(review): looks like a
# debugging limit on dataset size - confirm before using results from this run.
gim = GenericIndexedModule(direct, features, feature_handlers, len_override=5)
gim.setup("fit")
val_dl = gim.val_dataloader()

In [14]:
# Load the Chemformer checkpoint; in this notebook it is only used to inspect its
# stored "hyper_parameters". map_location="cpu" avoids materialising the saved weights
# on the GPU (and lets the cell run on CPU-only machines) since no tensors are needed.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
obj = torch.load("tempdata/chemformer/model.ckpt", map_location="cpu")

In [16]:
# Display the hyper-parameters stored in the Chemformer checkpoint (this cell's output).
hyper_parameters = obj["hyper_parameters"]
hyper_parameters

{'pad_token_idx': 0,
 'vocab_size': 523,
 'd_model': 512,
 'num_layers': 6,
 'num_heads': 8,
 'd_feedforward': 2048,
 'lr': 1.0,
 'weight_decay': 0.0,
 'activation': 'gelu',
 'num_steps': 1933600,
 'max_seq_len': 512,
 'dropout': 0.1,
 'schedule': 'transformer',
 'warm_up_steps': 8000,
 'batch_size': 128,
 'acc_batches': 1,
 'mask_prob': 0.1,
 'epochs': 10,
 'clip_grad': 1.0,
 'train_tokens': 'None',
 'num_buckets': 12,
 'limit_val_batches': 1.0,
 'augment': True,
 'task': 'mask_aug'}

In [17]:
# Keyword arguments forwarded to Moonshot.load_from_checkpoint.
# The first group copies the Chemformer checkpoint's hyper_parameters (displayed above);
# the second group adds Moonshot-specific settings.
args = {
  'pad_token_idx': 0,
  'vocab_size': 523,
  'd_model': 512,
  'num_layers': 6,
  'num_heads': 8,
  'd_feedforward': 2048,
  # The Chemformer checkpoint used lr=1.0. This dict previously listed 'lr' twice
  # (1.0 here, 1.0e-3 at the end) and the later value silently won; keep the single
  # effective value in this position so key order is unchanged.
  'lr': 1.0e-3,
  'weight_decay': 0.0,
  'activation': 'gelu',
  'num_steps': 1933600,
  'max_seq_len': 512,
  'dropout': 0.1,
  'schedule': 'transformer',
  'warm_up_steps': 8000,
  'batch_size': 128,
  'acc_batches': 1,
  'mask_prob': 0.1,
  'epochs': 10,
  'clip_grad': 1.0,
  'train_tokens': 'None',
  'num_buckets': 12,
  'limit_val_batches': 1.0,
  'augment': True,
  'task': 'mask_aug',
  # Moonshot-specific settings.
  'module_only': True,
  'dim_model': 512,
  'dim_coords': [224, 224, 64],
  'coord_enc': 'sce',
  'wavelength_bounds': [[0.01, 250], [0.01, 250]],
  'gce_resolution': 0.1,
  'heads': 4,
  'layers': 4,
}

In [18]:
# Restore Moonshot from the checkpoint at `path`, overriding/adding the settings in `args`.
# strict=False means state-dict keys that do not match the module are not treated as an
# error (Lightning load_from_checkpoint semantics) - NOTE(review): confirm which keys are
# expected to be missing/unexpected so real mismatches are not hidden.
model = Moonshot.load_from_checkpoint(path, strict=False, **args)
model = model.cuda()

  rank_zero_warn(
Initialized SignCoordinateEncoder[512] with dims [224, 224, 64] and 2 positional encoders. 64 bits are reserved for encoding the final bit


In [22]:
# Take a single validation batch. next(iter(...)) raises StopIteration on an empty
# loader instead of silently leaving stale/undefined `hsqc` and `collated_smiles`.
hsqc, collated_smiles = next(iter(val_dl))
hsqc = hsqc.cuda()
# Every entry except "raw_smiles" is moved to the GPU; "raw_smiles" is left on the host
# (presumably the original SMILES strings rather than a tensor - TODO confirm).
collated_smiles = {
  k: v if k == "raw_smiles" else v.cuda()
  for k, v in collated_smiles.items()
}

In [33]:
# Look up the ids of the tokeniser's special tokens.
begin_tok = tokeniser.begin_token
pad_tok = tokeniser.pad_token
end_tok = tokeniser.end_token
begin_tok_idx, pad_tok_idx, end_tok_idx = (
  tokeniser.vocab[tok] for tok in (begin_tok, pad_tok, end_tok)
)
print(f"{begin_tok_idx=} {pad_tok_idx=} {end_tok_idx=}")

seq_len = 50  # number of decoder positions (columns of token_ids)
b_s = 8       # number of rows in token_ids (batch size)

# Every row is [begin, pad, pad, ..., pad] of length seq_len, as int64 on the GPU.
token_ids = torch.full((b_s, seq_len), pad_tok_idx, dtype=torch.long)
token_ids[:, 0] = begin_tok_idx
token_ids = token_ids.cuda()
# All-zero float32 mask with the same shape as token_ids.
decoder_pad_mask = torch.zeros(b_s, seq_len).cuda()
print(token_ids.size())

begin_tok_idx=2 pad_tok_idx=0 end_tok_idx=3
torch.Size([51, 2])


In [23]:
# Run the decoder over growing prefixes of token_ids (positions [0, i) for i = 1..seq_len-1).
model.eval()
with torch.no_grad():
  for i in range(1, seq_len):
    decoder_inputs = token_ids[:, :i]
    # The mask must come from decoder_pad_mask, not from the token ids themselves.
    decoder_mask = decoder_pad_mask[:, :i]
    # Start from the batch's own entries so any other keys the model reads are kept,
    # then override the decoder inputs with the prefix being decoded.
    my_collated_smiles = {
      **collated_smiles,
      "decoder_inputs": decoder_inputs,
      "decoder_mask": decoder_mask,
    }
    # NOTE(review): assumes hsqc's batch size equals b_s - confirm.
    # Call the module (not .forward) so nn.Module hooks run as usual.
    output = model((hsqc, my_collated_smiles))
    # TODO(review): token_ids is never updated from `output`, so every step still
    # sees only the begin token followed by padding. Choosing the next token (e.g.
    # argmax over the vocab dimension) depends on the layout of Moonshot.forward's
    # return value, which is not visible here - confirm it and write the prediction
    # into token_ids[:, i].

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
