In [1]:
import json
import torch
from torch.utils.data import IterableDataset
import pandas as pd
import numpy as np
from collections import defaultdict
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pickle
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dict = json.load(open('data/dataset.json'))
df = pd.DataFrame(data_dict)
df.head()

Unnamed: 0,protein_target,smiles_1,smiles_0
0,MPHEPHEPLTPPFSALPDPAGAPSRRQSRQRPQLSSDSPSAFRASR...,"[[CCN1C(=CC(C)=O)Sc2ccc(OC)cc21, 1], [CC12OC(C...","[[CC(C)N1NC(=C2C=c3cc(O)ccc3=N2)c2c(N)ncnc21, ..."
1,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,"[[CSc1nc2ccccc2n1CC(=O)c1ccc(S(N)(=O)=O)cc1, 1...","[[CNC(=O)c1ccc(S(N)(=O)=O)cc1, 0], [Nc1ccc(CC(..."
2,MDPLNLSWYDDDLERQNWSRPFNGSEGKADRPHYNYYAMLLTLLIF...,[[CC(C)CC1C(=O)N2CCCC2C2(O)OC(NC(=O)C3C=C4c5cc...,"[[Oc1cc2c(cc1O)C1c3ccccc3CNC1CC2, 0], [CCCN1Cc..."
3,MDRSKENCISGPVKATAPVGGPKRVLVTQQFPCQNPLPVNSGQAQR...,[[Cc1cc(Nc2nc(Sc3ccc(NC(=O)CN4CCC(O)C4)cc3)nn3...,[[CC(Oc1cc(-c2cnn(C3CCNCC3)c2)cnc1N)c1c(Cl)ccc...
4,MRVVVIGAGVIGLSTALCIHERYHSVLQPLDIKVYADRFTPLTTTD...,"[[O=c1[nH]c2ccc(F)cc2cc1O, 1], [O=c1[nH]c2ccc(...","[[O=C(O)c1cc(CCc2ccc(Cl)cc2)c[nH]1, 0], [O=c1o..."


In [3]:
df

Unnamed: 0,protein_target,smiles_1,smiles_0
0,MPHEPHEPLTPPFSALPDPAGAPSRRQSRQRPQLSSDSPSAFRASR...,"[[CCN1C(=CC(C)=O)Sc2ccc(OC)cc21, 1], [CC12OC(C...","[[CC(C)N1NC(=C2C=c3cc(O)ccc3=N2)c2c(N)ncnc21, ..."
1,MSHHWGYGKHNGPEHWHKDFPIAKGERQSPVDIDTHTAKYDPSLKP...,"[[CSc1nc2ccccc2n1CC(=O)c1ccc(S(N)(=O)=O)cc1, 1...","[[CNC(=O)c1ccc(S(N)(=O)=O)cc1, 0], [Nc1ccc(CC(..."
2,MDPLNLSWYDDDLERQNWSRPFNGSEGKADRPHYNYYAMLLTLLIF...,[[CC(C)CC1C(=O)N2CCCC2C2(O)OC(NC(=O)C3C=C4c5cc...,"[[Oc1cc2c(cc1O)C1c3ccccc3CNC1CC2, 0], [CCCN1Cc..."
3,MDRSKENCISGPVKATAPVGGPKRVLVTQQFPCQNPLPVNSGQAQR...,[[Cc1cc(Nc2nc(Sc3ccc(NC(=O)CN4CCC(O)C4)cc3)nn3...,[[CC(Oc1cc(-c2cnn(C3CCNCC3)c2)cnc1N)c1c(Cl)ccc...
4,MRVVVIGAGVIGLSTALCIHERYHSVLQPLDIKVYADRFTPLTTTD...,"[[O=c1[nH]c2ccc(F)cc2cc1O, 1], [O=c1[nH]c2ccc(...","[[O=C(O)c1cc(CCc2ccc(Cl)cc2)c[nH]1, 0], [O=c1o..."
...,...,...,...
1042,MARARPPPPPSPPPGLLPLLPPLLLLPLLLLPAGCRALEETLMDTK...,"[[Cc1ccc(O)cc1Nc1ccnc(Nc2cccc(C(N)=O)c2)n1, 1]...",[[COc1cc(N2CCC(N3CCN(C)CC3)CC2)ccc1Nc1ncc(Cl)c...
1043,MRANDALQVLGLLFSLARGSEVGNSQAVCPGTLNGLSVTGDAENQY...,[[C=CC(=O)Nc1cc2c(Nc3ccc(F)c(Cl)c3)ncnc2cc1OCC...,[[Cc1cc(Nc2cc(N3CCN(C)CC3)nc(Sc3ccc(NC(=O)C4CC...
1044,MDDKDIDKELRQKLNFSYCEETEIEGQKKVEESREASSQTPEKGEV...,[[CCOc1cc2ncc(C#N)c(Nc3ccc(OCc4ccccn4)c(Cl)c3)...,[[Cc1cnc(Nc2ccc(OCCN3CCCC3)cc2)nc1Nc1cccc(S(=O...
1045,MAGWIQAQQLQGDALRQMQVLYGQHFPIEVRHYLAQWIESQPWDAI...,[],[[CN(CC(=O)N(Cc1ccc(C2CCCCC2)cc1)c1ccc(C(=O)O)...


In [4]:
train_df = df.sample(frac = 0.8)

temp_df = df.drop(train_df.index)

val_df = temp_df.sample(frac = 0.5)

test_df = temp_df.drop(val_df.index)

In [15]:
train_df.to_json('data/train.json')
val_df.to_json('data/val.json')
test_df.to_json('data/test.json')

### Precompute SMILES embeddings

In [5]:
all_smiles = set()

for l in df.smiles_0:
    l = map(lambda x: x[0], l)
    all_smiles.update(l)

for l in df.smiles_1:
    l = map(lambda x: x[0], l)
    all_smiles.update(l)

In [6]:
smiles_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
smiles_model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

In [7]:
def get_smiles_embeddings(smiles_inputs, tokenizer, model):
    """
    Returns a tensor of pretrained SMILES embeddings for the given SMILES inputs.
    """    
    smiles_tokenized_inputs = tokenizer(smiles_inputs, padding=True, truncation=True, return_tensors="pt")
    smiles_raw_outputs = model(**smiles_tokenized_inputs)

    smiles_mask = torch.unsqueeze(smiles_tokenized_inputs['attention_mask'], dim=2)
    smiles_logits = smiles_raw_outputs.logits
    smiles_logits = smiles_logits.masked_fill(smiles_mask == 0, 0)

    # pooled_smiles_embeddings = torch.sum(smiles_logits, dim=1) 
    seq_lens = torch.sum(smiles_tokenized_inputs['attention_mask'], dim=1)
    seq_lens = seq_lens.reshape((-1,1,1))
    smiles_logits_avg = torch.sum(smiles_logits / seq_lens, dim=1)

    return smiles_logits_avg

In [8]:
smiles_to_embeddings = {}
for smiles in tqdm(all_smiles):
    embed = get_smiles_embeddings([smiles], smiles_tokenizer, smiles_model)
    smiles_to_embeddings[smiles] = embed[0].detach().numpy()

100%|██████████| 7862/7862 [02:51<00:00, 45.93it/s]


In [9]:
with open('data/smiles_to_embeddings.pickle', 'wb') as f:
    pickle.dump(smiles_to_embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

In [10]:
with open('data/smiles_to_embeddings.pickle', 'rb') as f:
    loaded = pickle.load(f)