Inject TpuCompiler pass to IfrtBackendCompiler
PiperOrigin-RevId: 623935433
deqiangc authored and tensorflower-gardener committed Apr 11, 2024
1 parent 18fa42d commit 1c56b45
Showing 2 changed files with 14 additions and 8 deletions.
tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc

@@ -146,14 +146,14 @@ absl::Status IfrtBackendCompiler::CompileTensorflow(
     tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module);
   }
 
-  // Run backward compat pass so that we can use bridge to do clustering.
-  auto backward_compat_result =
-      tensorflow::RunTPUBackwardCompatConversion(module, {});
-  if (mlir::failed(backward_compat_result)) {
-    return diag_handler.Combine(
-        absl::InternalError("Failed to handle legacy TPU Ops"));
+  if (tpu_compiler_ != nullptr) {
+    // Run backward compat pass so that we can use bridge to do clustering.
+    if (mlir::failed(
+            tpu_compiler_->RunTPUBackwardCompatConversion(module, {}))) {
+      return diag_handler.Combine(
+          absl::InternalError("Failed to handle legacy TPU Ops"));
+    }
   }
 
   if (VLOG_IS_ON(1)) {
     tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_after", module);
   }
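With this change, CompileTensorflow no longer calls tensorflow::RunTPUBackwardCompatConversion as a free function; it goes through the injected tpu_compiler_ and skips the conversion when none was provided, so the TPU-specific pass only runs when a TPU compiler is wired in. The real TpuCompiler is declared in tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h (included by the header change below); the following is only a rough sketch of the shape of that seam, inferred from this call site. The class and options-struct names are assumptions, not the actual declarations.

// Rough sketch (not the actual declaration) of the injection seam used above,
// inferred from `tpu_compiler_->RunTPUBackwardCompatConversion(module, {})`:
// the method takes an mlir::ModuleOp plus a brace-initializable options value
// and returns an mlir::LogicalResult that callers test with mlir::failed().
#include "mlir/IR/BuiltinOps.h"          // mlir::ModuleOp
#include "mlir/Support/LogicalResult.h"  // mlir::LogicalResult

struct TpuBackwardCompatOptions {};  // assumed name for the `{}` argument above

class TpuCompilerSeamSketch {
 public:
  virtual ~TpuCompilerSeamSketch() = default;

  // Runs the legacy-TPU-op conversion on `module`; failure is reported
  // through the returned LogicalResult.
  virtual mlir::LogicalResult RunTPUBackwardCompatConversion(
      mlir::ModuleOp module, const TpuBackwardCompatOptions& options) = 0;
};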
tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h

@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_
 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_
 
-
 #include "absl/status/status.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 
 namespace tensorflow {
@@ -28,11 +28,17 @@ namespace ifrt_serving {
 // Implements the custom backend compiler for IFRT based serving in TFRT.
 class IfrtBackendCompiler : public tensorflow::BackendCompiler {
  public:
+  explicit IfrtBackendCompiler(TpuCompiler* tpu_compiler = nullptr)
+      : tpu_compiler_(tpu_compiler) {}
+
   // Rewrites the tensorflow graph in MLIR for IFRT serving. The methods
   // extracts regions for IFRT execution on accelerator (e.g. TPU).
   absl::Status CompileTensorflow(
       tensorflow::tfrt_stub::ModelRuntimeContext& model_context,
       mlir::ModuleOp module) const override;
 
+ private:
+  TpuCompiler* tpu_compiler_;  // Not owned.
 };
 
 } // namespace ifrt_serving
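The new constructor lets callers decide whether the TPU backward-compat conversion participates in compilation at all. A minimal wiring sketch, assuming a concrete TpuCompiler instance is obtained elsewhere (the factory function below is hypothetical and not part of this commit); the pointer is not owned by IfrtBackendCompiler and must outlive it:

#include <memory>

#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"

namespace tensorflow {

// Sketch only: builds the IFRT backend compiler with or without TPU support.
// Passing nullptr (or relying on the default argument) means
// CompileTensorflow skips RunTPUBackwardCompatConversion entirely, so
// CPU/GPU-only deployments never invoke the TPU-specific pass.
std::unique_ptr<ifrt_serving::IfrtBackendCompiler> MakeIfrtBackendCompiler(
    TpuCompiler* tpu_compiler /* may be nullptr; not owned */) {
  return std::make_unique<ifrt_serving::IfrtBackendCompiler>(tpu_compiler);
}

}  // namespace tensorflow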
