Tool to convert TF SavedModel to StableHLO

Here is the signature of the provided API:

```c++
// Converts a TensorFlow model (either from a SavedModel or an MLIR module) to a
// StableHLO MLIR module.
//
// Args:
//  input_path: The path to the input TensorFlow SavedModel or MLIR module.
//  context: The MLIR context to use for parsing or creating the MLIR module.
//  exported_model_signatures: A comma-separated list of exported model
//    signatures (functions) to convert.
//  tag_names: A comma-separated list of tag names used for loading SavedModel.
//  input_arg_shapes_str: A string representation of input argument shapes.
//    Shapes for different tensors are separated by ':', and dimension sizes for
//    the same tensor are separated by ','. For example,
//    'input-arg-shapes=1,2::1,?' expresses input arguments with shapes [1,2],
//    [] and [1,?].
//  is_input_mlir_module: If true, `input_path` is treated as an MLIR
//    module instead of a SavedModel.
//
// Returns:
//   An absl::StatusOr containing the converted StableHLO MLIR module on
//   success, or an absl::Status with an error message on failure.
absl::StatusOr<OwningOpRef<ModuleOp>> TfToStablehlo(
    absl::string_view input_path, MLIRContext* context,
    absl::string_view exported_model_signatures, absl::string_view tag_names,
    absl::string_view input_arg_shapes_str, bool is_input_mlir_module = false);
```
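
To make the contract concrete, here is a hypothetical call site. The dialect-registration helper, header paths, and the SavedModel path are illustrative assumptions, not part of this commit; only `TfToStablehlo` itself is from the API above:

```c++
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"

absl::Status ConvertAndDump() {
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);  // assumed registration helper
  mlir::MLIRContext context(registry);

  // Per the format above: arg0 = [1,2], arg1 = [] (scalar), arg2 = [1,?].
  absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> module =
      TfToStablehlo("/path/to/saved_model", &context,
                    /*exported_model_signatures=*/"serving_default",
                    /*tag_names=*/"serve",
                    /*input_arg_shapes_str=*/"1,2::1,?");
  if (!module.ok()) return module.status();
  (*module)->dump();  // print the resulting StableHLO module
  return absl::OkStatus();
}
```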

Reverts 2ba594d

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11721 from inailuig:mpicollectives_pytype 3924cc0fbbb63e9503f38a59aede3b8e817b17fa
PiperOrigin-RevId: 625351420
sdasgup3 authored and tensorflower-gardener committed Apr 24, 2024
1 parent b251941 commit f67804f
Showing 48 changed files with 1,112 additions and 247 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -64,6 +64,7 @@ cc_library(
    deps = [
        ":stablehlo_util",
        "//tensorflow/compiler/mlir/tensorflow",
+        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <vector>

+#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
@@ -68,14 +69,33 @@ class RenameEntrypointToMainPass
    }

    if (entrypoints.empty()) {
-      fail(module, "No entrypoints found");
-    } else if (entrypoints.size() == 1) {
+      return fail(module, "No entrypoints found");
+    }
+    if (entrypoints.size() == 1) {
      auto entrypoint = entrypoints.begin()->second;
      Builder builder(entrypoint);
      entrypoint.setName(builder.getStringAttr("main"));
-    } else {
-      fail(module, "Too many entrypoints found");
+      return;
    }
+
+    // In case we have more than 1 entry points, choose the one with
+    // 'tf.entry_function' attribute set.
+    llvm::SmallVector<func::FuncOp, 4> candidate_funcs;
+    for (auto& entrypoint : entrypoints) {
+      if (entrypoint.second->hasAttr("tf.entry_function")) {
+        candidate_funcs.push_back(entrypoint.second);
+      }
+    }
+
+    if (candidate_funcs.empty()) {
+      return fail(module, "No entrypoints found");
+    }
+    if (candidate_funcs.size() > 1) {
+      return fail(module, "Too many entrypoints found");
+    }
+    // Found entrypoint
+    Builder builder(candidate_funcs[0]);
+    candidate_funcs[0].setName(builder.getStringAttr("main"));
  }
};
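
A minimal sketch of driving this pass standalone — the parsing boilerplate and the `RenameToMain` wrapper are illustrative assumptions; only `CreateRenameEntrypointToMainPass` (used later in this commit's pipeline) is from the source:

```c++
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
// ... plus the header declaring mlir::odml::CreateRenameEntrypointToMainPass.

// Parse a module from `path` and rename its unique (or
// `tf.entry_function`-tagged) entrypoint to `main`.
mlir::LogicalResult RenameToMain(mlir::MLIRContext& context,
                                 llvm::StringRef path) {
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceFile<mlir::ModuleOp>(path, &context);
  if (!module) return mlir::failure();
  mlir::PassManager pm(&context);
  pm.addPass(mlir::odml::CreateRenameEntrypointToMainPass());
  return pm.run(*module);
}
```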

@@ -77,7 +77,7 @@ absl::StatusOr<ImportedMlirModuleOp> SavedModelToMlirModuleOp(
        module_op.status().ToString()));
  }

-  return std::make_pair(module_op->release(), std::move(bundle));
+  return std::make_pair(std::move(*module_op), std::move(bundle));
}

absl::StatusOr<absl::flat_hash_map<FunctionName, FunctionAlias>>
@@ -119,7 +119,7 @@ void UpdateFunctionAliases(
  });
}

-absl::StatusOr<ModuleOp> ImportSavedModel(
+absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModel(
    const absl::string_view saved_model_path,
    const std::vector<std::string>& signature_keys,
    const std::unordered_set<std::string>& tags,
@@ -132,7 +132,7 @@ absl::StatusOr<ModuleOp> ImportSavedModel(
      SavedModelToMlirModuleOp(saved_model_path, tags, signature_keys, ctx));
  auto [module_op, saved_model_bundle] = std::move(imported_module);

-  UpdateFunctionAliases(function_aliases, module_op);
+  UpdateFunctionAliases(function_aliases, *module_op);

  // Collect the names of the functions that have aliases so that they may not
  // be inlined.
@@ -143,11 +143,11 @@ absl::StatusOr<ModuleOp> ImportSavedModel(

  TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph(
      mlir_dump_file_prefix, /*is_inliner_run=*/true,
-      /*noinline_functions=*/aliased_function_names, module_op, &ctx,
+      /*noinline_functions=*/aliased_function_names, *module_op, &ctx,
      saved_model_bundle == nullptr ? nullptr
                                    : saved_model_bundle->GetSession(),
      /*run_tf_to_stablehlo=*/true, /*deserialize_xla_call_module=*/false));
-  return module_op;
+  return std::move(module_op);
}

}  // namespace mlir::quant::stablehlo
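
The common thread in this file's changes is ownership: instead of `release()`-ing the imported module into a raw `ModuleOp` that the caller must remember to erase, the import path now returns an `OwningOpRef`. A standalone sketch of that pattern (generic MLIR usage, not code from this commit):

```c++
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"

// The returned handle erases the module when the last owner goes out of
// scope, so nothing has to free a release()'d ModuleOp by hand.
mlir::OwningOpRef<mlir::ModuleOp> MakeEmptyModule(mlir::MLIRContext& ctx) {
  mlir::OpBuilder builder(&ctx);
  return mlir::OwningOpRef<mlir::ModuleOp>(
      mlir::ModuleOp::create(builder.getUnknownLoc()));
}

int main() {
  mlir::MLIRContext ctx;
  mlir::OwningOpRef<mlir::ModuleOp> module = MakeEmptyModule(ctx);
  module->dump();  // borrow via operator->; ownership stays with `module`
}  // `module` is destroyed, and the ModuleOp erased, here
```
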
@@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
@@ -38,7 +39,8 @@ namespace mlir::quant::stablehlo {
// `tensorflow::Session` which may be useful when reading values from resources
// (e.g. `TF::VarHandleOp`s).
using ImportedMlirModuleOp =
-    std::pair<ModuleOp, std::unique_ptr<::tensorflow::SavedModelBundle>>;
+    std::pair<OwningOpRef<ModuleOp>,
+              std::unique_ptr<::tensorflow::SavedModelBundle>>;

// Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`.
//
@@ -72,7 +74,7 @@ void UpdateFunctionAliases(
// Loads a SavedModel to `mlir::ModuleOp` and performs preprocesses including
// shape inference and graph freezing.
// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes.
-absl::StatusOr<ModuleOp> ImportSavedModel(
+absl::StatusOr<OwningOpRef<ModuleOp>> ImportSavedModel(
    absl::string_view saved_model_path,
    const std::vector<std::string>& signature_keys,
    const std::unordered_set<std::string>& tags,
@@ -29,6 +29,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h"
@@ -105,22 +106,22 @@ absl::Status QuantizeStaticRangePtq(
  }

  TF_ASSIGN_OR_RETURN(
-      ModuleOp module_op,
+      OwningOpRef<ModuleOp> module,
      ImportSavedModel(src_saved_model_path, signature_keys, tags,
                       quantization_config, PreCalibrationComponent::kName,
                       *function_aliases, *ctx));

  StaticRangePtqComponent static_range_ptq_component(
      ctx.get(), &py_function_library, src_saved_model_path, signature_keys,
      tags, signature_def_map, *function_aliases);
-  TF_ASSIGN_OR_RETURN(module_op, static_range_ptq_component.Run(
-                                     module_op, quantization_config));
+  TF_ASSIGN_OR_RETURN(
+      *module, static_range_ptq_component.Run(*module, quantization_config));

  TF_ASSIGN_OR_RETURN(
      const ExportedModel post_calibrated_exported_model,
      CreateExportedModel(signature_keys, tags, quantization_config,
                          PostCalibrationComponent::kName, *function_aliases,
-                          *ctx, module_op));
+                          *ctx, *module));

  // Remove the `tpu` tag for exporting because the output quantized model is
  // essentially a CPU model.
@@ -85,20 +85,20 @@ absl::Status QuantizeWeightOnlyPtq(
  }

  TF_ASSIGN_OR_RETURN(
-      ModuleOp module_op,
+      auto module,
      ImportSavedModel(src_saved_model_path, signature_keys, tags,
                       quantization_config, WeightOnlyPtqComponent::kName,
                       *function_aliases, *ctx));

  WeightOnlyPtqComponent weight_only_ptq_component(ctx.get());
  TF_ASSIGN_OR_RETURN(
-      module_op, weight_only_ptq_component.Run(module_op, quantization_config));
+      *module, weight_only_ptq_component.Run(*module, quantization_config));

  TF_ASSIGN_OR_RETURN(
      const ExportedModel post_calibrated_exported_model,
      CreateExportedModel(signature_keys, tags, quantization_config,
                          WeightOnlyPtqComponent::kName, *function_aliases,
-                          *ctx, module_op));
+                          *ctx, *module));

  // Remove the `tpu` tag for exporting because the output quantized model is
  // essentially a CPU model.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/quantization/tensorflow/BUILD
@@ -504,6 +504,7 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"

+#include <cstdint>
#include <memory>
#include <optional>
#include <string>
@@ -23,6 +24,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -69,7 +71,9 @@ void AddTFToStablehloPasses(mlir::PassManager& pm) {

// Converts TF SavedModel to StableHLO module. The input TF SavedModel can have
// StableHLO module serialized into a XlaCallModuleOp. (ex: JAX/PyTorch models)
-void AddTFToStablehloPasses(mlir::PassManager& pm) {
+void AddTFToStablehloPasses(
+    mlir::PassManager& pm,
+    llvm::ArrayRef<llvm::ArrayRef<int64_t>> input_arg_shapes) {
  pm.addPass(mlir::odml::CreateRenameEntrypointToMainPass());
  // TODO: b/230572023 - Consider improving shape inference for While op instead
  // of dropping the attribute. This need not be correct for models not trained
@@ -97,7 +101,7 @@ void AddTFToStablehloPasses(mlir::PassManager& pm) {
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  // Propagates shapes on the TensorFlow graph.
-  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+  pm.addPass(mlir::TF::CreateTFShapeInferencePass(input_arg_shapes));
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());
@@ -110,7 +114,7 @@ void AddTFToStablehloPasses(mlir::PassManager& pm) {

  // Generic MLIR optimization passes.
  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+  pm.addPass(mlir::TF::CreateTFShapeInferencePass(input_arg_shapes));

  // Legalizes TF UniformQuantized types into MHLO. Part of the official
  // TF/XLA bridge component.
@@ -137,7 +141,8 @@ absl::Status PreprocessAndFreezeGraph(
    const absl::flat_hash_set<std::string>& noinline_functions,
    mlir::ModuleOp module_op, mlir::MLIRContext* context,
    std::optional<Session*> session, const bool run_tf_to_stablehlo,
-    const bool deserialize_xla_call_module) {
+    const bool deserialize_xla_call_module,
+    llvm::ArrayRef<llvm::ArrayRef<int64_t>> input_arg_shapes) {
  mlir::PassManager pm_before_freezing_variables(context);
  mlir::StatusScopedDiagnosticHandler statusHandler(module_op.getContext(),
                                                    /*propagate=*/true);
@@ -169,7 +174,7 @@ absl::Status PreprocessAndFreezeGraph(
  if (run_tf_to_stablehlo) {
    // AddLegalizeTFToStablehloPasses expects frozen TF variables when
    // legalizing to stablehlo.constant.
-    AddTFToStablehloPasses(pm_after_freezing_variables);
+    AddTFToStablehloPasses(pm_after_freezing_variables, input_arg_shapes);
  }

  if (deserialize_xla_call_module) {
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_

+#include <cstdint>
#include <optional>
#include <string>

@@ -45,7 +46,8 @@ absl::Status PreprocessAndFreezeGraph(
    const absl::flat_hash_set<std::string>& noinline_functions,
    mlir::ModuleOp module_op, mlir::MLIRContext* context,
    std::optional<Session*> session, bool run_tf_to_stablehlo,
-    bool deserialize_xla_call_module);
+    bool deserialize_xla_call_module,
+    llvm::ArrayRef<llvm::ArrayRef<int64_t>> input_arg_shapes = {});

// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file
// prefix.
@@ -56,10 +58,15 @@ inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op,
      /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix,
      /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context,
      session, /*run_tf_to_stablehlo=*/false,
-      /*deserialize_xla_call_module=*/false);
+      /*deserialize_xla_call_module=*/false, /*input_arg_shapes=*/{});
}

-void AddTFToStablehloPasses(mlir::PassManager& pm);
+// TF->StableHLO has limited support for dynamic shapes.
+// Some models can only be converted with explicitly provided input argument
+// shapes.
+void AddTFToStablehloPasses(
+    mlir::PassManager& pm,
+    llvm::ArrayRef<llvm::ArrayRef<int64_t>> input_arg_shapes = {});

} // namespace quantization
} // namespace tensorflow
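A hypothetical call site for the new parameter; the shape values and the use of `mlir::ShapedType::kDynamic` to encode a `?` dimension are assumptions for illustration:

```c++
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"

void BuildPipeline(mlir::MLIRContext& context) {
  mlir::PassManager pm(&context);
  // arg0 has static shape [1,2]; arg1 is [1,?] with one dynamic dimension
  // (assumed to be spelled as ShapedType::kDynamic).
  llvm::SmallVector<int64_t> arg0 = {1, 2};
  llvm::SmallVector<int64_t> arg1 = {1, mlir::ShapedType::kDynamic};
  llvm::SmallVector<llvm::ArrayRef<int64_t>> input_arg_shapes = {arg0, arg1};
  tensorflow::quantization::AddTFToStablehloPasses(pm, input_arg_shapes);
}
```
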
106 changes: 106 additions & 0 deletions tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/BUILD
@@ -0,0 +1,106 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist")

package_group(
name = "internal_visibility_allowlist_package",
packages = [
"//learning/brain/mlir/quantization/stablehlo/python/integration_test/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/compiler/mlir/quantization/...",
"//tensorflow/compiler/mlir/tf2xla/transforms/...",
"//tensorflow/lite/...",
"//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1
] + internal_visibility_allowlist(),
)

package(
# copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"],
default_visibility = [
":internal_visibility_allowlist_package",
"//tensorflow:__pkg__",
],
licenses = ["notice"],
)

cc_library(
name = "tf_to_stablehlo",
srcs = [
"tf_to_stablehlo.cc",
],
hdrs = [
"tf_to_stablehlo.h",
],
compatible_with = get_compatible_with_portable(),
deps = [
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_import",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",
"//tensorflow/compiler/mlir/tensorflow/transforms:shape_inference_pass",
"//tensorflow/core:core_cpu_base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
alwayslink = True,
)

tf_cc_binary(
name = "tf-to-stablehlo-translate",
srcs = [
"tf_to_stablehlo_translate.cc",
],
visibility = [":internal_visibility_allowlist_package"],
deps = [
":tf_to_stablehlo",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)
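Assuming a standard TensorFlow source checkout, the new tool should build with the usual Bazel invocation, e.g. `bazel build //tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:tf-to-stablehlo-translate`.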

glob_lit_tests(
    name = "all_tests",
    data = [":test_utilities"],
    # TODO: b/288344501 - Enable OSS tests again when stable-quant-opt works well.
    default_tags = [
        "no_oss",
        "no_pip",
    ],
    driver = "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:run_lit.sh",
    size_override = {
    },
    tags_override = {
    },
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        ":tf-to-stablehlo-translate",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
        "@llvm-project//mlir:run_lit.sh",
    ],
)
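Note that the `no_oss` and `no_pip` default tags keep these lit tests disabled in the OSS build for now (see the TODO above); internally they would run through the generated `:all_tests` suite.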
