In [62]:
from argparse import Namespace
from os import path
from time import time
from fast_transformers.masking import LengthMask as LM
from rdkit import Chem
from transformers import BertTokenizer

import pandas as pd
import regex as re
import torch
import yaml

from molformer.train_pubchem_light import LightningModule

In [63]:
with open('./molformer/hparams.yaml', 'r') as f:
    config = Namespace(**yaml.safe_load(f))

In [64]:
tokenizer = MolTranBertTokenizer('./molformer/bert_vocab.txt')

In [65]:
ckpt = './molformer/checkpoints/N-Step-Checkpoint_3_30000.ckpt'
lm = LightningModule.load_from_checkpoint(ckpt, config=config, vocab=tokenizer.vocab)

/oak/stanford/groups/mrivas/projects/multiomics/tlmenest/conda/envs/docking/lib/python3.10/site-packages/pytorch_lightning/utilities/migration/migration.py:207: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.1.5 to v2.2.0.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint molformer/checkpoints/N-Step-Checkpoint_3_30000.ckpt`


Hi
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding


In [66]:
def batch_split(data, batch_size=64):
    i = 0
    while i < len(data):
        yield data[i:min(i+batch_size, len(data))]
        i += batch_size

def embed(model, smiles, tokenizer, batch_size=64):
    model.eval()
    embeddings = []
    for batch in batch_split(smiles, batch_size=batch_size):
        batch_enc = tokenizer.batch_encode_plus(batch, padding=True, add_special_tokens=True)
        idx, mask = torch.tensor(batch_enc['input_ids']), torch.tensor(batch_enc['attention_mask'])
        with torch.no_grad():
            token_embeddings = model.blocks(model.tok_emb(idx), length_mask=LM(mask.sum(-1)))
        # average pooling over tokens
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        embedding = sum_embeddings / sum_mask
        embeddings.append(embedding.detach().cpu())
    return torch.cat(embeddings)

In [67]:
filepath        = 'data/chembl_33_chemreps.txt'  
data = pd.read_table(filepath, sep='\t')

In [68]:
data

Unnamed: 0,chembl_id,canonical_smiles,standard_inchi,standard_inchi_key
0,CHEMBL153534,Cc1cc(-c2csc(N=C(N)N)n2)cn1C,InChI=1S/C10H13N5S/c1-6-3-7(4-15(6)2)8-5-16-10...,MFRNFCWYPYSFQQ-UHFFFAOYSA-N
1,CHEMBL440060,CC[C@H](C)[C@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H...,InChI=1S/C123H212N44O34S/c1-19-63(12)96(164-11...,RSEQNZQKBMRQNM-VRGFNVLHSA-N
2,CHEMBL440245,CCCC[C@@H]1NC(=O)[C@@H](NC(=O)[C@H](CC(C)C)NC(...,InChI=1S/C160H268N50O41/c1-23-27-41-95-134(228...,FTKBTEIKPOYCEX-OZSLQWTKSA-N
3,CHEMBL440249,CC(C)C[C@@H]1NC(=O)CNC(=O)[C@H](c2ccc(O)cc2)NC...,InChI=1S/C124H154ClN21O39/c1-57(2)48-81-112(17...,UYSXXKGACMHPIM-KFGDMSGDSA-N
4,CHEMBL405398,Brc1cccc(Nc2ncnc3ccncc23)c1NCCN1CCOCC1,InChI=1S/C19H21BrN6O/c20-15-2-1-3-17(18(15)22-...,VDSXZXJEWIWBCG-UHFFFAOYSA-N
...,...,...,...,...
2372669,CHEMBL4298696,CCCCCCCCCCCCCCCCCCPCCCCCCCCCCCCCC.F[PH](F)(F)(...,InChI=1S/C32H67P.F6HP/c1-3-5-7-9-11-13-15-17-1...,ZAKUDCIPPLAGQL-UHFFFAOYSA-N
2372670,CHEMBL4298698,C[n+]1cn([C@@H]2O[C@H](CO[P@@](=O)(S)OP(=O)([O...,InChI=1S/C11H18N5O13P3S/c1-15-3-16(8-5(15)9(19...,OTIKKVINVWNBOQ-LDJOHHLFSA-N
2372671,CHEMBL4298702,c1ccc(C2CC(C3CC(c4ccccc4)OC(c4ccccc4)C3)CC(c3c...,InChI=1S/C34H34O2/c1-5-13-25(14-6-1)31-21-29(2...,NZIGZXNUFVMHNV-UHFFFAOYSA-N
2372672,CHEMBL4298703,CSCC[C@H](NC=O)C(=O)N[C@@H](CCCNC(=N)NS(=O)(=O...,InChI=1S/C78H107N18O21PS2/c1-43-44(2)65(45(3)5...,IIHLOGWTFCCTPB-WTIPWMETSA-N


In [69]:
df = data.head(10)

In [70]:
smiles = data.canonical_smiles.tolist()

In [71]:
smiles = smiles[:10]

In [72]:
len(smiles)

10

In [73]:
def canonicalize(s):
    return Chem.MolToSmiles(Chem.MolFromSmiles(s), canonical=True, isomericSmiles=False)

smiles = df.canonical_smiles.apply(canonicalize)
X = embed(lm, smiles, tokenizer).numpy()

In [74]:
X.shape

(10, 768)

In [75]:
precision = X.dtype

print("Precision:", precision)

Precision: float32


# Do it for all of CHEMBL

In [82]:
def batch_split(data, batch_size=64):
    """
    Generator to yield batches of data.
    """
    i = 0
    while i < len(data):
        yield data[i:min(i+batch_size, len(data))]
        i += batch_size

def canonicalize(s):
    """
    Canonicalize SMILES strings.
    """
    return Chem.MolToSmiles(Chem.MolFromSmiles(s), canonical=True, isomericSmiles=False)

def embed(model, smiles, tokenizer, batch_size=64):
    model.eval()
    embeddings = []
    for batch in batch_split(smiles, batch_size=batch_size):
        start_time = time() 
        
        batch_enc = tokenizer.batch_encode_plus(batch, padding=True, add_special_tokens=True, return_tensors="pt")
        idx, mask = batch_enc['input_ids'], batch_enc['attention_mask']
        with torch.no_grad():
            token_embeddings = model.blocks(model.tok_emb(idx), length_mask=LM(mask.sum(-1)))
        embeddings.append(token_embeddings.detach().cpu())

        end_time = time() 
        print(f"Processed batch in {end_time - start_time:.2f} seconds.") 

    return torch.cat(embeddings)

In [83]:
def process_and_save(chembl_df, model, tokenizer, batch_size=64):
    for batch in batch_split(chembl_df, batch_size=batch_size):
        chembl_ids  = batch['chembl_id'].tolist()
        smiles_list = batch['canonical_smiles'].apply(canonicalize).tolist()
        embeddings = embed(model, smiles_list, tokenizer, batch_size=len(smiles_list))

        for chembl_id, embedding in zip(chembl_ids, embeddings):
            torch.save(embedding, path.join('ligands', f"{chembl_id}.pt"))

        return embeddings

In [84]:
embs = process_and_save(data, lm, tokenizer, batch_size=64)

Processed batch in 51.42 seconds.


In [87]:
data

Unnamed: 0,chembl_id,canonical_smiles,standard_inchi,standard_inchi_key
0,CHEMBL153534,Cc1cc(-c2csc(N=C(N)N)n2)cn1C,InChI=1S/C10H13N5S/c1-6-3-7(4-15(6)2)8-5-16-10...,MFRNFCWYPYSFQQ-UHFFFAOYSA-N
1,CHEMBL440060,CC[C@H](C)[C@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H...,InChI=1S/C123H212N44O34S/c1-19-63(12)96(164-11...,RSEQNZQKBMRQNM-VRGFNVLHSA-N
2,CHEMBL440245,CCCC[C@@H]1NC(=O)[C@@H](NC(=O)[C@H](CC(C)C)NC(...,InChI=1S/C160H268N50O41/c1-23-27-41-95-134(228...,FTKBTEIKPOYCEX-OZSLQWTKSA-N
3,CHEMBL440249,CC(C)C[C@@H]1NC(=O)CNC(=O)[C@H](c2ccc(O)cc2)NC...,InChI=1S/C124H154ClN21O39/c1-57(2)48-81-112(17...,UYSXXKGACMHPIM-KFGDMSGDSA-N
4,CHEMBL405398,Brc1cccc(Nc2ncnc3ccncc23)c1NCCN1CCOCC1,InChI=1S/C19H21BrN6O/c20-15-2-1-3-17(18(15)22-...,VDSXZXJEWIWBCG-UHFFFAOYSA-N
...,...,...,...,...
2372669,CHEMBL4298696,CCCCCCCCCCCCCCCCCCPCCCCCCCCCCCCCC.F[PH](F)(F)(...,InChI=1S/C32H67P.F6HP/c1-3-5-7-9-11-13-15-17-1...,ZAKUDCIPPLAGQL-UHFFFAOYSA-N
2372670,CHEMBL4298698,C[n+]1cn([C@@H]2O[C@H](CO[P@@](=O)(S)OP(=O)([O...,InChI=1S/C11H18N5O13P3S/c1-15-3-16(8-5(15)9(19...,OTIKKVINVWNBOQ-LDJOHHLFSA-N
2372671,CHEMBL4298702,c1ccc(C2CC(C3CC(c4ccccc4)OC(c4ccccc4)C3)CC(c3c...,InChI=1S/C34H34O2/c1-5-13-25(14-6-1)31-21-29(2...,NZIGZXNUFVMHNV-UHFFFAOYSA-N
2372672,CHEMBL4298703,CSCC[C@H](NC=O)C(=O)N[C@@H](CCCNC(=N)NS(=O)(=O...,InChI=1S/C78H107N18O21PS2/c1-43-44(2)65(45(3)5...,IIHLOGWTFCCTPB-WTIPWMETSA-N


In [86]:
embs.shape

torch.Size([64, 456, 768])

In [46]:
embs.shape

torch.Size([64, 768])

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = lm.to(device)  
embeddings = embed(model, smiles, tokenizer, device)