# Generate molecules

In [40]:
# Imports

# standard library
import os

# third-party (auxiliary imports)
import torch
import torch.nn.functional as F

# hyformer
from hyformer.configs.dataset import DatasetConfig
from hyformer.configs.tokenizer import TokenizerConfig
from hyformer.configs.model import ModelConfig

from hyformer.utils.datasets.auto import AutoDataset
from hyformer.utils.tokenizers.auto import AutoTokenizer
from hyformer.models.auto import AutoModel

from hyformer.configs.trainer import TrainerConfig
from hyformer.trainers.trainer import Trainer

from hyformer.models.wrappers import HyformerEncoderWrapper

# autoreload magic: enable the IPython autoreload extension in mode 2 so that
# edits to imported modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2


In [41]:
# Paths
#
# The two cluster roots default to their original absolute locations, but can
# be redirected through environment variables so the notebook is not tied to a
# machine where /lustre is mounted. With the variables unset (or empty), every
# constant below has exactly the value it had before.

MODEL_NAME = 'hyformer'
TASK_NAME = 'combined'

# Root of the results tree; override with HYFORMER_RESULTS_DIR.
RESULTS_DIR = os.environ.get("HYFORMER_RESULTS_DIR") or "/lustre/groups/aih/hyformer/results"
# Root handed to the dataset loader; override with HYFORMER_DATASET_DIR.
DATASET_DIR = os.environ.get("HYFORMER_DATASET_DIR") or "/lustre/groups/aih/hyformer/data"

# Directory of the specific run whose checkpoint is loaded.
OUTPUT_DIR = f"{RESULTS_DIR}/distribution_learning/guacamol/{MODEL_NAME}/{TASK_NAME}/hpo/lr_6e-4"

# Config files are given relative to the current working directory.
DATASET_CONFIG_PATH = 'configs/datasets/guacamol/config.json'
TOKENIZER_CONFIG_PATH = 'configs/tokenizers/smiles/guacamol/config.json'
MODEL_CONFIG_PATH = 'configs/models/guacamol_vocab/hyformer/config.json'
MODEL_CHECKPOINT_PATH = f'{OUTPUT_DIR}/ckpt.pt'

In [10]:
# Load dataset, tokenizer and model

# Dataset: build the "test" split from the dataset config, rooted at DATASET_DIR
dataset_config = DatasetConfig.from_config_filepath(DATASET_CONFIG_PATH)
test_dataset = AutoDataset.from_config(
    dataset_config,
    split="test",
    root=DATASET_DIR,
)

# Tokenizer: instantiate from its config file
tokenizer_config = TokenizerConfig.from_config_filepath(TOKENIZER_CONFIG_PATH)
tokenizer = AutoTokenizer.from_config(tokenizer_config)

# Model: instantiate from its config, then call load_pretrained on the checkpoint path
model_config = ModelConfig.from_config_filepath(MODEL_CONFIG_PATH)
model = AutoModel.from_config(model_config)
model.load_pretrained(filepath=MODEL_CHECKPOINT_PATH)




In [50]:
# Encoder setup: use CUDA when torch reports it as available, otherwise the CPU
batch_size = 16
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Obtain an encoder from the loaded model, bound to the tokenizer, batch size and device
encoder = model.to_encoder(
    tokenizer=tokenizer,
    batch_size=batch_size,
    device=device,
)


In [64]:
# Two example SMILES strings to pass through the encoder
smiles = [
    'C1CCCCC1',
    'C1CCCC1',
]
embeddings = encoder.encode(smiles)

Encoding samples: 100%|██████████| 1/1 [00:04<00:00,  4.92s/it]


In [65]:
# Display the array returned by encoder.encode (one row per input SMILES string)
embeddings

array([[-1.23591805, -0.71528178, -1.16197789, ..., -2.22820139,
         1.0620836 , -1.64014733],
       [-1.24176145, -0.73540109, -1.1245482 , ..., -2.21864176,
         0.95939815, -1.68918061]])