In [2]:
import json
from tqdm import tqdm

from transformers import T5Tokenizer, T5ForConditionalGeneration

from code.config import cfg
from code.utils import set_seed, init_path
import code.utils
from code.data_utils.dataset import DatasetLoader

import warnings
warnings.filterwarnings('ignore')

In [3]:
# Seed all RNGs from the project config before anything stochastic runs.
set_seed(cfg.seed)

# Run configuration.
cfg.dataset = "ogbg-molbace"  # alternative: "ogbg-molhiv"
cfg.demo_test = True  # True: small demo run written to a separate "test_" caption file
cfg.device = 0  # passed to `.to()` for the model and inputs; an int selects a CUDA device

# Demo runs write to a "test_"-prefixed file so they never overwrite full-run captions.
caption_file_name = (
    f"{code.utils.project_root_path}/input/caption/"
    f"{'test_' if cfg.demo_test else ''}smiles2caption_{cfg.dataset}.json"
)

In [3]:
# Load the dataset's raw text field; each entry is later fed to the tokenizer as a SMILES string.
dataloader = DatasetLoader(name=cfg.dataset, text='raw')

# Demo mode keeps only the first 10 entries so the captioning loop finishes quickly.
text = dataloader.text[:10] if cfg.demo_test else dataloader.text

In [4]:
# MolT5 checkpoint that maps SMILES strings to natural-language captions.
# Kept in one constant so the tokenizer and model can never point at different checkpoints.
MOLT5_CHECKPOINT = "laituan245/molt5-large-smiles2caption"

# `legacy=True` pins the tokenizer behaviour that was previously chosen implicitly
# (the default is the legacy behaviour). Stating it explicitly silences the
# "default legacy behaviour" notice and guards against a future change of default.
tokenizer = T5Tokenizer.from_pretrained(
    MOLT5_CHECKPOINT,
    model_max_length=512,
    legacy=True,
)
model = T5ForConditionalGeneration.from_pretrained(MOLT5_CHECKPOINT).to(cfg.device)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Caption each SMILES string with 5-beam search (generation capped at 512 tokens).
# An explicit append loop, rather than a comprehension, means an interrupted run
# still leaves the captions produced so far in `list_caption`.
list_caption = []
for smiles in tqdm(text):
    encoded = tokenizer(smiles, return_tensors="pt")
    generated_ids = model.generate(
        encoded.input_ids.to(cfg.device),
        num_beams=5,
        max_length=512,
    )
    # Only the first returned sequence is decoded; special tokens are stripped.
    caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    list_caption.append(caption)

100%|██████████| 10/10 [00:27<00:00,  2.79s/it]


In [6]:
# Prepare the output path (presumably creates missing parent directories; see code.utils.init_path),
# then write the captions as a JSON list in the same order as `text`.
init_path(dir_or_file=caption_file_name)
with open(caption_file_name, 'w') as caption_fp:
    caption_fp.write(json.dumps(list_caption, indent=2))