Skip to content

Commit

Permalink
Add in and out of scope() tests for various strategy methods.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 239850687
  • Loading branch information
tomhennigan authored and tensorflower-gardener committed Mar 22, 2019
1 parent e0e9847 commit f648dae
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/python/distribute/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,17 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":distribute_lib",
":input_lib",
":reduce_util",
":values",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)

Expand Down
137 changes: 137 additions & 0 deletions tensorflow/python/distribute/distribute_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,21 @@
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class _TestReplicaContext(distribute_lib.ReplicaContext):
Expand All @@ -41,6 +49,11 @@ def _get_test_variable(name, synchronization, aggregation):
}


def _test_input_fn(input_context):
del input_context
return dataset_ops.DatasetV2.from_tensors(1.).repeat()


class _TestStrategy(distribute_lib.DistributionStrategy):

def __init__(self):
Expand All @@ -49,6 +62,13 @@ def __init__(self):

class _TestExtended(distribute_lib.DistributionStrategyExtended):

def __init__(self, distribute):
super(_TestExtended, self).__init__(distribute)
device_map = values.ReplicaDeviceMap(["/device:CPU:0"])
worker_device_pairs = [("", ["/device:CPU:0"])]
self._input_workers = input_lib.InputWorkers(device_map,
worker_device_pairs)

def _call_for_each_replica(self, fn, args, kwargs):
with _TestReplicaContext(
self._container_strategy(),
Expand All @@ -59,6 +79,45 @@ def _create_variable(self, next_creator, *args, **kwargs):
return _get_test_variable(kwargs["name"], kwargs["synchronization"],
kwargs["aggregation"])

def _make_input_fn_iterator(
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, [distribute_lib.InputContext()])

def _local_results(self, value):
return (value,)

def _reduce_to(self, reduce_op, value, destinations):
del reduce_op, destinations
return value

def _experimental_make_numpy_dataset(self, numpy_input, session):
del session
return dataset_ops.DatasetV2.from_tensor_slices(numpy_input)

def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):
# TODO(tomhennigan) This is missing many things (e.g. ctx.run_op).
ctx = input_lib.MultiStepContext()
for _ in range(iterations):
fn(ctx, iterator.get_next())
return ctx

def _update(self, var, fn, args, kwargs, group):
# The implementations of _update() and _update_non_slot() are identical
# except _update() passes `var` as the first argument to `fn()`.
return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
del colocate_with
result = fn(*args, **kwargs)
if group:
return result
else:
return nest.map_structure(self._unwrap, result)


def _assert_in_default_state(t):
t.assertIs(ds_context._get_default_replica_context(),
Expand All @@ -69,6 +128,25 @@ def _assert_in_default_state(t):
t.assertFalse(ds_context.has_strategy())


def _run_in_and_out_of_scope(unbound_test_method):
def wrapper(test_case):
dist = _TestStrategy()
# Running in the default (replica) scope should be supported.
_assert_in_default_state(test_case)
unbound_test_method(test_case, dist)
# As well as running in the strategy scope.
with dist.scope():
unbound_test_method(test_case, dist)
_assert_in_default_state(test_case)
# When run under a different strategy the test method should fail.
another_strategy = _TestStrategy()
msg = "Mixing different .*Strategy objects"
with test_case.assertRaisesRegexp(RuntimeError, msg):
with another_strategy.scope():
unbound_test_method(test_case, dist)
return wrapper


class TestStrategyTest(test.TestCase):

def testCallForEachReplica(self):
Expand Down Expand Up @@ -236,6 +314,65 @@ def testSameScopeNesting(self):
self.assertIs(dist, ds_context.get_strategy())
_assert_in_default_state(self)

@_run_in_and_out_of_scope
def testMakeInputFnIterator(self, dist):
self.assertIsNotNone(dist.make_input_fn_iterator(_test_input_fn))

@_run_in_and_out_of_scope
def testReduce(self, dist):
x = constant_op.constant(1.)
x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x)
self.assertEqual(self.evaluate(x), self.evaluate(x_r))

@_run_in_and_out_of_scope
def testExperimentalMakeNumpyDataset(self, dist):
numpy_input = np.ones([10], dtype=np.float32)
dataset = dist.extended.experimental_make_numpy_dataset(numpy_input)
self.assertEqual(
self.evaluate(dataset.reduce(0., lambda a, b: a + b)), 10.)

@_run_in_and_out_of_scope
def testExperimentalRunStepsOnIterator(self, dist):
all_inputs = []
dataset = dataset_ops.Dataset.from_tensors(1.).repeat()
dist.extended.experimental_run_steps_on_iterator(
lambda _, inputs: all_inputs.append(self.evaluate(inputs)),
dataset.make_one_shot_iterator())
self.assertEqual(all_inputs, [1.])

@_run_in_and_out_of_scope
def testReduceTo(self, dist):
x = constant_op.constant(1.)
x_r = dist.extended.reduce_to(reduce_util.ReduceOp.MEAN, x, "/CPU:0")
self.assertEqual(self.evaluate(x), self.evaluate(x_r))

@_run_in_and_out_of_scope
def testBatchReduceTo(self, dist):
x = constant_op.constant(1.)
y = constant_op.constant(1.)
x_r, y_r = dist.extended.batch_reduce_to(reduce_util.ReduceOp.MEAN,
((x, "/CPU:0"), (y, "/CPU:0")))
self.assertEqual(self.evaluate(x), self.evaluate(x_r))
self.assertEqual(self.evaluate(y), self.evaluate(y_r))

@_run_in_and_out_of_scope
def testUpdate(self, dist):
with dist.scope():
v = variables.Variable(1.)
t = constant_op.constant(2.)

def assign_fn(vv, tt):
self.assertIs(vv, v)
self.assertIs(tt, t)
dist.extended.update(v, assign_fn, (t,))

@_run_in_and_out_of_scope
def testUpdateNonSlot(self, dist):
t = constant_op.constant(2.)
update_calls = []
dist.extended.update_non_slot(t, lambda: update_calls.append(1))
self.assertEqual(len(update_calls), 1)


class DefaultDistributionStrategyTest(test.TestCase):

Expand Down

0 comments on commit f648dae

Please sign in to comment.