In [None]:
import json
import torch
from torch.utils.data import IterableDataset
import pandas as pd
import numpy as np
from collections import defaultdict
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pickle
from tqdm import tqdm

In [None]:
# Load the raw dataset into a DataFrame.
# `with` closes the file handle deterministically; json.load(open(...)) leaks it
# until garbage collection. JSON interchange text is UTF-8 (RFC 8259), so say so
# explicitly instead of relying on the platform's default locale encoding.
with open('data/dataset.json', encoding='utf-8') as dataset_file:
    data_dict = json.load(dataset_file)
df = pd.DataFrame(data_dict)
df.head()

In [None]:
# Randomly split the rows 80/20 into train and test sets.
# A fixed random_state makes the split reproducible across kernel restarts.
TRAIN_FRAC = 0.8
SPLIT_SEED = 42
train_df = df.sample(frac=TRAIN_FRAC, random_state=SPLIT_SEED)

# The test set is every row not drawn into train_df. This is index-based, so it
# relies on df having a unique index; the assert below catches a violation.
test_df = df.drop(train_df.index)
assert len(train_df) + len(test_df) == len(df), "train/test split lost or duplicated rows"

In [None]:
# Write small subsets to disk: the first 100 train rows and the first 10 test rows.
N_TRAIN_EXPORT, N_TEST_EXPORT = 100, 10
train_df.head(N_TRAIN_EXPORT).to_json('data/train.json')
test_df.head(N_TEST_EXPORT).to_json('data/test.json')

### Precompute SMILES embeddings

In [None]:
# Collect every distinct SMILES string referenced in either column.
# NOTE(review): assumes each cell of smiles_0 / smiles_1 is an iterable of
# entries whose first element is the SMILES string — confirm against the dataset.
all_smiles = {
    entry[0]
    for column in (df.smiles_0, df.smiles_1)
    for entries in column
    for entry in entries
}

In [None]:
# Tokenizer and masked-LM model must come from the same checkpoint.
CHEMBERTA_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
smiles_tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_CHECKPOINT)
smiles_model = AutoModelForMaskedLM.from_pretrained(CHEMBERTA_CHECKPOINT)

In [None]:
# Two example SMILES strings used to sanity-check the tokenize -> forward -> pool
# steps in the following cells. (Other example molecules live in smiles_batch_0.)
smiles_inputs = [
    'Cn1c(=O)n(Cc2ccccc2)c(=O)c2cc(COCc3ccccc3)cnc21',
    'CC(CC(=O)CC(C)C1CC(=O)C2(C)C3=C(C(=O)CC12C)C1(C)CCC(O)C(C)(C)C1CC3O)C(=O)O',
]

In [None]:
# Tokenize the examples (padded to the longest one) and run the model.
# Inference only: torch.no_grad() stops autograd from recording a graph for the
# forward pass, which would otherwise be kept alive by smiles_raw_outputs.
smiles_tokenized_inputs = smiles_tokenizer(smiles_inputs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    smiles_raw_outputs = smiles_model(**smiles_tokenized_inputs)

In [None]:
# Zero the logits at padding positions (attention_mask == 0) so they cannot
# contribute to pooling. The extra trailing axis lets the mask broadcast over
# the last logits dimension.
smiles_mask = smiles_tokenized_inputs['attention_mask'].unsqueeze(dim=2)
smiles_logits = smiles_raw_outputs.logits.masked_fill(smiles_mask == 0, 0)

In [None]:
# Inspect the attention mask: 1 marks a real token, 0 marks padding added by padding=True.
smiles_tokenized_inputs['attention_mask']

In [None]:
# Count of real (non-padding) tokens per row, reshaped to (N, 1, 1) so it
# broadcasts against the 3-D logits tensor.
seq_lens = smiles_tokenized_inputs['attention_mask'].sum(dim=1).reshape(-1, 1, 1)

In [None]:
# Sanity-check that the logits and the per-row lengths have compatible ranks.
shapes_to_check = (smiles_logits.shape, seq_lens.shape)
print(*shapes_to_check)

In [None]:
# Mean-pool over the token axis. Padding positions were zeroed earlier, so
# dividing by the real-token count gives a mean over real tokens only.
smiles_logits_avg = (smiles_logits / seq_lens).sum(dim=1)
smiles_logits_avg.shape

In [None]:
# Alternative pooling for comparison: plain sum over the token axis (no length normalisation).
pooled_smiles_embeddings = smiles_logits.sum(dim=1)
pooled_smiles_embeddings.shape

In [None]:
def get_smiles_embeddings(smiles_inputs, tokenizer, model):
    """
    Return mean-pooled model outputs, one row per SMILES string.

    The strings are tokenized together (padded to the longest in the batch,
    truncated by the tokenizer) and run through ``model``. The output logits
    are then averaged over the non-padding positions given by the attention
    mask.

    Args:
        smiles_inputs: list of SMILES strings.
        tokenizer: tokenizer called as ``tokenizer(texts, padding=True,
            truncation=True, return_tensors="pt")``; its output must contain
            ``'attention_mask'``.
        model: callable taking the tokenizer output as keyword arguments and
            returning an object with a ``.logits`` tensor.

    Returns:
        torch.Tensor with one row per input. Presumably shaped
        (len(smiles_inputs), vocab_size) for a masked-LM head — TODO confirm.
        The whole computation runs under ``torch.no_grad()``, so the result
        carries no autograd graph.

    NOTE(review): these are masked-LM *logits*, not hidden states. If a
    conventional embedding is wanted, consider ``output_hidden_states=True``.
    That would change the embedding size, so it would invalidate cached pickles.
    """
    encoded = tokenizer(smiles_inputs, padding=True, truncation=True, return_tensors="pt")
    # Pure inference: avoid building (and holding) an autograd graph.
    with torch.no_grad():
        logits = model(**encoded).logits
        attention_mask = encoded['attention_mask']

        # Zero padding positions; the trailing axis lets the mask broadcast over logits.
        logits = logits.masked_fill(attention_mask.unsqueeze(2) == 0, 0)

        # Real-token count per row, shaped (N, 1, 1) for broadcasting.
        seq_lens = attention_mask.sum(dim=1).reshape(-1, 1, 1)
        return (logits / seq_lens).sum(dim=1)

In [None]:
# Embed each distinct SMILES on its own (a batch of one, so no padding is involved)
# and store a NumPy copy keyed by the SMILES string.
# - torch.no_grad(): pure inference, so no autograd graph is built per molecule.
# - sorted(): set iteration order for str varies between runs (hash randomization);
#   sorting makes the progress order, and any failure point, reproducible.
smiles_to_embeddings = {}
with torch.no_grad():
    for smiles in tqdm(sorted(all_smiles)):
        embed = get_smiles_embeddings([smiles], smiles_tokenizer, smiles_model)
        smiles_to_embeddings[smiles] = embed[0].detach().numpy()

In [None]:
# Persist the SMILES -> embedding lookup so later runs can skip recomputation.
with open('data/smiles_to_embeddings_v2.pickle', 'wb') as out_file:
    pickle.dump(smiles_to_embeddings, out_file, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached lookup.
# NOTE: pickle.load can execute arbitrary code — only load this file if it was
# produced by this notebook or another trusted source.
embeddings_path = 'data/smiles_to_embeddings_v2.pickle'
with open(embeddings_path, 'rb') as in_file:
    loaded = pickle.load(in_file)

In [None]:
# Number of SMILES strings that have a cached embedding.
len(loaded)

In [None]:
# Preview a few entries rather than rendering every embedding array in the output.
dict(list(loaded.items())[:3])

In [None]:
# Two small example batches used to spot-check lookups in the cached embeddings.
smiles_batch_0 = [
    "O=C(NC1CCCN(Cc2cccc3ccccc23)C1)N1CCC2NNC(c3ccc4nccn4c3)C2C1",
    "Cc1cc(Nc2cc(N3CCN(C)CC3)nc(Sc3ccc(NC(=O)C4CC4)cc3)n2)n[nH]1",
]

smiles_batch_1 = [
    "Cn1c(=O)n(Cc2ccccc2)c(=O)c2cc(COCc3ccccc3)cnc21",
    "CC(CC(=O)CC(C)C1CC(=O)C2(C)C3=C(C(=O)CC12C)C1(C)CCC(O)C(C)(C)C1CC3O)C(=O)O",
]

In [None]:
# Stack the cached 1-D embeddings for one batch into a single 2-D array and check its shape.
np.stack([loaded[s] for s in smiles_batch_0]).shape