# Calculate substrate embeddings by Mole-Bert

In [3]:
import os
import torch
from mole_bert.model import GNN_graphpred
from mole_bert.loader import mol_to_graph_data_obj_simple
import pandas as pd
from rdkit import Chem
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
current_dir = os.getcwd()

MODEL_CONFIG = {
    "num_layer": 5,  # number of graph conv layers
    "emb_dim": 300,  # embedding dimension in graph conv layers
    "num_tasks": 1,  # output feature dimention
    "drop_ratio": 0.5,  # dropout ratio
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
mole_bert_model = GNN_graphpred(MODEL_CONFIG['num_layer'], MODEL_CONFIG['emb_dim'], num_tasks=MODEL_CONFIG['num_tasks'], drop_ratio=MODEL_CONFIG['drop_ratio']).to(device)
mole_bert_model.from_pretrained(f'{current_dir}/mole_bert/model_gin/Mole-BERT.pth')
for p in mole_bert_model.parameters():
    p.requires_grad = False

# config
df_all = pd.read_json(f'{current_dir}/../dataset/df_multi_tasks.json')
input_smiles = df_all['smiles'].drop_duplicates()
print(f'number of smiles : {len(input_smiles)}')

embeddings = []
valid_smiles = []
error_smiles = []
mole_bert_model.eval()
with torch.no_grad():
    for mol_smiles in tqdm(input_smiles):
        try:
            mol = Chem.MolFromSmiles(mol_smiles)
            graph = mol_to_graph_data_obj_simple(mol).to(device)
            result = mole_bert_model(graph)
            embedding = mole_bert_model(graph)[1].cpu().numpy().mean(axis=0)
            embeddings.append(embedding)
            valid_smiles.append(mol_smiles)

        except Exception as e:
            error_smiles.append(mol_smiles)

df_result = pd.DataFrame({'smiles': valid_smiles, 'molebert': embeddings})
df_result.to_json(f'{current_dir}/results/df_molebert.json')
print(f'num of valid smiles: {len(valid_smiles)}')

# 定义文件名
output_file = f"{current_dir}/results/mole_bert_error_smiles.txt"

# 写入文件
print(f'num of error smiles: {len(error_smiles)}')
with open(output_file, "w") as file:
    for smiles in error_smiles:
        file.write(smiles + "\n")

df_result.head()

number of smiles : 3770


100%|██████████| 3770/3770 [00:18<00:00, 209.00it/s]


num of valid smiles: 3737
num of error smiles: 33


Unnamed: 0,smiles,molebert
0,CC=CCO,"[0.21251829, -0.26969907, -0.043005787, 0.0656..."
1,CC=CC=O,"[0.09008882, -0.30007035, 0.0593893, 0.2563640..."
2,C1CC2=CC=CC=C2C1O,"[-0.13591726, -0.14269216, -0.25000495, 0.0303..."
3,CCC(C)O,"[-0.2930693, -0.2486641, -0.20994978, 0.161511..."
4,CCCC(C)O,"[0.051676642, -0.19323665, 0.010754765, 0.1088..."


# Smiles_trans

In [1]:
import torch
from smiles_transformer.build_vocab import WordVocab
from smiles_transformer.pretrain_trfm import TrfmSeq2seq
from smiles_transformer.utils import split
import pandas as pd
from tqdm import tqdm


def smiles_to_vec(Smiles):
    pad_index = 0
    unk_index = 1
    eos_index = 2
    sos_index = 3
    mask_index = 4
    vocab = WordVocab.load_vocab(f'{current_dir}/smiles_transformer/vocab.pkl')
    def get_inputs(sm):
        seq_len = 220
        sm = sm.split()
        if len(sm)>218:
            print('SMILES is too long ({:d})'.format(len(sm)))
            sm = sm[:109]+sm[-109:]
        ids = [vocab.stoi.get(token, unk_index) for token in sm]
        ids = [sos_index] + ids + [eos_index]
        seg = [1]*len(ids)
        padding = [pad_index]*(seq_len - len(ids))
        ids.extend(padding), seg.extend(padding)
        return ids, seg

    def get_array(smiles):
        x_id, x_seg = [], []
        for sm in smiles:
            a,b = get_inputs(sm)
            x_id.append(a)
            x_seg.append(b)
        return torch.tensor(x_id), torch.tensor(x_seg)

    trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
    trfm.load_state_dict(torch.load(f'{current_dir}/smiles_transformer/trfm_12_23000.pkl'))
    trfm.eval()
    x_split = [split(sm) for sm in Smiles]
    xid, xseg = get_array(x_split)
    X = trfm.encode(torch.t(xid))

    return X

# config
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
df_all = pd.read_json(f'{current_dir}/../ori_data/df_sequence_smiles_valid.json')
input_smiles = df_all['smiles'].drop_duplicates()
print(f'number of smiles : {len(input_smiles)}')

error_smiles = []
valid_smiles = []
embeddings = []
for mol_smiles in tqdm(input_smiles):
    try:
        embedding = smiles_to_vec(mol_smiles)
        embeddings.append(embedding.mean(axis=0))
        valid_smiles.append(mol_smiles)

    except Exception as e:
        error_smiles.append(mol_smiles)

df_result = pd.DataFrame({'smiles': valid_smiles, 'transsmiles': embeddings})
df_result.to_json(f'{current_dir}/results/df_transsmiles.json')
print(f'num of valid smiles: {len(valid_smiles)}')

# 定义文件名
output_file = f"{current_dir}/results/trans_smiles_error_smiles.txt"

# 写入文件
print(f'num of error smiles: {len(error_smiles)}')
with open(output_file, "w") as file:
    for smiles in error_smiles:
        file.write(smiles + "\n")

df_result.head()

number of smiles : 3770


  trfm.load_state_dict(torch.load('./smiles_transformer/trfm_12_23000.pkl'))
100%|██████████| 3770/3770 [13:24<00:00,  4.69it/s]


num of valid smiles: 3770
num of error smiles: 0


Unnamed: 0,smiles,transsmiles
0,CC=CCO,"[-0.08475136, -0.28812924, 0.332107, 0.1950996..."
1,CC=CC=O,"[-0.07424254, -0.27599233, 0.31990165, 0.21637..."
2,C1CC2=CC=CC=C2C1O,"[-0.070939295, -0.25465176, 0.32130536, 0.2360..."
3,CCC(C)O,"[-0.09016073, -0.28187382, 0.31951085, 0.24044..."
4,CCCC(C)O,"[-0.093634106, -0.29001388, 0.32210785, 0.2253..."
