Make graphsage node attribute inference reproducible (#844)
GraphSAGE node attribute inference (NAI) is now tested for reproducibility, and a new
internal class `SeededPerBatch` is introduced to make it easier to manage per-batch
random states in a multi-threaded environment.

Part of #749
kjun9 committed Feb 19, 2020
1 parent c13a988 commit 4070ccf
Showing 10 changed files with 211 additions and 84 deletions.
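
In user-facing terms, reproducibility now comes from two seeds: the generator's existing `seed` argument, which now seeds one neighbourhood sampler per batch via `SeededPerBatch`, and a new `seed` argument on `flow()`, which seeds the per-epoch shuffling inside `NodeSequence`. A minimal usage sketch follows; the graph `G` and the `train_ids`/`train_targets` names are placeholders for illustration, not part of this commit.

```python
# Sketch only: `G` is assumed to be an already-constructed StellarGraph, and
# `train_ids` / `train_targets` are hypothetical node IDs and labels.
from stellargraph.mapper import GraphSAGENodeGenerator

generator = GraphSAGENodeGenerator(
    G, batch_size=50, num_samples=[10, 5], seed=42  # seeds per-batch neighbour sampling
)
train_seq = generator.flow(
    train_ids, train_targets, shuffle=True, seed=42  # seeds per-epoch shuffling
)
# With both seeds fixed, repeated runs draw the same shuffled batches and the
# same sampled neighbourhoods, which is what makes training reproducible.
```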
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,8 @@ Some new algorithms and features are still under active development, and are ava
- CI: [\#760](https://github.com/stellargraph/stellargraph/pull/760)

- More detailed information about Heterogeneous GraphSAGE (HinSAGE) has been added to StellarGraph's readthedocs documentation [\#839](https://github.com/stellargraph/stellargraph/pull/839).
- The following algorithms are now reproducible:
- Supervised GraphSAGE Node Attribute Inference [\#844](https://github.com/stellargraph/stellargraph/pull/844)

## [0.9.0](https://github.com/stellargraph/stellargraph/tree/v0.9.0)

1 change: 0 additions & 1 deletion stellargraph/data/explorer.py
@@ -26,7 +26,6 @@


import numpy as np
import random
import warnings
from collections import defaultdict, deque
from scipy import stats
46 changes: 8 additions & 38 deletions stellargraph/mapper/sampled_link_generators.py
@@ -28,7 +28,6 @@
import operator
import collections
import abc
import threading
import warnings
from functools import reduce
from tensorflow import keras
@@ -41,7 +40,7 @@
)
from ..core.utils import is_real_iterable
from . import LinkSequence, OnDemandLinkSequence
from ..random import random_state
from ..random import SeededPerBatch


class BatchedLinkGenerator(abc.ABC):
@@ -217,41 +216,12 @@ def __init__(self, G, batch_size, num_samples, seed=None, name=None):
self.head_node_types = self.schema.node_types * 2

self._graph = G
self._batch_sampler_rs, _ = random_state(seed)
self._samplers = list()
self._lock = threading.Lock()

def _sampler(self, batch_num):
"""
Get the sampler for a particular batch number. Each batch number has an associated sampler
with its own random state, so that batches being fetched in parallel do not interfere with
each other's random states. The seed for each sampler is a combination of the Sequence
object's seed and the batch number. For its intended use in a Keras/TF workflow, if there
are N batches in an epoch, there will be N samplers created, each corresponding to a
particular ``batch_num``.
Args:
batch_num (int): Batch number
Returns:
SampledBreadthFirstWalk object
"""
self._lock.acquire()
try:
return self._samplers[batch_num]
except IndexError:
# always create a new seeded sampler in ascending order of batch number
# this ensures seeds are deterministic even when batches are run in parallel
for n in range(len(self._samplers), batch_num + 1):
seed = self._batch_sampler_rs.randint(0, 2 ** 32 - 1)
self._samplers.append(
SampledBreadthFirstWalk(
self._graph, graph_schema=self.schema, seed=seed,
)
)
return self._samplers[batch_num]
finally:
self._lock.release()
self._samplers = SeededPerBatch(
lambda s: SampledBreadthFirstWalk(
self._graph, graph_schema=self.schema, seed=s
),
seed=seed,
)

def sample_features(self, head_links, batch_num):
"""
@@ -290,7 +260,7 @@ def get_levels(loc, lsize, samples_per_hop, walks):
# of 2 nodes, so we are extracting 2 head nodes per edge
batch_feats = []
for hns in zip(*head_links):
node_samples = self._sampler(batch_num).run(
node_samples = self._samplers[batch_num].run(
nodes=hns, n=1, n_size=self.num_samples
)

33 changes: 24 additions & 9 deletions stellargraph/mapper/sampled_node_generators.py
@@ -46,6 +46,7 @@
from ..core.graph import StellarGraph, GraphSchema
from ..core.utils import is_real_iterable
from . import NodeSequence
from ..random import SeededPerBatch


class BatchedNodeGenerator(abc.ABC):
@@ -92,10 +93,10 @@ def __init__(self, G, batch_size, schema=None):
self.sampler = None

@abc.abstractmethod
def sample_features(self, head_nodes):
def sample_features(self, head_nodes, batch_num):
pass

def flow(self, node_ids, targets=None, shuffle=False):
def flow(self, node_ids, targets=None, shuffle=False, seed=None):
"""
Creates a generator/sequence object for training or evaluation
with the supplied node ids and numeric targets.
@@ -144,7 +145,12 @@ def flow(self, node_ids, targets=None, shuffle=False):
)

return NodeSequence(
self.sample_features, self.batch_size, node_ids, targets, shuffle=shuffle
self.sample_features,
self.batch_size,
node_ids,
targets,
shuffle=shuffle,
seed=seed,
)

def flow_from_dataframe(self, node_targets, shuffle=False):
@@ -209,16 +215,20 @@ def __init__(self, G, batch_size, num_samples, seed=None, name=None):
)

# Create sampler for GraphSAGE
self.sampler = SampledBreadthFirstWalk(G, graph_schema=self.schema, seed=seed)
self._samplers = SeededPerBatch(
lambda s: SampledBreadthFirstWalk(G, graph_schema=self.schema, seed=s),
seed=seed,
)

def sample_features(self, head_nodes):
def sample_features(self, head_nodes, batch_num):
"""
Sample neighbours recursively from the head nodes, collect the features of the
sampled nodes, and return these as a list of feature arrays for the GraphSAGE
algorithm.
Args:
head_nodes: An iterable of head nodes to perform sampling on.
batch_num (int): Batch number
Returns:
A list of the same length as ``num_samples`` of collected features from
@@ -227,7 +237,9 @@ def sample_features(self, head_nodes):
where num_sampled_at_layer is the cumulative product of `num_samples`
for that layer.
"""
node_samples = self.sampler.run(nodes=head_nodes, n=1, n_size=self.num_samples)
node_samples = self._samplers[batch_num].run(
nodes=head_nodes, n=1, n_size=self.num_samples
)

# The number of samples for each head node (not including itself)
num_full_samples = np.sum(np.cumprod(self.num_samples))
@@ -310,14 +322,15 @@ def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None)
G, graph_schema=self.schema, seed=seed
)

def sample_features(self, head_nodes):
def sample_features(self, head_nodes, batch_num):
"""
Sample neighbours recursively from the head nodes, collect the features of the
sampled nodes, and return these as a list of feature arrays for the GraphSAGE
algorithm.
Args:
head_nodes: An iterable of head nodes to perform sampling on.
batch_num (int): Batch number
Returns:
A list of feature tensors from the sampled nodes at each layer, each of shape:
@@ -418,14 +431,15 @@ def __init__(
G, graph_schema=self.schema, seed=seed
)

def sample_features(self, head_nodes):
def sample_features(self, head_nodes, batch_num):
"""
Sample neighbours recursively from the head nodes, collect the features of the
sampled nodes, and return these as a list of feature arrays for the GraphSAGE
algorithm.
Args:
head_nodes: An iterable of head nodes to perform sampling on.
batch_num (int): Batch number
Returns:
A list of the same length as ``num_samples`` of collected features from
@@ -495,13 +509,14 @@ def __init__(self, G, batch_size, name=None):
super().__init__(G, batch_size)
self.name = name

def sample_features(self, head_nodes):
def sample_features(self, head_nodes, batch_num):
"""
Sample content features of the head nodes, and return these as a list of feature
arrays for the attri2vec algorithm.
Args:
head_nodes: An iterable of head nodes to perform sampling on.
batch_num (int): Batch number
Returns:
A list of feature arrays, with each element being the feature of a
10 changes: 7 additions & 3 deletions stellargraph/mapper/sequences.py
@@ -40,6 +40,7 @@
from tensorflow.keras.utils import Sequence
from ..data.unsupervised_sampler import UnsupervisedSampler
from ..core.utils import is_real_iterable
from ..random import random_state


class NodeSequence(Sequence):
@@ -62,7 +63,9 @@ class NodeSequence(Sequence):
shuffle (bool): If True (default) the ids will be randomly shuffled every epoch.
"""

def __init__(self, sample_function, batch_size, ids, targets=None, shuffle=True):
def __init__(
self, sample_function, batch_size, ids, targets=None, shuffle=True, seed=None
):
# Check that ids is an iterable
if not is_real_iterable(ids):
raise TypeError("IDs must be an iterable or numpy array of graph node IDs")
@@ -93,6 +96,7 @@ def __init__(self, sample_function, batch_size, ids, targets=None, shuffle=True)
self.data_size = len(self.ids)
self.shuffle = shuffle
self.batch_size = batch_size
self._rs, _ = random_state(seed)

# Shuffle IDs to start
self.on_epoch_end()
@@ -130,7 +134,7 @@ def __getitem__(self, batch_num):
batch_targets = None if self.targets is None else self.targets[batch_indices]

# Get features for nodes
batch_feats = self._sample_function(head_ids)
batch_feats = self._sample_function(head_ids, batch_num)

return batch_feats, batch_targets

@@ -140,7 +144,7 @@ def on_epoch_end(self):
"""
self.indices = list(range(self.data_size))
if self.shuffle:
random.shuffle(self.indices)
self._rs.shuffle(self.indices)


class LinkSequence(Sequence):
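
The `on_epoch_end` change above replaces the module-level `random.shuffle` with a shuffle on the sequence's own seeded state (`self._rs`), so shuffling no longer depends on, or perturbs, the global random state. A standalone sketch of the effect, using `random.Random` as a stand-in for that state:

```python
import random

def epoch_orders(seed, n=5, epochs=3):
    """Stand-in for NodeSequence: shuffle indices once per epoch with a private RNG."""
    rs = random.Random(seed)        # plays the role of self._rs
    indices = list(range(n))
    orders = []
    for _ in range(epochs):
        rs.shuffle(indices)         # what on_epoch_end() now does
        orders.append(list(indices))
    return orders

# the same seed yields the same sequence of epoch orderings on every run
assert epoch_orders(seed=7) == epoch_orders(seed=7)
```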
30 changes: 30 additions & 0 deletions stellargraph/random.py
@@ -23,6 +23,7 @@

import random as rn
import numpy.random as np_rn
import threading
from collections import namedtuple


@@ -75,3 +76,32 @@ def set_seed(seed):
_rs = _global_state()
else:
_rs = _seeded_state(seed)


class SeededPerBatch:
"""
Internal utility class for managing a random state per batch number in a multi-threaded
environment.
"""

def __init__(self, create_with_seed, seed):
self._create_with_seed = create_with_seed
self._walkers = []
self._lock = threading.Lock()
self._rs, _ = random_state(seed)

def __getitem__(self, batch_num):
self._lock.acquire()
try:
return self._walkers[batch_num]
except IndexError:
# always create a new seeded sampler in ascending order of batch number
# this ensures seeds are deterministic even when batches are run in parallel
self._walkers.extend(
self._create_with_seed(self._rs.randrange(2 ** 32))
for _ in range(len(self._walkers), batch_num + 1)
)
return self._walkers[batch_num]
finally:
self._lock.release()
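
`SeededPerBatch` is internal rather than public API, but its guarantee is easy to demonstrate: the factory is always invoked in ascending batch order with seeds drawn from a single parent state, so two instances built with the same seed hand out identically seeded objects no matter which batch is requested first. A minimal sketch, assuming `stellargraph.random` is importable as shown in the diff above:

```python
import random
from stellargraph.random import SeededPerBatch

def make_rng(s):
    # factory: turn a drawn seed into a seeded object (here, just a Python RNG)
    return random.Random(s)

a = SeededPerBatch(make_rng, seed=42)
b = SeededPerBatch(make_rng, seed=42)

# requesting batch 5 first still assigns batches 0..5 their seeds in order,
# so both instances agree batch by batch
assert a[5].random() == b[5].random()
assert a[0].random() == b[0].random()
```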
4 changes: 2 additions & 2 deletions tests/mapper/test_directed_node_generator.py
@@ -56,7 +56,7 @@ def sample_one_hop(self, num_in_samples, num_out_samples):
flow = gen.flow(node_ids=nodes, shuffle=False)

# Obtain tree of sampled features
features = gen.sample_features(nodes)
features = gen.sample_features(nodes, 0)
num_hops = len(in_samples)
tree_len = 2 ** (num_hops + 1) - 1
assert len(features) == tree_len
@@ -123,7 +123,7 @@ def test_two_hop(self):
)
flow = gen.flow(node_ids=nodes, shuffle=False)

features = gen.sample_features(nodes)
features = gen.sample_features(nodes, 0)
num_hops = 2
tree_len = 2 ** (num_hops + 1) - 1
assert len(features) == tree_len
58 changes: 29 additions & 29 deletions tests/mapper/test_node_mappers.py
@@ -191,45 +191,45 @@ def test_nodemapper_1():
GraphSAGENodeGenerator(G1, batch_size=2, num_samples=[2, 2]).flow(["A", "B"])


def test_nodemapper_shuffle():
@pytest.mark.parametrize("shuffle", [True, False])
def test_nodemapper_shuffle(shuffle):
n_feat = 1
n_batch = 2

G = example_graph_2(feature_size=n_feat)
nodes = list(G.nodes())

# With shuffle
random.seed(15)
mapper = GraphSAGENodeGenerator(G, batch_size=n_batch, num_samples=[0]).flow(
nodes, nodes, shuffle=True
def flatten_features(seq):
# check (features == labels) and return flattened features
batches = [
(np.ravel(seq[i][0][0]), np.array(seq[i][1])) for i in range(len(seq))
]
features, labels = zip(*batches)
features, labels = np.concatenate(features), np.concatenate(labels)
assert all(features == labels)
return features

def consecutive_epochs(seq):
features = flatten_features(seq)
seq.on_epoch_end()
features_next = flatten_features(seq)
return features, features_next

seq = GraphSAGENodeGenerator(G, batch_size=n_batch, num_samples=[0]).flow(
nodes, nodes, shuffle=shuffle
)

expected_node_batches = [[5, 4], [3, 1], [2]]
assert len(mapper) == 3
for ii in range(len(mapper)):
nf, nl = mapper[ii]
assert all(np.ravel(nf[0]) == expected_node_batches[ii])
assert all(np.array(nl) == expected_node_batches[ii])
max_iter = 5
comparison_results = set()

# This should re-shuffle the IDs
mapper.on_epoch_end()
expected_node_batches = [[4, 3], [1, 5], [2]]
assert len(mapper) == 3
for ii in range(len(mapper)):
nf, nl = mapper[ii]
assert all(np.ravel(nf[0]) == expected_node_batches[ii])
assert all(np.array(nl) == expected_node_batches[ii])
for i in range(max_iter):
f1, f2 = consecutive_epochs(seq)
comparison_results.add(all(f1 == f2))

# With no shuffle
mapper = GraphSAGENodeGenerator(G, batch_size=n_batch, num_samples=[0]).flow(
nodes, nodes, shuffle=False
)
expected_node_batches = [[1, 2], [3, 4], [5]]
assert len(mapper) == 3
for ii in range(len(mapper)):
nf, nl = mapper[ii]
assert all(np.ravel(nf[0]) == expected_node_batches[ii])
assert all(np.array(nl) == expected_node_batches[ii])
if not shuffle:
assert comparison_results == {True}
else:
assert False in comparison_results


def test_nodemapper_with_labels():
