In [1]:
from argparse import Namespace


In [2]:
import os
from os import path
from time import time

from fast_transformers.masking import LengthMask as LM
from rdkit import Chem
from transformers import BertTokenizer

In [6]:
import pandas as pd
import regex as re
import torch
import yaml

from molformer.train_pubchem_light import LightningModule
from src.molformer import MolTranBertTokenizer

In [4]:
# Read the hyperparameters from the YAML file and expose them as attributes
# of a Namespace (the top-level YAML mapping becomes the keyword arguments).
with open('./molformer/hparams.yaml', 'r') as f:
    config = Namespace(**yaml.safe_load(f))

In [7]:
# Build the SMILES tokenizer from the vocabulary file stored under ./molformer.
tokenizer = MolTranBertTokenizer('./molformer/bert_vocab.txt')

In [8]:
# Restore the model from the checkpoint file; the hyperparameters and the
# tokenizer vocabulary are forwarded to load_from_checkpoint as keyword arguments.
ckpt = './molformer/checkpoints/N-Step-Checkpoint_3_30000.ckpt'
lm = LightningModule.load_from_checkpoint(ckpt, config=config, vocab=tokenizer.vocab)

/oak/stanford/groups/mrivas/projects/multiomics/tlmenest/conda/envs/docking/lib/python3.10/site-packages/pytorch_lightning/utilities/migration/migration.py:207: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.1.5 to v2.2.0.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint molformer/checkpoints/N-Step-Checkpoint_3_30000.ckpt`


Hi
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding


In [9]:
def batch_split(data, batch_size=64):
    """Yield consecutive slices of ``data`` holding at most ``batch_size`` items.

    Args:
        data: Any sized object that supports ``len()`` and slice indexing,
            e.g. a list or a pandas Series.
        batch_size: Maximum number of items per yielded slice; must be positive.

    Yields:
        ``data[start:start + batch_size]`` for ``start = 0, batch_size, ...``.
        The final slice may be shorter. Nothing is yielded for empty input.

    Raises:
        ValueError: If ``batch_size`` is not positive, because such a step can
            never advance through the data. As this is a generator, the error
            surfaces when iteration starts.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    # Slicing clamps at the end of the sequence, so no explicit min() is needed.
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

def embed(model, smiles, tokenizer, batch_size=64):
    """Compute one mean-pooled embedding per SMILES string, batch by batch.

    Args:
        model: Module exposing ``tok_emb`` and ``blocks``; it is switched to
            eval mode.
        smiles: Sliceable collection of SMILES strings (list or pandas Series).
        tokenizer: Tokenizer providing ``batch_encode_plus``.
        batch_size: Number of SMILES encoded per forward pass.

    Returns:
        A CPU tensor with one row per input SMILES: the attention-mask-weighted
        mean of the token embeddings produced by ``model.blocks``.
    """
    model.eval()
    # Inputs must live on the same device as the weights, otherwise the forward
    # pass fails as soon as the model has been moved to a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    embeddings = []
    for batch in batch_split(smiles, batch_size=batch_size):
        # Hand the tokenizer a plain list even when `smiles` is a pandas Series.
        batch_enc = tokenizer.batch_encode_plus(list(batch), padding=True, add_special_tokens=True)
        idx = torch.tensor(batch_enc['input_ids'], device=device)
        mask = torch.tensor(batch_enc['attention_mask'], device=device)
        with torch.no_grad():
            token_embeddings = model.blocks(model.tok_emb(idx), length_mask=LM(mask.sum(-1)))
            # Average pooling over tokens: padding positions are zeroed out by
            # the attention mask before summing.
            input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            # The clamp guards against a division by zero for an all-padding row.
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            embedding = sum_embeddings / sum_mask
        embeddings.append(embedding.detach().cpu())
    return torch.cat(embeddings)

In [10]:
# Load the tab-separated ChEMBL chemical-representations table into a DataFrame.
filepath        = 'data/chembl_33_chemreps.txt'  
data = pd.read_table(filepath, sep='\t')

In [11]:
# Display the loaded table (pandas renders a truncated preview).
data

Unnamed: 0,chembl_id,canonical_smiles,standard_inchi,standard_inchi_key
0,CHEMBL153534,Cc1cc(-c2csc(N=C(N)N)n2)cn1C,InChI=1S/C10H13N5S/c1-6-3-7(4-15(6)2)8-5-16-10...,MFRNFCWYPYSFQQ-UHFFFAOYSA-N
1,CHEMBL440060,CC[C@H](C)[C@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H...,InChI=1S/C123H212N44O34S/c1-19-63(12)96(164-11...,RSEQNZQKBMRQNM-VRGFNVLHSA-N
2,CHEMBL440245,CCCC[C@@H]1NC(=O)[C@@H](NC(=O)[C@H](CC(C)C)NC(...,InChI=1S/C160H268N50O41/c1-23-27-41-95-134(228...,FTKBTEIKPOYCEX-OZSLQWTKSA-N
3,CHEMBL440249,CC(C)C[C@@H]1NC(=O)CNC(=O)[C@H](c2ccc(O)cc2)NC...,InChI=1S/C124H154ClN21O39/c1-57(2)48-81-112(17...,UYSXXKGACMHPIM-KFGDMSGDSA-N
4,CHEMBL405398,Brc1cccc(Nc2ncnc3ccncc23)c1NCCN1CCOCC1,InChI=1S/C19H21BrN6O/c20-15-2-1-3-17(18(15)22-...,VDSXZXJEWIWBCG-UHFFFAOYSA-N
...,...,...,...,...
2372669,CHEMBL4298696,CCCCCCCCCCCCCCCCCCPCCCCCCCCCCCCCC.F[PH](F)(F)(...,InChI=1S/C32H67P.F6HP/c1-3-5-7-9-11-13-15-17-1...,ZAKUDCIPPLAGQL-UHFFFAOYSA-N
2372670,CHEMBL4298698,C[n+]1cn([C@@H]2O[C@H](CO[P@@](=O)(S)OP(=O)([O...,InChI=1S/C11H18N5O13P3S/c1-15-3-16(8-5(15)9(19...,OTIKKVINVWNBOQ-LDJOHHLFSA-N
2372671,CHEMBL4298702,c1ccc(C2CC(C3CC(c4ccccc4)OC(c4ccccc4)C3)CC(c3c...,InChI=1S/C34H34O2/c1-5-13-25(14-6-1)31-21-29(2...,NZIGZXNUFVMHNV-UHFFFAOYSA-N
2372672,CHEMBL4298703,CSCC[C@H](NC=O)C(=O)N[C@@H](CCCNC(=N)NS(=O)(=O...,InChI=1S/C78H107N18O21PS2/c1-43-44(2)65(45(3)5...,IIHLOGWTFCCTPB-WTIPWMETSA-N


In [12]:
# Read the CSV listing ChEMBL IDs; the IDs end up in a column named '0'.
missing_tensor = pd.read_csv('missing_tensor.csv')

In [14]:
# Display the frame read from missing_tensor.csv.
missing_tensor

Unnamed: 0,0
0,CHEMBL4082509
1,CHEMBL4085507
2,CHEMBL4085491
3,CHEMBL4082078
4,CHEMBL4083033
...,...
798,CHEMBL4082599
799,CHEMBL4082762
800,CHEMBL4081074
801,CHEMBL4081986


In [16]:
# Replace the DataFrame with a plain list of the IDs in its '0' column.
# NOTE(review): this rebinds the same name from a DataFrame to a list, so the
# cell cannot be re-run without first re-running the read_csv cell.
missing_tensor = missing_tensor['0'].tolist()

In [18]:
# Number of IDs in the list.
len(missing_tensor)

803

In [19]:
# Keep only the rows of the table whose chembl_id appears in missing_tensor.
to_do = data[data.chembl_id.isin(missing_tensor)]

In [20]:
# Display the selected rows.
to_do

Unnamed: 0,chembl_id,canonical_smiles,standard_inchi,standard_inchi_key
2274676,CHEMBL4085225,COC(=O)C1=C(CN2CCOC(C(=O)O)C2)NC(c2nccs2)=N[C@...,InChI=1S/C21H20ClFN4O5S/c1-31-21(30)16-14(9-27...,INFRXZHVTNHELZ-LWKPJOBUSA-N
2274677,CHEMBL4085226,COc1cc(OC)c(OC)cc1CCC(=O)NC[C@@H]1N[C@H](CO)[C...,InChI=1S/C18H28N2O7/c1-25-13-7-15(27-3)14(26-2...,HZSVKZVOJZURNT-MIFHMHLRSA-N
2274678,CHEMBL4085227,CCCNC(=O)C1CCN(C(=O)c2cc3oc(Br)cc3n2Cc2ccc(OC)...,InChI=1S/C24H28BrN3O4/c1-3-10-26-23(29)17-8-11...,LJQVQJFIHNJHEM-UHFFFAOYSA-N
2274679,CHEMBL4085237,Cc1cc2c(o1)-c1c(c3ccccc3c(=O)n1CCCn1ccnc1)C(=O...,InChI=1S/C22H17N3O4/c1-13-11-16-19(26)20(27)17...,RVUVFPIJQPYVEK-UHFFFAOYSA-N
2274680,CHEMBL4085238,NS(=O)(=O)c1ccc(Nc2nc(OCC3CCCCC3)c3cn[nH]c3n2)cc1,"InChI=1S/C18H22N6O3S/c19-28(25,26)14-8-6-13(7-...",MWINFFOPTSOXJI-UHFFFAOYSA-N
...,...,...,...,...
2275474,CHEMBL4081369,NC(C(=O)O)C(O)c1cccnc1,InChI=1S/C8H10N2O3/c9-6(8(12)13)7(11)5-2-1-3-1...,FHHXXEYZWSACRS-UHFFFAOYSA-N
2275475,CHEMBL4081466,CC(=O)c1ccc(NC(=O)C2CC(O)CN2S(=O)(=O)c2ccc([N+...,InChI=1S/C19H19N3O7S/c1-12(23)13-2-4-14(5-3-13...,CKBSSLROVOVDML-UHFFFAOYSA-N
2275476,CHEMBL4081467,CCCc1sc(NC(=O)C23CC4CC(CC(C4)C2)C3)c(C(=O)O[11...,InChI=1S/C21H29NO3S/c1-4-5-16-12(2)17(19(23)25...,WLAVAJOUZGLRDS-KTXUZGJCSA-N
2275477,CHEMBL4081468,Fc1ccc2nc(Cl)c(Cn3nnc(CSc4ccccc4)n3)cc2c1,InChI=1S/C18H13ClFN5S/c19-18-13(8-12-9-14(20)6...,RLASJFUSYDQSDP-UHFFFAOYSA-N


In [69]:
# First 10 rows of the table, used as a small sample.
df = data.head(10)

In [70]:
# All canonical SMILES of the table as a plain Python list.
smiles = data.canonical_smiles.tolist()

In [71]:
# Keep only the first 10 SMILES.
# NOTE(review): rebinding `smiles` in place makes this cell non-idempotent.
smiles = smiles[:10]

In [72]:
# Sanity check on the number of SMILES kept.
len(smiles)

10

In [73]:
def canonicalize(s):
    """Return the canonical, non-isomeric SMILES for ``s``.

    Stereochemistry is dropped from the output (``isomericSmiles=False``).

    Args:
        s: SMILES string to canonicalize.

    Returns:
        The canonical SMILES string produced by RDKit.

    Raises:
        ValueError: If RDKit cannot parse ``s``.
    """
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; report the
        # offending string instead of an opaque error from MolToSmiles.
        raise ValueError(f"RDKit could not parse SMILES: {s!r}")
    return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)

# Canonicalize the SMILES of the 10-row sample (this rebinds `smiles` to a
# pandas Series), embed them, and convert the returned tensor to a NumPy array.
smiles = df.canonical_smiles.apply(canonicalize)
X = embed(lm, smiles, tokenizer).numpy()

In [74]:
# Shape of the embedding matrix: one row per embedded SMILES.
X.shape

(10, 768)

In [75]:
# Report the floating-point dtype of the embedding array.
precision = X.dtype

print("Precision:", precision)

Precision: float32


# Do it for all of ChEMBL

In [74]:
def canonicalize(s):
    """
    Canonicalize a SMILES string.

    Returns the canonical, non-isomeric SMILES for ``s``; stereochemistry is
    dropped from the output (``isomericSmiles=False``).

    Raises:
        ValueError: If RDKit cannot parse ``s``.
    """
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; report the
        # offending string instead of an opaque error from MolToSmiles.
        raise ValueError(f"RDKit could not parse SMILES: {s!r}")
    return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)

def embed(model, smiles, tokenizer, device=None, verbose=True):
    """Embed one SMILES string, or a collection of them, with mean pooling.

    Args:
        model: Module exposing ``tok_emb`` and ``blocks``; it is switched to
            eval mode.
        smiles: A single SMILES string, or an iterable of SMILES strings that
            are encoded together as one padded batch.
        tokenizer: Tokenizer providing ``batch_encode_plus``.
        device: Optional device for the input tensors. Defaults to the device
            of the model's parameters so inputs and weights always match.
        verbose: When true, print the timing line and the output size.

    Returns:
        A CPU tensor with one row per SMILES (a single row for a string
        input): the attention-mask-weighted mean of the token embeddings.

    Raises:
        ValueError: If ``smiles`` is an empty collection.
    """
    model.eval()
    start_time = time()

    # Normalize the input so a lone string and a collection share one code path.
    batch = [smiles] if isinstance(smiles, str) else list(smiles)
    if not batch:
        raise ValueError("embed() needs at least one SMILES string")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # padding=True pads only up to the longest sequence of the batch, so a
    # single string is not padded at all and the lengths passed to LM always
    # describe the real tokens of each row.
    encoded = tokenizer.batch_encode_plus(batch, padding=True, add_special_tokens=True, return_tensors="pt")
    idx, mask = encoded['input_ids'].to(device), encoded['attention_mask'].to(device)
    with torch.no_grad():
        token_embeddings = model.blocks(model.tok_emb(idx), length_mask=LM(mask.sum(-1)))
        # Average pooling over tokens; padding positions are masked out.
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # The clamp guards against a division by zero for an all-padding row.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # Return on the CPU so saved tensors load on machines without a GPU.
        embs = (sum_embeddings / sum_mask).cpu()

    end_time = time()
    if verbose:
        print(f"Processed {smiles} in {end_time - start_time:.2f} seconds.") 
        print(embs.size())

    return embs

In [75]:
def process_and_save(chembl_df, model, tokenizer, out_dir='missing'):
    """Embed every row of ``chembl_df`` and save each embedding to its own file.

    Args:
        chembl_df: DataFrame with ``chembl_id`` and ``canonical_smiles`` columns.
        model: Model forwarded to ``embed``.
        tokenizer: Tokenizer forwarded to ``embed``.
        out_dir: Directory that receives one ``<chembl_id>.pt`` file per row;
            it is created when it does not exist yet.

    Returns:
        None. The embeddings are only written to disk.
    """
    # torch.save does not create parent directories, so make sure the target
    # directory exists before the first embedding is written.
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): each SMILES is embedded exactly as stored in the frame; the
    # `canonicalize` helper is not applied here — confirm that this is intended.
    for chembl_id, smiles in zip(chembl_df['chembl_id'], chembl_df['canonical_smiles']):
        print(chembl_id)
        print(smiles)
        embedding = embed(model, smiles, tokenizer)
        torch.save(embedding, path.join(out_dir, f"{chembl_id}.pt"))

In [76]:
# process_and_save writes one .pt file per compound and returns nothing, so
# its result is deliberately not bound to a name.
process_and_save(to_do, lm, tokenizer)

CHEMBL4085225
COC(=O)C1=C(CN2CCOC(C(=O)O)C2)NC(c2nccs2)=N[C@H]1c1ccc(F)cc1Cl
Processed COC(=O)C1=C(CN2CCOC(C(=O)O)C2)NC(c2nccs2)=N[C@H]1c1ccc(F)cc1Cl in 0.15 seconds.
torch.Size([1, 768])
CHEMBL4085226
COc1cc(OC)c(OC)cc1CCC(=O)NC[C@@H]1N[C@H](CO)[C@H](O)[C@@H]1O
Processed COc1cc(OC)c(OC)cc1CCC(=O)NC[C@@H]1N[C@H](CO)[C@H](O)[C@@H]1O in 0.12 seconds.
torch.Size([1, 768])
CHEMBL4085227
CCCNC(=O)C1CCN(C(=O)c2cc3oc(Br)cc3n2Cc2ccc(OC)cc2)CC1
Processed CCCNC(=O)C1CCN(C(=O)c2cc3oc(Br)cc3n2Cc2ccc(OC)cc2)CC1 in 0.13 seconds.
torch.Size([1, 768])
CHEMBL4085237
Cc1cc2c(o1)-c1c(c3ccccc3c(=O)n1CCCn1ccnc1)C(=O)C2=O
Processed Cc1cc2c(o1)-c1c(c3ccccc3c(=O)n1CCCn1ccnc1)C(=O)C2=O in 0.15 seconds.
torch.Size([1, 768])
CHEMBL4085238
NS(=O)(=O)c1ccc(Nc2nc(OCC3CCCCC3)c3cn[nH]c3n2)cc1
Processed NS(=O)(=O)c1ccc(Nc2nc(OCC3CCCCC3)c3cn[nH]c3n2)cc1 in 0.10 seconds.
torch.Size([1, 768])
CHEMBL4085239
Nc1noc2c(N3CC[C@@H](NC(=O)c4ccc(-n5cnnc5C(F)F)cc4Cl)C3)ncc(Cl)c12
Processed Nc1noc2c(N3CC[C@@H](NC(=O)c4ccc(-n5cnnc5

In [30]:
# NOTE(review): the execution count of this cell is out of order relative to
# its neighbours, so the output shown may come from an earlier experiment —
# re-run on a fresh kernel or remove.
embs.size()

torch.Size([64, 183, 768])

In [86]:
# NOTE(review): the execution count of this cell is out of order relative to
# its neighbours, so the output shown may come from an earlier experiment —
# re-run on a fresh kernel or remove.
embs.shape

torch.Size([64, 456, 768])

In [46]:
# NOTE(review): the execution count of this cell is out of order relative to
# its neighbours, so the output shown may come from an earlier experiment —
# re-run on a fresh kernel or remove.
embs.shape

torch.Size([64, 768])

In [None]:
# Pick the GPU when one is available and move the model there before embedding.
# NOTE(review): this cell has never been executed. `embed` is called with a
# device as its 4th positional argument — confirm that the `embed` definition
# in effect accepts a device in that position (the batched variant defined
# earlier takes `batch_size` there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = lm.to(device)  
embeddings = embed(model, smiles, tokenizer, device)