# Build a vocabulary using all tokens

In [1]:
# Imports

import pandas as pd
import numpy as np
import os
import json

from tqdm.auto import tqdm

from hyformer.configs.dataset import DatasetConfig
from hyformer.utils.datasets.auto import AutoDataset


# autoreload
%load_ext autoreload
%autoreload 2


In [2]:

# constants

DATA_DIR = '/lustre/groups/aih/hyformer/data'

BENCHMARKS = [
    'guacamol',
    'hi/drd2',
    'hi/hiv',
    'hi/kdr',
    'hi/sol',
    'lo/drd2',
    'lo/kdr',
    'lo/kcnh2',
    'molecule_net/scaffold/bace',
    'molecule_net/scaffold/bbbp',
    'molecule_net/scaffold/clintox',
    'molecule_net/scaffold/esol',
    'molecule_net/scaffold/freesolv',
    'molecule_net/scaffold/hiv',
    'molecule_net/scaffold/lipo',
    'molecule_net/scaffold/sider',
    'molecule_net/scaffold/tox21',
    'molecule_net/scaffold/toxcast',
    # 'unimol'
]

ADDITIONAL_DATASETS_FILEPATH = '/lustre/groups/aih/hyformer/data_proprietary/fibrosis/raw/FDA_Div_Filtered_Dataset.csv'
    

CONFIG_FILEPATH = 'configs/datasets/{benchmark}/config.json'

In [3]:
import re


class RegexTokenizer:

    SMILES_REGEX_PATTERN = r"""(\[[^\]]+\]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|%[0-9]{2}|[0-9])"""
    def __init__(self, regex_pattern: str = SMILES_REGEX_PATTERN):
        self.regex = re.compile(regex_pattern)

    def get_tokens(self, text):
        return list(set(self.regex.findall(text)))



In [4]:
tokens = []
tokenizer = RegexTokenizer()

In [5]:
# Load data and extract tokens

for benchmark in BENCHMARKS:
    
    config_filepath = CONFIG_FILEPATH.format(benchmark=benchmark)
    dataset_config = DatasetConfig.from_config_filepath(config_filepath)
    
    for split in ['train', 'val', 'test']:
        dataset = AutoDataset.from_config(dataset_config, split=split, root=DATA_DIR)
        for idx in tqdm(range(len(dataset)), desc=f"Extracting tokens from {benchmark} {split} split"):
            # extract tokens based on regex pattern from tokenizer
            tokens.extend(tokenizer.get_tokens(dataset[idx]['data']))
            break
        tokens = list(set(tokens))  

tokens.sort()


Extracting tokens from guacamol train split:   0%|          | 0/1273104 [00:00<?, ?it/s]

Extracting tokens from guacamol val split:   0%|          | 0/79568 [00:00<?, ?it/s]

Extracting tokens from guacamol test split:   0%|          | 0/238706 [00:00<?, ?it/s]

Extracting tokens from hi/drd2 train split:   0%|          | 0/2146 [00:00<?, ?it/s]

Extracting tokens from hi/drd2 val split:   0%|          | 0/239 [00:00<?, ?it/s]

Extracting tokens from hi/drd2 test split:   0%|          | 0/1190 [00:00<?, ?it/s]

Extracting tokens from hi/hiv train split:   0%|          | 0/14126 [00:00<?, ?it/s]

Extracting tokens from hi/hiv val split:   0%|          | 0/1570 [00:00<?, ?it/s]

Extracting tokens from hi/hiv test split:   0%|          | 0/7847 [00:00<?, ?it/s]

Extracting tokens from hi/kdr train split:   0%|          | 0/450 [00:00<?, ?it/s]

Extracting tokens from hi/kdr val split:   0%|          | 0/50 [00:00<?, ?it/s]

Extracting tokens from hi/kdr test split:   0%|          | 0/3116 [00:00<?, ?it/s]

Extracting tokens from hi/sol train split:   0%|          | 0/1297 [00:00<?, ?it/s]

Extracting tokens from hi/sol val split:   0%|          | 0/145 [00:00<?, ?it/s]

Extracting tokens from hi/sol test split:   0%|          | 0/721 [00:00<?, ?it/s]

Extracting tokens from lo/drd2 train split:   0%|          | 0/1985 [00:00<?, ?it/s]

Extracting tokens from lo/drd2 val split:   0%|          | 0/221 [00:00<?, ?it/s]

Extracting tokens from lo/drd2 test split:   0%|          | 0/267 [00:00<?, ?it/s]

Extracting tokens from lo/kdr train split:   0%|          | 0/450 [00:00<?, ?it/s]

Extracting tokens from lo/kdr val split:   0%|          | 0/50 [00:00<?, ?it/s]

Extracting tokens from lo/kdr test split:   0%|          | 0/437 [00:00<?, ?it/s]

Extracting tokens from lo/kcnh2 train split:   0%|          | 0/2981 [00:00<?, ?it/s]

Extracting tokens from lo/kcnh2 val split:   0%|          | 0/332 [00:00<?, ?it/s]

Extracting tokens from lo/kcnh2 test split:   0%|          | 0/406 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/bace train split:   0%|          | 0/1210 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/bace val split:   0%|          | 0/151 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/bace test split:   0%|          | 0/152 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/bbbp train split:   0%|          | 0/1631 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/bbbp val split:   0%|          | 0/204 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/bbbp test split:   0%|          | 0/204 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/clintox train split:   0%|          | 0/1182 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/clintox val split:   0%|          | 0/148 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/clintox test split:   0%|          | 0/148 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/esol train split:   0%|          | 0/902 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/esol val split:   0%|          | 0/113 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/esol test split:   0%|          | 0/113 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/freesolv train split:   0%|          | 0/513 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/freesolv val split:   0%|          | 0/64 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/freesolv test split:   0%|          | 0/65 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/hiv train split:   0%|          | 0/32901 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/hiv val split:   0%|          | 0/4113 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/hiv test split:   0%|          | 0/4113 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/lipo train split:   0%|          | 0/3360 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/lipo val split:   0%|          | 0/420 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/lipo test split:   0%|          | 0/420 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/sider train split:   0%|          | 0/1141 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/sider val split:   0%|          | 0/143 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/sider test split:   0%|          | 0/143 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/tox21 train split:   0%|          | 0/6264 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/tox21 val split:   0%|          | 0/783 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/tox21 test split:   0%|          | 0/784 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/toxcast train split:   0%|          | 0/6860 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/toxcast val split:   0%|          | 0/858 [00:00<?, ?it/s]

Extracting tokens from molecule_net/scaffold/toxcast test split:   0%|          | 0/858 [00:00<?, ?it/s]

In [26]:
len(tokens)

75

In [34]:
vocab = {}
with open('data/vocabulary/unimol.txt', "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        token = line.strip()
        if token:  # Skip empty lines
            vocab[token] = i
            
print("Loaded vocabulary length:", len(vocab))

tokens = list(vocab.keys())

Loaded vocabulary length: 458


In [32]:
# Add new tokens to vocabulary

new_tokens = []
data = pd.read_csv(ADDITIONAL_DATASETS_FILEPATH, index_col=0)['standard_smiles'].tolist()

for idx in range(len(data)):
    new_tokens.extend(tokenizer.get_tokens(data[idx]))

new_tokens = list(set(new_tokens))

print("Extended vocabulary length:", len(new_tokens))




Extended vocabulary length: 75


In [37]:
tokens_extended = list(set(tokens + new_tokens))

In [38]:
len(tokens_extended)

460

In [40]:
with open('data/vocabulary/unimol_extended.txt', 'w') as f:
    for token in tokens_extended:
        f.write(token + '\n')


In [20]:
new_tokens = []

for idx in range(len(data)):
    new_tokens.extend(tokenizer.get_tokens(data[idx]))

new_tokens = list(set(new_tokens))

new_tokens



['[S-]',
 '[n+]',
 '\\',
 '5',
 '.',
 'I',
 '[se]',
 '[Li+]',
 '[S+]',
 '[I-]',
 '[Hg]',
 '[Gd+3]',
 'c',
 'O',
 '4',
 'C',
 '[o+]',
 '[Pt]',
 '[2H]',
 '[Co+3]',
 'P',
 '[Cl-]',
 '[C@]',
 '[n-]',
 '[Cu+2]',
 'n',
 '[C@@]',
 '-',
 '[N@@+]',
 'B',
 '=',
 's',
 '[Ca+2]',
 ')',
 '[nH]',
 '[P@]',
 '[Br-]',
 '2',
 '3',
 '[C@H]',
 '#',
 '[O-]',
 'Br',
 '[C-]',
 '9',
 '/',
 '[N@+]',
 '7',
 '[Se]',
 '[Au+]',
 'N',
 '[Fe+4]',
 'F',
 'S',
 '[C@@H]',
 '[Sb]',
 '[nH+]',
 '[N+]',
 '[Na+]',
 'o',
 '[NH2+]',
 '[As]',
 '8',
 '[cH-]',
 '[Mn+2]',
 '[Zn+2]',
 '1',
 '[N-]',
 '[K+]',
 '[Pt+2]',
 '(',
 '6',
 '[NH+]',
 '[c+]',
 'Cl']

In [21]:
tokens.extend(new_tokens)

tokens = list(set(tokens))



In [22]:
tokens

['6',
 '[S-]',
 '.',
 '[n+]',
 '\\',
 '5',
 'I',
 '[NH+]',
 '[se]',
 '[Li+]',
 '[S+]',
 '[I-]',
 'c',
 'O',
 '4',
 'C',
 '[o+]',
 '(',
 '[Pt]',
 '[2H]',
 '[Co+3]',
 'P',
 '[Cl-]',
 '[K+]',
 '[C@]',
 '[n-]',
 'n',
 '[C@@]',
 '-',
 'B',
 '=',
 '[Ca+2]',
 ')',
 '[nH]',
 '[P@]',
 '[Br-]',
 '2',
 '3',
 '[C@H]',
 '#',
 '[O-]',
 'Br',
 '[C-]',
 '9',
 '/',
 '[N@+]',
 '7',
 '[Se]',
 '[Au+]',
 'N',
 '[Fe+4]',
 'F',
 'S',
 '[C@@H]',
 '[Sb]',
 'Cl',
 '[nH+]',
 '[N+]',
 '[Na+]',
 'o',
 '[NH2+]',
 '[As]',
 '8',
 '[cH-]',
 '[Mn+2]',
 '[Zn+2]',
 '1',
 '[N-]',
 '[Hg]',
 '[Pt+2]',
 's',
 '[Gd+3]',
 '[N@@+]',
 '[c+]',
 '[Cu+2]']