-
Notifications
You must be signed in to change notification settings - Fork 74.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tool to convert TF SavedModel to StableHLO
Here is the signature of the provide API: ```c++ // Converts a TensorFlow model (either from a SavedModel or an MLIR module) to a // StableHLO MLIR module. // // Args: // input_path: The path to the input TensorFlow SavedModel or MLIR module. // context: The MLIR context to use for parsing or creating the MLIR module. // exported_model_signatures: A comma-separated list of exported model // signatures (functions) to convert. // tag_names: A comma-separated list of tag names used for loading SavedModel. // input_arg_shapes_str: A string representation of input argument shapes. // Shapes for different tensors are separated by ':', and dimension sizes for // the same tensor are separated by ','. For example, // 'input-arg-shapes=1,2::1,?' expresses input arguments with shapes [1,2], // [] and [1,?]. // is_input_mlir_module: If true, `input_path` is treated as an MLIR // module instead of a SavedModel. // // Returns: // An absl::StatusOr containing the converted StableHLO MLIR module on // success, or an absl::Status with an error message on failure. absl::StatusOr<OwningOpRef<ModuleOp>> TfToStablehlo( absl::string_view input_path, MLIRContext* context, absl::string_view exported_model_signatures, absl::string_view tag_names, absl::string_view input_arg_shapes_str, bool is_input_mlir_module = false); ``` Reverts 2ba594d FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11721 from inailuig:mpicollectives_pytype 3924cc0fbbb63e9503f38a59aede3b8e817b17fa PiperOrigin-RevId: 625351420
- Loading branch information
1 parent
b251941
commit f67804f
Showing
48 changed files
with
1,112 additions
and
247 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary") | ||
load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") | ||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") | ||
load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") | ||
|
||
package_group( | ||
name = "internal_visibility_allowlist_package", | ||
packages = [ | ||
"//learning/brain/mlir/quantization/stablehlo/python/integration_test/...", | ||
"//tensorflow/compiler/mlir/lite/...", | ||
"//tensorflow/compiler/mlir/quantization/...", | ||
"//tensorflow/compiler/mlir/tf2xla/transforms/...", | ||
"//tensorflow/lite/...", | ||
"//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1 | ||
] + internal_visibility_allowlist(), | ||
) | ||
|
||
package( | ||
# copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"], | ||
default_visibility = [ | ||
":internal_visibility_allowlist_package", | ||
"//tensorflow:__pkg__", | ||
], | ||
licenses = ["notice"], | ||
) | ||
|
||
cc_library( | ||
name = "tf_to_stablehlo", | ||
srcs = [ | ||
"tf_to_stablehlo.cc", | ||
], | ||
hdrs = [ | ||
"tf_to_stablehlo.h", | ||
], | ||
compatible_with = get_compatible_with_portable(), | ||
deps = [ | ||
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:saved_model_import", | ||
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", | ||
"//tensorflow/compiler/mlir/tensorflow/transforms:shape_inference_pass", | ||
"//tensorflow/core:core_cpu_base", | ||
"@com_google_absl//absl/algorithm:container", | ||
"@com_google_absl//absl/container:flat_hash_set", | ||
"@com_google_absl//absl/status", | ||
"@com_google_absl//absl/status:statusor", | ||
"@com_google_absl//absl/strings", | ||
"@com_google_absl//absl/strings:string_view", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Parser", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:Transforms", | ||
"@local_tsl//tsl/platform:errors", | ||
"@local_tsl//tsl/platform:statusor", | ||
], | ||
alwayslink = True, | ||
) | ||
|
||
tf_cc_binary( | ||
name = "tf-to-stablehlo-translate", | ||
srcs = [ | ||
"tf_to_stablehlo_translate.cc", | ||
], | ||
visibility = [":internal_visibility_allowlist_package"], | ||
deps = [ | ||
":tf_to_stablehlo", | ||
"//tensorflow/compiler/mlir:init_mlir", | ||
"//tensorflow/compiler/mlir/tensorflow", | ||
"@com_google_absl//absl/status", | ||
"@com_google_absl//absl/status:statusor", | ||
"@com_google_absl//absl/strings", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:AllPassesAndDialects", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Parser", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
], | ||
) | ||
|
||
glob_lit_tests( | ||
name = "all_tests", | ||
data = [":test_utilities"], | ||
# TODO: b/288344501 - Enable OSS tests again when stable-quant-opt works well. | ||
default_tags = [ | ||
"no_oss", | ||
"no_pip", | ||
], | ||
driver = "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:run_lit.sh", | ||
size_override = { | ||
}, | ||
tags_override = { | ||
}, | ||
test_file_exts = ["mlir"], | ||
) | ||
|
||
# Bundle together all of the test utilities that are used by tests. | ||
filegroup( | ||
name = "test_utilities", | ||
testonly = True, | ||
data = [ | ||
":tf-to-stablehlo-translate", | ||
"@llvm-project//llvm:FileCheck", | ||
"@llvm-project//llvm:not", | ||
"@llvm-project//mlir:run_lit.sh", | ||
], | ||
) |
Oops, something went wrong.