Skip to content

Commit

Permalink
[tf.data] move StructuredFunctionWrapper into a common module
Browse files Browse the repository at this point in the history
refactor using structured_function module

use the new API location of structured_function

move DEBUG_MODE to dataset_ops
  • Loading branch information
kvignesh1420 committed Sep 15, 2021
1 parent 3e57c7c commit efc4dea
Show file tree
Hide file tree
Showing 9 changed files with 361 additions and 306 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.eager import def_function
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
Expand All @@ -33,7 +34,7 @@ class SingleThreadedFlatMapDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(
self._map_func = structured_function.StructuredFunctionWrapper(
map_func,
self._transformation_name(),
dataset=input_dataset,
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/python/data/experimental/ops/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
Expand Down Expand Up @@ -331,7 +332,7 @@ def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
drop_remainder, use_legacy_function=False):
self._input_dataset = input_dataset

self._map_func = dataset_ops.StructuredFunctionWrapper(
self._map_func = structured_function.StructuredFunctionWrapper(
map_func,
"tf.data.experimental.map_and_batch()",
dataset=input_dataset,
Expand Down Expand Up @@ -437,7 +438,7 @@ def to_ragged_variant(value):
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access

# Tuples are automatically unpacked by `dataset.map` so we repack them.
if dataset_ops._should_unpack(input_dataset.element_spec): # pylint: disable=protected-access
if structured_function._should_unpack(input_dataset.element_spec): # pylint: disable=protected-access
map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
else:
map_fn = lambda value: nest.map_structure(to_ragged_variant, value)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow/python/data/experimental/ops/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
Expand Down Expand Up @@ -285,7 +286,7 @@ def __init__(self, input_dataset, key_func, reducer):

def _make_key_func(self, key_func, input_dataset):
"""Make wrapping defun for key_func."""
self._key_func = dataset_ops.StructuredFunctionWrapper(
self._key_func = structured_function.StructuredFunctionWrapper(
key_func, self._transformation_name(), dataset=input_dataset)
if not self._key_func.output_structure.is_compatible_with(
tensor_spec.TensorSpec([], dtypes.int64)):
Expand All @@ -296,7 +297,7 @@ def _make_key_func(self, key_func, input_dataset):

def _make_init_func(self, init_func):
"""Make wrapping defun for init_func."""
self._init_func = dataset_ops.StructuredFunctionWrapper(
self._init_func = structured_function.StructuredFunctionWrapper(
init_func,
self._transformation_name(),
input_structure=tensor_spec.TensorSpec([], dtypes.int64))
Expand All @@ -313,7 +314,7 @@ def _make_reduce_func(self, reduce_func, input_dataset):
need_to_rerun = True
while need_to_rerun:

wrapped_func = dataset_ops.StructuredFunctionWrapper(
wrapped_func = structured_function.StructuredFunctionWrapper(
reduce_func,
self._transformation_name(),
input_structure=(self._state_structure, input_dataset.element_spec),
Expand Down Expand Up @@ -366,7 +367,7 @@ def _make_reduce_func(self, reduce_func, input_dataset):

def _make_finalize_func(self, finalize_func):
"""Make wrapping defun for finalize_func."""
self._finalize_func = dataset_ops.StructuredFunctionWrapper(
self._finalize_func = structured_function.StructuredFunctionWrapper(
finalize_func, self._transformation_name(),
input_structure=self._state_structure)

Expand Down
5 changes: 3 additions & 2 deletions tensorflow/python/data/experimental/ops/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
Expand Down Expand Up @@ -176,7 +177,7 @@ def _set_save_dataset_attributes(dataset, shard_func, path):
else:
use_shard_func = True

wrapped_func = dataset_ops.StructuredFunctionWrapper(
wrapped_func = structured_function.StructuredFunctionWrapper(
shard_func,
"save()",
input_structure=dataset.element_spec,
Expand Down Expand Up @@ -221,7 +222,7 @@ def __init__(self, path, element_spec=None, compression=None,
else:
self._element_spec = element_spec
self._compression = compression
self._reader_func = dataset_ops.StructuredFunctionWrapper(
self._reader_func = structured_function.StructuredFunctionWrapper(
reader_func,
"load()",
# Dataset of datasets of input elements
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/data/experimental/ops/prefetching_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import structure
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
Expand Down Expand Up @@ -238,7 +239,7 @@ def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism

self._map_func = dataset_ops.StructuredFunctionWrapper(
self._map_func = structured_function.StructuredFunctionWrapper(
map_func,
self._transformation_name(),
dataset=input_dataset,
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/data/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@ package(
licenses = ["notice"],
)

py_library(
name = "structured_function",
srcs = ["structured_function.py"],
srcs_version = "PY3",
deps = [
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:util",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:structure",
],
)

py_library(
name = "dataset_ops",
srcs = ["dataset_ops.py"],
srcs_version = "PY3",
deps = [
":iterator_ops",
":options",
":structured_function",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
Expand Down

0 comments on commit efc4dea

Please sign in to comment.