Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update function aliases in an inline manner #63150

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,15 @@ cc_library(
compatible_with = get_compatible_with_portable(),
deps = [
":types",
"//tensorflow/cc/saved_model:reader",
"//tensorflow/core/protobuf:for_core_protos_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:errors",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,9 @@ absl::StatusOr<ExportedModel> CalibrationComponent::ExportToSavedModel(
TF_ASSIGN_OR_RETURN(const SmallVector<AssetFileDef> asset_file_defs,
RunExportPasses(export_opts, *ctx_, module_op));

const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases =
UpdateFunctionAliases(function_aliases_, module_op);

TF_ASSIGN_OR_RETURN(ExportedModel exported_model,
ConvertMlirModuleToExportedModel(
module_op, checkpoint_dir, updated_function_aliases,
module_op, checkpoint_dir, function_aliases_,
{asset_file_defs.begin(), asset_file_defs.end()}));

py_function_lib_->SaveExportedModel(dst_saved_model_path, exported_model,
Expand All @@ -208,17 +204,14 @@ absl::StatusOr<ModuleOp> CalibrationComponent::ImportCalibratedSavedModel(
tags_, signature_keys_, *ctx_));
ModuleOp module_op = imported_module.first;

const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases_post_calibration =
UpdateFunctionAliases(function_aliases_, module_op);
UpdateFunctionAliases(function_aliases_, module_op);

// Collect the names of the functions that have aliases so that they may not
// be inlined.
absl::flat_hash_set<std::string> aliased_function_names;
absl::c_for_each(updated_function_aliases_post_calibration,
[&](const auto& aliases) {
return aliased_function_names.insert(aliases.first);
});
absl::c_for_each(function_aliases_, [&](const auto& aliases) {
return aliased_function_names.insert(aliases.first);
});

// Freezing is required again since variables might have been produced
// during the pre-calibration step. `is_inliner_run = false` to prevent the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class CalibrationComponent : public Component {

// Function alias mapping for pre-calibrated SavedModel. Used to preserve
// aliased functions.
const absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases_;
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases_;

// Tags to identify the MetaGraphDef to load from a SavedModel.
const std::unordered_set<std::string> tags_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,42 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h"

#include <string>
#include <utility>
#include <unordered_set>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tsl/platform/errors.h"

namespace mlir::quant::stablehlo {

absl::flat_hash_map<FunctionName, FunctionAlias> UpdateFunctionAliases(
const absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases,
ModuleOp module_op) {
absl::flat_hash_map<FunctionName, FunctionAlias> updated_function_aliases;
// Reads the function alias table out of the SavedModel's MetaGraphDef.
// Returns a function-name -> alias mapping, or an error status if the
// MetaGraphDef identified by `tags` cannot be loaded from
// `saved_model_path`.
absl::StatusOr<absl::flat_hash_map<FunctionName, FunctionAlias>>
GetFunctionAliases(absl::string_view saved_model_path,
                   const std::unordered_set<std::string>& tags) {
  tensorflow::MetaGraphDef meta_graph;
  TF_RETURN_IF_ERROR(tensorflow::ReadMetaGraphDefFromSavedModel(
      saved_model_path, tags, &meta_graph));

  // Copy the proto map into an absl map via its range constructor.
  const auto& aliases = meta_graph.meta_info_def().function_aliases();
  return absl::flat_hash_map<FunctionName, FunctionAlias>(aliases.begin(),
                                                          aliases.end());
}

void UpdateFunctionAliases(
absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
ModuleOp module_op) {
absl::flat_hash_set<FunctionName> existing_func_names;
module_op->walk([&](func::FuncOp func_op) {
FunctionName func_name = func_op.getSymName().str();
existing_func_names.insert(func_name);
// We may retrieve the original function's name from the attribute.
// Functions without this attribute are ignored.
auto original_func_name =
Expand All @@ -38,14 +59,15 @@ absl::flat_hash_map<FunctionName, FunctionAlias> UpdateFunctionAliases(
if (auto alias_itr = function_aliases.find(original_func_name.str());
alias_itr != function_aliases.end()) {
const FunctionAlias alias = alias_itr->second;
const FunctionName new_func_name = func_op.getSymName().str();

updated_function_aliases[new_func_name] = alias;
function_aliases[func_name] = alias;
}
}
});

return updated_function_aliases;
// Remove aliases to functions that no longer exist.
absl::erase_if(function_aliases, [&existing_func_names](const auto& item) {
return !existing_func_names.contains(item.first);
});
}

} // namespace mlir::quant::stablehlo
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,31 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_

#include <string>
#include <unordered_set>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"

namespace mlir::quant::stablehlo {

// Returns the updated function aliases. `module_op` may have different
// Gets the function aliases from the SavedModel.
absl::StatusOr<absl::flat_hash_map<FunctionName, FunctionAlias>>
GetFunctionAliases(absl::string_view saved_model_path,
const std::unordered_set<std::string>& tags);

// Updates the function aliases. `module_op` may have different
// function names from the original model, so it re-associates the aliases
// with the new function names. The input `function_aliases` is a function
// name -> alias mapping of the original model and is updated in place.
// The original function's
// name is retrieved by looking at the "tf._original_func_name" string attribute
// attached to a `func::FuncOp`.
absl::flat_hash_map<FunctionName, FunctionAlias> UpdateFunctionAliases(
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases,
void UpdateFunctionAliases(
absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
ModuleOp module_op);

} // namespace mlir::quant::stablehlo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ TEST_F(UpdateFunctionAliasesTest, NoAliasesReturnsEmptyMap) {
)mlir");
ASSERT_TRUE(module_op);

const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases =
UpdateFunctionAliases(/*function_aliases=*/{}, *module_op);
EXPECT_THAT(updated_function_aliases, IsEmpty());
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases;
UpdateFunctionAliases(function_aliases, *module_op);
EXPECT_THAT(function_aliases, IsEmpty());
}

TEST_F(UpdateFunctionAliasesTest, AliasUpdatedByMlirFunctionName) {
Expand All @@ -55,11 +54,11 @@ TEST_F(UpdateFunctionAliasesTest, AliasUpdatedByMlirFunctionName) {
)mlir");
ASSERT_TRUE(module_op);

const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases = UpdateFunctionAliases(
/*function_aliases=*/{{"main_original", "main_alias"}}, *module_op);
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases{
{"main_original", "main_alias"}};
UpdateFunctionAliases(function_aliases, *module_op);

EXPECT_THAT(updated_function_aliases,
EXPECT_THAT(function_aliases,
UnorderedElementsAre(Pair("main", "main_alias")));
}

Expand All @@ -74,11 +73,11 @@ TEST_F(UpdateFunctionAliasesTest, IgnoresUnmatchedFunctions) {

// There is no alias corresponding to "main_original". The existing entry
// without a corresponding function is ignored.
const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases = UpdateFunctionAliases(
/*function_aliases=*/{{"not_main", "not_main_alias"}}, *module_op);
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases{
{"not_main", "not_main_alias"}};
UpdateFunctionAliases(function_aliases, *module_op);

EXPECT_THAT(updated_function_aliases, IsEmpty());
EXPECT_THAT(function_aliases, IsEmpty());
}

TEST_F(UpdateFunctionAliasesTest,
Expand All @@ -92,11 +91,29 @@ TEST_F(UpdateFunctionAliasesTest,
ASSERT_TRUE(module_op);

// The existing entry whose function no longer exists is removed.
const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases = UpdateFunctionAliases(
/*function_aliases=*/{{"main_original", "main_alias"}}, *module_op);
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases{
{"main_original", "main_alias"}};
UpdateFunctionAliases(function_aliases, *module_op);

EXPECT_THAT(updated_function_aliases, IsEmpty());
EXPECT_THAT(function_aliases, IsEmpty());
}

TEST_F(UpdateFunctionAliasesTest, FunctionNameNotChanged) {
// @main_original does not carry the "tf._original_func_name" attribute, so
// `UpdateFunctionAliases` performs no re-keying for it.
OwningOpRef<ModuleOp> module_op = ParseModuleOpString(R"mlir(
func.func private @main_original(%arg: tensor<1x2xf32>) -> (tensor<1x2xf32>) {
return %arg : tensor<1x2xf32>
}
)mlir");
ASSERT_TRUE(module_op);

// The function still exists under its original name, so the existing alias
// entry is kept as-is (neither erased nor re-keyed).
absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases{
{"main_original", "main_alias"}};
UpdateFunctionAliases(function_aliases, *module_op);

EXPECT_THAT(function_aliases,
UnorderedElementsAre(Pair("main_original", "main_alias")));
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,19 @@ absl::StatusOr<ModuleOp> ImportSavedModel(
const std::vector<std::string>& signature_keys,
const std::unordered_set<std::string>& tags,
const QuantizationConfig& quantization_config,
const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) {
TF_ASSIGN_OR_RETURN(
ImportedMlirModuleOp imported_module,
SavedModelToMlirModuleOp(saved_model_path, tags, signature_keys, ctx));
auto [module_op, saved_model_bundle] = std::move(imported_module);

const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases =
UpdateFunctionAliases(function_aliases, module_op);
UpdateFunctionAliases(function_aliases, module_op);

// Collect the names of the functions that have aliases so that they may not
// be inlined.
absl::flat_hash_set<std::string> aliased_function_names;
absl::c_for_each(updated_function_aliases, [&](const auto& aliases) {
absl::c_for_each(function_aliases, [&](const auto& aliases) {
return aliased_function_names.insert(aliases.first);
});

Expand All @@ -201,7 +199,7 @@ absl::StatusOr<ExportedModel> CreateExportedModel(
const std::vector<std::string>& signature_keys,
const std::unordered_set<std::string>& tags,
const QuantizationConfig& quantization_config,
const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op) {
TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName());
const ExportOptions export_opts = {
Expand All @@ -213,12 +211,10 @@ absl::StatusOr<ExportedModel> CreateExportedModel(
TF_ASSIGN_OR_RETURN(const SmallVector<AssetFileDef> asset_file_defs,
RunExportPasses(export_opts, ctx, module_op));

const absl::flat_hash_map<FunctionName, FunctionAlias>
updated_function_aliases =
UpdateFunctionAliases(function_aliases, module_op);
UpdateFunctionAliases(function_aliases, module_op);

return ConvertMlirModuleToExportedModel(
module_op, checkpoint_dir, updated_function_aliases,
module_op, checkpoint_dir, function_aliases,
{asset_file_defs.begin(), asset_file_defs.end()});
}

Expand Down Expand Up @@ -264,29 +260,35 @@ absl::Status QuantizeStaticRangePtq(
const QuantizationConfig& quantization_config,
const std::vector<std::string>& signature_keys,
const absl::flat_hash_map<std::string, SignatureDef>& signature_def_map,
const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
const PyFunctionLibrary& py_function_library) {
std::unordered_set<std::string> tags;
tags.insert(quantization_config.tf_saved_model().tags().begin(),
quantization_config.tf_saved_model().tags().end());

std::unique_ptr<MLIRContext> ctx = CreateMlirContextForQuantization();

absl::StatusOr<absl::flat_hash_map<FunctionName, FunctionAlias>>
function_aliases = GetFunctionAliases(src_saved_model_path, tags);
if (!function_aliases.ok()) {
return absl::InternalError(absl::StrCat(
"Failed to get function alias: ", function_aliases.status().message()));
}

TF_ASSIGN_OR_RETURN(
ModuleOp module_op,
ImportSavedModel(src_saved_model_path, signature_keys, tags,
quantization_config, function_aliases, *ctx));
quantization_config, *function_aliases, *ctx));

StaticRangePtqComponent static_range_ptq_component(
ctx.get(), &py_function_library, src_saved_model_path, signature_keys,
tags, signature_def_map, function_aliases);
tags, signature_def_map, *function_aliases);
TF_ASSIGN_OR_RETURN(module_op, static_range_ptq_component.Run(
module_op, quantization_config));

TF_ASSIGN_OR_RETURN(
const ExportedModel post_calibrated_exported_model,
CreateExportedModel(signature_keys, tags, quantization_config,
function_aliases, *ctx, module_op));
*function_aliases, *ctx, module_op));

// Remove the `tpu` tag for exporting because the output quantized model is
// essentially a CPU model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ class StaticRangePtqComponent : public Component {
//
// `signature_keys` specify the signatures that correspond to functions to be
// quantized. `signature_def_map` connects the signature keys to
// `SignatureDef`s. `function_aliases` maps actual function names to the
// function aliases, as defined by the
// `MetaGraphDef::MetaInfoDef::function_aliases` from the input SavedModel.
// `SignatureDef`s.
//
// Returns a non-OK status when the quantization is not successful.
// LINT.IfChange
Expand All @@ -97,7 +95,6 @@ absl::Status QuantizeStaticRangePtq(
const std::vector<std::string>& signature_keys,
const absl::flat_hash_map<std::string, tensorflow::SignatureDef>&
signature_def_map,
const absl::flat_hash_map<std::string, std::string>& function_aliases,
const tensorflow::quantization::PyFunctionLibrary& py_function_library);
// LINT.ThenChange(../python/pywrap_quantization.cc:static_range_ptq)

Expand Down
2 changes: 0 additions & 2 deletions tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ pytype_strict_library(
":pywrap_quantization",
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib_py",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:save_model",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/saved_model:loader",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,12 @@ PYBIND11_MODULE(pywrap_quantization, m) {
serialized `SignatureDef` mapping for the `signature_def_map_serialized`
argument.

`function_aliases` maps actual function names to the function aliases,
as defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the
input SavedModel.

Raises `StatusNotOk` exception if when the run was unsuccessful.
)pbdoc",
py::arg("src_saved_model_path"), py::arg("dst_saved_model_path"),
py::arg("quantization_config_serialized"), py::kw_only(),
py::arg("signature_keys"), py::arg("signature_def_map_serialized"),
py::arg("function_aliases"), py::arg("py_function_library"));
py::arg("py_function_library"));
// LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq)

// If the function signature changes, likely its corresponding .pyi type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def static_range_ptq(
*,
signature_keys: list[str],
signature_def_map_serialized: dict[str, bytes],
function_aliases: dict[str, str],
py_function_library: py_function_lib.PyFunctionLibrary,
) -> Any: ... # Status

Expand Down