Merged commit includes the following changes:
511837617  by A. Unique TensorFlower<gardener@tensorflow.org>:
    Automated rollback of changelist 509256232.

511836298  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA:CPU] Outline fusion regions in the presence of implicit constant-like operands

--
511832706  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Remove decorator usage for parallel devices as a first step in removing `control_flow_ops.py`'s dependency on `def_function.py` to eliminate circular dependencies.

--
511832154  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [StableHLO to MHLO] Handle bounds of Gather op

    Based on:
    openxla/stablehlo#908

--
511832050  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Specify `save_type='checkpoint'` when calling `trackable_children`.

--
511830259  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Uses the same clustering algorithm as the TF2XLA bridge for TAC passes. It produces better results (fewer clusters) than the current clustering.

--
511830057  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add an option to strip location information from ops.

    These end up becoming huge during large fusions (like TpuRewritePass) and in practice don't help much with readability, as locations for TF are typically indicated by the op name.
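
    A minimal sketch of how the new option might be enabled (hedged: the flag name
    tf_mlir_strip_debug and the TF_XLA_FLAGS parsing come from the flags.cc change
    further down; the rest is illustrative):

    import os
    # TF_XLA_FLAGS is read when TensorFlow initializes, so set it before importing.
    os.environ["TF_XLA_FLAGS"] = "--tf_mlir_strip_debug=true"
    import tensorflow as tf  # MLIR dumps should now omit location (loc) info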

--
511826780  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA] Add all-gather-start/done multi-operand shape inference tests

--
511815596  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Describe a potential problem when the LLVM toolchain is used.

--
511810373  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update TFRT dependency to use revision
    http://github.com/tensorflow/runtime/commit/1ed8df8df17f431936090acde5122456c5eed394.

--
511801672  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Plumb CollectiveAllToAllV2 to MLIR.

    Also fix an inconsistency regarding CollectiveGatherV2.

    Follow-up of #59598

--
511799604  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Integrate LLVM at llvm/llvm-project@219ba2fb7b0a

    Updates LLVM usage to match
    [219ba2fb7b0a](llvm/llvm-project@219ba2fb7b0a)

--
511797844  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update TFRT dependency to use revision
    http://github.com/tensorflow/runtime/commit/9b50571c5b76f32103ad5099b0993581af3d0592.

--
511793598  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update type-checks of `DatasetV1` and `DatasetV2` to use abstract types.

    Changes usages of the internal `DatasetV1` and `DatasetV2` types
    to use the `tensorflow.types.data` versions instead of the concrete
    implementations.

    This helps reduce the tendency for cyclic dependencies involving the
    `dataset_ops.py` module.

    Usages of the concrete type (e.g. instantiation, member access) are not
    affected by this change.
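
    As a hedged sketch (the abstract-type module path is assumed from the
    description above), a type check against the abstract type rather than the
    concrete implementation could look like:

    from tensorflow.python.types import data as data_types

    def is_dataset(obj):
      # Accepts any DatasetV2 implementation without importing dataset_ops.py.
      return isinstance(obj, data_types.DatasetV2)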

--
511765379  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [GmlSt] Remove 'distribute' attribute from ParallelOp tiling params.

--
511760944  by A. Unique TensorFlower<gardener@tensorflow.org>:
    Automated rollback of changelist 510948939.

511756866  by A. Unique TensorFlower<gardener@tensorflow.org>:
    Automated rollback of changelist 482514900.

511749935  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA:CPU Next] Disable tiling of linalg.generic.

--
511746972  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [GmlSt] Do not tile linalg.generic if it was tiled already.

    The "tiled labels" are not populated correctly. Theoretically, we shouldn't
    check for the parent.

--
511745314  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA:CPU Next] Fix scalarization of scf.for.

    If the type of the output is not specified, then FromElementsOp creates a 1D tensor:

    void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                               ValueRange elements) {
      assert(!elements.empty() && "expected at least one element");
      Type resultType = RankedTensorType::get(
          {static_cast<int64_t>(elements.size())}, elements.front().getType());
      build(builder, result, resultType, elements);
    }

    The test I had before in scalarization.mlir was 1D by coincidence and therefore worked.

--
511741158  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fill calculates the output tensor if all the required information is available during Prepare.

    This means that the output will be available for the subsequent operators' Prepare, and Eval will be free.

--
511737032  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [GmlSt] Remove gml_st.materialize.

--
511736776  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA:CPU Next] Disable scf.if vectorization.

--
511734615  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [DelegatePerformance] Removed a strengthened precondition and changed the metric value type from float to double.

    The change removes the strengthened precondition from the inherited method computeModelReport() in the derived class ModelBenchmarkReport. It also changes the metric value type from float to double to avoid potential calculation errors.

--
511730756  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA:CPU Next] Add a pattern to tile linalg.generic to 1 and fuse greedily.

--
511726589  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [GmlSt] Remove useless peeling label.

--
511719057  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Transform grouped convolution to depthwise convolution when possible.

--
511716904  by A. Unique TensorFlower<gardener@tensorflow.org>:

    PR #59775: Fix tensor_or_memref build error without math.h

    Imported from GitHub PR #59775

    #59536 (comment)

    Since the previous PR was rolled back, this one adopts the reviewer's feedback to fix the Mac build error.

    Thanks @reedwm for the feedback and suggestion!
    Copybara import of the project:

    --
    4f9f856 by Chao Chen <cchen104@amd.com>:

    fix tensor_or_memref build error without math.h

    Merging this change closes #59775

--
511715187  by A. Unique TensorFlower<gardener@tensorflow.org>:

    compat: Update forward compatibility horizon to 2023-02-23

--
511715185  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update GraphDef version to 1416.

--
511712677  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA] Further speedup HloModule::Print by using integer key for CanonicalNameMap.

--
511712612  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Go: Update generated wrapper functions for TensorFlow ops.

--
511708539  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update ops-related pbtxt files.

--
511706378  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add SegmentMaxV2 op with num_segments as an additional input.

    The only difference from SegmentMax is the additional input `num_segments`.
    This helps in evaluating the output shape at compile time.
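
    A hedged usage sketch (assuming the new op is exposed through tf.raw_ops as
    SegmentMaxV2 with the input names given above):

    import tensorflow as tf

    data = tf.constant([1.0, 3.0, 2.0, 5.0])
    segment_ids = tf.constant([0, 0, 1, 1])
    # Passing num_segments makes the output shape [2] known at compile time.
    out = tf.raw_ops.SegmentMaxV2(data=data, segment_ids=segment_ids, num_segments=2)
    # out == [3.0, 5.0]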

--
511703897  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Go: Update generated wrapper functions for TensorFlow ops.

--
511703121  by A. Unique TensorFlower<gardener@tensorflow.org>:

    PR #59616: ReLU Epilogue Fusion for FP8 GEMMs in XLA

    Imported from GitHub PR #59616

    Enables the epilogue fusion of ReLU activations for FP8 GEMMs.
    Copybara import of the project:

    --
    8813bb2 by Philipp Hack <phack@nvidia.com>:

    Epilogue fusion of ReLU activations for FP8 GEMMs.

    --
    c9ee1d5 by Philipp Hack <phack@nvidia.com>:

    Epilogue fusion of ReLU activations for FP8 GEMMs.

    Merging this change closes #59616

--
511693497  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add SegmentMinV2 op with num_segments as an additional input.

    The only difference from SegmentMin is the additional input `num_segments`.
    This helps in evaluating the output shape at compile time.

--
511692851  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Go: Update generated wrapper functions for TensorFlow ops.

--
511690064  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [jax2tf] Use CUDA and ROCM instead of GPU for XlaCallModuleOp platforms

    JAX is moving to using ROCM and CUDA instead of the generic GPU platform
    type and it is already supporting separate lowerings for ROCM and CUDA.
    To keep up with this functionality, we move XlaCallModuleOp to support the
    ROCM and CUDA platforms.

--
511687577  by A. Unique TensorFlower<gardener@tensorflow.org>:

    VLOG(1) upon fallback in phase 2. Whether the new or the old bridge ran is usually the first thing to find out when debugging, so when a fallback from the new bridge to the old bridge happens, it should be reported at log level 1.

--
511684159  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Cast BF16 DepthwiseConv2D ops to f32 ops.

    This change casts BF16 DepthwiseConv2D ops to f32 to make them ready for quantization. However, since DepthwiseConv2D quantization is currently disabled due to a performance issue, this change does not guarantee BF16 DepthwiseConv2D quantization.

--
511677183  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Added a util function to process einsum

--
511676949  by A. Unique TensorFlower<gardener@tensorflow.org>:

    #tf-data Retry empty repetitions when repeating data service dataset.

    The ForeverRepeat op assumes that if the first repetition produces no data,
    all future repetitions will produce no data. That is not always true.
    For example, when using tf.data service, different repetitions may
    produce different numbers of elements, and empty repetitions should be
    retried.
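
    For context, a hedged sketch of the affected pattern (the dispatcher address
    and job name are placeholders):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    ds = ds.apply(tf.data.experimental.service.distribute(
        processing_mode="distributed_epoch",
        service="grpc://dispatcher:5000",  # hypothetical dispatcher address
        job_name="shared_job"))
    # Repetitions of a data service dataset can yield different element counts,
    # so an empty repetition should be retried rather than ending the repeat.
    ds = ds.repeat()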

--
511657402  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA] More speedups to HloModule::Print.

--
511647352  by A. Unique TensorFlower<gardener@tensorflow.org>:
    Automated rollback of changelist 511611390.

511644926  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Refactor TfrtPipelineOptions to a separate file so that it can be reused.

--
511642132  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update TFRT dependency to use revision
    http://github.com/tensorflow/runtime/commit/410c7f3b07f1d1170b13242a2cf9af4ec7edd33f.

--
511640480  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update `DimsAre` matcher to be polymorphic over both `TfLiteTensor` and `TfLiteIntArrays`.

--
511639990  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix Relayout handling under XLA SPMD.

    Follows the plan by samuelslee@, now that chensunx@ has contributed
    the pass that lowers Relayout to Identity.

--
511639775  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Integrate LLVM at llvm/llvm-project@a7b6978285c1

    Updates LLVM usage to match
    [a7b6978285c1](llvm/llvm-project@a7b6978285c1)

--
511637257  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Include full node_def info in activity watcher.

--
511637139  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update TFRT dependency to use revision
    http://github.com/tensorflow/runtime/commit/bd588322b3d6660903ee0df9b55d2f589aa5dc8a.

--
511632436  by A. Unique TensorFlower<gardener@tensorflow.org>:
    Automated rollback of changelist 511597567.

511611390  by A. Unique TensorFlower<gardener@tensorflow.org>:

    PR #57956: Performance Enhancements for Sparse Embedding Lookups

    Imported from GitHub PR #57956

    Introduces performance options for sparse embedding lookups that can appreciably speed up the training of recommendation systems. Sparse lookups alternatively accept inputs described by RaggedTensors, which are more memory efficient. Performance is further increased by the optional use of a simplified and typically faster embedding lookup.

    In the sparse embedding micro benchmarks in tensorflow/python/eager/benchmarks_test.py, the number of examples per second on a DGX A100 system increases from approx. 1,300 with SparseTensor and without simplified lookup to approx. 11,200 with RaggedTensor inputs and simplified lookup (+760%). The combination of SparseTensor inputs and simplified lookup yields approx. 3,000 examples per second (+130%).
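
    For reference, a hedged sketch of the baseline SparseTensor lookup path (the
    RaggedTensor input path and the simplified-lookup option are what this PR
    introduces; their exact argument names are not reproduced here):

    import tensorflow as tf

    params = tf.random.normal([1000, 64])  # embedding table
    sp_ids = tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=tf.constant([3, 17, 42], dtype=tf.int64),
        dense_shape=[2, 2])
    # Baseline: SparseTensor ids, mean-combined embedding per example row.
    emb = tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights=None,
                                        combiner="mean")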
    Copybara import of the project:

    --
    00ee1a9 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    fe6eb18 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    47157c3 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    8826a9b by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    4e77072 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    3d1096a by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    c115b80 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    842591c by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    2a307f3 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    ee90704 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    9e216e3 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    c393302 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    5165f07 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    f3e88c4 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    8384791 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    497f9c3 by Philipp Hack <phack@nvidia.com>:

    Adds performance enhancements for sparse embedding lookups.

    --
    2166e5d by Philipp Hack <phack@nvidia.com>:

    Performance enhancements for sparse embedding lookups.

    --
    13dd2c8 by Philipp Hack <phack@nvidia.com>:

    Performance enhancements for sparse embedding lookups.

    --
    a83f4d4 by Philipp Hack <phack@nvidia.com>:

    Performance enhancements for sparse embedding lookups.

    Merging this change closes #57956

--
511609283  by A. Unique TensorFlower<gardener@tensorflow.org>:

    #tf-data Ramp down `autotune_buffer_optimization` experiment.

--
511606842  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Avoid double counting graph building time for nested functions.

--
511606056  by A. Unique TensorFlower<gardener@tensorflow.org>:

    PR #59619: [NVIDIA TF] Throw error only for non-empty function definitions.

    Imported from GitHub PR #59619

    During function serialization and deserialization with SavedModels, there can be node ops of type `func` with default values that have no associated name. In such cases, we shouldn't look up undefined empty strings in the function library. This is a trivial change and helps with how SavedModels are loaded.
    Copybara import of the project:

    --
    e0fe60f by Pavani Majety <pmajety@nvidia.com>:

    [BugFix] Throw error only for non-empty function definitions.

    Add reason for the required change.

    Add comment to both locations.

    Merging this change closes #59619

--
511605660  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix typo in comment

--
511604096  by A. Unique TensorFlower<gardener@tensorflow.org>:

    #tf-data-service Put distributed_save_test and snapshot_ft_test together.

--
511603520  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update the tracking bug now that the step id is propagated correctly.

--
511602286  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add a size() member function to HloProtoMap.

--
511599929  by A. Unique TensorFlower<gardener@tensorflow.org>:

    gpu_delegate: Update to support Mali G715

--
511597567  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [tf-lite] Enable parallel transpose.

--
511596662  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add an aggregated stat for the per-step result.
    The goal is to remove all_reduce_db_per_core.
    We will lose the ability to separate compute time and synchronization time for TPU only,
    but I think that is fine: most collectives on TPU are now async ops, so this algorithm is no longer very informative. We will count all async collectives as synchronization time for TPU.
    For GPU, this doesn't apply.

--
511590747  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [PJRT C API] Add a README file to provide communication channel and resources.

--
511585380  by A. Unique TensorFlower<gardener@tensorflow.org>:
    Automated rollback of changelist 511565925.

511578143  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA] Minor fixes in ShardingPropagation.

    - Moves the misplaced comment block of replicate_on_last_tile_dim_.
    - Uses existing variable root_instr to eliminate redundant accesses.
    - Replaces bitwise and with logical and.
    - Uses reverse iterator instead of reversing the list.

--
511569312  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Remove `type_spec`'s direct dependency on `framework/ops.py`.

    Changes `type_spec` references to `tf.Tensor` to use an abstract base type
    for `isinstance` checks.

    Uses the conversion function in `tensor_conversion_registry` instead
    of its wrapper in `ops`. While `tensor_conversion_registry` has an indirect
    dependency on `ops.py`, there is future work planned to remove that dependency.
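
    A hedged sketch of the kind of check described (the abstract-base module path
    is an assumption):

    from tensorflow.python.types import core as core_types

    def is_tensor_like(value):
      # Abstract base-class check; avoids a direct import of framework/ops.py.
      return isinstance(value, core_types.Tensor)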

--
511568795  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Check that the shape of indices matches the shape of dense_shape in the sparse_fill_empty_rows op.
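
    A hedged illustration of the invariant being checked, using the public wrapper
    (indices must have shape [N, ndims], where ndims equals the length of dense_shape):

    import tensorflow as tf

    sp = tf.sparse.SparseTensor(
        indices=[[0, 0], [3, 1]],  # shape [N, 2]; 2 matches len(dense_shape)
        values=[7, 9],
        dense_shape=[5, 4])
    filled, empty_row_indicator = tf.sparse.fill_empty_rows(sp, default_value=0)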

--
511567163  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [xla-next][mlir][sparse] add mhlo sparsity rewriting to pipeline

--
511565925  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Update visibility to fix OSS build

--

PiperOrigin-RevId: 511837617
tensorflower-gardener authored and MichaelHudgins committed Feb 23, 2023
1 parent 62dcf7e commit 5eaf871
Showing 214 changed files with 3,304 additions and 1,967 deletions.
21 changes: 8 additions & 13 deletions tensorflow/compiler/jit/flags.cc
@@ -225,11 +225,10 @@ void AllocateAndParseFlags() {
// bridge, on a per-graph basis).
bool enable_mlir_bridge = false;
bool enable_mlir_bridge_is_explicit = false;
bool enable_mlir_merge_control_flow_pass = true;
bool enable_mlir_convert_control_to_data_outputs_pass = false;
// Dump graphs in TFG dialect.
bool use_tfg_graph_dumper = false;
bool enable_mlir_generic_outside_compilation = false;

mlir_flags = new MlirCommonFlags;

flag_list = new std::vector<Flag>(
{Flag("tf_xla_enable_lazy_compilation",
@@ -275,25 +274,27 @@ void AllocateAndParseFlags() {
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
&enable_mlir_bridge_is_explicit),
Flag("tf_mlir_enable_merge_control_flow_pass",
&enable_mlir_merge_control_flow_pass,
&mlir_flags->tf_mlir_enable_merge_control_flow_pass,
"Enables MergeControlFlow pass for MLIR-Based TensorFlow Compiler "
"Bridge."),
Flag("tf_mlir_enable_convert_control_to_data_outputs_pass",
&enable_mlir_convert_control_to_data_outputs_pass,
&mlir_flags->tf_mlir_enable_convert_control_to_data_outputs_pass,
"Enables `tf-executor-convert-control-to-data-outputs` pass for "
"MLIR-Based TensorFlow Compiler Bridge."),
Flag("tf_dump_graphs_in_tfg", &use_tfg_graph_dumper,
"When tf_dump_graphs_in_tfg is true, graphs after transformations "
"are dumped in MLIR TFG dialect and not in GraphDef"),
Flag("tf_mlir_strip_debug", &mlir_flags->tf_mlir_strip_debug,
"Strip debug information (like loc) from programs for easier "
"reading of op listings."),
Flag("tf_mlir_enable_generic_outside_compilation",
&enable_mlir_generic_outside_compilation,
&mlir_flags->tf_mlir_enable_generic_outside_compilation,
"Enables OutsideCompilation passes for MLIR-Based TensorFlow "
"Generic Compiler Bridge.")});

AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);

mlir_flags = new MlirCommonFlags;
if (!enable_mlir_bridge_is_explicit) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
@@ -304,12 +305,6 @@ void AllocateAndParseFlags() {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
}
mlir_flags->tf_mlir_enable_merge_control_flow_pass =
enable_mlir_merge_control_flow_pass;
mlir_flags->tf_mlir_enable_convert_control_to_data_outputs_pass =
enable_mlir_convert_control_to_data_outputs_pass;
mlir_flags->tf_mlir_enable_generic_outside_compilation =
enable_mlir_generic_outside_compilation;

if (use_tfg_graph_dumper) {
UseMlirForGraphDump(MlirDumpConfig{}.elide_large_attributes().emit_dialect(
7 changes: 4 additions & 3 deletions tensorflow/compiler/jit/flags.h
@@ -155,9 +155,10 @@ struct BuildXlaOpsPassFlags {
struct MlirCommonFlags {
ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;

bool tf_mlir_enable_merge_control_flow_pass;
bool tf_mlir_enable_convert_control_to_data_outputs_pass;
bool tf_mlir_enable_generic_outside_compilation;
bool tf_mlir_enable_merge_control_flow_pass = true;
bool tf_mlir_enable_convert_control_to_data_outputs_pass = false;
bool tf_mlir_enable_generic_outside_compilation = false;
bool tf_mlir_strip_debug = false;
};

// Flags for the JitRt pipeline -- see tf_jitrt_pipeline.h for details.
2 changes: 2 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2165,6 +2165,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"ScatterNd",
"SegmentSumV2",
"SegmentProdV2",
"SegmentMinV2",
"SegmentMaxV2",
"SelfAdjointEigV2",
"SoftmaxCrossEntropyWithLogits",
"SpaceToBatch",
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/experimental/common/BUILD
@@ -8,6 +8,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow:cluster_util",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
@@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/cluster_util.h"

namespace mlir {
namespace TFL {
@@ -203,6 +204,9 @@ void ExtractSubgraphToFunc(const Subgraph& subgraph, OpBuilder& builder,
op->dropAllReferences();
op->erase();
}
// Ensure that users of the call op's results appear after the launch op in
// order to preserve the dominance property.
TF::ReorderOpResultUses(call_op);
}

} // namespace common
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/experimental/tac/BUILD
@@ -248,7 +248,9 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
"//tensorflow/compiler/mlir/lite/experimental/common:outline_operations",
"//tensorflow/compiler/mlir/lite/experimental/tac/hardwares:target_hardware",
"//tensorflow/compiler/mlir/tensorflow:cluster_util",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
