From 4070ccf24d67802a08a9f70ca3e7567a0a53a7fc Mon Sep 17 00:00:00 2001 From: kevin <33508488+kjun9@users.noreply.github.com> Date: Wed, 19 Feb 2020 14:38:14 +1100 Subject: [PATCH] Make graphsage node attribute inference reproducible (#844) GraphSAGE NAI is tested for reproducibility, and a new internal class `SeededPerBatch` is introduced to make managing random states in a multi-threaded environment easier. Part of #749 --- CHANGELOG.md | 2 + stellargraph/data/explorer.py | 1 - .../mapper/sampled_link_generators.py | 46 ++---------- .../mapper/sampled_node_generators.py | 33 ++++++--- stellargraph/mapper/sequences.py | 10 ++- stellargraph/random.py | 30 ++++++++ tests/mapper/test_directed_node_generator.py | 4 +- tests/mapper/test_node_mappers.py | 58 +++++++-------- ...ervised_graphsage.py => test_graphsage.py} | 73 ++++++++++++++++++- tests/test_random.py | 38 ++++++++++ 10 files changed, 211 insertions(+), 84 deletions(-) rename tests/reproducibility/{test_unsupervised_graphsage.py => test_graphsage.py} (60%) create mode 100644 tests/test_random.py diff --git a/CHANGELOG.md b/CHANGELOG.md index be6b84319..4dd3600d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ Some new algorithms and features are still under active development, and are ava - CI: [\#760](https://github.com/stellargraph/stellargraph/pull/760) - More detailed information about Heterogeneous GraphSAGE (HinSAGE) has been added to StellarGraph's readthedocs documentation [\#839](https://github.com/stellargraph/stellargraph/pull/839). 
+- The following algorithms are now reproducible: + - Supervised GraphSAGE Node Attribute Inference [\#844](https://github.com/stellargraph/stellargraph/pull/844) ## [0.9.0](https://github.com/stellargraph/stellargraph/tree/v0.9.0) diff --git a/stellargraph/data/explorer.py b/stellargraph/data/explorer.py index 65704ebbf..5ad79e896 100644 --- a/stellargraph/data/explorer.py +++ b/stellargraph/data/explorer.py @@ -26,7 +26,6 @@ import numpy as np -import random import warnings from collections import defaultdict, deque from scipy import stats diff --git a/stellargraph/mapper/sampled_link_generators.py b/stellargraph/mapper/sampled_link_generators.py index 34e6ee49b..679e4711c 100644 --- a/stellargraph/mapper/sampled_link_generators.py +++ b/stellargraph/mapper/sampled_link_generators.py @@ -28,7 +28,6 @@ import operator import collections import abc -import threading import warnings from functools import reduce from tensorflow import keras @@ -41,7 +40,7 @@ ) from ..core.utils import is_real_iterable from . import LinkSequence, OnDemandLinkSequence -from ..random import random_state +from ..random import SeededPerBatch class BatchedLinkGenerator(abc.ABC): @@ -217,41 +216,12 @@ def __init__(self, G, batch_size, num_samples, seed=None, name=None): self.head_node_types = self.schema.node_types * 2 self._graph = G - self._batch_sampler_rs, _ = random_state(seed) - self._samplers = list() - self._lock = threading.Lock() - - def _sampler(self, batch_num): - """ - Get the sampler for a particular batch number. Each batch number has an associated sampler - with its own random state, so that batches being fetched in parallel do not interfere with - each other's random states. The seed for each sampler is a combination of the Sequence - object's seed and the batch number. For its intended use in a Keras/TF workflow, if there - are N batches in an epoch, there will be N samplers created, each corresponding to a - particular ``batch_num``. 
- - Args: - batch_num (int): Batch number - - Returns: - SampledBreadthFirstWalk object - """ - self._lock.acquire() - try: - return self._samplers[batch_num] - except IndexError: - # always create a new seeded sampler in ascending order of batch number - # this ensures seeds are deterministic even when batches are run in parallel - for n in range(len(self._samplers), batch_num + 1): - seed = self._batch_sampler_rs.randint(0, 2 ** 32 - 1) - self._samplers.append( - SampledBreadthFirstWalk( - self._graph, graph_schema=self.schema, seed=seed, - ) - ) - return self._samplers[batch_num] - finally: - self._lock.release() + self._samplers = SeededPerBatch( + lambda s: SampledBreadthFirstWalk( + self._graph, graph_schema=self.schema, seed=s + ), + seed=seed, + ) def sample_features(self, head_links, batch_num): """ @@ -290,7 +260,7 @@ def get_levels(loc, lsize, samples_per_hop, walks): # of 2 nodes, so we are extracting 2 head nodes per edge batch_feats = [] for hns in zip(*head_links): - node_samples = self._sampler(batch_num).run( + node_samples = self._samplers[batch_num].run( nodes=hns, n=1, n_size=self.num_samples ) diff --git a/stellargraph/mapper/sampled_node_generators.py b/stellargraph/mapper/sampled_node_generators.py index 2cfa985b5..c6e29daf7 100644 --- a/stellargraph/mapper/sampled_node_generators.py +++ b/stellargraph/mapper/sampled_node_generators.py @@ -46,6 +46,7 @@ from ..core.graph import StellarGraph, GraphSchema from ..core.utils import is_real_iterable from . 
import NodeSequence +from ..random import SeededPerBatch class BatchedNodeGenerator(abc.ABC): @@ -92,10 +93,10 @@ def __init__(self, G, batch_size, schema=None): self.sampler = None @abc.abstractmethod - def sample_features(self, head_nodes): + def sample_features(self, head_nodes, batch_num): pass - def flow(self, node_ids, targets=None, shuffle=False): + def flow(self, node_ids, targets=None, shuffle=False, seed=None): """ Creates a generator/sequence object for training or evaluation with the supplied node ids and numeric targets. @@ -144,7 +145,12 @@ def flow(self, node_ids, targets=None, shuffle=False): ) return NodeSequence( - self.sample_features, self.batch_size, node_ids, targets, shuffle=shuffle + self.sample_features, + self.batch_size, + node_ids, + targets, + shuffle=shuffle, + seed=seed, ) def flow_from_dataframe(self, node_targets, shuffle=False): @@ -209,9 +215,12 @@ def __init__(self, G, batch_size, num_samples, seed=None, name=None): ) # Create sampler for GraphSAGE - self.sampler = SampledBreadthFirstWalk(G, graph_schema=self.schema, seed=seed) + self._samplers = SeededPerBatch( + lambda s: SampledBreadthFirstWalk(G, graph_schema=self.schema, seed=s), + seed=seed, + ) - def sample_features(self, head_nodes): + def sample_features(self, head_nodes, batch_num): """ Sample neighbours recursively from the head nodes, collect the features of the sampled nodes, and return these as a list of feature arrays for the GraphSAGE @@ -219,6 +228,7 @@ def sample_features(self, head_nodes): Args: head_nodes: An iterable of head nodes to perform sampling on. + batch_num (int): Batch number Returns: A list of the same length as ``num_samples`` of collected features from @@ -227,7 +237,9 @@ def sample_features(self, head_nodes): where num_sampled_at_layer is the cumulative product of `num_samples` for that layer. 
""" - node_samples = self.sampler.run(nodes=head_nodes, n=1, n_size=self.num_samples) + node_samples = self._samplers[batch_num].run( + nodes=head_nodes, n=1, n_size=self.num_samples + ) # The number of samples for each head node (not including itself) num_full_samples = np.sum(np.cumprod(self.num_samples)) @@ -310,7 +322,7 @@ def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None) G, graph_schema=self.schema, seed=seed ) - def sample_features(self, head_nodes): + def sample_features(self, head_nodes, batch_num): """ Sample neighbours recursively from the head nodes, collect the features of the sampled nodes, and return these as a list of feature arrays for the GraphSAGE @@ -318,6 +330,7 @@ def sample_features(self, head_nodes): Args: head_nodes: An iterable of head nodes to perform sampling on. + batch_num (int): Batch number Returns: A list of feature tensors from the sampled nodes at each layer, each of shape: @@ -418,7 +431,7 @@ def __init__( G, graph_schema=self.schema, seed=seed ) - def sample_features(self, head_nodes): + def sample_features(self, head_nodes, batch_num): """ Sample neighbours recursively from the head nodes, collect the features of the sampled nodes, and return these as a list of feature arrays for the GraphSAGE @@ -426,6 +439,7 @@ def sample_features(self, head_nodes): Args: head_nodes: An iterable of head nodes to perform sampling on. + batch_num (int): Batch number Returns: A list of the same length as ``num_samples`` of collected features from @@ -495,13 +509,14 @@ def __init__(self, G, batch_size, name=None): super().__init__(G, batch_size) self.name = name - def sample_features(self, head_nodes): + def sample_features(self, head_nodes, batch_num): """ Sample content features of the head nodes, and return these as a list of feature arrays for the attri2vec algorithm. Args: head_nodes: An iterable of head nodes to perform sampling on. 
+ batch_num (int): Batch number Returns: A list of feature arrays, with each element being the feature of a diff --git a/stellargraph/mapper/sequences.py b/stellargraph/mapper/sequences.py index 10811f061..1b29c4297 100644 --- a/stellargraph/mapper/sequences.py +++ b/stellargraph/mapper/sequences.py @@ -40,6 +40,7 @@ from tensorflow.keras.utils import Sequence from ..data.unsupervised_sampler import UnsupervisedSampler from ..core.utils import is_real_iterable +from ..random import random_state class NodeSequence(Sequence): @@ -62,7 +63,9 @@ class NodeSequence(Sequence): shuffle (bool): If True (default) the ids will be randomly shuffled every epoch. """ - def __init__(self, sample_function, batch_size, ids, targets=None, shuffle=True): + def __init__( + self, sample_function, batch_size, ids, targets=None, shuffle=True, seed=None + ): # Check that ids is an iterable if not is_real_iterable(ids): raise TypeError("IDs must be an iterable or numpy array of graph node IDs") @@ -93,6 +96,7 @@ def __init__(self, sample_function, batch_size, ids, targets=None, shuffle=True) self.data_size = len(self.ids) self.shuffle = shuffle self.batch_size = batch_size + self._rs, _ = random_state(seed) # Shuffle IDs to start self.on_epoch_end() @@ -130,7 +134,7 @@ def __getitem__(self, batch_num): batch_targets = None if self.targets is None else self.targets[batch_indices] # Get features for nodes - batch_feats = self._sample_function(head_ids) + batch_feats = self._sample_function(head_ids, batch_num) return batch_feats, batch_targets @@ -140,7 +144,7 @@ def on_epoch_end(self): """ self.indices = list(range(self.data_size)) if self.shuffle: - random.shuffle(self.indices) + self._rs.shuffle(self.indices) class LinkSequence(Sequence): diff --git a/stellargraph/random.py b/stellargraph/random.py index 6394ba841..6f7b398c1 100644 --- a/stellargraph/random.py +++ b/stellargraph/random.py @@ -23,6 +23,7 @@ import random as rn import numpy.random as np_rn +import threading from 
collections import namedtuple @@ -75,3 +76,32 @@ def set_seed(seed): _rs = _global_state() else: _rs = _seeded_state(seed) + + +class SeededPerBatch: + """ + Internal utility class for managing a random state per batch number in a multi-threaded + environment. + + """ + + def __init__(self, create_with_seed, seed): + self._create_with_seed = create_with_seed + self._walkers = [] + self._lock = threading.Lock() + self._rs, _ = random_state(seed) + + def __getitem__(self, batch_num): + self._lock.acquire() + try: + return self._walkers[batch_num] + except IndexError: + # always create a new seeded sampler in ascending order of batch number + # this ensures seeds are deterministic even when batches are run in parallel + self._walkers.extend( + self._create_with_seed(self._rs.randrange(2 ** 32)) + for _ in range(len(self._walkers), batch_num + 1) + ) + return self._walkers[batch_num] + finally: + self._lock.release() diff --git a/tests/mapper/test_directed_node_generator.py b/tests/mapper/test_directed_node_generator.py index 882d575ad..ac4408749 100644 --- a/tests/mapper/test_directed_node_generator.py +++ b/tests/mapper/test_directed_node_generator.py @@ -56,7 +56,7 @@ def sample_one_hop(self, num_in_samples, num_out_samples): flow = gen.flow(node_ids=nodes, shuffle=False) # Obtain tree of sampled features - features = gen.sample_features(nodes) + features = gen.sample_features(nodes, 0) num_hops = len(in_samples) tree_len = 2 ** (num_hops + 1) - 1 assert len(features) == tree_len @@ -123,7 +123,7 @@ def test_two_hop(self): ) flow = gen.flow(node_ids=nodes, shuffle=False) - features = gen.sample_features(nodes) + features = gen.sample_features(nodes, 0) num_hops = 2 tree_len = 2 ** (num_hops + 1) - 1 assert len(features) == tree_len diff --git a/tests/mapper/test_node_mappers.py b/tests/mapper/test_node_mappers.py index 7a53caf55..d90b8eca7 100644 --- a/tests/mapper/test_node_mappers.py +++ b/tests/mapper/test_node_mappers.py @@ -191,45 +191,45 @@ def 
test_nodemapper_1(): GraphSAGENodeGenerator(G1, batch_size=2, num_samples=[2, 2]).flow(["A", "B"]) -def test_nodemapper_shuffle(): +@pytest.mark.parametrize("shuffle", [True, False]) +def test_nodemapper_shuffle(shuffle): n_feat = 1 n_batch = 2 G = example_graph_2(feature_size=n_feat) nodes = list(G.nodes()) - # With shuffle - random.seed(15) - mapper = GraphSAGENodeGenerator(G, batch_size=n_batch, num_samples=[0]).flow( - nodes, nodes, shuffle=True + def flatten_features(seq): + # check (features == labels) and return flattened features + batches = [ + (np.ravel(seq[i][0][0]), np.array(seq[i][1])) for i in range(len(seq)) + ] + features, labels = zip(*batches) + features, labels = np.concatenate(features), np.concatenate(labels) + assert all(features == labels) + return features + + def consecutive_epochs(seq): + features = flatten_features(seq) + seq.on_epoch_end() + features_next = flatten_features(seq) + return features, features_next + + seq = GraphSAGENodeGenerator(G, batch_size=n_batch, num_samples=[0]).flow( + nodes, nodes, shuffle=shuffle ) - expected_node_batches = [[5, 4], [3, 1], [2]] - assert len(mapper) == 3 - for ii in range(len(mapper)): - nf, nl = mapper[ii] - assert all(np.ravel(nf[0]) == expected_node_batches[ii]) - assert all(np.array(nl) == expected_node_batches[ii]) + max_iter = 5 + comparison_results = set() - # This should re-shuffle the IDs - mapper.on_epoch_end() - expected_node_batches = [[4, 3], [1, 5], [2]] - assert len(mapper) == 3 - for ii in range(len(mapper)): - nf, nl = mapper[ii] - assert all(np.ravel(nf[0]) == expected_node_batches[ii]) - assert all(np.array(nl) == expected_node_batches[ii]) + for i in range(max_iter): + f1, f2 = consecutive_epochs(seq) + comparison_results.add(all(f1 == f2)) - # With no shuffle - mapper = GraphSAGENodeGenerator(G, batch_size=n_batch, num_samples=[0]).flow( - nodes, nodes, shuffle=False - ) - expected_node_batches = [[1, 2], [3, 4], [5]] - assert len(mapper) == 3 - for ii in range(len(mapper)): - 
nf, nl = mapper[ii] - assert all(np.ravel(nf[0]) == expected_node_batches[ii]) - assert all(np.array(nl) == expected_node_batches[ii]) + if not shuffle: + assert comparison_results == {True} + else: + assert False in comparison_results def test_nodemapper_with_labels(): diff --git a/tests/reproducibility/test_unsupervised_graphsage.py b/tests/reproducibility/test_graphsage.py similarity index 60% rename from tests/reproducibility/test_unsupervised_graphsage.py rename to tests/reproducibility/test_graphsage.py index ac7960c55..4053ff616 100644 --- a/tests/reproducibility/test_unsupervised_graphsage.py +++ b/tests/reproducibility/test_graphsage.py @@ -15,10 +15,12 @@ # limitations under the License. +import numpy as np import pytest import random import tensorflow as tf from stellargraph.data.unsupervised_sampler import UnsupervisedSampler +from stellargraph.mapper.sampled_node_generators import GraphSAGENodeGenerator from stellargraph.mapper.sampled_link_generators import GraphSAGELinkGenerator from stellargraph.layer.graphsage import GraphSAGE from stellargraph.layer.link_inference import link_classification @@ -87,12 +89,68 @@ def unsup_gs( return model +def gs_nai_model(num_samples, generator, targets, optimizer, bias, dropout, normalize): + layer_sizes = [50] * len(num_samples) + graphsage = GraphSAGE( + layer_sizes=layer_sizes, + generator=generator, + bias=bias, + dropout=dropout, + normalize=normalize, + ) + # Build the model and expose input and output sockets of graphsage, for node inputs: + x_inp, x_out = graphsage.build() + pred = tf.keras.layers.Dense(units=targets.shape[1], activation="softmax")(x_out) + model = tf.keras.Model(inputs=x_inp, outputs=pred) + + model.compile(optimizer=optimizer, loss=tf.keras.losses.categorical_crossentropy) + + return model + + +def gs_nai( + g, + targets, + num_samples, + optimizer, + batch_size=4, + epochs=4, + bias=True, + dropout=0.0, + normalize="l2", + seed=0, + shuffle=True, +): + set_seed(seed) + 
tf.random.set_seed(seed) + if shuffle: + random.seed(seed) + + nodes = list(g.nodes()) + generator = GraphSAGENodeGenerator(g, batch_size, num_samples) + train_gen = generator.flow(nodes, targets, shuffle=True) + + model = gs_nai_model( + num_samples, generator, targets, optimizer, bias, dropout, normalize + ) + + model.fit_generator( + train_gen, + epochs=epochs, + verbose=1, + use_multiprocessing=False, + workers=4, + shuffle=shuffle, + ) + return model + + @pytest.mark.parametrize("shuffle", [True, False]) -def test_reproducibility(petersen_graph, shuffle): +def test_unsupervised(petersen_graph, shuffle): assert_reproducible( lambda: unsup_gs( petersen_graph, - [2], + [2, 2], tf.optimizers.Adam(1e-3), epochs=4, walk_length=2, @@ -100,3 +158,14 @@ def test_reproducibility(petersen_graph, shuffle): shuffle=shuffle, ) ) + + +@pytest.mark.parametrize("shuffle", [True, False]) +def test_nai(petersen_graph, shuffle): + target_size = 10 + targets = np.random.rand(len(petersen_graph.nodes()), target_size) + assert_reproducible( + lambda: gs_nai( + petersen_graph, targets, [2, 2], tf.optimizers.Adam(1e-3), shuffle=shuffle + ) + ) diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 000000000..0b971d22d --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2020 Data61, CSIRO +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from stellargraph.random import SeededPerBatch +import numpy as np + + +def test_seeded_per_batch(): + num_batches = 10 + num_iter = 10 + seed = 0 + + # different permutations of batch numbers should always give the same seeds + batch_nums_perms = [np.random.permutation(num_batches) for _ in range(num_iter)] + + def get_batches(batch_nums): + s = SeededPerBatch(create_with_seed=lambda x: x, seed=seed) + batches = [0] * num_batches + + for batch_num in batch_nums: + batches[batch_num] = s[batch_num] + + return tuple(batches) + + assert len({get_batches(batch_nums) for batch_nums in batch_nums_perms}) == 1