In [1]:
import os
import sys
from pathlib import Path

# Directory added to sys.path so the project-local modules imported in the next cell
# (train_autoencoder, utils) can be found. Set THESIS_PROJECT_PATH to point elsewhere on
# another machine; the default is the original author's local checkout.
project_path = Path(os.environ.get(
    "THESIS_PROJECT_PATH",
    "/Users/stefanhangler/Documents/Uni/Msc_AI/Thesis/Code.nosync/jku-ml-seminar23/drug-discovery/individual/hangler/thesis",
))
# Guard the append so re-running this cell does not stack duplicate sys.path entries.
if str(project_path) not in sys.path:
    sys.path.append(str(project_path))

In [2]:
import warnings

import torch
from rdkit import Chem
from transformers import BertTokenizer

from train_autoencoder import ldmol_autoencoder
from utils import AE_SMILES_encoder

class SMILESEncoder:
    """Encode SMILES strings into latent embeddings with a pretrained LDMol autoencoder.

    Args:
        checkpoint_path: Checkpoint file loaded with ``torch.load``; its ``'model'``
            entry is used as the autoencoder state dict.
        tokenizer_path: Vocab file passed to ``BertTokenizer``.
        config_path: Directory containing ``config_encoder.json`` and
            ``config_decoder.json``.
        embed_dim: ``embed_dim`` entry of the autoencoder config. Defaults to 256;
            presumably must match the checkpoint being loaded — TODO confirm.
    """

    def __init__(self, checkpoint_path, tokenizer_path, config_path, embed_dim=256):
        # Keep case and skip BERT's basic pre-tokenization: SMILES is case-sensitive
        # (e.g. aliphatic 'C' vs aromatic 'c').
        self.tokenizer = BertTokenizer(vocab_file=tokenizer_path, do_lower_case=False, do_basic_tokenize=False)
        # Build config file paths with pathlib rather than string concatenation so that
        # config_path works with or without a trailing separator.
        config_dir = Path(config_path)
        self.ae_config = {
            'bert_config_decoder': str(config_dir / 'config_decoder.json'),
            'bert_config_encoder': str(config_dir / 'config_encoder.json'),
            'embed_dim': embed_dim,
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ldmol_autoencoder(config=self.ae_config, no_train=True, tokenizer=self.tokenizer).to(self.device)
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        # Tensors are mapped to CPU first; load_state_dict copies them onto the model's device.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False tolerates key mismatches, but silently ignoring them can hide a
        # wrong or partial checkpoint, so report anything that did not line up.
        incompatible = self.model.load_state_dict(checkpoint['model'], strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {checkpoint_path!s} does not match the model exactly: "
                f"{len(incompatible.missing_keys)} missing key(s) "
                f"(e.g. {incompatible.missing_keys[:3]}), "
                f"{len(incompatible.unexpected_keys)} unexpected key(s) "
                f"(e.g. {incompatible.unexpected_keys[:3]}).",
                stacklevel=2,
            )
        self.model.eval()

    def encode_smiles(self, smiles_list):
        """Canonicalize SMILES strings and encode them with the autoencoder.

        Args:
            smiles_list: Iterable of SMILES strings.

        Returns:
            The tensor returned by ``AE_SMILES_encoder`` with dims 1 and 2 swapped and a
            trailing singleton dim appended (so it is 4-D). Assumes the encoder output
            is ``(batch, seq, dim)`` — TODO confirm against ``AE_SMILES_encoder``.

        Raises:
            ValueError: If a string cannot be parsed by RDKit.
        """
        canonical_smiles = []
        for smile in smiles_list:
            mol = Chem.MolFromSmiles(smile)
            # MolFromSmiles returns None on parse failure; fail with the offending string
            # instead of an opaque RDKit argument error from MolToSmiles(None).
            if mol is None:
                raise ValueError(f"Invalid SMILES string: {smile!r}")
            canonical_smiles.append(Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True))
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            latents = AE_SMILES_encoder(canonical_smiles, self.model)
        return latents.permute(0, 2, 1).unsqueeze(-1)

: 

In [None]:
    # Define paths (relative paths resolve against the kernel's current working directory)
    checkpoint_path = "./Pretrain/checkpoint_autoencoder.ckpt"  # Replace with your checkpoint path
    tokenizer_path = "./vocab_bpe_300_sc.txt"  # Replace with your tokenizer vocab file
    config_path = "./"  # Path containing config_encoder.json and config_decoder.json

    # Initialize encoder (builds tokenizer and model, then loads checkpoint weights)
    encoder = SMILESEncoder(checkpoint_path, tokenizer_path, config_path)

    # Input SMILES strings: ethanol, benzene, acetic acid, cyclohexane
    smiles_strings = ["CCO", "C1=CC=CC=C1", "CC(=O)O", "C1CCCCC1"]

    # Get embeddings
    embeddings = encoder.encode_smiles(smiles_strings)
    # encode_smiles applies permute(0, 2, 1).unsqueeze(-1), so the result is 4-D with a
    # trailing singleton dim. The leading dim is presumably one entry per input SMILES and
    # the middle dims (embed_dim, latent_size) — TODO confirm against AE_SMILES_encoder.
    assert embeddings.dim() == 4 and embeddings.shape[-1] == 1, embeddings.shape
    assert embeddings.shape[0] == len(smiles_strings), embeddings.shape
    print(f"Embeddings shape: {embeddings.shape}")
    print("Sample embedding for first SMILES:")
    print(embeddings[0])