In [39]:
import os
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForMaskedLM, AutoModelWithLMHead, AutoTokenizer, pipeline, DataCollatorForLanguageModeling

In [56]:
import sys
# Append the relative "../../src" directory to the module search path before the
# project import below. The entry is relative, so it is resolved against the
# current working directory of the kernel.
# NOTE(review): presumably this is where the dfs_transformer package lives — confirm.
sys.path.append("../../src")
from dfs_transformer.utils.rdkit import isValid
from rdkit import Chem
import numpy as np



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Using backend: pytorch


In [22]:

# Any model weights from the link above will work here.
# The checkpoint id is named once so the model and the tokenizer are always
# loaded from the same place.
CHECKPOINT = "seyonec/ChemBERTA_PubChem1M_shard00_155k"

# AutoModelForMaskedLM is the task-specific replacement for the deprecated
# AutoModelWithLMHead; for a masked-LM checkpoint it loads the same architecture.
model = AutoModelForMaskedLM.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# Convenience fill-mask pipeline built on the model/tokenizer pair loaded above.
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)



In [23]:
# Masked-language-modelling collator: mlm=True with a 15% masking probability.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [35]:
class RawTextDataset(Dataset):
    """
    Map-style Torch dataset that serves one tokenized line of a text file per item.

    The file at ``file_path`` is read into memory once, at construction time, and
    split into lines. Tokenization is deferred to ``__getitem__``, so only the raw
    text (not the token ids) is held for the whole corpus.

    Parameters
    ----------
    tokenizer : callable
        Called as ``tokenizer(text, add_special_tokens=True, truncation=True,
        max_length=block_size)``; the result must expose an ``"input_ids"`` entry.
    file_path : str
        Path of the UTF-8 text file to load, one example per line.
    block_size : int
        Maximum sequence length handed to the tokenizer as ``max_length``.

    Examples
    --------
    >>> dataset = RawTextDataset(tokenizer=tokenizer, file_path="shard_00_selfies.txt", block_size=512)
    Loaded Dataset
    Number of lines: 999988
    Block size: 512
    """

    def __init__(self, tokenizer, file_path: str, block_size: int):
        super().__init__()
        self.tokenizer = tokenizer
        self.file_path = file_path
        self.block_size = block_size

        # Read the file the caller asked for (previously a fixed path was read
        # here and the ``file_path`` argument was silently ignored).
        self.dataset = Path(self.file_path).read_text(encoding="utf-8").splitlines()
        print("Loaded Dataset")
        self.len = len(self.dataset)
        print("Number of lines: " + str(self.len))
        print("Block size: " + str(self.block_size))

    def __len__(self):
        """Return the number of lines read from the file."""
        return self.len

    def preprocess(self, feature_dict):
        """
        Tokenize one example and return its token ids as a 1-D tensor.

        Despite the parameter name, ``__getitem__`` passes a single line of text
        here; it is forwarded unchanged to the tokenizer. Sequences are truncated
        to ``block_size`` tokens and are not padded.
        """
        batch_encoding = self.tokenizer(
            feature_dict,
            add_special_tokens=True,
            truncation=True,
            max_length=self.block_size,
        )
        return torch.tensor(batch_encoding["input_ids"])

    def __getitem__(self, i):
        """Return the tokenized form of line ``i``."""
        line = self.dataset[i]
        example = self.preprocess(line)
        return example

In [36]:
# Build the evaluation dataset; sequences are truncated to 512 tokens.
dataset = RawTextDataset(
    tokenizer=tokenizer,
    file_path='/mnt/ssd/datasets/pubchemvalid.txt',
    block_size=512,
)

Loaded Dataset
Number of lines: 9942
Block size: 512


In [37]:
# Batches of 16 examples, collated (and masked) by data_collator.
# shuffle=False is the DataLoader default; it is spelled out because a later
# cell zips the predictions with the lines of the source file in order.
dl = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator,
)

In [51]:
# Move the model onto the first CUDA device for inference.
model = model.to(torch.device('cuda:0'))

In [58]:
import tqdm

# The collator picks masked positions at random; seeding torch's global RNG
# makes the masking (and therefore the metrics below) repeatable.
SEED = 0
torch.manual_seed(SEED)

# Follow whatever device the model currently lives on instead of hard-coding it.
device = next(model.parameters()).device

smiles = []
# Inference only: without no_grad every forward pass records an autograd graph,
# which costs memory and time for gradients that are never used.
with torch.no_grad():
    for data in tqdm.tqdm(dl):
        # Positions selected by the collator for prediction (labels != -100).
        mask = data['labels'] != -100
        pred = data['input_ids'].clone()
        # NOTE(review): no attention_mask is passed, so padded positions are
        # attended to — confirm this matches how the checkpoint was trained.
        logits = model(data['input_ids'].to(device)).logits
        # Take the argmax on the model's device so only token ids, not the full
        # vocabulary-sized logits, are copied back to the CPU.
        pred[mask] = logits.argmax(dim=2).cpu()[mask]
        # Unselected positions keep their input token; decode one string per row.
        smiles += tokenizer.batch_decode(pred, skip_special_tokens=True)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 622/622 [01:44<00:00,  5.94it/s]


In [59]:
# Per-prediction validity flag: parse each decoded string with RDKit and hand the
# result to the project's isValid helper (its exact criteria are defined elsewhere).
valid = np.asarray(list(map(lambda decoded: isValid(Chem.MolFromSmiles(decoded)), smiles)))
# Fraction of predictions flagged as valid.
n_flagged, n_total = valid.sum(), len(valid)
print(n_flagged/n_total)

0.9747535707101187


In [60]:
# Reference SMILES, one per line, read from the same file the dataset was built from.
reference_path = Path("/mnt/ssd/datasets/pubchemvalid.txt")
reference_text = reference_path.read_text(encoding="utf-8")
orig_smiles = reference_text.splitlines()

In [62]:
def _canonical_smiles(sml):
    """Return RDKit's canonical SMILES for ``sml``, or None if it cannot be produced."""
    try:
        mol = Chem.MolFromSmiles(sml)
        # MolFromSmiles signals a parse failure by returning None.
        return None if mol is None else Chem.MolToSmiles(mol)
    except Exception:
        # Best effort, as before, but no longer swallowing KeyboardInterrupt/SystemExit.
        return None


# zip() stops at the shorter input, so a length mismatch would silently compare
# misaligned or truncated data.
assert len(smiles) == len(orig_smiles), "predictions and reference SMILES are misaligned"

same = []
for sml, osml in tqdm.tqdm(zip(smiles, orig_smiles), total=len(orig_smiles)):
    csml1 = _canonical_smiles(sml)
    csml2 = _canonical_smiles(osml)
    # Pairs where either side cannot be canonicalised are skipped, so the ratio
    # below is measured over parseable pairs only, not over all predictions.
    if csml1 is None or csml2 is None:
        continue
    same += [csml1 == csml2]

if same:
    print(np.asarray(same).sum()/len(same))
else:
    # Avoid a division by zero when nothing could be compared.
    print("No pair could be canonicalised")

9942it [00:03, 2628.24it/s]

0.692085440099061



