In [2]:
import os
import sys
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

module_path = os.path.abspath(os.path.join('../chemical_properties_predictor'))
if module_path not in sys.path:
    sys.path.append(module_path)
import models
import data_utils as data

%matplotlib inline
%load_ext autoreload

To best prepare our datasets for training the VAE, there are some pre-processing steps and filtering that needs to occur. The objective of this notebook is to develop functions with the ChEMBL dataset that are applicable to any other dataset. The needed operations are:

- tokenization
- padding
- creating a set (library) of unique tokens for individual datasets
    - or curating a universal library from all available datasets
- generating char_weights, which are a TFIDF-score
- 

In [3]:
chembl_path = '../data/ChEMBL_subset.csv'
# chembl_path = '../ChEMBL_subset.csv'

df = pd.read_csv(chembl_path, delimiter = ';')
col_names = df.columns.tolist()
# df.head()
col_names

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


['Molecule ChEMBL ID',
 'Molecule Name',
 'Molecule Max Phase',
 'Molecular Weight',
 '#RO5 Violations',
 'AlogP',
 'Compound Key',
 'Smiles',
 'Standard Type',
 'Standard Relation',
 'Standard Value',
 'Standard Units',
 'pChEMBL Value',
 'Data Validity Comment',
 'Comment',
 'Uo Units',
 'Ligand Efficiency BEI',
 'Ligand Efficiency LE',
 'Ligand Efficiency LLE',
 'Ligand Efficiency SEI',
 'Potential Duplicate',
 'Assay ChEMBL ID',
 'Assay Description',
 'Assay Type',
 'BAO Format ID',
 'BAO Label',
 'Assay Organism',
 'Assay Tissue ChEMBL ID',
 'Assay Tissue Name',
 'Assay Cell Type',
 'Assay Subcellular Fraction',
 'Assay Parameters',
 'Assay Variant Accession',
 'Assay Variant Mutation',
 'Target ChEMBL ID',
 'Target Name',
 'Target Organism',
 'Target Type',
 'Document ChEMBL ID',
 'Source ID',
 'Source Description',
 'Document Journal',
 'Document Year',
 'Cell ChEMBL ID',
 'Properties']

In [4]:
#drop all rows with a nan in the 3 columns of interest
mols_df = df[['Smiles']].replace('None', np.nan)
mols_df = mols_df.dropna()
mols_df = mols_df.astype({'Smiles':str})

# #filter out molecules with SMILES strings longer than 250
# mols_df = mols_df[mols_df.str.len() <= 250]

# #Take every 10th sample to create toy dataset
# print(mols_df.shape)
# mols_df = mols_df.iloc[lambda x: x.index % 10 == 0]
print(mols_df.shape)
mols_df.head()

(247104, 1)


Unnamed: 0,Smiles
0,O=c1oc(SCc2ccccc2)nc2ccccc12
1,Cc1c(O)c(=O)ccn1C
2,CC(C)c1cccc(C(C)C)c1OC(=O)[N-]S(=O)(=O)Oc1c(C(...
3,CN(O)C(=O)Cc1ccc(CC(=O)C2c3cccc(O)c3C(=O)c3c(O...
4,Nc1nc(=O)n([C@H]2CS[C@@H](CO)O2)cc1I


In [4]:
%autoreload

smiles_encoder = data.SmilesEncoder()

smiles = mols_df['Smiles'].tolist()

tokenized_smiles = []

for smi_str in tqdm(smiles, total = len(smiles)):
#     smi_str = mols_df.iloc[i, 0]
    padded = smiles_encoder.pad(smi_str, max_length = 250)
    tokenized = smiles_encoder.tokenize(padded)
    tokenized_smiles.append(tokenized)
    
print(tokenized_smiles[0])

AssertionError: Supply CHAR_DICT

In [108]:
# print(tokenized_smiles[0])

char_set = []

for tokd in tqdm(tokenized_smiles, total = len(tokenized_smiles)):
    tok_set = list(set(tokd))
    char_set.extend(tok_set)
    char_set = list(set(char_set))

#need to ensure that the padding token is at index 0
pad_idx = char_set.index('*')
del char_set[pad_idx]
char_set.insert(0, '*')

print(len(char_set))
print(char_set)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=236902.0), HTML(value='')))


183
['*', '54', '97', '[S@+]', '91', '[C-]', '85', '[C@@]', '[14CH2]', '191', '[S@@]', '1', '431', '211', '19', 'I', '\\', '10', '543', '[N@+]', '[3H]', '26', '86', '[C@]', '7', '/', '68', '213', 'C', '[18O]', '[P@+]', '[I-]', '321', '36', '432', 'c', '[N@@+]', '18', '20', '[N+]', '[76Br]', '62', '.', '[o+]', '[11c]', '713', '72', 'S', 'Br', '42', '(', '231', '[17F]', '871', '[O+]', '[14C]', '2', '[35S]', '35', '[s+]', '75', '[127I]', '[14cH]', '56', '[131I]', '#', '6', '%', '25', 'Cl', '[Cl+3]', '[11CH3]', '34', '[S@]', '[P@@+]', '5', '69', '[14C@H]', '9', '312', '[P-]', '192', '123', '[124I]', '28', '57', '[PH]', '[76BrH]', '67', '[NH-]', '324', '61', '65', '21', ')', '341', '[NH3+]', '[11CH2]', '32', '[14CH]', '[123I]', '63', '[19F]', '[Na+]', '[14c]', '37', '24', '22', '[O-]', '78', '[CaH2]', '[N-]', '[14C@@H]', '8', '64', '53', '171', '[S+]', '[P+]', '41', '[14CH3]', '[Mg+2]', '162', '[n-]', '[Cl-]', 'O', '23', '[C@H]', '[OH-]', '[B-]', '46', '[125I]', '132', '[O]', '73', '[n+]',

In [109]:
%autoreload

char_dict = {}
for i in range(len(char_set)):
    char_dict[char_set[i]] = i
    
print(len(char_dict))

char_params = {
    'MAX_LENGTH': 250,
    'CHAR_DICT': char_dict,
    'NUM_CHAR': len(char_set)
}

char_weights = smiles_encoder.get_char_weights(tokenized_smiles, char_params)

183


In [110]:
char_weights

array([0.5       , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 0.88435443, 1.        , 1.        ,
       1.        , 0.65874936, 1.        , 1.        , 1.        ,
       1.        , 0.91909856, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.86885986, 1.        ,
       0.80395201, 1.        , 1.        , 0.60633794, 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.59377694, 1.        , 1.        , 1.        , 0.94020025,
       1.        , 1.        , 0.89833955, 1.        , 1.        ,
       1.        , 1.        , 0.8142734 , 0.96210695, 1.        ,
       0.625921  , 1.        , 1.        , 1.        , 1.        ,
       1.        , 0.67163758, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.90062842, 0.97570592, 1.        , 1.        , 0.79492957,
       1.        , 1.        , 1.        , 1.        , 1.     

In [111]:
#save char_dict, char_weights for loading into model

with open('../data/CHAR_DICT.json', 'w') as f:
    json.dump(char_dict, f)
    f.close()
    
np.save('../data/CHAR_WEIGHTS.pickle', char_weights, allow_pickle = True)

In [112]:
char_dict['B']

KeyError: 'B'

In [4]:
import re

ELEMENT_SYMBOLS = [#uncommon atoms that can be construed as two atoms (e.g. CS, NO) are excluded
    'Ac', 'Ag', 'Al', 'Am', 'Ar', 'As', 'At', 'Au', 'Ba', 'Be', 'Bh', 'Bi', 'Bk', 'Br', 'B', 'Ca', 'Cd', 'Ce',
    'Cl', 'Cm', 'Cr', 'Cu', 'C', 'Db', 'Ds', 'Dy', 'Er', 'Es', 'Eu', 'Fe', 'Fl', 'Fm', 'Fr', 'F', 'Ga', 'Gd',
    'Ge', 'He', 'Hg', 'H', 'In', 'Ir', 'I', 'Kr', 'K', 'La', 'Li', 'Lr', 'Lu', 'Lv', 'Md', 'Mg', 'Mn', 'Mo',
    'Mt', 'Na', 'Nb', 'Nd', 'Ne', 'Ni', 'N', 'O', 'Pa', 'Pb', 'Pd', 'Pm', 'Po', 'Pr', 'Pt', 'Pu', 'P', 'Ra',
    'Rb', 'Re', 'Rf', 'Rg', 'Rh', 'Rn', 'Ru', 'Sb', 'Se', 'Sg', 'Si', 'Sm', 'Sr', 'S', 'Ta', 'Tb', 'Tc', 'Te',
    'Th', 'Ti', 'Tl', 'Tm', 'Uuo', 'Uup', 'Uus', 'Uut', 'U', 'V', 'W', 'Xe', 'Yb', 'Y', 'Zn', 'Zr'
    ]

BRACKET_SYMBOLS = [
    '\(', '\)', '\[', '\]', '\{', '\}'
]

NUMBERS = '(\d{3}|\d{2}|\d{1})'

SMILES_SYMBOLS = [
    '\.', '=', '#', '-', '\+', '\+', '\\\\', '\/', '_', ':', '~', '@@', '@@', '@', '@', '\?', '>', '\*', '\$', '\%'
]

element_re = re.compile('|'.join(ELEMENT_SYMBOLS), flags = re.I)
number_re = re.compile(NUMBERS)
smiles_re = re.compile('|'.join(SMILES_SYMBOLS))
bracket_re = re.compile('|'.join(BRACKET_SYMBOLS))

def match_brackets(string):
    matches = []
    for m in bracket_re.finditer(string):
        match_span = (m.start(), m.group())
        matches.append(match_span)
    return matches
    
def match_atoms(string):
    matches = []
    for m in element_re.finditer(string):
        match_span = (m.start(), m.group())
        matches.append(match_span)
    return matches

def match_smiles_symbols(string):
    matches = []
    for m in smiles_re.finditer(string):
        match_span = (m.start(), m.group())
        matches.append(match_span)
    return matches

def match_numbers(string):
    matches = []
    for m in number_re.finditer(string):
        match_span = (m.start(), m.group())
        matches.append(match_span)
    return matches

In [5]:
test_smiles = 'Nc1nc(=O)n([C@H]2CS[C@@H](CO)O2)cc1I'
print('Length of smiles string is: \t', len(test_smiles))

brackets = match_brackets(test_smiles)
atoms = match_atoms(test_smiles)
symbols = match_smiles_symbols(test_smiles)
numbers = match_numbers(test_smiles)

print(brackets, '\n\t', len(brackets))
print(atoms, '\n\t', len(atoms))
print(symbols, '\n\t', len(symbols))
print(numbers, '\n\t', len(numbers))

Length of smiles string is: 	 36
[(5, '('), (8, ')'), (10, '('), (11, '['), (15, ']'), (19, '['), (24, ']'), (25, '('), (28, ')'), (31, ')')] 
	 10
[(0, 'N'), (1, 'c'), (3, 'n'), (4, 'c'), (7, 'O'), (9, 'n'), (12, 'C'), (14, 'H'), (17, 'C'), (18, 'S'), (20, 'C'), (23, 'H'), (26, 'C'), (27, 'O'), (29, 'O'), (32, 'c'), (33, 'c'), (35, 'I')] 
	 18
[(6, '='), (13, '@'), (21, '@@')] 
	 3
[(2, '1'), (16, '2'), (30, '2'), (34, '1')] 
	 4


In [6]:
def reconstruct_smiles_from_re(original, brackets, atoms, symbols, numbers):
    smiles_len = len(original)
    reconstructed = ''
    
    for i in range(smiles_len):
        reconstructed += 'z'
        
    for br in brackets:
        ind = br[0]
        tok = br[1]
        if len(tok) > 1:
            tok_num = len(tok)
        else:
            tok_num = 1
        for i in range(tok_num):
            reconstructed = reconstructed[:ind+i] + tok[i]  + reconstructed[ind+i+1:]
                
    for at in atoms:
        ind = at[0]
        tok = at[1]
        if len(tok) > 1:
            tok_num = len(tok)
        else:
            tok_num = 1
        for i in range(tok_num):
            reconstructed = reconstructed[:ind+i] + tok[i]  + reconstructed[ind+i+1:]
            
    for sy in symbols:
        ind = sy[0]
        tok = sy[1]
        if len(tok) > 1:
            tok_num = len(tok)
        else:
            tok_num = 1
        for i in range(tok_num):
            reconstructed = reconstructed[:ind+i] + tok[i]  + reconstructed[ind+i+1:]
            
    for num in numbers:
        ind = num[0]
        tok = num[1]
        if len(tok) > 1:
            tok_num = len(tok)
        else:
            tok_num = 1
        for i in range(tok_num):
            reconstructed = reconstructed[:ind+i] + tok[i]  + reconstructed[ind+i+1:]
                
    assert len(reconstructed) == len(original), ('smiles error: length mismatch between original and reconstructed')
            
    return reconstructed

In [7]:
rec = reconstruct_smiles_from_re(test_smiles, brackets, atoms, symbols, numbers)
print(test_smiles)
print(rec)

Nc1nc(=O)n([C@H]2CS[C@@H](CO)O2)cc1I
Nc1nc(=O)n([C@H]2CS[C@@H](CO)O2)cc1I


In [8]:
def spans_to_index_list(char_dict, brackets, atoms, symbols, numbers):
    spans = brackets + atoms + symbols + numbers
    
    sorted_spans = sorted(spans, key = lambda spn: spn[0])
    print(sorted_spans)
    
    token_list = []
    for span in sorted_spans:
        tok = span[1]
        ind = char_dict[tok]
        token_list.append(ind)
        
    return token_list


def build_char_dict(smiles_strings):
    char_dict = {}
    char_set = []
    for smiles in tqdm(smiles_strings, total = len(smiles_strings)):
        brackets = match_brackets(smiles)
        atoms = match_atoms(smiles)
        symbols = match_smiles_symbols(smiles)
        numbers = match_numbers(smiles)
        
        span_list = brackets+atoms+symbols+numbers
        span_list = [tup[1] for tup in span_list]
        span_set = list(set(span_list))
        
        char_set.extend(span_set)
        char_set = list(set(char_set))
        char_set.sort()
        
    try:
        #need to ensure that the padding token is at index 0
        pad_idx = char_set.index('*')
        del char_set[pad_idx]
        char_set.insert(0, '*')
    except:
        char_set.insert(0, '*')
        
    for i, char in enumerate(char_set):
        char_dict[char] = i
        
    return char_dict

In [5]:
%autoreload

smiles = mols_df['Smiles'].tolist()

tokenizer = data.SmilesTokenizer(char_dict = None)

char_dict = tokenizer.build_char_dict(smiles)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=247104.0), HTML(value='')))




In [6]:
with open('../data/CHAR_DICT.json', 'w') as f:
    json.dump(char_dict, f)
    f.close()

char_dict

{'*': 0,
 '#': 1,
 '%': 2,
 '(': 3,
 ')': 4,
 '+': 5,
 '-': 6,
 '.': 7,
 '/': 8,
 '1': 9,
 '10': 10,
 '102': 11,
 '109': 12,
 '11': 13,
 '115': 14,
 '12': 15,
 '123': 16,
 '124': 17,
 '125': 18,
 '127': 19,
 '13': 20,
 '131': 21,
 '132': 22,
 '14': 23,
 '15': 24,
 '16': 25,
 '162': 26,
 '17': 27,
 '171': 28,
 '175': 29,
 '18': 30,
 '19': 31,
 '191': 32,
 '192': 33,
 '2': 34,
 '20': 35,
 '21': 36,
 '211': 37,
 '213': 38,
 '22': 39,
 '23': 40,
 '231': 41,
 '24': 42,
 '245': 43,
 '25': 44,
 '26': 45,
 '27': 46,
 '28': 47,
 '3': 48,
 '31': 49,
 '312': 50,
 '314': 51,
 '32': 52,
 '321': 53,
 '324': 54,
 '34': 55,
 '341': 56,
 '35': 57,
 '36': 58,
 '37': 59,
 '4': 60,
 '41': 61,
 '412': 62,
 '42': 63,
 '43': 64,
 '431': 65,
 '432': 66,
 '45': 67,
 '46': 68,
 '5': 69,
 '51': 70,
 '52': 71,
 '53': 72,
 '54': 73,
 '543': 74,
 '56': 75,
 '57': 76,
 '58': 77,
 '6': 78,
 '61': 79,
 '62': 80,
 '63': 81,
 '64': 82,
 '642': 83,
 '65': 84,
 '67': 85,
 '68': 86,
 '69': 87,
 '7': 88,
 '713': 89,
 '72': 

In [7]:
%autoreload

char_params = {
    'MAX_LENGTH': 250,
    'CHAR_DICT': char_dict,
    'NUM_CHAR': len(char_dict)
}

tokenizer = data.SmilesTokenizer(char_dict = char_dict)
tokenized_smiles = []
for smi in tqdm(smiles, total = len(smiles)):
#     print(smi)
    tokenized, index_list = tokenizer.tokenize(smi)
#     print(tokenized)
    tokenized_smiles.append(tokenized)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=247104.0), HTML(value='')))




In [8]:
%autoreload

print(tokenized_smiles[0])

smiles_encoder = data.SmilesEncoder(CHAR_DICT = char_dict)
char_weights = smiles_encoder.get_char_weights(tokenized_smiles, char_params)

np.save('../data/CHAR_WEIGHTS.pickle', char_weights, allow_pickle = True)

['O', '=', 'c', '1', 'o', 'c', '(', 'S', 'C', 'c', '2', 'c', 'c', 'c', 'c', 'c', '2', ')', 'n', 'c', '2', 'c', 'c', 'c', 'c', 'c', '12']


In [9]:
char_weights

array([0.5       , 0.90045848, 1.        , 0.61895721, 0.61895721,
       0.89009346, 0.76104495, 0.89135292, 0.80079304, 0.65662237,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.86297838, 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.66937234,
       1.        , 0.92211497, 1.        , 1.        , 1.        ,
       0.93319289, 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.70068396, 1.        ,
       1.        , 1.        , 0.98461713, 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.75799074, 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.83570046,
       1.        , 1.        , 1.        , 1.        , 1.     