In [1]:
from tqdm import trange
import numpy as np
import pandas as pd
import random
from tqdm import trange
import pickle 

# hl_pre — HOMO/LUMO (HL) property predictions from the pretrained graph encoder

In [2]:
from data import convert_to_single_emb, graph_transform
from trainer import *

In [3]:
def checknodes(df, max_node=150, smiles_col='smiles'):
    """Return the rows of ``df`` whose molecule has at most ``max_node`` atoms.

    Args:
        df: DataFrame holding one SMILES string per row.
        max_node: maximum allowed ``GetNumAtoms()`` value (inclusive).
        smiles_col: name of the SMILES column (defaults to ``'smiles'``).

    Returns:
        ``df.loc[...]`` restricted to the kept rows; original index labels are preserved.

    Raises:
        ValueError: if RDKit cannot parse one of the SMILES strings.
    """
    # Local import so the function works regardless of notebook cell execution order.
    from rdkit import Chem

    keep = []
    # Iterate over (label, value) pairs so lookups and the final .loc agree even
    # when df does not have a 0..n-1 RangeIndex.
    for label, smile in df[smiles_col].items():
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            # MolFromSmiles returns None instead of raising on parse failure.
            raise ValueError(f"Invalid SMILES at index {label!r}: {smile!r}")
        if mol.GetNumAtoms() <= max_node:
            keep.append(label)
    return df.loc[keep]
def get_inputs_bysmile(smi, max_node, shuffle_=False):
    """Featurise one SMILES string into fixed-size, zero-padded model inputs.

    Args:
        smi: SMILES string.
        max_node: number of node rows the graph and adjacency are padded to.
        shuffle_: accepted for interface compatibility; currently unused.

    Returns:
        ``(smile, padded_graph, padded_adj, mask)``. ``padded_graph`` and
        ``padded_adj`` have ``max_node`` rows. ``mask`` is a ``(max_node, 1)``
        tensor holding ``1/num_node`` on the first ``num_node`` rows and 0 elsewhere.
        ``padded_adj`` is padded on both dimensions, so it is presumably
        ``(max_node, max_node)``; this assumes ``graph_transform`` returns a
        square ``num_node`` adjacency — TODO confirm.

    Raises:
        ValueError: if the molecule has no nodes or more than ``max_node`` nodes.
    """
    smile = smi
    num_node, graph, adj = graph_transform(smile, max_node=max_node)
    if num_node < 1 or num_node > max_node:
        # Negative padding in ConstantPad2d crops instead of failing. That would
        # silently truncate the molecule and leave a mask that no longer sums to 1.
        # num_node == 0 would also divide by zero below.
        raise ValueError(
            f"SMILES {smile!r} has {num_node} nodes; expected 1..{max_node}"
        )
    graph = convert_to_single_emb(graph).long()
    mask = torch.zeros(max_node, 1)

    # pad1 only adds rows (nodes); pad2 adds both rows and columns.
    pad1 = nn.ConstantPad2d((0, 0, 0, max_node - num_node), 0)
    pad2 = nn.ConstantPad2d((0, max_node - num_node, 0, max_node - num_node), 0)

    padded_graph = pad1(graph)
    padded_adj = pad2(adj)
    # Uniform weights over real nodes; padded rows stay 0.
    mask[:num_node] = 1/num_node

    assert padded_adj.shape[0] == max_node
    assert padded_graph.shape[0] == max_node

    return smile, padded_graph, padded_adj, mask

def batch_get_inputs_bysmile(smi_list, max_node, shuffle_=False):
    """Featurise a list of SMILES strings into stacked, padded batch tensors.

    Args:
        smi_list: non-empty sequence of SMILES strings.
        max_node: number of node rows each molecule is padded to.
        shuffle_: forwarded to ``get_inputs_bysmile``.

    Returns:
        ``(graphs, adjs, masks)``, each with leading dimension ``len(smi_list)``.
        ``graphs`` is a long tensor of shape ``(B, max_node, F)``, where F is
        taken from the first molecule. ``adjs`` has shape
        ``(B, max_node, max_node)`` and ``masks`` has shape ``(B, max_node, 1)``;
        both use torch's default float dtype.

    Raises:
        ValueError: if ``smi_list`` is empty.
    """
    batch_size = len(smi_list)
    if batch_size == 0:
        # Without this check the buffers below are never allocated, and the
        # return fails with UnboundLocalError.
        raise ValueError("smi_list must contain at least one SMILES string")

    graphs = adjs = masks = None
    for i, smi in enumerate(smi_list):
        _, padded_graph, padded_adj, mask = get_inputs_bysmile(
            smi, max_node=max_node, shuffle_=shuffle_
        )
        if graphs is None:
            # Feature width F is only known after featurising the first molecule.
            graphs = torch.zeros(batch_size, max_node, padded_graph.shape[-1]).long()
            adjs = torch.zeros(batch_size, max_node, max_node)
            masks = torch.zeros(batch_size, max_node, 1)
        graphs[i] = padded_graph
        adjs[i] = padded_adj
        masks[i] = mask
    return graphs, adjs, masks

def hl_pre_single(smiles, max_node, mean, std):
    """Run the HL model on one SMILES string and return de-normalised predictions.

    Args:
        smiles: SMILES string of the molecule.
        max_node: padding size passed to ``batch_get_inputs_bysmile``.
        mean, std: de-normalisation tensors; raw outputs are mapped back as
            ``output * std + mean``.

    Returns:
        1-D numpy array of predictions for this molecule. The outputs recorded
        in this notebook have three values.

    Uses the notebook globals ``hl_model`` and ``device``.
    """
    graphs, adjs, masks = batch_get_inputs_bysmile([smiles], max_node=max_node)
    graphs = graphs.to(device)
    adjs = adjs.to(device)
    masks = masks.to(device)
    with torch.no_grad():
        results1 = hl_model([graphs, adjs, masks])
        # De-normalise inside no_grad so the result never carries autograd
        # history (e.g. if mean/std require grad), which would break .numpy().
        results1 = torch.add(torch.multiply(results1, std), mean)
    # .cpu() is a no-op on CPU and is required before .numpy() on a CUDA device.
    return results1[0].cpu().numpy()

def hl_pre_batch(smiles_list, max_node, mean, std):
    """Run the HL model on a batch of SMILES strings and return de-normalised predictions.

    Args:
        smiles_list: non-empty sequence of SMILES strings.
        max_node: padding size passed to ``batch_get_inputs_bysmile``.
        mean, std: de-normalisation tensors; raw outputs are mapped back as
            ``output * std + mean``.

    Returns:
        numpy array with one row of predictions per input SMILES, in input order.

    Uses the notebook globals ``hl_model`` and ``device``.
    """
    graphs, adjs, masks = batch_get_inputs_bysmile(smiles_list, max_node=max_node)
    graphs = graphs.to(device)
    adjs = adjs.to(device)
    masks = masks.to(device)
    with torch.no_grad():
        results1 = hl_model([graphs, adjs, masks])
        # Keep de-normalisation under no_grad so .numpy() never sees a graph-tracked tensor.
        results1 = torch.add(torch.multiply(results1, std), mean)
    # .cpu() is a no-op on CPU and is required before .numpy() on a CUDA device.
    return results1.cpu().numpy()

In [None]:
device = torch.device('cpu')
# NOTE(review): torch.load unpickles a whole nn.Module, which can execute arbitrary
# code; only load trusted checkpoint files. On torch >= 2.6 (weights_only defaults
# to True) loading a full module may additionally require weights_only=False.
# map_location lets a checkpoint that was saved on a GPU load on this CPU device.
hl_model = torch.load('HL_GENEncoderPre.pt', map_location=device).to(device)
# Compatibility shim: GELU modules pickled with torch 1.10 have no `approximate`
# attribute, which newer torch versions expect, so define it explicitly.
for m in hl_model.modules():
    if type(m) is nn.GELU:
        m.approximate = 'none'

hl_model.eval()
# scaler holds (mean, std); hl_pre_* de-normalise outputs as output * std + mean.
mean = hl_model.scaler[0].to(device)
std = hl_model.scaler[1].to(device)

In [10]:
# Example prediction for a large boron-centred molecule carrying several nitrile groups.
hl_pre_single(
    'N#CC1=CC=C(C=C1C#N)OC2=CC(OB(OC3=CC(OC4=CC(C#N)=C(C=C4)C#N)=CC=C3)OC5=CC=CC(OC6=CC=C(C#N)C(C#N)=C6)=C5)=CC=C2',
    max_node=hl_model.max_node,
    mean=mean,
    std=std,
)

array([-6.4206886, -3.2992914,  3.1222663], dtype=float32)

In [31]:
# Sanity check on a small molecule (benzene).
hl_pre_single('c1ccccc1', max_node=hl_model.max_node, mean=mean, std=std)

array([-6.393021 , -1.2908527,  5.1026797], dtype=float32)

# tp_pre — Tp prediction from mordred descriptors plus the predicted HL features

In [12]:
from sklearn import svm
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import KFold, cross_validate

In [13]:
# calculate descriptors
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Chem import Descriptors
from mordred import Calculator, descriptors

In [14]:
def smi2smi(smiles, chiral=True, H=False):
    """Return RDKit's SMILES re-serialisation of ``smiles``.

    Args:
        smiles: input SMILES string.
        chiral: keep stereochemistry in the output (``isomericSmiles``).
        H: write all hydrogens explicitly (``allHsExplicit``).

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles returns None (not an exception) on parse failure; passing
        # None on to MolToSmiles would fail with an opaque argument error.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    return Chem.MolToSmiles(mol, isomericSmiles=chiral, allHsExplicit=H)

def calculate_des(smiles_list, des_cal):
    """Compute one row of descriptor values per SMILES string.

    Args:
        smiles_list: sequence of SMILES strings; each is normalised with
            ``smi2smi`` before being parsed.
        des_cal: mapping name -> callable that takes an RDKit Mol (in this
            notebook, mordred descriptor objects).

    Returns:
        np.ndarray with one row per SMILES and one column per entry of
        ``des_cal``, in ``des_cal``'s iteration order. A descriptor call that
        raises is stored as ``np.nan``; an empty ``smiles_list`` yields shape
        ``(0, len(des_cal))``.
    """
    mols = [Chem.MolFromSmiles(smi2smi(smiles)) for smiles in smiles_list]
    calculators = list(des_cal.values())
    if not mols:
        # np.array([]) would have shape (0,); keep the 2-D contract instead.
        return np.empty((0, len(calculators)))

    des_list = []
    for mol in mols:
        row = []
        for calc_fn in calculators:
            try:
                value = calc_fn(mol)
            except Exception:
                # Best-effort: a failing descriptor becomes NaN (tp_pre zero-fills
                # NaNs after scaling). Catching Exception rather than everything
                # lets KeyboardInterrupt/SystemExit propagate.
                value = np.nan
            row.append(value)
        des_list.append(row)
    return np.array(des_list)

In [15]:
def get_input_mw(PN_smiles, ini_smiles, dsc_rate, PN_descriptors_mord, ini_descriptors_mord, mw=None):
    """Build the feature vector for one PN/initiator pair, including a molecular-weight feature.

    Layout: [PN descriptors, initiator descriptors, dsc_rate, mw,
    element 0 of the PN HL prediction (named "homo" here),
    last element of the initiator HL prediction (named "hl" here)].

    Args:
        PN_smiles: SMILES of the PN molecule.
        ini_smiles: SMILES of the initiator molecule.
        dsc_rate: scalar appended as-is to the feature vector.
        PN_descriptors_mord, ini_descriptors_mord: descriptor mappings passed to
            ``calculate_des``.
        mw: molecular-weight feature. If None, it is computed as the exact
            molecular weight of ``PN_smiles``.

    Raises:
        ValueError: if ``mw`` is None and ``PN_smiles`` cannot be parsed.
    """
    if mw is None:
        # NOTE(review): assumes the MW feature refers to the PN molecule; confirm
        # against how the training features were built.
        pn_mol = Chem.MolFromSmiles(PN_smiles)
        if pn_mol is None:
            raise ValueError(f"RDKit could not parse SMILES: {PN_smiles!r}")
        mw = Descriptors.ExactMolWt(pn_mol)
    hls_1 = hl_pre_single(PN_smiles, hl_model.max_node, mean, std)
    hls_2 = hl_pre_single(ini_smiles, hl_model.max_node, mean, std)
    PN_pre_homo = np.array([hls_1[0]])
    ini_pre_hl = np.array([hls_2[-1]])

    PN_des = calculate_des([PN_smiles], PN_descriptors_mord)[0]
    ini_des = calculate_des([ini_smiles], ini_descriptors_mord)[0]

    inputs = np.concatenate((PN_des, ini_des, np.array([dsc_rate]), np.array([mw]), PN_pre_homo, ini_pre_hl))

    return inputs

def get_input(PN_smiles, ini_smiles, dsc_rate, PN_descriptors_mord, ini_descriptors_mord):
    """Assemble the 1-D feature vector for one PN/initiator pair.

    Concatenation order:
      1. descriptors of the PN molecule,
      2. descriptors of the initiator,
      3. ``dsc_rate``,
      4. element 0 of the PN HL-model prediction (named "homo" in this notebook),
      5. last element of the initiator HL-model prediction (named "hl").
    """
    pn_pred = hl_pre_single(PN_smiles, hl_model.max_node, mean, std)
    ini_pred = hl_pre_single(ini_smiles, hl_model.max_node, mean, std)

    pn_features = calculate_des([PN_smiles], PN_descriptors_mord)[0]
    ini_features = calculate_des([ini_smiles], ini_descriptors_mord)[0]

    pieces = [
        pn_features,
        ini_features,
        np.array([dsc_rate]),
        np.array([pn_pred[0]]),
        np.array([ini_pred[-1]]),
    ]
    return np.concatenate(pieces)

def get_input_batch(PN_smiles_list, ini_smiles_list, dsc_rate, PN_descriptors_mord, ini_descriptors_mord):
    """Batched version of ``get_input``: one feature row per PN/initiator pair.

    Columns are ordered as in ``get_input``: PN descriptors, initiator
    descriptors, dsc_rate, element 0 of the PN HL prediction, and the last
    element of the initiator HL prediction.

    Args:
        PN_smiles_list, ini_smiles_list: equal-length sequences of SMILES; row i
            pairs ``PN_smiles_list[i]`` with ``ini_smiles_list[i]``.
        dsc_rate: a scalar shared by every row, or a sequence with one value per row.
        PN_descriptors_mord, ini_descriptors_mord: descriptor mappings passed to
            ``calculate_des``.

    Returns:
        2-D numpy array with ``len(PN_smiles_list)`` rows.

    Raises:
        ValueError: if the two SMILES lists differ in length, or if a
            per-row ``dsc_rate`` has the wrong length.
    """
    n_rows = len(PN_smiles_list)
    if n_rows != len(ini_smiles_list):
        raise ValueError(
            f"PN_smiles_list has {n_rows} entries but ini_smiles_list has {len(ini_smiles_list)}"
        )
    hls_1 = hl_pre_batch(PN_smiles_list, hl_model.max_node, mean, std)
    hls_2 = hl_pre_batch(ini_smiles_list, hl_model.max_node, mean, std)
    PN_pre_homo = hls_1[:, 0].reshape(-1, 1)
    ini_pre_hl = hls_2[:, -1].reshape(-1, 1)

    PN_des = calculate_des(PN_smiles_list, PN_descriptors_mord)
    ini_des = calculate_des(ini_smiles_list, ini_descriptors_mord)
    # A scalar is repeated for every row; a per-row sequence is used as given.
    rate_col = np.broadcast_to(np.asarray(dsc_rate), (n_rows,)).reshape(-1, 1)
    inputs = np.concatenate((PN_des, ini_des, rate_col, PN_pre_homo, ini_pre_hl), axis=1)

    return inputs

def tp_pre(model, inputs):
    """Standardise ``inputs``, predict with ``model`` and undo the target scaling.

    ``model`` must provide ``x_scaler = (x_mean, x_std)``,
    ``y_scaler = (y_mean, y_std)`` and a ``predict`` method. ``inputs`` is
    reshaped to rows of ``x_mean.shape[0]`` features. NaNs remaining after
    standardisation are replaced with 0, i.e. the training mean. A 1e-9 term
    guards both scalings against zero std.
    """
    x_mean, x_std = model.x_scaler
    y_mean, y_std = model.y_scaler
    n_features = x_mean.shape[0]

    scaled = (inputs.reshape(-1, n_features) - x_mean) / (x_std + 1e-9)
    scaled = np.where(np.isnan(scaled), 0, scaled)

    raw_pred = model.predict(scaled)
    return raw_pred * (1e-9 + y_std) + y_mean

In [26]:
# NOTE(review): pickle.load can execute arbitrary code from the file; only load trusted models.
with open('Tp_brgr_svr.pkl', 'rb') as f:
    regr = pickle.load(f)

calc = Calculator(descriptors, ignore_3D=True)


def _select_descriptors(calculator, names):
    """Map str(descriptor) -> descriptor for each calculator descriptor whose name is in ``names``.

    The dict follows ``calculator.descriptors`` order, not the order of ``names``,
    so the columns produced by calculate_des follow mordred's order.
    NOTE(review): confirm this matches the feature order the regressor was trained on.
    Names with no matching descriptor trigger a warning, because a missing
    column would shift every later feature.
    """
    import warnings

    wanted = set(names)
    selected = {}
    for desc in calculator.descriptors:
        name = str(desc)
        if name in wanted:
            selected[name] = desc
    missing = wanted.difference(selected)
    if missing:
        warnings.warn(
            f"{len(missing)} descriptor name(s) not found in the mordred calculator, "
            f"e.g. {sorted(missing)[:10]}"
        )
    return selected


PN_descriptors_mord = _select_descriptors(calc, regr.PN_xnames)
ini_descriptors_mord = _select_descriptors(calc, regr.ini_xnames)

In [None]:
# Predict with `regr` for one fixed PN molecule paired with every initiator in regr.ini_dict.
PN_smiles = 'C[Si](c1cc(Oc2cc(C#N)c(C#N)cc2)ccc1)(c3cc(Oc4cc(C#N)c(C#N)cc4)ccc3)c5ccccc5'
DSC_RATE = 10  # value passed as the dsc_rate feature for every prediction
for key, ini_smiles in regr.ini_dict.items():
    result = tp_pre(regr, get_input(PN_smiles, ini_smiles, DSC_RATE, PN_descriptors_mord, ini_descriptors_mord))
    print(key, result)