forked from chemprop/chemprop
/
vocab.py
239 lines (186 loc) · 9.16 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from argparse import Namespace
from copy import deepcopy
from functools import partial
from multiprocessing import Pool
import random
from typing import Callable, List, FrozenSet, Set, Tuple, Union
from collections import Counter
from rdkit import Chem
import torch
from chemprop.features import atom_features, bond_features, get_atom_fdim, FunctionalGroupFeaturizer
class Vocab:
    """Maps atom/substructure "words" to indices for BERT-style pretraining targets."""

    def __init__(self, args: Namespace, smiles: List[str]):
        """
        Builds the vocabulary from a list of SMILES strings.

        :param args: Namespace of arguments (the bert_* options control the vocab).
        :param smiles: SMILES strings used to collect vocabulary words.
        """
        self.substructure_sizes = args.bert_substructure_sizes
        self.vocab_func = partial(
            atom_vocab,
            vocab_func=args.bert_vocab_func,
            substructure_sizes=self.substructure_sizes,
            args=args
        )

        if args.bert_vocab_func == 'feature_vector':
            # Raw feature vectors are predicted directly, so no discrete vocab is built.
            self.unk = None
            self.output_size = get_atom_fdim(args, is_output=True)
            return

        self.unk = 'unk'
        self.smiles = smiles
        self.vocab = get_vocab(args, self.vocab_func, self.smiles)
        self.vocab.add(self.unk)
        self.vocab_size = len(self.vocab)
        sorted_words = sorted(self.vocab)
        self.vocab_mapping = dict(zip(sorted_words, range(len(sorted_words))))
        self.output_size = self.vocab_size

    def w2i(self, word: str) -> int:
        """Converts a word to its vocab index, falling back to the unk index for unknown words."""
        if self.unk is None:
            # feature_vector mode: the "word" already is the prediction target.
            return word
        return self.vocab_mapping.get(word, self.vocab_mapping[self.unk])

    def smiles2indices(self, smiles: str) -> Tuple[List[int], List[List[int]]]:
        """Featurizes a SMILES string and maps each word to its vocab index, with neighbor indices."""
        words, nb_indices = self.vocab_func(smiles, nb_info=True)
        indices = [self.w2i(word) for word in words]
        return indices, nb_indices
def get_substructures_from_atom(atom: Chem.Atom,
                                max_size: int,
                                substructure: Set[int] = None) -> Set[FrozenSet[int]]:
    """
    Recursively gets all substructures up to a maximum size starting from an atom in a substructure.

    :param atom: The atom to start at.
    :param max_size: The maximum size of the substructure to find.
    :param substructure: The current substructure that atom is in.
    :return: A set of substructures starting at atom where each substructure is a frozenset of indices.
    """
    assert max_size >= 1

    if substructure is None:
        substructure = {atom.GetIdx()}

    substructures = {frozenset(substructure)}

    if len(substructure) == max_size:
        return substructures

    # Get neighbors which are not already in the substructure
    new_neighbors = [neighbor for neighbor in atom.GetNeighbors() if neighbor.GetIdx() not in substructure]

    for neighbor in new_neighbors:
        # Define new substructure with neighbor. A shallow copy suffices because
        # the set holds only immutable atom indices (deepcopy was needless overhead).
        new_substructure = set(substructure)
        new_substructure.add(neighbor.GetIdx())

        # Skip if new substructure has already been considered
        if frozenset(new_substructure) in substructures:
            continue

        # Recursively get substructures including this substructure plus neighbor
        new_substructures = get_substructures_from_atom(neighbor, max_size, new_substructure)

        # Add those substructures to current set of substructures
        substructures |= new_substructures

    return substructures
def get_substructures(atoms: List[Chem.Atom],
                      sizes: List[int],
                      max_count: int = None) -> Set[FrozenSet[int]]:
    """
    Gets up to max_count substructures (frozenset of atom indices) from a molecule.

    Note: Uses randomness to guarantee that the first max_count substructures
    found are a random sample of the substructures in the molecule.
    (It's not perfectly random, depending on the graph structure, but probably good enough
    for our purposes. There's a bit of bias toward substructures on the periphery.)

    :param atoms: A list of atoms in the molecule.
    :param sizes: The sizes of substructures to find.
    :param max_count: The maximum number of substructures to find (None means no limit).
    :return: A set of substructures where each substructure is a frozenset of indices.
    """
    # Check `is None` rather than truthiness: an explicit max_count=0 should
    # yield zero substructures, not unlimited (the old `or` form conflated them).
    if max_count is None:
        max_count = float('inf')

    # Shuffle a copy so the caller's list is not reordered as a side effect.
    atoms = list(atoms)
    random.shuffle(atoms)

    max_size = max(sizes)  # hoisted: invariant across the loop
    substructures = set()

    for atom in atoms:
        # Get all substructures up to max size starting from atom
        new_substructures = get_substructures_from_atom(atom, max_size)

        # Filter substructures to those which are one of the desired sizes
        new_substructures = [substructure for substructure in new_substructures if len(substructure) in sizes]

        for new_substructure in new_substructures:
            if len(substructures) >= max_count:
                return substructures

            substructures.add(new_substructure)

    return substructures
def substructure_to_feature(mol: Chem.Mol,
                            substructure: FrozenSet[int],
                            fg_features: List[List[int]] = None) -> str:
    """
    Converts a substructure (set of atom indices) to a feature string
    by sorting and concatenating atom and bond feature vectors.

    :param mol: A molecule.
    :param substructure: A set of atom indices representing a substructure.
    :param fg_features: A list of k-hot vectors indicating the functional groups each atom belongs to.
    :return: A string representing the featurization of the substructure.
    """
    if fg_features is None:
        fg_features = [None] * mol.GetNumAtoms()

    indices = list(substructure)
    atoms = [mol.GetAtomWithIdx(idx) for idx in indices]

    # Collect every bond whose two endpoints are both inside the substructure.
    bonds = []
    for pos, a1 in enumerate(indices):
        for a2 in indices[pos + 1:]:
            bond = mol.GetBondBetweenAtoms(a1, a2)
            if bond is not None:
                bonds.append(bond)

    feature_strings = [str(atom_features(atom, fg_features[atom.GetIdx()])) for atom in atoms]
    feature_strings.extend(str(bond_features(bond)) for bond in bonds)
    feature_strings.sort()  # ensure identical feature string for different atom/bond ordering

    return str(feature_strings)
def atom_vocab(smiles: str,
               vocab_func: str,
               args: Namespace = None,
               substructure_sizes: List[int] = None,
               nb_info: bool = False) -> Union[List[str],
                                               Tuple[List[str], List[List[int]]]]:
    """
    Computes the vocabulary "words" for each atom (or substructure) of a molecule.

    :param smiles: A SMILES string.
    :param vocab_func: The featurization scheme ('atom', 'atom_features', 'feature_vector', or 'substructure').
    :param args: Namespace of arguments (used to decide on functional group features).
    :param substructure_sizes: Substructure sizes to extract when vocab_func == 'substructure'.
    :param nb_info: Whether to also return each atom's neighbor indices.
    :return: The list of words, plus per-atom neighbor index lists when nb_info is True.
    """
    if vocab_func not in ['atom', 'atom_features', 'feature_vector', 'substructure']:
        raise ValueError(f'vocab_func "{vocab_func}" not supported.')

    mol = Chem.MolFromSmiles(smiles)
    atoms = mol.GetAtoms()

    use_fg = args is not None and \
        ('functional_group' in args.additional_atom_features or
         'functional_group' in args.additional_output_features)

    if use_fg:
        fg_features = FunctionalGroupFeaturizer(args).featurize(mol)
    else:
        fg_features = [None] * len(atoms)

    if vocab_func == 'feature_vector':
        features = [atom_features(atom, fg) for atom, fg in zip(atoms, fg_features)]
    elif vocab_func == 'atom_features':
        features = [str(atom_features(atom, fg)) for atom, fg in zip(atoms, fg_features)]
    elif vocab_func == 'atom':
        features = [str(atom.GetAtomicNum()) for atom in atoms]
    elif vocab_func == 'substructure':
        substructures = get_substructures(list(atoms), substructure_sizes)
        features = [substructure_to_feature(mol, substructure, fg_features) for substructure in substructures]
    else:
        # Unreachable given the guard above; kept as a defensive fallback.
        raise ValueError(f'vocab_func "{vocab_func}" not supported.')

    if nb_info:
        # Atoms are sorted by index, so position i holds atom i's neighbor indices.
        nb_indices = [[nb.GetIdx() for nb in atom.GetNeighbors()] for atom in atoms]
        return features, nb_indices

    return features
def vocab(pair: Tuple[Callable, str, bool]) -> Set[str]:
    """
    Applies a vocab function to one SMILES string (module-level so it can be pickled for multiprocessing).

    :param pair: A tuple of (vocab function, SMILES string, whether to deduplicate into a set).
    :return: The words for the molecule, as a set when requested, otherwise as returned by the vocab function.
    """
    vocab_func, smiles, as_set = pair
    words = vocab_func(smiles, nb_info=False)
    if as_set:
        words = set(words)
    return words
def get_vocab(args: Namespace, vocab_func: Callable, smiles: List[str]) -> Set[str]:
    """
    Computes the vocabulary over a list of SMILES strings, optionally in parallel.

    :param args: Namespace with the sequential, bert_max_vocab_size, and bert_smiles_to_sample options.
    :param vocab_func: The function mapping a SMILES string to its vocabulary words.
    :param smiles: The SMILES strings to build the vocabulary from.
    :return: The set of vocabulary words (restricted to the bert_max_vocab_size most common words if nonzero).
    """
    sequential, max_vocab_size, smiles_to_sample = \
        args.sequential, args.bert_max_vocab_size, args.bert_smiles_to_sample

    if 0 < smiles_to_sample < len(smiles):
        # random.sample draws without replacement and, unlike the previous
        # in-place shuffle, does not mutate the caller's list.
        smiles = random.sample(smiles, smiles_to_sample)

    pairs = [(vocab_func, smile, max_vocab_size == 0) for smile in smiles]

    if max_vocab_size == 0:
        # Unlimited vocab: union the per-molecule word sets directly.
        if sequential:
            return set.union(*map(vocab, pairs))
        with Pool() as pool:
            return set.union(*pool.map(vocab, pairs))

    # Limited vocab: count word frequencies, then keep only the most common words.
    if sequential:
        vocab_lists = map(vocab, pairs)
    else:
        with Pool() as pool:
            vocab_lists = pool.map(vocab, pairs)

    counter = Counter()
    for elt_list in vocab_lists:
        counter.update(elt_list)

    return {elt for elt, count in counter.most_common(max_vocab_size)}
def load_vocab(path: str) -> Vocab:
    """
    Loads the Vocab a model was trained with.

    :param path: Path where the model checkpoint is saved.
    :return: The Vocab object that the model was trained with.
    """
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    return checkpoint['args'].vocab