-
Notifications
You must be signed in to change notification settings - Fork 409
Expand file tree
/
Copy pathmol.py
More file actions
68 lines (53 loc) · 2.24 KB
/
mol.py
File metadata and controls
68 lines (53 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from ogb.utils.features import (allowable_features, atom_to_feature_vector,
bond_to_feature_vector, atom_feature_vector_to_dict, bond_feature_vector_to_dict)
from rdkit import Chem
import numpy as np
def ReorderCanonicalRankAtoms(mol):
    """
    Renumber the atoms of ``mol`` in RDKit canonical-rank order.

    :param mol: RDKit molecule.
    :return: tuple ``(mol_renum, order)`` where ``mol_renum`` is the molecule
        returned by ``Chem.RenumberAtoms`` and ``order`` is a tuple whose i-th
        entry is the original index of the atom placed at position i.
    """
    ranks = list(Chem.CanonicalRankAtoms(mol))
    # Sort atom indices by canonical rank. sorted() is stable, so any tied
    # ranks keep ascending index order. Unlike the previous
    # tuple(zip(*...))[1] idiom, this does not raise IndexError for a
    # molecule with no atoms.
    order = tuple(sorted(range(len(ranks)), key=lambda idx: ranks[idx]))
    mol_renum = Chem.RenumberAtoms(mol, order)
    return mol_renum, order
def smiles2graph(smiles_string, removeHs=True, reorder_atoms=False):
    """
    Convert a SMILES string into a dict-based molecular graph.

    :param smiles_string: SMILES string (str) parsed with ``Chem.MolFromSmiles``.
    :param removeHs: if True (default), use the molecule as parsed; if False,
        add explicit hydrogens with ``Chem.AddHs`` so they become graph nodes.
    :param reorder_atoms: if True, renumber atoms by canonical rank via
        ``ReorderCanonicalRankAtoms`` before featurizing.
    :return: dict with keys
        - 'edge_index': int64 array of shape [2, num_edges] in COO format;
          every bond is stored twice, once in each direction,
        - 'edge_feat': int64 array of shape [num_edges, 3]
          (bond type, bond stereo, is_conjugated),
        - 'node_feat': int64 array built from one
          ``atom_to_feature_vector`` row per atom,
        - 'num_nodes': number of atoms.
    :raises ValueError: if RDKit cannot parse ``smiles_string``.
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None. Fail here
        # with a clear message instead of an AttributeError further down.
        raise ValueError(f"Invalid SMILES string: {smiles_string!r}")
    if not removeHs:
        mol = Chem.AddHs(mol)
    if reorder_atoms:
        mol, _ = ReorderCanonicalRankAtoms(mol)

    # atoms: one feature vector per atom, in the molecule's atom order
    x = np.array([atom_to_feature_vector(atom) for atom in mol.GetAtoms()],
                 dtype=np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = bond_to_feature_vector(bond)
        # add edges in both directions
        edges_list.extend([(i, j), (j, i)])
        edge_features_list.extend([edge_feature, edge_feature])

    if edges_list:
        # Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:
        # mol has no bonds: still return correctly shaped 2-D arrays
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    return {
        'edge_index': edge_index,
        'edge_feat': edge_attr,
        'node_feat': x,
        'num_nodes': len(x),
    }
if __name__ == '__main__':
    # Demo: featurize a fixed polycyclic molecule and dump the resulting dict.
    example_smiles = 'O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5'
    graph = smiles2graph(example_smiles)
    print(graph)