Skip to content

Commit

Permalink
apply test review change
Browse files Browse the repository at this point in the history
  • Loading branch information
kushanam authored and serach24 committed Jun 4, 2021
1 parent 35426e1 commit d5facb6
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions tensorflow/python/distribute/input_lib_type_spec_test.py
Expand Up @@ -460,26 +460,17 @@ def f(v):
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
enable_get_next_as_optional=[True, False],
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
],))
experimental_place_dataset_on_device=[True,False],
experimental_prefetch_to_device=[True, False],))
def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution,
enable_get_next_as_optional,
input_options):
experimental_place_dataset_on_device,
experimental_prefetch_to_device):

if experimental_place_dataset_on_device and experimental_prefetch_to_device:
self.skipTest("Setting experimental_place_dataset_on_device and "
"experimental_prefetch_to_device to `True` is not allowed "
"when using distribute_lib.InputReplicationMode.PER_REPLICA.")

fname1 = os.path.join(self.get_temp_dir(), "1.txt")
_create_text_file(fname1, 5)
Expand All @@ -492,12 +483,16 @@ def dataset_fn(input_context):
input_context.input_pipeline_id)
return readers.TextLineDatasetV2(dataset).map(
string_ops.string_to_number).batch(
input_context.get_per_replica_batch_size(4))
input_context.get_per_replica_batch_size(4))

options = distribute_lib.InputOptions(
experimental_place_dataset_on_device = experimental_place_dataset_on_device,
experimental_prefetch_to_device = experimental_prefetch_to_device,
experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA)

distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
ds = distribution.experimental_distribute_datasets_from_function(dataset_fn,
input_options)
ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options)

iterator = iter(ds)
_check_type_spec_structure(iterator)
Expand Down

0 comments on commit d5facb6

Please sign in to comment.