-
Notifications
You must be signed in to change notification settings - Fork 425
/
sequences.py
525 lines (421 loc) · 19.5 KB
/
sequences.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
# -*- coding: utf-8 -*-
#
# Copyright 2018-2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sequences to provide input to Keras
"""
__all__ = [
"NodeSequence",
"LinkSequence",
"OnDemandLinkSequence",
"FullBatchSequence",
"SparseFullBatchSequence",
"RelationalFullBatchNodeSequence",
]
import warnings
import operator
import random
import collections
import numpy as np
import itertools as it
import networkx as nx
import scipy.sparse as sps
from tensorflow.keras import backend as K
from functools import reduce
from tensorflow.keras.utils import Sequence
from ..data.unsupervised_sampler import UnsupervisedSampler
from ..core.utils import is_real_iterable
from ..random import random_state
class NodeSequence(Sequence):
    """Keras-compatible data generator to use with the Keras
    methods :meth:`keras.Model.fit_generator`, :meth:`keras.Model.evaluate_generator`,
    and :meth:`keras.Model.predict_generator`.

    This class generates data samples for node inference models
    and should be created using the `.flow(...)` method of
    :class:`GraphSAGENodeGenerator` or :class:`DirectedGraphSAGENodeGenerator`
    or :class:`HinSAGENodeGenerator` or :class:`Attri2VecNodeGenerator`.
    These generator classes are used within the NodeSequence to generate
    the required features for downstream ML tasks from the graph.

    Args:
        sample_function (Callable): A function that returns features for supplied head nodes.
        batch_size (int): Number of head nodes to include in each batch.
        ids (list): A list of the node_ids to be used as head-nodes in the downstream task.
        targets (list, optional): A list of targets or labels to be used in the downstream task.
        shuffle (bool): If True (default) the ids will be randomly shuffled every epoch.
        seed (int, optional): Seed for the random state used to shuffle the ids.
    """

    def __init__(
        self, sample_function, batch_size, ids, targets=None, shuffle=True, seed=None
    ):
        # Check that ids is an iterable
        if not is_real_iterable(ids):
            raise TypeError("IDs must be an iterable or numpy array of graph node IDs")

        # Check targets is iterable & has the correct length
        if targets is not None:
            if not is_real_iterable(targets):
                raise TypeError("Targets must be None or an iterable or numpy array ")
            if len(ids) != len(targets):
                raise ValueError(
                    "The length of the targets must be the same as the length of the ids"
                )
            self.targets = np.asanyarray(targets)
        else:
            self.targets = None

        # Store the function used to draw feature samples from the graph.
        # FIX: `collections.Callable` was removed in Python 3.10 (it lives in
        # `collections.abc`); the builtin `callable` is the portable check.
        if callable(sample_function):
            self._sample_function = sample_function
        else:
            raise TypeError(
                "({}) The sampling function expects a callable function.".format(
                    type(self).__name__
                )
            )

        self.ids = list(ids)
        self.data_size = len(self.ids)
        self.shuffle = shuffle
        self.batch_size = batch_size
        self._rs, _ = random_state(seed)

        # Shuffle IDs to start
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return int(np.ceil(self.data_size / self.batch_size))

    def __getitem__(self, batch_num):
        """
        Generate one batch of data

        Args:
            batch_num (int): number of a batch

        Returns:
            batch_feats (list): Node features for nodes and neighbours sampled from a
                batch of the supplied IDs
            batch_targets (list): Targets/labels for the batch.
        """
        start_idx = self.batch_size * batch_num
        end_idx = start_idx + self.batch_size

        if start_idx >= self.data_size:
            raise IndexError("Mapper: batch_num larger than length of data")

        # The ID indices for this batch
        batch_indices = self.indices[start_idx:end_idx]

        # Get head (root) nodes
        head_ids = [self.ids[ii] for ii in batch_indices]

        # Get corresponding targets
        batch_targets = None if self.targets is None else self.targets[batch_indices]

        # Get features for nodes
        batch_feats = self._sample_function(head_ids, batch_num)

        return batch_feats, batch_targets

    def on_epoch_end(self):
        """
        Shuffle all head (root) nodes at the end of each epoch
        """
        self.indices = list(range(self.data_size))
        if self.shuffle:
            self._rs.shuffle(self.indices)
class LinkSequence(Sequence):
    """
    Keras-compatible data generator to use with Keras methods :meth:`keras.Model.fit_generator`,
    :meth:`keras.Model.evaluate_generator`, and :meth:`keras.Model.predict_generator`

    This class generates data samples for link inference models
    and should be created using the :meth:`flow` method of
    :class:`GraphSAGELinkGenerator` or :class:`HinSAGELinkGenerator` or :class:`Attri2VecLinkGenerator`.

    Args:
        sample_function (Callable): A function that returns features for supplied head nodes.
        batch_size (int): Number of links to include in each batch.
        ids (iterable): Link IDs to batch, each link id being a tuple of (src, dst) node ids.
        targets (list, optional): A list of targets or labels to be used in the downstream task.
        shuffle (bool): If True (default) the ids will be randomly shuffled every epoch.
        seed (int, optional): Seed for the random state used to shuffle the ids.
    """

    def __init__(
        self, sample_function, batch_size, ids, targets=None, shuffle=True, seed=None
    ):
        # Check that ids is an iterable
        if not is_real_iterable(ids):
            raise TypeError("IDs must be an iterable or numpy array of graph node IDs")

        # Check targets is iterable & has the correct length
        # (this single check replaces a duplicated length check in an
        # earlier revision of this constructor)
        if targets is not None:
            if not is_real_iterable(targets):
                raise TypeError("Targets must be None or an iterable or numpy array ")
            if len(ids) != len(targets):
                raise ValueError(
                    "The length of the targets must be the same as the length of the ids"
                )
            self.targets = np.asanyarray(targets)
        else:
            self.targets = None

        # Store the function used to draw feature samples from the graph.
        # FIX: `collections.Callable` was removed in Python 3.10 (it lives in
        # `collections.abc`); the builtin `callable` is the portable check.
        if callable(sample_function):
            self._sample_features = sample_function
        else:
            raise TypeError(
                "({}) The sampling function expects a callable function.".format(
                    type(self).__name__
                )
            )

        self.batch_size = batch_size
        self.ids = list(ids)
        self.data_size = len(self.ids)
        self.shuffle = shuffle
        # Use a dedicated, optionally-seeded random state for shuffling,
        # consistent with NodeSequence (instead of the global `random` module).
        self._rs, _ = random_state(seed)

        # Shuffle the IDs to begin
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return int(np.ceil(self.data_size / self.batch_size))

    def __getitem__(self, batch_num):
        """
        Generate one batch of data

        Args:
            batch_num (int): number of a batch

        Returns:
            batch_feats (list): Node features for nodes and neighbours sampled from a
                batch of the supplied IDs
            batch_targets (list): Targets/labels for the batch.
        """
        start_idx = self.batch_size * batch_num
        end_idx = start_idx + self.batch_size

        if start_idx >= self.data_size:
            raise IndexError("Mapper: batch_num larger than length of data")

        # The ID indices for this batch
        batch_indices = self.indices[start_idx:end_idx]

        # Get head (root) nodes for links
        head_ids = [self.ids[ii] for ii in batch_indices]

        # Get targets for nodes
        batch_targets = None if self.targets is None else self.targets[batch_indices]

        # Get node features for batch of link ids
        batch_feats = self._sample_features(head_ids, batch_num)

        return batch_feats, batch_targets

    def on_epoch_end(self):
        """
        Shuffle all link IDs at the end of each epoch
        """
        self.indices = list(range(self.data_size))
        if self.shuffle:
            self._rs.shuffle(self.indices)
class OnDemandLinkSequence(Sequence):
    """
    Keras-compatible data generator to use with Keras methods :meth:`keras.Model.fit_generator`,
    :meth:`keras.Model.evaluate_generator`, and :meth:`keras.Model.predict_generator`

    This class generates data samples for link inference models
    and should be created using the :meth:`flow` method of
    :class:`GraphSAGELinkGenerator` or :class:`Attri2VecLinkGenerator`.

    Args:
        sample_function (Callable): A function that returns features for supplied head nodes.
        batch_size (int): Number of links to include in each batch.
        walker (UnsupervisedSampler): An object that encapsulates the neighbourhood sampling
            of a graph, yielding batches of positive and negative samples.
        shuffle (bool): If True (default) the batches are re-generated every epoch.
    """

    def __init__(self, sample_function, batch_size, walker, shuffle=True):
        # Store the function used to draw feature samples from the graph.
        # FIX: `collections.Callable` was removed in Python 3.10 (it lives in
        # `collections.abc`); the builtin `callable` is the portable check.
        if callable(sample_function):
            self._sample_features = sample_function
        else:
            raise TypeError(
                "({}) The sampling function expects a callable function.".format(
                    type(self).__name__
                )
            )

        if not isinstance(walker, UnsupervisedSampler):
            raise TypeError(
                "({}) UnsupervisedSampler is required.".format(type(self).__name__)
            )

        self.batch_size = batch_size
        self.walker = walker
        self.shuffle = shuffle
        # FIXME(#681): all batches are created at once, so this is no longer "on demand"
        self._batches = self._create_batches()
        self.length = len(self._batches)
        self.data_size = sum(len(batch[0]) for batch in self._batches)

    def __getitem__(self, batch_num):
        """
        Generate one batch of data.

        Args:
            batch_num (int): number of a batch

        Returns:
            batch_feats (list): Node features for nodes and neighbours sampled from a
                batch of the supplied IDs
            batch_targets (list): Targets/labels for the batch.
        """
        if batch_num >= len(self):
            # FIX: corrected typo "esstaimted" -> "estimated" in the message
            raise IndexError(
                "Mapper: batch_num larger than number of estimated batches for this epoch."
            )

        # Get head nodes and labels
        head_ids, batch_targets = self._batches[batch_num]

        # Obtain features for head ids
        batch_feats = self._sample_features(head_ids, batch_num)

        return batch_feats, batch_targets

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return self.length

    def _create_batches(self):
        # Delegate batch creation (random walks + negative sampling) to the walker.
        return self.walker.run(self.batch_size)

    def on_epoch_end(self):
        """
        Re-generate all batches at the end of each epoch
        """
        if self.shuffle:
            self._batches = self._create_batches()
def _full_batch_array_and_reshape(array, propagate_none=False):
"""
Args:
array: an array-like object
propagate_none: if True, return None when array is None
Returns:
array as a numpy array with an extra first dimension (batch dimension) equal to 1
"""
# if it's ok, just short-circuit on None (e.g. for target arrays, that may or may not exist)
if propagate_none and array is None:
return None
as_np = np.asanyarray(array)
return np.reshape(as_np, (1,) + as_np.shape)
class FullBatchSequence(Sequence):
    """
    Keras-compatible data generator for node inference models
    that require full-batch training (e.g., GCN, GAT).
    Use this class with the Keras methods :meth:`keras.Model.fit_generator`,
    :meth:`keras.Model.evaluate_generator`, and
    :meth:`keras.Model.predict_generator`,

    This class should be created using the `.flow(...)` method of
    :class:`FullBatchNodeGenerator`.

    Args:
        features (np.ndarray): An array of node features of size (N x F),
            where N is the number of nodes in the graph, F is the node feature size
        A (np.ndarray or sparse matrix): An adjacency matrix of the graph of size (N x N).
        targets (np.ndarray, optional): An optional array of node targets of size (N x C),
            where C is the target size (e.g., number of classes for one-hot class targets)
        indices (np.ndarray, optional): Array of indices to the feature and adjacency matrix
            of the targets. Required if targets is not None.
    """

    # Dense adjacency representation (cf. SparseFullBatchSequence)
    use_sparse = False

    def __init__(self, features, A, targets=None, indices=None):
        if (targets is not None) and (len(indices) != len(targets)):
            raise ValueError(
                "When passed together targets and indices should be the same length."
            )

        # FIX: removed two dead assignments (self.features / self.target_indices
        # were set from raw inputs and then unconditionally overwritten below;
        # np.asanyarray(None) also produced a useless 0-d object array).

        # Convert sparse matrix to dense:
        if sps.issparse(A) and hasattr(A, "toarray"):
            self.A_dense = _full_batch_array_and_reshape(A.toarray())
        elif isinstance(A, (np.ndarray, np.matrix)):
            self.A_dense = _full_batch_array_and_reshape(A)
        else:
            raise TypeError(
                "Expected input matrix to be either a Scipy sparse matrix or a Numpy array."
            )

        # Reshape all inputs to have batch dimension of 1
        self.features = _full_batch_array_and_reshape(features)
        self.target_indices = _full_batch_array_and_reshape(indices)
        self.inputs = [self.features, self.target_indices, self.A_dense]
        self.targets = _full_batch_array_and_reshape(targets, propagate_none=True)

    def __len__(self):
        # The whole graph is a single batch.
        return 1

    def __getitem__(self, index):
        return self.inputs, self.targets
class SparseFullBatchSequence(Sequence):
    """
    Keras-compatible data generator for node inference models
    that require full-batch training (e.g., GCN, GAT).
    Use this class with the Keras methods :meth:`keras.Model.fit_generator`,
    :meth:`keras.Model.evaluate_generator`, and
    :meth:`keras.Model.predict_generator`,

    This class uses sparse matrix representations to send data to the models,
    and only works with the Keras tensorflow backend. For any other backends,
    use the :class:`FullBatchSequence` class.

    This class should be created using the `.flow(...)` method of
    :class:`FullBatchNodeGenerator`.

    Args:
        features (np.ndarray): An array of node features of size (N x F),
            where N is the number of nodes in the graph, F is the node feature size
        A (sparse matrix): An adjacency matrix of the graph of size (N x N).
        targets (np.ndarray, optional): An optional array of node targets of size (N x C),
            where C is the target size (e.g., number of classes for one-hot class targets)
        indices (np.ndarray, optional): Array of indices to the feature and adjacency matrix
            of the targets. Required if targets is not None.
    """

    # Sparse adjacency representation (cf. FullBatchSequence)
    use_sparse = True

    def __init__(self, features, A, targets=None, indices=None):
        if (targets is not None) and (len(indices) != len(targets)):
            raise ValueError(
                "When passed together targets and indices should be the same length."
            )

        # COO format exposes the row/col/data triplets we need below.
        if not sps.isspmatrix(A):
            raise ValueError("Adjacency matrix not in expected sparse format")
        A = A.tocoo()

        # Represent the adjacency matrix as a (1, nnz, 2) index array plus a
        # (1, nnz) value array, each with a leading batch dimension of 1.
        edge_index_pairs = np.hstack((A.row[:, None], A.col[:, None]))
        self.A_indices = edge_index_pairs[np.newaxis, ...].astype("int64")
        self.A_values = A.data[np.newaxis, ...]

        # Give every remaining input a batch dimension of 1 as well.
        self.features = _full_batch_array_and_reshape(features)
        self.target_indices = _full_batch_array_and_reshape(indices)
        self.inputs = [
            self.features,
            self.target_indices,
            self.A_indices,
            self.A_values,
        ]
        self.targets = _full_batch_array_and_reshape(targets, propagate_none=True)

    def __len__(self):
        # The whole graph is a single batch.
        return 1

    def __getitem__(self, index):
        return self.inputs, self.targets
class RelationalFullBatchNodeSequence(Sequence):
    """
    Keras-compatible data generator for node inference models on relational graphs
    that require full-batch training (e.g., RGCN).
    Use this class with the Keras methods :meth:`keras.Model.fit_generator`,
    :meth:`keras.Model.evaluate_generator`, and
    :meth:`keras.Model.predict_generator`,

    This class uses either dense or sparse representations to send data to the models.

    This class should be created using the `.flow(...)` method of
    :class:`RelationalFullBatchNodeGenerator`.

    Args:
        features (np.ndarray): An array of node features of size (N x F),
            where N is the number of nodes in the graph, F is the node feature size
        As (list of sparse matrices): A list of length R of adjacency matrices of the graph of size (N x N)
            where R is the number of relationships in the graph.
        use_sparse (bool): If True, send the adjacency matrices as index/value pairs;
            otherwise send them as dense arrays.
        targets (np.ndarray, optional): An optional array of node targets of size (N x C),
            where C is the target size (e.g., number of classes for one-hot class targets)
        indices (np.ndarray, optional): Array of indices to the feature and adjacency matrix
            of the targets. Required if targets is not None.
    """

    def __init__(self, features, As, use_sparse, targets=None, indices=None):
        if (targets is not None) and (len(indices) != len(targets)):
            raise ValueError(
                "When passed together targets and indices should be the same length."
            )

        self.use_sparse = use_sparse

        if self.use_sparse:
            # One (1, nnz, 2) index array and one (1, nnz) value array per relation.
            indices_per_relation = []
            values_per_relation = []
            for adj in As:
                pairs = np.hstack((adj.row[:, None], adj.col[:, None]))
                indices_per_relation.append(np.expand_dims(pairs, 0))
                values_per_relation.append(np.expand_dims(adj.data, 0))
            self.A_indices = indices_per_relation
            self.A_values = values_per_relation
            self.As = self.A_indices + self.A_values
        else:
            # One dense (1, N, N) array per relation.
            self.As = [np.expand_dims(adj.todense(), 0) for adj in As]

        # Give features and target indices a batch dimension of 1 as well.
        self.target_indices = _full_batch_array_and_reshape(indices)
        self.features = _full_batch_array_and_reshape(features)
        self.inputs = [self.features, self.target_indices] + self.As
        self.targets = _full_batch_array_and_reshape(targets, propagate_none=True)

    def __len__(self):
        # The whole graph is a single batch.
        return 1

    def __getitem__(self, index):
        return self.inputs, self.targets