
Commit

[xla:cpu] Add error handling to async host kernel launch
PiperOrigin-RevId: 642121945
ezhulenev authored and tensorflower-gardener committed Jun 11, 2024
1 parent dd6e541 commit 84c5401
Showing 37 changed files with 450 additions and 198 deletions.
11 changes: 8 additions & 3 deletions RELEASE.md
@@ -26,6 +26,11 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>

* `tf.data`
* Add optional `synchronous` argument to `map`, to specify that the `map`
should run synchronously, as opposed to being parallelized when
`options.experimental_optimization.map_parallelization=True`. This saves
memory compared to setting `num_parallel_calls=1`.
* `tf.lite`
* `Dequantize` op supports `TensorType_INT4`.
* This change includes per-channel dequantization.
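
A minimal usage sketch of the new `synchronous` argument described in the `tf.data` note above (not part of this commit's diff; it assumes a TensorFlow build that already includes this change):

import tensorflow as tf

# With map_parallelization enabled, a plain map would normally be rewritten to
# run with num_parallel_calls=tf.data.AUTOTUNE; synchronous=True opts this map
# out of that rewrite, so nothing is buffered ahead of the consumer.
ds = tf.data.Dataset.range(5).map(lambda x: x + 1, synchronous=True)

options = tf.data.Options()
options.experimental_optimization.map_parallelization = True
ds = ds.with_options(options)

print(list(ds.as_numpy_iterator()))  # [1, 2, 3, 4, 5]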
@@ -158,7 +163,7 @@ release. It may break some edge cases of TensorFlow API usage.
delegate features.
* Flatbuffer version update:
* `GetTemporaryPointer()` bug fixed.
* Add int64 data type support for dynamic update slice's indice tensor.
* Add int64 data type support for dynamic update slice's indice tensor.

* `tf.data`
* Add `wait` to `tf.data.Dataset.load`. If `True`, for snapshots written
@@ -168,9 +173,9 @@ release. It may break some edge cases of TensorFlow API usage.
`distributed_save` are recommended to set it to `True`.

* `tf.tpu.experimental.embedding.TPUEmbeddingV2`
* Add `compute_sparse_core_stats` for sparse core users to profile the
* Add `compute_sparse_core_stats` for sparse core users to profile the
data with this API to get the `max_ids` and `max_unique_ids`. These
numbers will be needed to configure the sparse core embedding mid level
numbers will be needed to configure the sparse core embedding mid level
api.
* Remove the `preprocess_features` method since that's no longer needed.

1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -8604,6 +8604,7 @@ Creates a dataset that applies `f` to the outputs of `input_dataset`.
ConfinedAttr<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_inter_op_parallelism,
DefaultValuedOptionalAttr<BoolAttr, "false">:$preserve_cardinality,
DefaultValuedOptionalAttr<BoolAttr, "false">:$force_synchronous,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$metadata
);

@@ -25,7 +25,7 @@ def FuseMapAndBatch : Pat<
(TF_BatchDatasetV2Op
(TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types,
$output_shapes, $use_inter_op_parallelism, $preserve_cardinality,
$map_dataset_metadata),
$force_synchronous, $map_dataset_metadata),
$batch_size, $drop_remainder, $parallel_copy, $batch_output_types,
$batch_output_shapes, $unused_batch_dataset_metadata),
(TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size,
@@ -86,8 +86,11 @@ Status MapParallelization::OptimizeAndCollectStats(Cluster* cluster,

auto* function =
function_library.Find(map_node->attr().at("f").func().name());
if (function_utils::IsFunctionStateful(function_library, *function, true))
if (function_utils::IsFunctionStateful(function_library, *function, true) ||
(map_node->attr().contains("force_synchronous") &&
map_node->attr().at("force_synchronous").b())) {
continue;
}

auto* parallel_map =
graph.AddNode(MakeParallelMap(map_node->name(), &graph));
19 changes: 14 additions & 5 deletions tensorflow/core/kernels/data/map_dataset_op.cc
@@ -39,17 +39,19 @@ namespace data {
/* static */ constexpr const char* const MapDatasetOp::kOutputShapes;
/* static */ constexpr const char* const MapDatasetOp::kUseInterOpParallelism;
/* static */ constexpr const char* const MapDatasetOp::kPreserveCardinality;
/* static */ constexpr const char* const MapDatasetOp::kForceSynchronous;

class MapDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool preserve_cardinality)
bool preserve_cardinality, bool force_synchronous)
: DatasetBase(DatasetContext(ctx)),
input_(input),
preserve_cardinality_(preserve_cardinality),
force_synchronous_(force_synchronous),
captured_func_(std::move(captured_func)),
output_types_(output_types),
output_shapes_(output_shapes) {
@@ -142,14 +144,18 @@ class MapDatasetOp::Dataset : public DatasetBase {
AttrValue preserve_cardinality_attr;
b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);

// Attr: force_synchronous
AttrValue force_synchronous_attr;
b->BuildAttrValue(force_synchronous_, &force_synchronous_attr);

TF_RETURN_IF_ERROR(b->AddDataset(
this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs.
{std::make_pair(1, other_arguments)}, // Tensor list inputs.
{std::make_pair(kFunc, f_attr),
std::make_pair(kTarguments, other_arguments_types_attr),
std::make_pair(kUseInterOpParallelism, use_inter_op_parallelism_attr),
std::make_pair(kPreserveCardinality,
preserve_cardinality_attr)}, // Attrs
std::make_pair(kPreserveCardinality, preserve_cardinality_attr),
std::make_pair(kForceSynchronous, force_synchronous_attr)}, // Attrs
output));
return absl::OkStatus();
}
@@ -230,6 +236,7 @@ class MapDatasetOp::Dataset : public DatasetBase {

const DatasetBase* const input_;
const bool preserve_cardinality_;
const bool force_synchronous_;
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
@@ -250,6 +257,7 @@ MapDatasetOp::MapDatasetOp(OpKernelConstruction* ctx)
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kForceSynchronous, &force_synchronous_));
}

void MapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
@@ -259,8 +267,9 @@ void MapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
&captured_func));

*output = new Dataset(ctx, input, std::move(captured_func), output_types_,
output_shapes_, preserve_cardinality_);
*output =
new Dataset(ctx, input, std::move(captured_func), output_types_,
output_shapes_, preserve_cardinality_, force_synchronous_);
}

namespace {
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/data/map_dataset_op.h
@@ -34,6 +34,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
"use_inter_op_parallelism";
static constexpr const char* const kPreserveCardinality =
"preserve_cardinality";
static constexpr const char* const kForceSynchronous = "force_synchronous";

explicit MapDatasetOp(OpKernelConstruction* ctx);

@@ -47,6 +48,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
bool preserve_cardinality_;
bool force_synchronous_;
};

} // namespace data
1 change: 1 addition & 0 deletions tensorflow/core/ops/dataset_ops.cc
@@ -170,6 +170,7 @@ REGISTER_OP("MapDataset")
.Attr("output_shapes: list(shape) >= 1")
.Attr("use_inter_op_parallelism: bool = true")
.Attr("preserve_cardinality: bool = false")
.Attr("force_synchronous: bool = false")
.Attr("metadata: string = ''")
.SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
"output_types"))
1 change: 1 addition & 0 deletions tensorflow/core/ops/experimental_dataset_ops.cc
@@ -700,6 +700,7 @@ REGISTER_OP("ExperimentalMapDataset")
.Attr("output_shapes: list(shape) >= 1")
.Attr("use_inter_op_parallelism: bool = true")
.Attr("preserve_cardinality: bool = false")
.Attr("force_synchronous: bool = false")
.SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
"output_types"))
.SetShapeFn(shape_inference::ScalarShape);
@@ -68,6 +68,21 @@ def testMapParallelization(self, function, should_optimize):
self.assertDatasetProduces(
dataset, expected_output=[function(x) for x in range(5)])

@combinations.generate(test_base.default_test_combinations())
def testNoMapParallelizationWhenSynchronous(self):
dataset = (
dataset_ops.Dataset.range(5)
.apply(testing.assert_next(["Map"]))
.map(lambda x: x + 1, synchronous=True)
)
options = options_lib.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_parallelization = True
dataset = dataset.with_options(options)
self.assertDatasetProduces(
dataset, expected_output=[x + 1 for x in range(5)]
)

@combinations.generate(test_base.default_test_combinations())
def testCapturedConstant(self):
captured_t = constant_op.constant(42, dtype=dtypes.int64)
39 changes: 31 additions & 8 deletions tensorflow/python/data/ops/dataset_ops.py
@@ -2152,7 +2152,12 @@ def sparse_batch(self, batch_size, row_shape, name=None) -> "DatasetV2":
# pylint: disable=g-import-not-at-top,protected-access

def map(
self, map_func, num_parallel_calls=None, deterministic=None, name=None
self,
map_func,
num_parallel_calls=None,
deterministic=None,
synchronous=None,
name=None,
) -> "DatasetV2":
"""Maps `map_func` across the elements of this dataset.
@@ -2299,6 +2304,15 @@ def map(
determinism for performance. If not specified, the
`tf.data.Options.deterministic` option (`True` by default) controls the
behavior.
synchronous: (Optional.) Whether to force the map transformation to run
synchronously. This only matters when
`options.experimental_optimization.map_parallelization=True`. That
option would normally change the map to run with
`num_parallel_calls=tf.data.AUTOTUNE`, but if `synchronous=True` is
specified, the map will not be parallelized at all. This is useful for
saving memory, since even setting `num_parallel_calls=1` will cause one
batch to be buffered, while with `synchronous=True` the map
transformation doesn't buffer anything.
name: (Optional.) A name for the tf.data operation.
Returns:
@@ -2308,12 +2322,15 @@ def map(
# dataset_ops).
# pylint: disable=g-import-not-at-top,protected-access
from tensorflow.python.data.ops import map_op

return map_op._map_v2(
self,
map_func,
num_parallel_calls=num_parallel_calls,
deterministic=deterministic,
name=name)
synchronous=synchronous,
name=name,
)
# pylint: enable=g-import-not-at-top,protected-access

def flat_map(self, map_func, name=None) -> "DatasetV2":
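
A hedged sketch of the interaction documented in the new `synchronous` docstring above (not part of this commit's diff). It uses the internal `testing.assert_next` helper, the same helper the new unit test uses, to check whether the map_parallelization rewrite fired; internal modules under `tensorflow.python` may change without notice, and the sketch assumes a build that includes this change:

from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib

def build(synchronous):
  # assert_next raises at iteration time if the next transformation in the
  # optimized graph is not the named one, so it reveals whether the
  # map_parallelization rewrite turned Map into ParallelMap.
  expected = "Map" if synchronous else "ParallelMap"
  dataset = dataset_ops.Dataset.range(5).apply(testing.assert_next([expected]))
  if synchronous:
    dataset = dataset.map(lambda x: x + 1, synchronous=True)
  else:
    dataset = dataset.map(lambda x: x + 1)
  options = options_lib.Options()
  options.experimental_optimization.apply_default_optimizations = False
  options.experimental_optimization.map_parallelization = True
  return dataset.with_options(options)

print(list(build(synchronous=True).as_numpy_iterator()))   # [1, 2, 3, 4, 5]
print(list(build(synchronous=False).as_numpy_iterator()))  # [1, 2, 3, 4, 5]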
@@ -4069,20 +4086,26 @@ def padded_batch(self,
name=name))

@functools.wraps(DatasetV2.map)
def map(self,
map_func,
num_parallel_calls=None,
deterministic=None,
name=None):
def map(
self,
map_func,
num_parallel_calls=None,
deterministic=None,
synchronous=None,
name=None,
):
# Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
# dataset_ops).
# pylint: disable=g-import-not-at-top,protected-access
from tensorflow.python.data.ops import map_op

return map_op._map_v1(
self,
map_func,
num_parallel_calls=num_parallel_calls,
deterministic=deterministic)
deterministic=deterministic,
synchronous=synchronous,
)
# pylint: enable=g-import-not-at-top,protected-access

@deprecation.deprecated(None, "Use `tf.data.Dataset.map()")
