# Run PMO experiments

In [17]:
# Imports

import os, logging, argparse, sys

import torch
import torch.distributed as dist

from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.elastic.multiprocessing.errors import record

from hyformer.configs.dataset import DatasetConfig
from hyformer.configs.tokenizer import TokenizerConfig
from hyformer.configs.model import ModelConfig
from hyformer.configs.trainer import TrainerConfig


from hyformer.utils.datasets.auto import AutoDataset
from hyformer.utils.tokenizers.auto import AutoTokenizer
from hyformer.models.auto import AutoModel


from hyformer.trainers.trainer import Trainer


from hyformer.utils.reproducibility import set_seed

from experiments.pmo.utils import PMOOracle

# autoreload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# set seed and device

set_seed(0)
device = 'cuda:0'

In [18]:
# Load the model 

TOKENIZER_CONFIG_PATH = "configs/tokenizers/smiles/config.json"
MODEL_CONFIG_PATH = "configs/models/hyformer_small/config.json"
TRAINER_CONFIG_PATH = "configs/trainers/pmo/config.json"

MODEL_CKPT_PATH = "/lustre/groups/aih/hyformer/results/distribution_learning/guacamol/hyformer_small/lm_enumerated/checkpoint.pt"

OUT_DIR = "/lustre/groups/aih/hyformer/results/pmo/guacamol/{oracle_name}/hyformer_small/pmo"

# Hyperparameters

NUM_RETRAIN_EPOCHS = 2



In [16]:
# Load configurations

tokenizer_config = TokenizerConfig.from_config_filepath(TOKENIZER_CONFIG_PATH)
model_config = ModelConfig.from_config_filepath(MODEL_CONFIG_PATH)
trainer_config = TrainerConfig.from_config_filepath(TRAINER_CONFIG_PATH)

# Modify trainer config
trainer_config.max_epochs = NUM_RETRAIN_EPOCHS

# Initialize
tokenizer = AutoTokenizer.from_config(tokenizer_config)
model = AutoModel.from_config(model_config)
model.load_pretrained(MODEL_CKPT_PATH)
model.to(device)
   


Hyformer(
  (token_embedding): Embedding(511, 256)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (attention_layer): Attention(
        (q_proj): Linear(in_features=256, out_features=256, bias=False)
        (k_proj): Linear(in_features=256, out_features=256, bias=False)
        (v_proj): Linear(in_features=256, out_features=256, bias=False)
        (out): Linear(in_features=256, out_features=256, bias=False)
        (relative_embedding): RotaryEmbedding()
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=256, out_features=1024, bias=False)
        (w3): Linear(in_features=256, out_features=1024, bias=False)
        (w2): Linear(in_features=1024, out_features=256, bias=False)
      )
      (attention_layer_normalization): RMSNorm()
      (feed_forward_normalization): RMSNorm()
    )
  )
  (layer_norm): RMSNorm()
  (lm_head): Linear(in_features=256, out_features=511, bias=False)
  (mlm_head): Linear(in_features=256, out_features=511, bias=False)

In [23]:
# Initialize trainer
trainer = Trainer(
    config=trainer_config,
    model=model,
    tokenizer=tokenizer,
    device=device,
    )

In [19]:
ORACLE_NAME = 'zaleplon_mpo'

ORACLE_BUDGET = 10000
NUM_ITERATIONS = 100
FREQ_LOG = 100
assert NUM_ITERATIONS * FREQ_LOG == ORACLE_BUDGET
SEED = 0
TEMPERATURE = 1.0  # scheduling temperature // start with 1.0 and anneal to 1.5 and then down to 0.9
TOP_K = 10
TOP_P = 0.95


In [61]:
oracle = PMOOracle(
    name=ORACLE_NAME,
    max_number_of_calls=ORACLE_BUDGET,
    freq_log=FREQ_LOG,
    dtype='none'
    )

Initializing GuacaMol benchmark: zaleplon_mpo
Benchmark type: <class 'guacamol.goal_directed_benchmark.GoalDirectedBenchmark'>
Oracle type: <class 'guacamol.scoring_function.GeometricMeanScoringFunction'>


In [None]:
# WHAT ARE SATURN TRICKS?

# get_data(data_transform=..., replay_buffer=...)

In [29]:
model = model.to_generator(tokenizer=tokenizer, batch_size=1024, device=device, temperature=2.0, top_k=10, max_sequence_length=100, top_p=None)


In [31]:
samples = model.generate(number_samples=1000)



Generating samples: 100%|██████████| 1/1 [00:20<00:00, 20.80s/it]


In [62]:
scores = oracle(samples)

In [63]:
scores

[nan,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,


In [64]:
oracle.get_metrics()

{'avg_top1': -1.0,
 'avg_top10': -1.0,
 'avg_top100': -1.0,
 'auc_top1': nan,
 'auc_top10': nan,
 'auc_top100': nan,
 'n_oracle': 540}

In [42]:
smiles = "Cc1ccc(cc1)C(C)C(C)C"

In [43]:
oracle.oracle.score([smiles])

-1.0

In [40]:
from hyformer.utils.chemistry import is_valid

is_valid(smiles)

True

In [47]:
oracle.oracle.score(smiles)

-1.0

In [None]:
""" Get Guacamol objectives.

Source: https://github.com/BenevolentAI/guacamol/blob/master/guacamol/standard_benchmarks.py
"""

import math
import numpy as np

import networkx as nx
from tqdm import tqdm

from rdkit import Chem, DataStructs
from rdkit.Chem import Crippen, rdmolops
from rdkit.Chem.QED import qed
from rdkit.Chem.Fingerprints import FingerprintMols
from guacamol import standard_benchmarks

from rdkit import Chem


med1 = standard_benchmarks.median_camphor_menthol()  # 'Median molecules 1'
med2 = standard_benchmarks.median_tadalafil_sildenafil()  # 'Median molecules 2',
pdop = standard_benchmarks.perindopril_rings()  # 'Perindopril MPO',
osmb = standard_benchmarks.hard_osimertinib()  # 'Osimertinib MPO',
adip = standard_benchmarks.amlodipine_rings()  # 'Amlodipine MPO'
siga = standard_benchmarks.sitagliptin_replacement()  # 'Sitagliptin MPO'
zale = standard_benchmarks.zaleplon_with_other_formula()  # 'Zaleplon MPO'
valt = standard_benchmarks.valsartan_smarts()  # 'Valsartan SMARTS',
dhop = standard_benchmarks.decoration_hop()  # 'Deco Hop'
shop = standard_benchmarks.scaffold_hop()  # Scaffold Hop'
rano= standard_benchmarks.ranolazine_mpo()  # 'Ranolazine MPO'
fexo = standard_benchmarks.hard_fexofenadine()  # 'Fexofenadine MPO'... 'make fexofenadine less greasy'


guacamol_objs = {"med1": med1, "pdop": pdop, "adip": adip, "rano": rano, "osmb": osmb, "siga": siga, "zale": zale,
                 "valt": valt, "med2": med2, "dhop": dhop, "shop": shop, 'fexo': fexo}


GUACAMOL_TASK_NAMES = [
    'med1', 'pdop', 'adip', 'rano', 'osmb', 'siga',
    'zale', 'valt', 'med2', 'dhop', 'shop', 'fexo',
    'qed', 'qed_classification'
]


def smile_is_valid_mol(smile):
    if smile is None or len(smile)==0:
        return False
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return False
    return True


def smile_to_guacamole_score(obj_func_key, smile):
    if smile is None or len(smile)==0:
        return None
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    func = guacamol_objs[obj_func_key]
    score = func.objective.score(smile)
    if score is None:
        return None
    if score < 0:
        return None
    return score


def smile_to_rdkit_mol(smile):
    return Chem.MolFromSmiles(smile)


def smile_to_QED(smile):
    """
    Computes RDKit's QED score
    """
    if smile is None:
        return None
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    qed_score = qed(mol)
    return qed_score


def smile_to_sa(smile):
    """Synthetic Accessibility Score (SA):
    a heuristic estimate of how hard (10)
    or how easy (1) it is to synthesize a given molecule."""
    if smile is None:
        return None
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    return sascorer.calculateScore(mol)


def smile_to_penalized_logP(smile):
    """ calculate penalized logP for a given smiles string """
    if smile is None:
        return None
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    logp = Crippen.MolLogP(mol)
    sa = sascorer.calculateScore(mol)
    cycle_length = _cycle_score(mol)
    """
    Calculate final adjusted score.
    These magic numbers are the empirical means and
    std devs of the dataset.

    I agree this is a weird way to calculate a score...
    but this is what previous papers did!
    """
    score = (
            (logp - 2.45777691) / 1.43341767
            + (-sa + 3.05352042) / 0.83460587
            + (-cycle_length - -0.04861121) / 0.28746695
    )
    return max(score, -float("inf"))


def _cycle_score(mol):
    cycle_list = nx.cycle_basis(nx.Graph(rdmolops.GetAdjacencyMatrix(mol)))
    if len(cycle_list) == 0:
        cycle_length = 0
    else:
        cycle_length = max([len(j) for j in cycle_list])
    if cycle_length <= 6:
        cycle_length = 0
    else:
        cycle_length = cycle_length - 6
    return cycle_length


def smiles_to_desired_scores(smiles_list, task_id="logp", verbose=False):
    if verbose:
        return smiles_to_desired_scores_with_verbose(smiles_list, task_id)
    else:
        return smiles_to_desired_scores_without_verbose(smiles_list, task_id)


def smiles_to_desired_scores_with_verbose(smiles_list, task_id="logp"):
    scores = []
    for smiles_str in tqdm(smiles_list):
        if task_id == "logp":
            score_ = smile_to_penalized_logP(smiles_str)
        elif task_id == "qed":
            score_ = smile_to_QED(smiles_str)
        else:  # otherwise, assume it is a guacamol task
            score_ = smile_to_guacamole_score(task_id, smiles_str)
        if (score_ is not None) and (math.isfinite(score_)):
            scores.append(score_)
        else:
            scores.append(np.nan)

    return np.array(scores)


def smiles_to_desired_scores_without_verbose(smiles_list, task_id="logp"):
    scores = []
    for smiles_str in tqdm(smiles_list):
        if task_id == "logp":
            score_ = smile_to_penalized_logP(smiles_str)
        elif task_id == "qed":
            score_ = smile_to_QED(smiles_str)
        else:  # otherwise, assume it is a guacamol task
            score_ = smile_to_guacamole_score(task_id, smiles_str)
        if (score_ is not None) and (math.isfinite(score_)):
            scores.append(score_)
        else:
            scores.append(np.nan)

    return np.array(scores)


def get_fingerprint_similarity(smile1, smile2):
    mol1 = Chem.MolFromSmiles(smile1)
    mol2 = Chem.MolFromSmiles(smile2)
    if (mol1 is None) or (mol2 is None):
        print("one of the input smiles is not a valid molecule!")
        return None
    fp1 = FingerprintMols.FingerprintMol(mol1)
    fp2 = FingerprintMols.FingerprintMol(mol2)
    fps = DataStructs.FingerprintSimilarity(fp1, fp2)
    return fps

In [54]:
value = smile_to_guacamole_score(smile=smiles, obj_func_key='zale')

In [55]:
value

In [59]:
func = guacamol_objs["zale"]
score = func.objective.score("CCCccc")

In [60]:
score

-1.0