# Pre-trained MolFormer embeddings demo

This notebook demonstrates how to load a MolFormer checkpoint file and use it to output pre-trained embeddings (without fine-tuning) given a list of SMILES strings. These can then be used with a simple model to perform a classification task.

In [5]:
from google.colab import drive
drive.mount('/content/drive', force_remount = True)

Mounted at /content/drive


## Load checkpoint

In [6]:
!pip install Namespace

Collecting Namespace
  Downloading namespace-0.1.4-py3-none-any.whl (2.5 kB)
Collecting modcall<0.2.0,>=0.1.0 (from Namespace)
  Downloading modcall-0.1.0-py3-none-any.whl (2.8 kB)
Installing collected packages: modcall, Namespace
Successfully installed Namespace-0.1.4 modcall-0.1.0


In [1]:
from argparse import Namespace
import yaml

with open('/content/drive/MyDrive/Katritch Lab/Molformer-XL/hparams.yaml', 'r') as f:
    config = Namespace(**yaml.safe_load(f))
config

Namespace(accelerator='ddp', batch_size=64, beam_size=0, checkpoint_every=5000, clip_grad=50, config_load=None, config_save=None, d_dropout=0.2, data_path='', data_root='/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity', dataset_length=None, dataset_name='sol', debug=True, device='cuda', dropout=0.1, eval_every=1000, fast_dev_run=False, fc_h=512, finetune_path='', freeze_model=False, from_scratch=False, gen_save=None, gpus=8, grad_acc=1, log_file=None, lr=0.001, lr_end=0.00030000000000000003, lr_multiplier=8, lr_start=3e-05, max_epochs=4, max_len=202, measure_name='measure', min_len=1, mode='cls', model_arch='BERT_16GPU_Both_10percent_rotate_no_masking', model_load=None, model_save='model.pt', model_save_dir='./models_dump/', n_batch=1800, n_embd=768, n_head=12, n_jobs=1, n_last=1000, n_layer=12, n_samples=None, n_workers=8, nucleus_thresh=0.9, num_epoch=1, num_feats=32, num_nodes=1, num_seq_returned=0, num_workers=0, pretext_size=0, q_dropout=0.5, restart_path='',

In [3]:
import sys
sys.path.append('/content/drive/MyDrive/Katritch Lab/Molformer-XL')

In [4]:
from smiles_tokenizer import MolTranBertTokenizer

tokenizer = MolTranBertTokenizer('/content/drive/MyDrive/Katritch Lab/Molformer-XL/bert_vocab.txt')
tokenizer.vocab

OrderedDict([('<bos>', 0),
             ('<eos>', 1),
             ('<pad>', 2),
             ('<mask>', 3),
             ('C', 4),
             ('c', 5),
             ('(', 6),
             (')', 7),
             ('1', 8),
             ('O', 9),
             ('N', 10),
             ('2', 11),
             ('=', 12),
             ('n', 13),
             ('3', 14),
             ('[C@H]', 15),
             ('[C@@H]', 16),
             ('F', 17),
             ('S', 18),
             ('4', 19),
             ('Cl', 20),
             ('-', 21),
             ('o', 22),
             ('s', 23),
             ('[nH]', 24),
             ('#', 25),
             ('/', 26),
             ('Br', 27),
             ('[C@]', 28),
             ('[C@@]', 29),
             ('[N+]', 30),
             ('[O-]', 31),
             ('5', 32),
             ('\\', 33),
             ('.', 34),
             ('I', 35),
             ('6', 36),
             ('[S@]', 37),
             ('[S@@]', 38),
             ('P', 39)

In [11]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2023.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.5/30.5 MB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.2


In [12]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [19]:
#!pip install pytorch_lightning
!pip install pytorch-lightning==1.6.4

Collecting pytorch-lightning==1.6.4
  Downloading pytorch_lightning-1.6.4-py3-none-any.whl (585 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/585.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.6/585.5 kB[0m [31m3.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━[0m [32m409.6/585.5 kB[0m [31m5.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m585.5/585.5 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting pyDeprecate>=0.3.1 (from pytorch-lightning==1.6.4)
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting protobuf<=3.20.1 (from pytorch-lightning==1.6.4)
  Downloading protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m12.8 MB/s[0m eta [36

In [14]:
!pip install pytorch-fast-transformers

Collecting pytorch-fast-transformers
  Downloading pytorch-fast-transformers-0.4.0.tar.gz (93 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.6/93.6 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytorch-fast-transformers
  Building wheel for pytorch-fast-transformers (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-fast-transformers: filename=pytorch_fast_transformers-0.4.0-cp310-cp310-linux_x86_64.whl size=20071288 sha256=cceff9685a8913e993350988af52add708cae73be009bfe10db095057a9f0cfe
  Stored in directory: /root/.cache/pip/wheels/99/6b/6d/4abca344e31b65962d8e9d6fe298a5d2b89ff448493edc0df5
Successfully built pytorch-fast-transformers
Installing collected packages: pytorch-fast-transformers
Successfully installed pytorch-fast-transformers-0.4.0


In [5]:
!rm -rf apex

In [6]:
%%writefile setup.sh

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

Overwriting setup.sh


In [7]:
!sh setup.sh

Cloning into 'apex'...
remote: Enumerating objects: 11501, done.[K
remote: Counting objects: 100% (3569/3569), done.[K
remote: Compressing objects: 100% (489/489), done.[K
remote: Total 11501 (delta 3246), reused 3183 (delta 3077), pack-reused 7932[K
Receiving objects: 100% (11501/11501), 15.42 MiB | 5.95 MiB/s, done.
Resolving deltas: 100% (8075/8075), done.
Using pip 23.1.2 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
Processing /content/apex
  Running command Preparing metadata (pyproject.toml)


  torch.__version__  = 2.1.0+cu118


  running dist_info
  creating /tmp/pip-modern-metadata-b1ep1vd4/apex.egg-info
  writing /tmp/pip-modern-metadata-b1ep1vd4/apex.egg-info/PKG-INFO
  writing dependency_links to /tmp/pip-modern-metadata-b1ep1vd4/apex.egg-info/dependency_links.txt
  writing requirements to /tmp/pip-modern-metadata-b1ep1vd4/apex.egg-info/requires.txt
  writing top-level names to /tmp/pip-modern-metadata-b1ep1vd4/apex.egg-info/top_level.txt
  writing man

In [9]:
from train_pubchem_light import LightningModule

ckpt = '/content/drive/MyDrive/Katritch Lab/Molformer-XL/checkpoints/N-Step-Checkpoint_3_30000.ckpt'
lm = LightningModule(config, tokenizer.vocab).load_from_checkpoint(ckpt, config=config, vocab=tokenizer.vocab)
lm

Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding


LightningModule(
  (tok_emb): Embedding(2362, 768)
  (drop): Dropout(p=0.2, inplace=False)
  (blocks): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (attention): RotateAttentionLayer(
          (inner_attention): LinearAttention(
            (feature_map): GeneralizedRandomFeatures()
          )
          (query_projection): Linear(in_features=768, out_features=768, bias=True)
          (key_projection): Linear(in_features=768, out_features=768, bias=True)
          (value_projection): Linear(in_features=768, out_features=768, bias=True)
          (out_projection): Linear(in_features=768, out_features=768, bias=True)
          (rotaryemb): RotaryEmbedding()
        )
        (linear1): Linear(in_features=768, out_features=768, bias=True)
        (linear2): Linear(in_features=768, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise

## Run inference (get embeddings)

note: this runs on CPU

In [10]:
import torch
from fast_transformers.masking import LengthMask as LM

def batch_split(data, batch_size=64):
    i = 0
    while i < len(data):
        yield data[i:min(i+batch_size, len(data))]
        i += batch_size

def embed(model, smiles, tokenizer, batch_size=64):
    model.eval()
    embeddings = []
    for batch in batch_split(smiles, batch_size=batch_size):
        batch_enc = tokenizer.batch_encode_plus(batch, padding=True, add_special_tokens=True)
        idx, mask = torch.tensor(batch_enc['input_ids']), torch.tensor(batch_enc['attention_mask'])
        with torch.no_grad():
            token_embeddings = model.blocks(model.tok_emb(idx), length_mask=LM(mask.sum(-1)))
        # average pooling over tokens
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        embedding = sum_embeddings / sum_mask
        embeddings.append(embedding.detach().cpu())
    return torch.cat(embeddings)

## Use linear head for classification task

There are many ways to use the embeddings for downstream tasks, this is clearly a toy example.

In [3]:
import pandas as pd

df = pd.read_csv('/content/drive/MyDrive/Katritch Lab/Molformer-XL/data/bace/train.csv')#.sample(frac=0.1)  # speed things up...
df

Unnamed: 0,smiles,CID,Class,Unnamed: 3,pIC50,MW,AlogP,HBA,HBD,RB,...,PEOE6 (PEOE6),PEOE7 (PEOE7),PEOE8 (PEOE8),PEOE9 (PEOE9),PEOE10 (PEOE10),PEOE11 (PEOE11),PEOE12 (PEOE12),PEOE13 (PEOE13),PEOE14 (PEOE14),canvasUID
0,O1CC[C@@H](NC(=O)[C@@H](Cc2cc3cc(ccc3nc2N)-c2c...,BACE_1,1,,9.154901,431.56979,4.4014,3,2,5,...,53.205711,78.640335,226.855410,107.434910,37.133846,0.000000,7.980170,0.000000,0.000000,1
1,S1(=O)(=O)N(c2cc(cc3c2n(cc3CC)CC1)C(=O)N[C@H](...,BACE_3,1,,8.698970,591.74091,2.5499,4,3,11,...,70.365707,47.941147,192.406520,255.752550,23.654478,0.230159,15.879790,0.000000,24.663788,3
2,S1(=O)(=O)N(c2cc(cc3c2n(cc3CC)CC1)C(=O)N[C@H](...,BACE_5,1,,8.698970,629.71283,3.5086,3,3,11,...,78.945702,39.361153,179.712880,220.461300,23.654478,0.230159,15.879790,0.000000,26.100143,5
3,S(=O)(=O)(CCCCC)C[C@@H](NC(=O)c1cccnc1)C(=O)N[...,BACE_7,1,,8.698970,645.78009,3.1973,5,4,18,...,63.830162,52.390511,263.781340,190.542130,45.370659,0.000000,23.859961,0.000000,24.663788,7
4,O1c2c(cc(cc2)CC)[C@@H]([NH2+]C[C@@H](O)[C@H]2N...,BACE_9,1,,8.602060,556.71503,4.7010,4,3,5,...,53.205711,68.418541,299.000030,140.683620,28.755558,0.000000,15.879790,6.904104,24.663788,9
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1204,Clc1cc2nc([nH]c2cc1)N,BACE_1541,0,,3.136083,167.59560,1.4715,1,2,0,...,26.907076,25.739992,1.916970,40.882919,20.071724,19.404020,0.000000,0.000000,0.000000,1541
1205,Clc1cc2nc(n(c2cc1)C(CC(=O)NCC1CCOCC1)CC)N,BACE_1543,0,,3.000000,364.86969,2.5942,3,2,6,...,37.212799,37.681076,180.226410,95.670128,30.107586,9.368159,7.980170,0.000000,0.000000,1543
1206,Clc1cc2nc(n(c2cc1)C(CC(=O)NCc1ncccc1)CC)N,BACE_1544,0,,3.000000,357.83731,2.8229,3,2,6,...,45.792797,47.349350,122.401500,99.877144,30.107586,9.368159,7.980170,0.000000,0.000000,1544
1207,Brc1cc(ccc1)C1CC1C=1N=C(N)N(C)C(=O)C=1,BACE_1545,0,,2.953115,320.18451,3.0895,2,1,2,...,47.790600,22.563574,96.290794,58.798935,20.071724,9.368159,0.000000,6.904104,0.000000,1545


In [20]:
from rdkit import Chem
from sklearn.linear_model import LogisticRegression

def canonicalize(s):
    return Chem.MolToSmiles(Chem.MolFromSmiles(s), canonical=True, isomericSmiles=False)

smiles = df.smiles.apply(canonicalize)
X = embed(lm, smiles, tokenizer).numpy()
y = df.Class
head = LogisticRegression(max_iter = 1000).fit(X, y)

In [21]:
from sklearn.metrics import roc_auc_score

test_df = pd.read_csv('/content/drive/MyDrive/Katritch Lab/Molformer-XL/data/bace/test.csv').sample(frac=0.1)
X_test = embed(lm, test_df.smiles.apply(canonicalize), tokenizer).numpy()
roc_auc_score(test_df.Class, head.predict_proba(X_test)[:, 1])

0.8636363636363635