In [1]:
# %load preprocessing.py
import numpy as np
import pickle as pkl
import os, sys, sparse
from util import atomFeatures, bondFeatures2
from rdkit import Chem, RDConfig, rdBase
from rdkit.Chem import AllChem, ChemicalFeatures
from sklearn.metrics.pairwise import euclidean_distances
import pandas as pds
import save_dict as sd
from API import get_mpnn_input

In [2]:
from sklearn.feature_selection import VarianceThreshold

In [3]:
# Load every model input in one call from the polymer table and its structure file.
# The per-name notes below describe how later cells in this notebook use each value.
(
    index_dict,       # index_dict[name][node] selects rows of the molecule's coordinate array
    pu_dict,          # pu_dict[name]: node id -> key into the node-feature table
    pair_atom_dict,   # pair_atom_dict[name] holds 'pair_index' and 'pair_atom' entries
    adj_list,         # indexed by position in `structure`; one node adjacency matrix each
    structure,        # iterable of molecules
    n_max,            # padding size used for every per-molecule array
    feature,          # 2-D array: column 0 is the molecule name, the rest are target values
    pu_feature,       # raw node-feature table (wrapped in a DataFrame by a later cell)
) = get_mpnn_input('polymer.csv', 'polymer.sdf')

In [8]:
# Wrap the node-feature table in a DataFrame and transpose it, so that a node
# type can later be looked up as a row via `pu_feature.loc[...]`.
# NOTE(review): this cell rebinds its own input, so running it a second time in
# the same kernel transposes the table back. Run it exactly once after loading.
pu_feature = pds.DataFrame(pu_feature).transpose()

In [9]:
# Width of one node feature vector; sizes the last axis of the per-molecule
# node array built in the featurisation loop.
dim_node = 167

In [10]:
# Round a number to the nearest integer (halves round away from zero).
def get_int(num):
    """Round ``num`` to the nearest integer, with exact halves going away from zero.

    The rounding is done arithmetically rather than by splitting ``str(num)``
    on '.', so it also works for integers, for negative values and for floats
    whose text form uses an exponent (e.g. ``1e-05``).

    Parameters
    ----------
    num : int or float (NumPy scalars are accepted)
        The value to round.

    Returns
    -------
    int
        The nearest integer; 2.5 -> 3, -2.5 -> -3.

    Raises
    ------
    ValueError
        If ``num`` is NaN (raised by ``int``).
    OverflowError
        If ``num`` is infinite (raised by ``int``).
    """
    whole = int(num)  # truncates toward zero
    # The subtraction is exact for floats, so this compares the true fraction.
    if abs(num - whole) < 0.5:
        return whole
    # A fraction of at least one half means num is non-zero: step away from zero.
    return whole + 1 if num > 0 else whole - 1

In [None]:
# Per-molecule accumulators filled by the loop below:
# node features, edge features, node distances, targets and SMILES strings.
DV, DE, DP, DY, Dsmi = [], [], [], [], []

# Width of one edge feature vector (last axis of each edge tensor).
dim_edge = 8

# Element symbols; not referenced by the featurisation loop in this cell.
atom_list = [
    'Se', 'H', 'Li', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg',
    'Si', 'P', 'S', 'Cl', 'K', 'Ca', 'Br', 'Bi', 'Ge',
]
# Turn every usable molecule into padded node / edge / distance / target arrays.
for i, mol in enumerate(structure):
    # Guard first: a None entry must be skipped before any attribute access on it.
    if mol is None:
        continue

    name = mol.GetProp('_Name')                # monomer name
    num_index = list(pu_dict[name].values())   # node types
    num_node = len(num_index)                  # number of nodes in this monomer
    node_name = list(pu_dict[name].keys())     # node ids
    pair_atom = pair_atom_dict[name]

    # Molecules that fail sanitisation are skipped on purpose (best effort).
    # Catch Exception rather than everything so Ctrl-C still stops the loop.
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        continue

    # A '.' in the SMILES means several disconnected fragments: skip those too.
    smi = Chem.MolToSmiles(mol)
    if '.' in smi:
        continue
    rings = mol.GetRingInfo().AtomRings()

    # node DV: one feature row per node, zero-padded up to n_max rows
    node = np.zeros((n_max, dim_node), dtype=np.int8)
    for j in range(num_node):
        node[j, :] = pu_feature.loc[num_index[j]]

    # 3D pos DP: mean coordinate of each node, then pairwise node distances
    pos = mol.GetConformer().GetPositions()
    pos_array = np.array(pos)
    node_ave = []
    for j in range(num_node):
        # presumably the atom indices that make up this node -- TODO confirm
        atom_num = index_dict[name][node_name[j]]
        node_ave.append(pos_array[atom_num].mean(axis=0))

    node_pos = np.array(node_ave)
    proximity = np.zeros((n_max, n_max))
    proximity[:num_node, :num_node] = euclidean_distances(node_pos)

    # property DY: row of `feature` whose first column equals this molecule's name
    pid = np.where(feature[:, 0] == name)[0][0]
    targets = feature[pid, 1:]  # not called `property`, which would shadow the builtin

    # edge DE
    adj = adj_list[i]
    pair_index = pair_atom['pair_index']
    atom_pair = pair_atom['pair_atom']

    edge = np.zeros((n_max, n_max, dim_edge), dtype=np.float32)
    # NOTE(review): if adj is symmetric, a pair whose two nodes both sit before
    # the last position is visited from both sides and the later visit overwrites
    # both mirrored entries, while a pair involving the last node is visited
    # once. The order of slots 6/7 therefore depends on which visit ran last.
    # Kept as-is so the produced data does not change.
    for j in range(num_node-1):
        for k in range(num_node):
            if adj[j, k] != 1:
                continue
            j_name = node_name[j]
            k_name = node_name[k]
            j_index = num_index[j]
            k_index = num_index[k]
            # The bonded atom pair may be stored under either node order.
            pair1 = [j_name, k_name]
            pair2 = [k_name, j_name]
            if pair1 in pair_index:
                locat = pair_index.index(pair1)
            elif pair2 in pair_index:
                locat = pair_index.index(pair2)
            else:
                # adjacency says connected, but no atom pair is recorded: leave zeros
                continue
            pair = atom_pair[locat]
            edge[j, k, :6] = bondFeatures2(pos, pair[0], pair[1], mol, rings)
            # the two node types, scaled by 0.01, fill the last two slots
            edge[j, k, 6] = j_index*0.01
            edge[j, k, 7] = k_index*0.01
            edge[k, j, :] = edge[j, k, :]

    # append
    DV.append(np.array(node))
    DE.append(np.array(edge))
    DP.append(np.array(proximity))
    DY.append(np.array(targets, dtype=np.float64))
    Dsmi.append(smi)

    # progress line every 50th input index (only reached for molecules that were kept)
    if i % 50 == 0:
        print(i, Chem.MolToSmiles(Chem.RemoveHs(mol)), flush=True)

In [12]:
# np array: stack the per-molecule lists into batch arrays
DV = np.asarray(DV, dtype=np.int8)
# The edge tensors are built as float32 and carry fractional values (slots 6
# and 7 hold node types scaled by 0.01). Casting them to int8 would truncate
# those entries to zero, so the floating-point dtype is kept.
DE = np.asarray(DE, dtype=np.float32)
DP = np.asarray(DP)
DY = np.asarray(DY)
Dsmi = np.asarray(Dsmi)

# compression: the zero-padded node and edge arrays are stored in sparse COO form
DV = sparse.COO.from_numpy(DV)
DE = sparse.COO.from_numpy(DE)

print(DV.shape, DE.shape, DP.shape, DY.shape)

# save everything as a single pickled list: [DV, DE, DP, DY, Dsmi]
with open('MACCS.pkl','wb') as fw:
    pkl.dump([DV, DE, DP, DY, Dsmi], fw)

(697, 20, 167) (697, 20, 20, 8) (697, 20, 20) (697, 4)
