diff --git a/tensorflow_lite_support/cc/common.h b/tensorflow_lite_support/cc/common.h index 6514c4003..c522d789c 100644 --- a/tensorflow_lite_support/cc/common.h +++ b/tensorflow_lite_support/cc/common.h @@ -101,6 +101,9 @@ enum class TfLiteSupportStatus { // none was found or it was empty. // E.g. current task requires labels but none were found. kMetadataMissingLabelsError, + // The ProcessingUnit for tokenizer is not correctly configured. + // E.g BertTokenizer doesn't have a valid vocab file associated. + kMetadataInvalidTokenizerError, // Input tensor(s) error codes. diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h index 86cae91b0..6df40e51d 100644 --- a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h +++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h @@ -97,9 +97,6 @@ class BertQuestionAnswerer : public QuestionAnswerer { static constexpr int kNumLiteThreads = 4; static constexpr bool kUseLowerCase = true; - // Constant for model metadata - static constexpr int kTokenizerProcessUnitIndex = 0; - static StatusOr> CreateQuestionAnswererWithMetadata( const std::string& path_to_model_with_metadata); diff --git a/tensorflow_lite_support/cc/task/vision/BUILD b/tensorflow_lite_support/cc/task/vision/BUILD index 931204198..d426486fb 100644 --- a/tensorflow_lite_support/cc/task/vision/BUILD +++ b/tensorflow_lite_support/cc/task/vision/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -60,6 +61,8 @@ cc_library( "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", "//tensorflow_lite_support/cc/task/vision/utils:score_calibration", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -69,7 +72,6 @@ cc_library( "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/core/api", - "@org_tensorflow//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", ], ) @@ -93,6 +95,7 @@ cc_library( "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", "//tensorflow_lite_support/metadata/cc:metadata_extractor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -101,6 +104,5 @@ cc_library( "@flatbuffers", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/core/api", - "@org_tensorflow//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", ], ) diff --git a/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/tensorflow_lite_support/cc/task/vision/image_classifier.cc index 9618ab8c2..39c23ccd2 100644 --- a/tensorflow_lite_support/cc/task/vision/image_classifier.cc +++ b/tensorflow_lite_support/cc/task/vision/image_classifier.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" @@ -30,6 +29,8 @@ limitations under the License. #include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" namespace tflite { namespace support { diff --git a/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/tensorflow_lite_support/cc/task/vision/image_segmenter.cc index 4cc4110f1..684c17ae2 100644 --- a/tensorflow_lite_support/cc/task/vision/image_segmenter.cc +++ b/tensorflow_lite_support/cc/task/vision/image_segmenter.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/status_macros.h" @@ -31,6 +30,7 @@ limitations under the License. #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" namespace tflite { namespace support { diff --git a/tensorflow_lite_support/cc/task/vision/object_detector.cc b/tensorflow_lite_support/cc/task/vision/object_detector.cc index 61b5d84d7..1c2f87cd6 100644 --- a/tensorflow_lite_support/cc/task/vision/object_detector.cc +++ b/tensorflow_lite_support/cc/task/vision/object_detector.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" namespace tflite { diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc index c9385d155..ec632ab78 100644 --- a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc @@ -24,7 +24,6 @@ limitations under the License. namespace tflite::support::text::tokenizer { -using ::absl::StatusCode; using ::tflite::ProcessUnit; using ::tflite::SentencePieceTokenizerOptions; using ::tflite::support::CreateStatusWithPayload; @@ -35,7 +34,9 @@ StatusOr> CreateTokenizerFromMetadata( metadata_extractor.GetInputProcessUnit(kTokenizerProcessUnitIndex); if (tokenizer_process_unit == nullptr) { return CreateStatusWithPayload( - StatusCode::kNotFound, "No input process unit found from metadata."); + absl::StatusCode::kInvalidArgument, + "No input process unit found from metadata.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); } if (tokenizer_process_unit->options_type() == ProcessUnitOptions_BertTokenizerOptions) { @@ -45,8 +46,9 @@ StatusOr> CreateTokenizerFromMetadata( if (options->vocab_file() == nullptr || options->vocab_file()->size() < 1 || options->vocab_file()->Get(0)->name() == nullptr) { return CreateStatusWithPayload( - StatusCode::kInvalidArgument, - "Invalid vocab_file from input process unit."); + absl::StatusCode::kInvalidArgument, + "Invalid vocab_file from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); } ASSIGN_OR_RETURN(vocab_buffer, metadata_extractor.GetAssociatedFile( @@ -64,8 +66,9 @@ StatusOr> CreateTokenizerFromMetadata( options->sentencePiece_model()->size() < 1 || options->sentencePiece_model()->Get(0)->name() == nullptr) { return CreateStatusWithPayload( - StatusCode::kInvalidArgument, - "Invalid sentencePiece_model from input process unit."); + absl::StatusCode::kInvalidArgument, + "Invalid sentencePiece_model from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); } ASSIGN_OR_RETURN( model_buffer, @@ -74,12 +77,13 @@ StatusOr> CreateTokenizerFromMetadata( // TODO(b/160647204): Extract sentence piece model vocabulary return absl::make_unique(model_buffer.data(), model_buffer.size()); + } else { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + absl::StrCat("Incorrect options_type:", + tokenizer_process_unit->options_type()), + TfLiteSupportStatus::kMetadataInvalidTokenizerError); } - - return CreateStatusWithPayload( - StatusCode::kNotFound, - absl::StrCat("Incorrect options_type:", - tokenizer_process_unit->options_type())); } } // namespace tflite::support::text::tokenizer diff --git a/tensorflow_lite_support/metadata/java/BUILD b/tensorflow_lite_support/metadata/java/BUILD index cae0c9d17..9263c681f 100644 --- a/tensorflow_lite_support/metadata/java/BUILD +++ b/tensorflow_lite_support/metadata/java/BUILD @@ -14,7 +14,7 @@ METADATA_SRCS = glob( ) android_library( - name = "tensorflow-lite-support-metadata", + name = "tensorflowlite_support_metadata", srcs = METADATA_SRCS, manifest = "AndroidManifest.xml", deps = [ @@ -24,8 +24,13 @@ android_library( ], ) +alias( + name = "tensorflow-lite-support-metadata", + actual = ":tensorflowlite_support_metadata", +) + java_library( - name = "tensorflow-lite-support-metadata-lib", + name = "tensorflowlite_support_metadata_lib", srcs = METADATA_SRCS, javacopts = JAVACOPTS, resource_jars = [ @@ -38,3 +43,8 @@ java_library( "@org_checkerframework_qual", ], ) + +alias( + name = "tensorflow-lite-support-metadata-lib", + actual = ":tensorflowlite_support_metadata_lib", +) diff --git a/tensorflow_lite_support/metadata/java/src/javatests/BUILD b/tensorflow_lite_support/metadata/java/src/javatests/BUILD index 4d45eecf0..84ab680c4 100644 --- a/tensorflow_lite_support/metadata/java/src/javatests/BUILD +++ b/tensorflow_lite_support/metadata/java/src/javatests/BUILD @@ -30,7 +30,7 @@ android_local_test( ":test_lib", # unuseddeps: keep "//tensorflow_lite_support/metadata:metadata_schema_fbs_android", "//tensorflow_lite_support/metadata:schema_fbs_android", - "//tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata", + "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", "//third_party/java/checker_framework_annotations", "//third_party/java/jakarta_commons_io", "//third_party/java/truth:truth-android", @@ -47,7 +47,7 @@ android_local_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.support.metadata.ByteBufferChannelTest", deps = [ - "//tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata", + "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", "//third_party/java/truth:truth-android", ], ) @@ -64,7 +64,7 @@ android_local_test( test_class = "org.tensorflow.lite.support.metadata.ZipFileTest", deps = [ ":test_lib", - "//tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata", + "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", "//third_party/java/jakarta_commons_io", "//third_party/java/truth:truth-android", ], @@ -79,7 +79,7 @@ android_local_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.support.metadata.BoundedInputStreamTest", deps = [ - "//tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata", + "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata", "//third_party/java/truth:truth-android", ], ) @@ -92,7 +92,7 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.support.metadata.MetadataParserTest", deps = [ - "//tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata-lib", + "//tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib", "//third_party/java/junit", "//third_party/java/truth", ], diff --git a/tensorflow_lite_support/tools/ci_build/build_all.sh b/tensorflow_lite_support/tools/ci_build/build_all.sh index da2775896..0a82a26c4 100644 --- a/tensorflow_lite_support/tools/ci_build/build_all.sh +++ b/tensorflow_lite_support/tools/ci_build/build_all.sh @@ -18,7 +18,7 @@ set -ex bazel build -c opt --config=monolithic \ - //tensorflow_lite_support/java:tensorflow-lite-support \ + //tensorflow_lite_support/java:tensorflowlite_support \ //tensorflow_lite_support/codegen/python:codegen \ //tensorflow_lite_support/metadata:metadata \ - //tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata-lib + //tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib