# Generate molecules

In [106]:
# Imports

from hyformer.configs.tokenizer import TokenizerConfig
from hyformer.configs.model import ModelConfig

from hyformer.utils.tokenizers.auto import AutoTokenizer
from hyformer.models.auto import AutoModel

from hyformer.utils.chemistry import is_valid


# auxiliary imports
import torch
import torch.nn.functional as F

# autoreload magic
%reload_ext autoreload
%autoreload 2


In [107]:
# Paths

MODEL_NAME = 'hyformer_small'
TASK_NAME = 'lm'

TOKENIZER_CONFIG_PATH = 'configs/tokenizers/smiles/config.json'
MODEL_CONFIG_PATH = 'configs/models/hyformer_small/config.json'
MODEL_CHECKPOINT_PATH = f'/lustre/groups/aih/hyformer/results/distribution_learning/guacamol/{MODEL_NAME}/{TASK_NAME}/checkpoint.pt'

In [108]:
# Load model

tokenizer_config = TokenizerConfig.from_config_path(TOKENIZER_CONFIG_PATH)
tokenizer = AutoTokenizer.from_config(tokenizer_config)

# Load model
model_config = ModelConfig.from_config_path(MODEL_CONFIG_PATH)
model = AutoModel.from_config(model_config)
model.load_pretrained(filepath=MODEL_CHECKPOINT_PATH)



In [118]:
batch_size = 64
temperature = 1.2
top_k = None
top_p = 0.9
max_sequence_length = 100

device = 'cuda:0'

generator = model.to_generator(
    tokenizer=tokenizer,
    batch_size=batch_size,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    max_sequence_length=max_sequence_length,
    device=device
    )

samples = generator.generate(number_samples=1000)
is_valid_smiles = [is_valid(sample) for sample in samples]
print("Validity: ", sum(is_valid_smiles) / len(is_valid_smiles))

Generating samples: 100%|██████████| 16/16 [00:15<00:00,  1.04it/s]

Validity:  0.978





In [119]:
samples

['COc1ccc(-c2cc(C(=O)N3CCC(C)CC3)no2)cc1',
 'CCCCNC(=O)COC(=O)c1ccccc1C(=O)Nc1ccc(C)cc1',
 'CC(C)CN(Cc1ccc(OCC(F)(F)F)cc1)S(=O)(=O)C12CCC(C(=O)Nc3ccc(Cl)cc3)(CC1)CC2',
 'CNCCCC(Oc1cccc2ccccc12)c1ccccc1',
 'COc1cc(C=CC(=O)c2ccc(O)c(N=Nc3ccc(C(=O)O)cc3)c2)ccc1O',
 'COc1cccc(NC(=O)C2CCCN(S(=O)(=O)c3ccc(OC)cc3)C2)c1',
 'O=C(NCC(=O)N1CCN(Cc2ccccc2)CC1)c1ccc(Cl)cc1',
 'COc1cccc(-c2cc(C(=O)NCC3CCCO3)c3ccccc3n2)c1',
 'O=C(O)CCC(=O)NC(Cc1ccc(OCc2ccc3ccccc3n2)cc1)C(=O)O',
 'CCOC(=O)c1c(NC(=O)C2CCCO2)sc2c1CCN(C(C)=O)C2',
 'O=C(O)COc1cccc(CC2CCCCC2)c1',
 'COc1ccccc1NC(=O)CSc1nnc(-c2ccc(Cl)cc2)n1C',
 'COc1ccccc1CCN1CCN(Cc2ccc3c(c2)OCO3)CC1',
 'CN(Cc1ccccc1)C(=O)C(Cc1ccccc1)NC(=O)C1Cc2ccccc2CN1C(=O)C(N)Cc1c(C)cc(O)cc1C',
 'CCOC(=O)c1cc(C#N)c(N2CCCCC2)nc1C',
 'O=c1oc2ccccc2cc1C1=Nn2c(nnc2-c2ccc(F)cc2)SC1',
 'CC1=C(CC(=O)NCCCCN2CCN(CC(=O)Nc3ccc(F)cc3)CC2)c2ccccc2C(=O)O1',
 'O=C(CN1CCN(c2ccccc2)CC1)NN=Cc1ccccc1O',
 'O=C(CSc1nc2ccccc2s1)Nc1cccc(F)c1',
 'CC(=O)NC(C)c1ccc(OC2CCN(c3ccnc(N4CC(F)(F)C4)c3F)C2