# Generative Pre-Training from Molecules

In this notebook, we demonstrate how to pretrain the
[HuggingFace](https://huggingface.co/transformers/)
[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel) language model
on a SMILES corpus. [SMILES](https://www.daylight.com/dayhtml/doc/theory/theory.smiles.html) is
a language construct for representing molecules, with its own syntax and vocabulary of
molecular constituents. Pretraining GPT-2 on large and diverse corpora captures
general representations of molecules that can be transferred to downstream tasks such as
molecular-property prediction and low-data de novo molecular design.

---

*Author: Sanjar Adilov* <br/>
*Paper: [Generative Pre-Training from Molecules](https://doi.org/10.33774/chemrxiv-2021-5fwjd)*,
*DOI: 10.33774/chemrxiv-2021-5fwjd* <br/>
*Package: https://github.com/sanjaradylov/smiles-gpt*

## Main Package

Our [`smiles_gpt`](https://github.com/sanjaradylov/smiles-gpt/tree/master/smiles_gpt)
package implements
[pytorch-lightning](https://www.pytorchlightning.ai/)-compatible modules for data loading,
model training, and testing. The SMILES tokenizer and the downstream regression and
single-/multi-output classification models are also compatible with the HuggingFace API.

In [1]:
try:
    import smiles_gpt as gpt
except ImportError:
    import sys
    sys.path.extend([".."])  # Parent directory stores `smiles_gpt` package.
    import smiles_gpt as gpt



For demonstration purposes, we use only a 10K subset of the PubChem data made available by
the [ChemBERTa](https://arxiv.org/abs/2010.09885) developers. The original model was pretrained
on the first 5M compounds with the following hyperparameters:
```python
hyperparams = {"batch_size": 128, "max_epochs": 2, "max_length": 512,
               "learning_rate": 5e-4, "weight_decay": 0.0,
               "adam_eps": 1e-8, "adam_betas": (0.9, 0.999),
               "scheduler_T_max": 150_000, "final_learning_rate": 5e-8,
               "vocab_size": 1_000, "min_frequency": 2, "top_p": 0.96,
               "n_layer": 4, "n_head": 8, "n_embd": 512}
```

In [2]:
# 10K subset of PubChem SMILES dataset.
filename = "/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/notebooks/inhibitors_1_modified.txt"
# Directory to serialize a tokenizer and model.
checkpoint = "/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/checkpoints/benchmark-10m"
tokenizer_filename = "/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/checkpoints/benchmark-10m/tokenizer.json"

# Tokenizer, model, optimizer, scheduler, and trainer hyperparameters.
hyperparams = {"batch_size": 32, "max_epochs": 10, "min_epochs": 10,
               "max_length": 512, "learning_rate": 5e-4, "weight_decay": 0.0,
               "adam_eps": 1e-8, "adam_betas": (0.9, 0.999),
               "scheduler_T_max": 150_000, "final_learning_rate": 5e-8,
               "vocab_size": 1_000, "min_frequency": 2, "top_p": 0.96,
               "n_layer": 6, "n_head": 12, "n_embd": 12 * 48}

gpus = 4  # Specify either a list of GPU devices or an integer (0 for no GPU).
num_workers = 32  # Number of dataloader worker processes.
is_tokenizer_pretrained = True

## Tokenization

`smiles_gpt.SMILESBPETokenizer` first splits SMILES strings into characters, runs
byte-pair encoding, and augments the resulting list with `"<s>"` (beginning-of-SMILES) and
`"</s>"` (end-of-SMILES) special tokens. `smiles_gpt.SMILESAlphabet` stores 72 possible
characters as an initial vocabulary.
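
To make the byte-pair encoding step concrete, here is a minimal pure-Python sketch of how merges are learned from character-level tokens. It is a toy illustration rather than the `smiles_gpt` or Tokenizers implementation: the function name `learn_bpe_merges` and the three-string corpus are hypothetical, and the sketch ignores `min_frequency` and the initial alphabet.

```python
from collections import Counter


def learn_bpe_merges(corpus, num_merges):
    """Learn up to `num_merges` BPE merges from a list of SMILES strings."""
    # Start from character-level tokens.
    sequences = [list(smiles) for smiles in corpus]
    merges = []
    for _ in range(num_merges):
        # Count adjacent token pairs over the whole corpus.
        pair_counts = Counter()
        for tokens in sequences:
            pair_counts.update(zip(tokens, tokens[1:]))
        if not pair_counts:
            break
        best_pair, _count = pair_counts.most_common(1)[0]
        merges.append(best_pair)
        # Replace every occurrence of the most frequent pair with one merged token.
        merged_sequences = []
        for tokens in sequences:
            merged, i = [], 0
            while i < len(tokens):
                if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best_pair:
                    merged.append(tokens[i] + tokens[i + 1])
                    i += 2
                else:
                    merged.append(tokens[i])
                    i += 1
            merged_sequences.append(merged)
        sequences = merged_sequences
    return merges, sequences


toy_corpus = ["CCCC", "CCO", "CCN"]  # Hypothetical toy corpus.
toy_merges, toy_tokens = learn_bpe_merges(toy_corpus, num_merges=1)
# Wrap each token list in the beginning- and end-of-SMILES special tokens.
toy_encoded = [["<s>", *tokens, "</s>"] for tokens in toy_tokens]
print(toy_merges)   # [('C', 'C')]
print(toy_encoded)  # [['<s>', 'CC', 'CC', '</s>'], ['<s>', 'CC', 'O', '</s>'], ['<s>', 'CC', 'N', '</s>']]
```

Each additional merge adds one token to the vocabulary, which is why the vocabulary size passed to the trainer is the number of learned merges plus the size of the initial alphabet.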

In [3]:
tokenizer = gpt.SMILESBPETokenizer(dropout=None)

if not is_tokenizer_pretrained:
    alphabet = list(gpt.SMILESAlphabet().get_alphabet())
    tokenizer.train(filename,
                    vocab_size=hyperparams["vocab_size"] + len(alphabet),
                    min_frequency=hyperparams["min_frequency"],
                    initial_alphabet=alphabet)
    tokenizer.save_model(checkpoint)
    tokenizer.save(tokenizer_filename)
else:
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    tokenizer = tokenizer.from_file("/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/checkpoints/benchmark-10m/vocab.json",
                                    "/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/checkpoints/benchmark-10m/merges.txt")

tokenizer

Tokenizer(vocabulary_size=1072, model=BPE, unk_token=<unk>, suffix=, dropout=None)

[`SMILESBPETokenizer`](https://github.com/sanjaradylov/smiles-gpt/blob/master/smiles_gpt/tokenization.py#L23)
inherits `BaseTokenizer` from
[Tokenizers](https://huggingface.co/docs/tokenizers/python/latest/index.html). It is already
useful by itself. However, to make it more convenient and to follow the HuggingFace API, we load
a `transformers.PreTrainedTokenizerFast` instance of our tokenizer:

In [4]:
from pprint import pprint

tokenizer = gpt.SMILESBPETokenizer.get_hf_tokenizer(
    tokenizer_filename, model_max_length=hyperparams["max_length"])

smiles_string = "CC(Cl)=CCCC=C(C)Cl"
smiles_encoded = tokenizer(smiles_string)
smiles_merges = tokenizer.convert_ids_to_tokens(smiles_encoded["input_ids"])

pprint(smiles_encoded)
pprint(smiles_merges)

{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1],
 'input_ids': [1, 78, 142, 24, 101, 218, 109, 63, 2],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0]}
['<s>', 'CC', '(Cl)', '=', 'CCCC', '=C', '(C)C', 'l', '</s>']


## Data Module

[`smiles_gpt.LMDataModule`](https://github.com/sanjaradylov/smiles-gpt/blob/master/smiles_gpt/data.py#L248)
is a Lightning data module that loads SMILES data, encodes them
with `tokenizer`, and returns a PyTorch data loader that uses the
`transformers.DataCollatorForLanguageModeling` collator. Encodings contain tensors of shape
`hyperparams["max_length"]`: `"input_ids"` and `"labels"`.
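
The following sketch shows, with plain Python lists instead of tensors, what collation for causal language modeling amounts to: inputs are padded to a fixed length, and the labels are a copy of the inputs in which padded positions are replaced by `-100` so that the loss ignores them. The function name `collate_for_causal_lm` and the padding ID `0` are assumptions made for illustration; the actual collator operates on tensors and is configured inside `LMDataModule`.

```python
def collate_for_causal_lm(encoded_batch, pad_token_id, max_length, ignore_index=-100):
    """Pad token-ID lists to `max_length` and build language-modeling labels."""
    input_ids, attention_mask, labels = [], [], []
    for ids in encoded_batch:
        ids = ids[:max_length]  # Truncate sequences that are too long.
        n_pad = max_length - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        # Labels copy the inputs; padded positions are excluded from the loss.
        labels.append(ids + [ignore_index] * n_pad)
    return {"input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels}


# Two hypothetical encodings; 0 is an assumed padding ID.
toy_batch = collate_for_causal_lm([[1, 78, 2], [1, 78, 142, 24, 2]],
                                  pad_token_id=0, max_length=5)
print(toy_batch["input_ids"][0])  # [1, 78, 2, 0, 0]
print(toy_batch["labels"][0])     # [1, 78, 2, -100, -100]
```

Note that the labels are not shifted here. The shift by one position happens inside the model when the loss is computed.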

In [5]:
datamodule = gpt.LMDataModule(filename, tokenizer,
                              batch_size=hyperparams["batch_size"],
                              num_workers=num_workers)
# datamodule.setup()

# batch = next(iter(datamodule.train_dataloader()))
# pprint(batch)

## GPT-2 Model

Now we load HuggingFace
[`GPT2LMHeadModel`](https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel)
with a configuration built from the previously
defined model hyperparameters. The model processes a mini-batch of input IDs and labels and
returns predictions (logits) together with the cross-entropy loss between labels and predictions.
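
As a rough illustration of that loss, the sketch below computes the next-token cross-entropy for a single sequence in pure Python. `GPT2LMHeadModel` shifts logits and labels internally so that position $t$ predicts token $t + 1$; the helper `causal_lm_loss` is a hypothetical stand-in written for this explanation, not part of `transformers`.

```python
import math


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy for one sequence.

    `logits[t]` holds unnormalized scores over the vocabulary at position t;
    `labels[t]` is the token ID at position t.
    """
    total, count = 0.0, 0
    # Position t predicts token t + 1: drop the last logit and the first label.
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue  # Padded positions do not contribute.
        # Numerically stable log of the softmax normalizer.
        max_score = max(scores)
        log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        total += log_norm - scores[target]
        count += 1
    return total / count if count else float("nan")


# With uniform scores over 4 tokens, the loss is log(4) regardless of the labels.
uniform_logits = [[0.0, 0.0, 0.0, 0.0]] * 3
toy_loss = causal_lm_loss(uniform_logits, labels=[1, 2, 3])
print(toy_loss, math.log(4))
```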

In [6]:
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=tokenizer.vocab_size,
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    n_layer=hyperparams["n_layer"],
                    n_head=hyperparams["n_head"],
                    n_embd=hyperparams["n_embd"],
                    n_positions=hyperparams["max_length"],
                    n_ctx=hyperparams["max_length"])
model = GPT2LMHeadModel(config)

# outputs = model(**batch)
# outputs.keys()
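
As a sanity check on this configuration, the parameter count can be estimated by hand. The sketch below assumes the standard GPT-2 layout (a 4x MLP expansion and an output layer tied to the token embeddings) and the vocabulary size of 1072 printed by the tokenizer above; `gpt2_parameter_count` is a hypothetical helper, not a library function.

```python
def gpt2_parameter_count(vocab_size, n_positions, n_embd, n_layer):
    """Count GPT-2 parameters, assuming a 4x MLP and tied output embeddings."""
    embeddings = vocab_size * n_embd + n_positions * n_embd   # wte + wpe
    layer_norm = 2 * n_embd                                   # weight + bias
    attention = ((n_embd * 3 * n_embd + 3 * n_embd)           # c_attn
                 + (n_embd * n_embd + n_embd))                # c_proj
    mlp = ((n_embd * 4 * n_embd + 4 * n_embd)                 # c_fc
           + (4 * n_embd * n_embd + n_embd))                  # c_proj
    block = 2 * layer_norm + attention + mlp                  # ln_1, ln_2, attn, mlp
    return embeddings + n_layer * block + layer_norm          # plus final ln_f


n_params = gpt2_parameter_count(vocab_size=1072, n_positions=512,
                                n_embd=576, n_layer=6)
print(f"{n_params:,} parameters, about {n_params * 4 / 1e6:.3f} MB in float32")
# 24,846,336 parameters, about 99.385 MB in float32
```

This agrees with the 24.8 M parameters and 99.385 MB reported in the trainer summary later in this notebook.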

In [7]:
import torch

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
model.to(device)

GPT2LMHeadModel(
  (shared_parameters): ModuleDict()
  (transformer): GPT2Model(
    (shared_parameters): ModuleDict()
    (invertible_adapters): ModuleDict()
    (wte): Embedding(1072, 576)
    (wpe): Embedding(512, 576)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): MergedLinear(
            in_features=576, out_features=1728, bias=True
            (loras): ModuleDict()
          )
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
          (prefix_tuning): PrefixTuningShim(
            (prefix_gates): ModuleDict()
            (pool): PrefixTuningPool(
              (prefix_tunings): ModuleDict()
            )
          )
        )
        (ln_2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
       

In [10]:
import pytorch_lightning as pl
print(pl.__version__)

2.1.2


## Trainer

GPT-2 is trained with the autoregressive language modeling objective:
$$
P(\boldsymbol{s}) = P(s_1) \cdot P(s_2 | s_1) \cdots P(s_T | s_1, \ldots, s_{T-1}) =
\prod_{t=1}^{T} P(s_t | s_1, \ldots, s_{t-1}),
$$
where $\boldsymbol{s}$ is a tokenized (encoded) SMILES string of length $T$ and $s_t$ is a token
from the pretrained vocabulary $\mathcal{V}$.
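
Maximizing this product is equivalent to minimizing the average negative log-likelihood per token, and perplexity is the exponential of that average. The sketch below uses this standard definition; whether the `ppl` metric logged during training is computed in exactly this way is an assumption, and the function name `perplexity` is chosen for illustration.

```python
import math


def perplexity(token_probabilities):
    """Perplexity of one sequence from the conditional probability of each token."""
    # Average negative log-likelihood per token.
    nll = -sum(math.log(p) for p in token_probabilities) / len(token_probabilities)
    return math.exp(nll)


# A model that guesses uniformly among 4 tokens has perplexity 4.
print(perplexity([0.25] * 8))
# A model that is always certain and correct has perplexity 1.
print(perplexity([1.0] * 8))
```

Under this definition, a uniform guess over the whole vocabulary gives a perplexity equal to the vocabulary size, and 1 is the lower bound.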

We use `pytorch_lightning.Trainer` to train GPT-2. Since `Trainer` requires lightning modules,
we import our
[`smiles_gpt.GPT2LitModel`](https://github.com/sanjaradylov/smiles-gpt/blob/master/smiles_gpt/language_modeling.py#L10)
wrapper that implements training phases for
`GPT2LMHeadModel`, configures an `Adam` optimizer with a `CosineAnnealingLR` scheduler, and
logs the average perplexity every epoch.
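
For intuition about the schedule, here is the closed form of cosine annealing evaluated with the hyperparameters defined earlier. We assume that `final_learning_rate` plays the role of the scheduler's minimum learning rate and that `scheduler_T_max` is its period; the helper `cosine_annealed_lr` is hypothetical and only mirrors the formula, it does not call PyTorch.

```python
import math


def cosine_annealed_lr(step, base_lr=5e-4, final_lr=5e-8, t_max=150_000):
    """Closed-form cosine annealing from `base_lr` (step 0) to `final_lr` (step `t_max`)."""
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * step / t_max))


for step in (0, 75_000, 150_000):
    print(step, cosine_annealed_lr(step))
```

The rate starts at `learning_rate`, passes through the midpoint of the two rates at `t_max / 2`, and reaches `final_learning_rate` at `t_max`. If the scheduler advances once per batch (an assumption), a short run such as the one in this notebook covers only a small fraction of the 150,000-step period, so the learning rate stays very close to its initial value.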

In [11]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


checkpoint_cb = ModelCheckpoint("/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/inhibitors_checkpoint")
early_stopping_ppl = EarlyStopping(
    monitor="ppl",
    patience=4,
    min_delta=5e-3,
    check_finite=True,
    stopping_threshold=1.1,
    divergence_threshold=hyperparams["vocab_size"] / 10,
    verbose=True,
    mode="min",
    check_on_train_epoch_end=True,
)
trainer = Trainer(
    strategy="auto",
    callbacks=[checkpoint_cb, early_stopping_ppl],
    max_epochs=hyperparams["max_epochs"],
    min_epochs=hyperparams["min_epochs"],
    val_check_interval=0.4,
    limit_train_batches=0.2,
    log_every_n_steps=200,
)
lit_model = gpt.GPT2LitModel(
    model,
    batch_size=hyperparams["batch_size"],
    learning_rate=hyperparams["learning_rate"],
    final_learning_rate=hyperparams["final_learning_rate"],
    weight_decay=hyperparams["weight_decay"],
    adam_eps=hyperparams["adam_eps"],
    adam_betas=hyperparams["adam_betas"],
    scheduler_T_max=hyperparams["scheduler_T_max"],
)
trainer.fit(lit_model, datamodule)
lit_model.transformer.save_pretrained("/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/inhibitors_checkpoint")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type            | Params
------------------------------------------------
0 | transformer | GPT2LMHeadModel | 24.8 M
------------------------------------------------
24.8 M    Trainable params
0         Non-trainable params
24.8 M    Total params
99.385    Total estimated model params size (MB)
/home/piyush22194/.conda/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batche

Epoch 0: 100%|██████████| 45/45 [00:06<00:00,  7.04it/s, v_num=46, ppl=37.50] 

Metric ppl improved. New best score: 37.454


Epoch 1: 100%|██████████| 45/45 [00:04<00:00,  9.40it/s, v_num=46, ppl=19.60]

Metric ppl improved by 17.895 >= min_delta = 0.005. New best score: 19.559


Epoch 3: 100%|██████████| 45/45 [00:04<00:00,  9.67it/s, v_num=46, ppl=12.90]

Metric ppl improved by 6.621 >= min_delta = 0.005. New best score: 12.938


Epoch 4: 100%|██████████| 45/45 [00:05<00:00,  8.41it/s, v_num=46, ppl=7.210]

Metric ppl improved by 5.727 >= min_delta = 0.005. New best score: 7.212


Epoch 6: 100%|██████████| 45/45 [00:04<00:00,  9.56it/s, v_num=46, ppl=4.450]

Metric ppl improved by 2.758 >= min_delta = 0.005. New best score: 4.454


Epoch 8: 100%|██████████| 45/45 [00:04<00:00,  9.31it/s, v_num=46, ppl=3.900]

Metric ppl improved by 0.557 >= min_delta = 0.005. New best score: 3.897


Epoch 9: 100%|██████████| 45/45 [00:05<00:00,  8.65it/s, v_num=46, ppl=3.980]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 45/45 [00:06<00:00,  7.16it/s, v_num=46, ppl=3.980]


## Interpretability

[BertViz](https://github.com/jessevig/bertviz) inspects attention heads of transformers
capturing specific patterns in data. Each head can be representative of some syntactic
or short-/long-term relationships between tokens.

In [12]:
import torch
from bertviz import head_view

smiles = "CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1"
inputs = tokenizer(smiles, add_special_tokens=False, return_tensors="pt")
input_ids_list = inputs["input_ids"].tolist()[0]
model_1 = GPT2LMHeadModel.from_pretrained("/home/piyush22194/RNN/smiles-gpt-master/smiles-gpt-master/inhibitors_checkpoint", output_attentions=True)
attention = model_1(torch.LongTensor(input_ids_list))[-1]
tokens = tokenizer.convert_ids_to_tokens(input_ids_list)

# Don't worry if a snippet is not displayed---just rerun this cell.
head_view(attention, tokens)

<IPython.core.display.Javascript object>

In [13]:
from bertviz import model_view

# Don't worry if a snippet is not displayed---just rerun this cell.
model_view(attention, tokens)

<IPython.core.display.Javascript object>

## Sampling

Finally, we generate novel SMILES strings with top-$p$ sampling, i.e., sampling from the
smallest vocabulary subset $\mathcal{V}^{(p)} \subset \mathcal{V}$ of the most
probable tokens whose cumulative probability mass exceeds $p$, $0 < p < 1$. Generation
terminates upon encountering `"</s>"` or reaching the maximum length
`hyperparams["max_length"]`. Special tokens are removed during decoding.
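
The selection of $\mathcal{V}^{(p)}$ for a single step can be sketched in a few lines of pure Python. This is an illustration of the idea rather than the code that `generate` runs internally; the names `top_p_filter` and `sample_top_p` and the toy distribution are made up for this example.

```python
import random


def top_p_filter(probabilities, top_p):
    """Return {token_id: renormalized probability} for the top-p subset."""
    # Rank token IDs from most to least probable.
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    nucleus, cumulative = [], 0.0
    for token_id, prob in ranked:
        nucleus.append((token_id, prob))
        cumulative += prob
        if cumulative >= top_p:
            break  # Smallest prefix whose probability mass reaches top_p.
    return {token_id: prob / cumulative for token_id, prob in nucleus}


def sample_top_p(probabilities, top_p, rng=random):
    """Draw one token ID from the renormalized top-p subset."""
    nucleus = top_p_filter(probabilities, top_p)
    return rng.choices(list(nucleus), weights=list(nucleus.values()), k=1)[0]


toy_distribution = [0.5, 0.3, 0.15, 0.05]  # Hypothetical next-token probabilities.
print(sorted(top_p_filter(toy_distribution, top_p=0.9)))  # [0, 1, 2]
print(sample_top_p(toy_distribution, top_p=0.9))
```

With `top_p=0.9`, the least likely token (ID 3) can never be sampled, which trims the unreliable tail of the distribution while keeping the output diverse.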

In [13]:
import tqdm

model_1.eval()  # Set the base model to evaluation mode.

generated_smiles_list = []
n_generated = 350

for _ in tqdm.tqdm(range(n_generated)):
    # Generate from "<s>" so that the next token is arbitrary.
    smiles_start = torch.LongTensor([[tokenizer.bos_token_id]])
    # Get generated token IDs.
    generated_ids = model_1.generate(smiles_start,
                                     max_length=hyperparams["max_length"],
                                     do_sample=True, top_p=hyperparams["top_p"],
                                     pad_token_id=tokenizer.eos_token_id)
    # Decode the IDs into tokens and remove "<s>" and "</s>".
    generated_smiles = tokenizer.decode(generated_ids[0],
                                        skip_special_tokens=True)
    generated_smiles_list.append(generated_smiles)

generated_smiles_list[:5]

100%|██████████| 350/350 [29:50<00:00,  5.12s/it]  


['N#Cc1cnc(C(=O)Nc2ccc(F)c([C@]34C[C@H](C(F)(F)F)C(N)=N4)n2)nc1',
 'CC#Cc1cncc(-c2ccc(F)c([C@]3(C)C[C@@]3(CF)COC(N)=N3)c(Cl)c1',
 'N#c1cnc(C(=O)Nc2ccc(F)c([C@@]3(COC(N)=N3)c2)nc(C)c1',
 'CN1C(=O)C(CCOc2ccc(OC(F)F)cc2)(c2ccc3c(c2)C12N=C(N)NC[C@@H]2C)N2CC(F)(F)2',
 'CCO[C@H]1CC[C@]2(CC1)Cc1ccc(-c4cccc(C#N)c3)cc1C21N=C(N)N(C)O1']

In [14]:
import pandas as pd

# `generated_smiles_list` contains the SMILES strings generated in the previous cell.
output_filename = "generated_inhibitors.csv"

# Create a DataFrame with a single column named 'Molecule'.
df = pd.DataFrame({'Molecule': generated_smiles_list})

# Save the DataFrame to a CSV file (adjust the filename as needed).
df.to_csv(output_filename, index=False)

print(f"Generated SMILES have been successfully written to '{output_filename}'.")

Generated SMILES have been successfully written to 'generated_inhibitors.csv'.


## Further Reading

The pretrained model can be used for transferring knowledge to downstream tasks
including molecular property prediction. Check out
[`smiles_gpt`](https://github.com/sanjaradylov/smiles-gpt/tree/master/smiles_gpt)
repository for implementation details and
[smiles-gpt/scripts](https://github.com/sanjaradylov/smiles-gpt/tree/master/scripts)
directory for single-/multi-output classification scripts. To evaluate generated
molecules, consider distribution-learning metrics from
[moleculegen-ml](https://github.com/sanjaradylov/moleculegen-ml).
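
As a starting point before reaching for a full metrics package, two commonly reported string-level quantities, uniqueness and novelty, can be computed directly. The sketch below is a simplified illustration with hypothetical function names and toy data. It compares raw strings without canonicalization, and it does not check chemical validity, which requires a cheminformatics toolkit.

```python
def uniqueness(generated):
    """Fraction of distinct strings among the generated ones."""
    return len(set(generated)) / len(generated)


def novelty(generated, training):
    """Fraction of distinct generated strings that are absent from the training set."""
    unique_generated = set(generated)
    return len(unique_generated - set(training)) / len(unique_generated)


toy_generated = ["CCO", "CCO", "CCN", "c1ccccc1"]  # Hypothetical samples.
toy_training = ["CCO"]                             # Hypothetical training set.
print(uniqueness(toy_generated))             # 0.75
print(novelty(toy_generated, toy_training))  # 2 of 3 distinct strings are new
```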

If you find `smiles_gpt` or the examples from this repository useful in your
research, please consider citing:
> Adilov, Sanjar (2021): Generative Pre-Training from Molecules. ChemRxiv. Preprint. https://doi.org/10.33774/chemrxiv-2021-5fwjd