diff --git a/WORKSPACE b/WORKSPACE index e43bc6b79..48530a33b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -12,12 +12,35 @@ http_archive( ], ) +# Apple and Swift rules. +# https://github.com/bazelbuild/rules_apple/releases +http_archive( + name = "build_bazel_rules_apple", + sha256 = "ee9e6073aeb5a65c100cb9c44b0017c937706a4ae03176e14a7e78620a198079", + strip_prefix = "rules_apple-5131f3d46794bf227d296c82f30c2499c9de3c5b", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz", + "https://github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz", + ], +) + +# https://github.com/bazelbuild/rules_swift/releases +http_archive( + name = "build_bazel_rules_swift", + sha256 = "d0833bc6dad817a367936a5f902a0c11318160b5e80a20ece35fb85a5675c886", + strip_prefix = "rules_swift-3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz", + "https://github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz", + ], +) + http_archive( name = "org_tensorflow", - sha256 = "bb8b10da8184ce747f0348ea5b0d0aaf9e9bbe63cf68363d0e1bcdb72b4d3315", - strip_prefix = "tensorflow-5d49dc5526324443931a33cc84d66c8bcae9cea2", + sha256 = "972ec45352161e4308a1b203956eedfb56e22cc6ce4f4ec95f7b087aeb00559e", + strip_prefix = "tensorflow-2.3.0-rc0", urls = [ - "https://github.com/tensorflow/tensorflow/archive/5d49dc5526324443931a33cc84d66c8bcae9cea2.zip", # 2020-06-13 + "https://github.com/tensorflow/tensorflow/archive/v2.3.0-rc0.tar.gz", ], ) @@ -46,33 +69,105 @@ http_archive( ], ) +http_archive( + name = "six_archive", + build_file = "//third_party:six.BUILD", + sha256 = "d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73", + strip_prefix = "six-1.12.0", + urls = [ + 
"https://storage.googleapis.com/mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz", + ], +) + +http_archive( + name = "com_google_sentencepiece", + strip_prefix = "sentencepiece-1.0.0", + sha256 = "c05901f30a1d0ed64cbcf40eba08e48894e1b0e985777217b7c9036cac631346", + urls = [ + "https://github.com/google/sentencepiece/archive/1.0.0.zip", + ], +) + +http_archive( + name = "org_tensorflow_text", + sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8", + strip_prefix = "text-2.2.0", + urls = [ + "https://github.com/tensorflow/text/archive/v2.2.0.zip", + ], + patches = ["@//third_party:tensorflow_text_fix_local_config_tf.patch"], + patch_args = ["-p1"], + repo_mapping = {"@com_google_re2": "@com_googlesource_code_re2"}, +) + +http_archive( + name = "com_googlesource_code_re2", + sha256 = "d070e2ffc5476c496a6a872a6f246bfddce8e7797d6ba605a7c8d72866743bf9", + strip_prefix = "re2-506cfa4bffd060c06ec338ce50ea3468daa6c814", + urls = [ + "https://github.com/google/re2/archive/506cfa4bffd060c06ec338ce50ea3468daa6c814.tar.gz", + ], +) + +# ABSL cpp library lts_2020_02_25 +# Needed for absl/status http_archive( name = "com_google_absl", build_file = "//third_party:com_google_absl.BUILD", - sha256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a", - strip_prefix = "abseil-cpp-df3ea785d8c30a9503321a3d35ee7d35808f190d", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz", + ], + # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. 
+ patches = [ + "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff" ], + patch_args = [ + "-p1", + ], + strip_prefix = "abseil-cpp-20200225", + sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353" ) http_archive( - name = "six_archive", - build_file = "//third_party:six.BUILD", - sha256 = "d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73", - strip_prefix = "six-1.12.0", + name = "com_google_glog", + sha256 = "1ee310e5d0a19b9d584a855000434bb724aa744745d5b8ab1855c85bff8a8e21", + strip_prefix = "glog-028d37889a1e80e8a07da1b8945ac706259e5fd8", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz", - "https://pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz", + "https://mirror.bazel.build/github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz", + "https://github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz", ], ) -load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") -flatbuffers() +http_archive( + name = "zlib", + build_file = "//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], +) + +http_archive( + name = "org_libzip", + build_file = "//third_party:libzip.BUILD", + sha256 = "a5d22f0c87a2625450eaa5e10db18b8ee4ef17042102d04c62e311993a2ba363", + strip_prefix = "libzip-rel-1-5-1", + urls = [ + # Bazel does not like the official download link at libzip.org, + # so use the GitHub release tag. + "https://mirror.bazel.build/github.com/nih-at/libzip/archive/rel-1-5-1.zip", + "https://github.com/nih-at/libzip/archive/rel-1-5-1.zip", + ], +) +load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") + +flatbuffers() # Set up TF. 
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") tf_workspace(tf_repo_name="@org_tensorflow") diff --git a/tensorflow_lite_support/cc/BUILD b/tensorflow_lite_support/cc/BUILD new file mode 100644 index 000000000..4760dde02 --- /dev/null +++ b/tensorflow_lite_support/cc/BUILD @@ -0,0 +1,17 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "common", + srcs = [ + "common.cc", + ], + hdrs = ["common.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) diff --git a/tensorflow_lite_support/cc/common.cc b/tensorflow_lite_support/cc/common.cc new file mode 100644 index 000000000..47dd3bcc6 --- /dev/null +++ b/tensorflow_lite_support/cc/common.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/common.h" + +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace support { + +absl::Status CreateStatusWithPayload(absl::StatusCode canonical_code, + absl::string_view message, + TfLiteSupportStatus tfls_code) { + // NOTE: Ignores `message` if the canonical code is ok. 
+ absl::Status status = absl::Status(canonical_code, message); + // NOTE: Does nothing if the canonical code is ok. + status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code))); + return status; +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/common.h b/tensorflow_lite_support/cc/common.h new file mode 100644 index 000000000..6514c4003 --- /dev/null +++ b/tensorflow_lite_support/cc/common.h @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace tflite { +namespace support { + +// Name (aka type URL key) of the `absl::Status` payload which contains a +// stringified `TfLiteSupportStatus` code (see below). +inline constexpr absl::string_view kTfLiteSupportPayload = + "tflite::support::TfLiteSupportStatus"; + +// Error codes for TensorFlow Lite Support (TFLS) C++ APIs. +// +// Such codes capture errors encountered in the TFLS layer. They complement all +// the other type of errors that occur in the lower-level TF Lite codebase (see +// `TfLiteStatus` codes). 
+// +// At runtime, such codes are meant to be attached (where applicable) to a +// `absl::Status` in a key-value manner with `kTfLiteSupportPayload` as key and +// stringifed error code as value (aka payload). This logic is encapsulated in +// the `CreateStatusWithPayload` helper below for convenience. +// +// The returned status includes: +// 1. The canonical error code (INVALID_ARGUMENT) +// 2. The fine-grained error message ("Invalid metadata ...") +// 3. The specific TFLS code as a payload (kMetadataInvalidSchemaVersionError) +enum class TfLiteSupportStatus { + // Generic error codes. + + // Success. + kOk = 0, + // Unspecified error. + kError = 1, + // Invalid argument specified. + kInvalidArgumentError = 2, + // Invalid FlatBuffer file or buffer specified. + kInvalidFlatBufferError = 3, + + // File I/O error codes. + + // No such file. + kFileNotFoundError = 100, + // Permission issue. + kFilePermissionDeniedError, + // I/O error when reading file. + kFileReadError, + // I/O error when mmap-ing file. + kFileMmapError, + + // TensorFlow Lite metadata error codes. + + // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. + kMetadataInvalidSchemaVersionError = 200, + // No such associated file within metadata, or file has not been packed. + kMetadataAssociatedFileNotFoundError, + // ZIP I/O error when unpacking an associated file. + kMetadataAssociatedFileZipError, + // Inconsistency error between the metadata and actual TF Lite model. + // E.g.: number of labels and output tensor values differ. + kMetadataInconsistencyError, + // Invalid process units specified. + // E.g.: multiple ProcessUnits with the same type for a given tensor. + kMetadataInvalidProcessUnitsError, + // Inconsistency error with the number of labels. + // E.g.: label files for different locales have a different number of labels. + kMetadataNumLabelsMismatchError, + // Score calibration parameters parsing error. 
+ // E.g.: too many parameters provided in the corresponding associated file. + kMetadataMalformedScoreCalibrationError, + // Unexpected number of subgraphs for the current task. + // E.g.: image classification expects a single subgraph. + kMetadataInvalidNumSubgraphsError, + // A given tensor requires NormalizationOptions but none were found. + // E.g.: float input tensor requires normalization to preprocess input images. + kMetadataMissingNormalizationOptionsError, + // Invalid ContentProperties specified. + // E.g. expected ImageProperties, got BoundingBoxProperties. + kMetadataInvalidContentPropertiesError, + // Metadata is mandatory but was not found. + // E.g. current task requires TFLite Model Metadata but none was found. + kMetadataNotFoundError, + // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + // none was found or it was empty. + // E.g. current task requires labels but none were found. + kMetadataMissingLabelsError, + + // Input tensor(s) error codes. + + // Unexpected number of input tensors for the current task. + // E.g. current task expects a single input tensor. + kInvalidNumInputTensorsError = 300, + // Unexpected input tensor dimensions for the current task. + // E.g.: only 4D input tensors supported. + kInvalidInputTensorDimensionsError, + // Unexpected input tensor type for the current task. + // E.g.: current task expects a uint8 pixel image as input. + kInvalidInputTensorTypeError, + // Unexpected input tensor bytes size. + // E.g.: size in bytes does not correspond to the expected number of pixels. + kInvalidInputTensorSizeError, + // No correct input tensor found for the model. + // E.g.: input tensor name is not part of the text model's input tensors. + kInputTensorNotFoundError, + + // Output tensor(s) error codes. + + // Unexpected output tensor dimensions for the current task. + // E.g.: only a batch size of 1 is supported. 
+ kInvalidOutputTensorDimensionsError = 400, + // Unexpected input tensor type for the current task. + // E.g.: multi-head model with different output tensor types. + kInvalidOutputTensorTypeError, + // No correct output tensor found for the model. + // E.g.: output tensor name is not part of the text model's output tensors. + kOutputTensorNotFoundError, + // Unexpected number of output tensors for the current task. + // E.g.: current task expects a single output tensor. + kInvalidNumOutputTensorsError, + + // Image processing error codes. + + // Unspecified image processing failures. + kImageProcessingError = 500, + // Unexpected input or output buffer metadata. + // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + kImageProcessingInvalidArgumentError, + // Image processing operation failures. + // E.g. libyuv rotation failed for an unknown reason. + kImageProcessingBackendError, +}; + +// Convenience helper to create an `absl::Status` augmented with the +// fine-grained `tfls_code` attached as payload under the +// `kTfLiteSupportPayload` type URL key. +// +// This should only be used for non-ok codes since otherwise it does nothing +// more than returning an object identical to an OK status. See `absl::Status` +// for more details. 
+absl::Status CreateStatusWithPayload( + absl::StatusCode canonical_code, absl::string_view message, + TfLiteSupportStatus tfls_code = TfLiteSupportStatus::kError); + +} // namespace support +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ diff --git a/tensorflow_lite_support/cc/port/BUILD b/tensorflow_lite_support/cc/port/BUILD new file mode 100644 index 000000000..ea18c7f35 --- /dev/null +++ b/tensorflow_lite_support/cc/port/BUILD @@ -0,0 +1,37 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "statusor", + hdrs = [ + "statusor.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port/default:statusor", + ], +) + +cc_library( + name = "status_macros", + hdrs = [ + "status_macros.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port/default:status_macros", + ], +) + +cc_library( + name = "tflite_wrapper", + hdrs = [ + "tflite_wrapper.h", + ], + deps = ["//tensorflow_lite_support/cc/port/default:tflite_wrapper"], +) + +cc_library( + name = "integral_types", + hdrs = ["integral_types.h"], +) diff --git a/tensorflow_lite_support/cc/port/default/BUILD b/tensorflow_lite_support/cc/port/default/BUILD new file mode 100644 index 000000000..3f6e9e939 --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/BUILD @@ -0,0 +1,50 @@ +package( + default_visibility = [ + "//tensorflow_lite_support/cc/port:__pkg__", + "//tensorflow_lite_support/cc/test:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "statusor", + srcs = ["statusor.cc"], + hdrs = [ + "statusor.h", + "statusor_internals.h", + ], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/utility", + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "status_macros", + hdrs = [ + 
"status_macros.h", + ], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "tflite_wrapper", + srcs = ["tflite_wrapper.cc"], + hdrs = [ + "tflite_wrapper.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto", + ], +) diff --git a/tensorflow_lite_support/cc/port/default/status_macros.h b/tensorflow_lite_support/cc/port/default/status_macros.h new file mode 100644 index 000000000..47476c9ce --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/status_macros.h @@ -0,0 +1,215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ + +#include "absl/base/optimization.h" +#include "absl/status/status.h" + +// Evaluates an expression that produces a `absl::Status`. If the status is not +// ok, returns it from the current function. 
+// +// For example: +// absl::Status MultiStepFunction() { +// RETURN_IF_ERROR(Function(args...)); +// RETURN_IF_ERROR(foo.Method(args...)); +// return absl::OkStatus(); +// } +#define RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (::tflite::support::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr)}) { \ + } else /* NOLINT */ \ + return status_macro_internal_adaptor.Consume() + +// Executes an expression `rexpr` that returns a `tflite::support::StatusOr`. +// On OK, moves its value into the variable defined by `lhs`, otherwise returns +// from the current function. By default the error status is returned +// unchanged, but it may be modified by an `error_expression`. If there is an +// error, `lhs` is not evaluated; thus any side effects that `lhs` may have +// only occur in the success case. +// +// Interface: +// +// ASSIGN_OR_RETURN(lhs, rexpr) +// ASSIGN_OR_RETURN(lhs, rexpr, error_expression); +// +// WARNING: if lhs is parenthesized, the parentheses are removed. See examples +// for more details. +// +// WARNING: expands into multiple statements; it cannot be used in a single +// statement (e.g. as the body of an if statement without {})! +// +// Example: Declaring and initializing a new variable (ValueType can be anything +// that can be initialized with assignment, including references): +// ASSIGN_OR_RETURN(ValueType value, MaybeGetValue(arg)); +// +// Example: Assigning to an existing variable: +// ValueType value; +// ASSIGN_OR_RETURN(value, MaybeGetValue(arg)); +// +// Example: Assigning to an expression with side effects: +// MyProto data; +// ASSIGN_OR_RETURN(*data.mutable_str(), MaybeGetValue(arg)); +// // No field "str" is added on error. +// +// Example: Assigning to a std::unique_ptr. +// ASSIGN_OR_RETURN(std::unique_ptr ptr, MaybeGetPtr(arg)); +// +// Example: Assigning to a map. 
Because of C preprocessor +// limitation, the type used in ASSIGN_OR_RETURN cannot contain comma, so +// wrap lhs in parentheses: +// ASSIGN_OR_RETURN((absl::flat_hash_map my_map), GetMap()); +// Or use auto if the type is obvious enough: +// ASSIGN_OR_RETURN(const auto& my_map, GetMapRef()); +// +// Example: Assigning to structured bindings. The same situation with comma as +// in map, so wrap the statement in parentheses. +// ASSIGN_OR_RETURN((const auto& [first, second]), GetPair()); + +#define ASSIGN_OR_RETURN(...) \ + STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \ + (__VA_ARGS__) + +// ================================================================= +// == Implementation details, do not rely on anything below here. == +// ================================================================= + +// Some builds do not support C++14 fully yet, using C++11 constexpr technique. +constexpr bool TFLSHasPotentialConditionalOperator(const char* lhs, int index) { + return (index == -1 + ? false + : (lhs[index] == '?' + ? true + : TFLSHasPotentialConditionalOperator(lhs, index - 1))); +} + +// MSVC incorrectly expands variadic macros, splice together a macro call to +// work around the bug. +#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) 
NAME +#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ + STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args + +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, _) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ + STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ + error_expression) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + ::absl::Status _(std::move(statusor).status()); \ + (void)_; /* error_expression is allowed to not use this variable */ \ + return (error_expression); \ + } \ + { \ + static_assert( \ + #lhs[0] != '(' || #lhs[sizeof(#lhs) - 2] != ')' || \ + !TFLSHasPotentialConditionalOperator(#lhs, sizeof(#lhs) - 2), \ + "Identified potential conditional operator, consider not " \ + "using ASSIGN_OR_RETURN"); \ + } \ + STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ + std::move(statusor).value() + +// Internal helpers for macro expansion. +#define STATUS_MACROS_IMPL_EAT(...) +#define STATUS_MACROS_IMPL_REM(...) __VA_ARGS__ +#define STATUS_MACROS_IMPL_EMPTY() + +// Internal helpers for emptyness arguments check. +#define STATUS_MACROS_IMPL_IS_EMPTY_INNER(...) \ + STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(__VA_ARGS__, 0, 1) +#define STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty + +#define STATUS_MACROS_IMPL_IS_EMPTY(...) \ + STATUS_MACROS_IMPL_IS_EMPTY_I(__VA_ARGS__) +#define STATUS_MACROS_IMPL_IS_EMPTY_I(...) \ + STATUS_MACROS_IMPL_IS_EMPTY_INNER(_, ##__VA_ARGS__) + +// Internal helpers for if statement. 
+#define STATUS_MACROS_IMPL_IF_1(_Then, _Else) _Then +#define STATUS_MACROS_IMPL_IF_0(_Then, _Else) _Else +#define STATUS_MACROS_IMPL_IF(_Cond, _Then, _Else) \ + STATUS_MACROS_IMPL_CONCAT_(STATUS_MACROS_IMPL_IF_, _Cond) \ + (_Then, _Else) + +// Expands to 1 if the input is parenthesized. Otherwise expands to 0. +#define STATUS_MACROS_IMPL_IS_PARENTHESIZED(...) \ + STATUS_MACROS_IMPL_IS_EMPTY(STATUS_MACROS_IMPL_EAT __VA_ARGS__) + +// If the input is parenthesized, removes the parentheses. Otherwise expands to +// the input unchanged. +#define STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ + STATUS_MACROS_IMPL_IF(STATUS_MACROS_IMPL_IS_PARENTHESIZED(__VA_ARGS__), \ + STATUS_MACROS_IMPL_REM, STATUS_MACROS_IMPL_EMPTY()) \ + __VA_ARGS__ + +// Internal helper for concatenating macro values. +#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +// The GNU compiler emits a warning for code like: +// +// if (foo) +// if (bar) { } else baz; +// +// because it thinks you might want the else to bind to the first if. This +// leads to problems with code like: +// +// if (do_expr) RETURN_IF_ERROR(expr) << "Some message"; +// +// The "switch (0) case 0:" idiom is used to suppress this. +#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + switch (0) \ + case 0: \ + default: // NOLINT + +namespace tflite { +namespace support { +namespace status_macro_internal { + +// Provides a conversion to bool so that it can be used inside an if statement +// that declares a variable. 
+class StatusAdaptorForMacros { + public: + StatusAdaptorForMacros(const ::absl::Status& status) // NOLINT + : status_(status) {} + + StatusAdaptorForMacros(::absl::Status&& status) // NOLINT + : status_(std::move(status)) {} + + StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; + StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; + + explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); } + + ::absl::Status&& Consume() { return std::move(status_); } + + private: + ::absl::Status status_; +}; + +} // namespace status_macro_internal +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ diff --git a/tensorflow_lite_support/cc/port/default/statusor.cc b/tensorflow_lite_support/cc/port/default/statusor.cc new file mode 100644 index 000000000..7394355ea --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/statusor.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. 
+ +#include "tensorflow_lite_support/cc/port/default/statusor.h" + +#include + +#include +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace support { + +BadStatusOrAccess::BadStatusOrAccess(absl::Status status) + : status_(std::move(status)) {} + +BadStatusOrAccess::~BadStatusOrAccess() = default; + +const char* BadStatusOrAccess::what() const noexcept { + return "Bad StatusOr access"; +} + +const absl::Status& BadStatusOrAccess::status() const { return status_; } + +namespace internal_statusor { + +void Helper::HandleInvalidStatusCtorArg(absl::Status* status) { + const char* kMessage = + "An OK status is not a valid constructor argument to StatusOr"; + LOG(DFATAL) << kMessage; + // In optimized builds, we will fall back to ::util::error::INTERNAL. + *status = absl::InternalError(kMessage); +} + +void Helper::Crash(const absl::Status& status) { + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; +} + +void ThrowBadStatusOrAccess(absl::Status status) { +#ifdef ABSL_HAVE_EXCEPTIONS + throw BadStatusOrAccess(std::move(status)); +#else + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; +#endif +} + +} // namespace internal_statusor +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/port/default/statusor.h b/tensorflow_lite_support/cc/port/default/statusor.h new file mode 100644 index 000000000..96f6bb3e0 --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/statusor.h @@ -0,0 +1,570 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "tensorflow_lite_support/cc/port/default/statusor_internals.h" + +namespace tflite { +namespace support { + +#ifndef SWIG +class BadStatusOrAccess : public std::exception { + public: + explicit BadStatusOrAccess(absl::Status status); + ~BadStatusOrAccess() override; + const char* what() const noexcept override; + const absl::Status& status() const; + + private: + absl::Status status_; +}; +#endif // !SWIG + +// Returned StatusOr objects may not be ignored. +// Note: Disabled for SWIG as it doesn't parse attributes correctly. Codesearch +// doesn't handle ifdefs as part of a class definitions (b/6995610), so we use a +// forward declaration. 
+#ifndef SWIG +template +class ABSL_MUST_USE_RESULT StatusOr; +#endif + +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::CopyCtorBase, + private internal_statusor::MoveCtorBase, + private internal_statusor::CopyAssignBase, + private internal_statusor::MoveAssignBase { + template + friend class StatusOr; + + typedef internal_statusor::StatusOrData Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // tflite::support::StatusOr> will be initialized with an + // empty vector, instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr is copy constructible if T is copy constructible. + StatusOr(const StatusOr&) = default; + // StatusOr is copy assignable if T is copy constructible and copy + // assignable. + StatusOr& operator=(const StatusOr&) = default; + +#ifndef SWIG + + // StatusOr is move constructible if T is move constructible. + StatusOr(StatusOr&&) = default; + // StatusOr is moveAssignable if T is move constructible and move + // assignable. + StatusOr& operator=(StatusOr&&) = default; + + // Converting constructors from StatusOr, when T is constructible from U. + // To avoid ambiguity, they are disabled if T is also constructible from + // StatusOr. Explicit iff the corresponding construction of T from U is + // explicit. 
+ template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation>, + std::is_constructible, + std::is_convertible, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr(const StatusOr& other) // NOLINT + : Base(static_cast::Base&>(other)) {} + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation>, + std::is_constructible, + absl::negation>, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + explicit StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} + + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation>, std::is_constructible, + std::is_convertible, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr(StatusOr&& other) // NOLINT + : Base(static_cast::Base&&>(other)) {} + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation>, std::is_constructible, + absl::negation>, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + explicit StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} + + // Conversion copy/move assignment operator, T must be constructible and + // assignable from U. Only enable if T cannot be directly assigned from + // StatusOr. 
+ template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation>, + std::is_constructible, + std::is_assignable, + absl::negation< + internal_statusor:: + IsConstructibleOrConvertibleOrAssignableFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr& operator=(const StatusOr& other) { + this->Assign(other); + return *this; + } + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation>, std::is_constructible, + std::is_assignable, + absl::negation< + internal_statusor:: + IsConstructibleOrConvertibleOrAssignableFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr& operator=(StatusOr&& other) { + this->Assign(std::move(other)); + return *this; + } + +#endif // SWIG + + // Constructs a new StatusOr with the given value. After calling this + // constructor, this->ok() will be true and the contained value may be + // retrieved with value(), operator*(), or operator->(). + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: T is copy constructible. + // TODO(b/113125838): Replace this constructor with a direct-initialization + // constructor. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling this + // constructor, this->ok() will be false and calls to value() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing util::OkStatus() here will have the effect + // of passing util::error::INTERNAL as a fallback. + StatusOr(const absl::Status& status); + StatusOr& operator=(const absl::Status& status); + +#ifndef SWIG + // Perfect-forwarding value assignment operator. 
+ // If `*this` contains a `T` value before the call, the contained value is + // assigned from `std::forward(v)`; Otherwise, it is directly-initialized + // from `std::forward(v)`. + // This function does not participate in overload unless: + // 1. `std::is_constructible_v` is true, + // 2. `std::is_assignable_v` is true. + // 3. `std::is_same_v, std::remove_cvref_t>` is false. + // 4. Assigning `U` to `T` is not ambiguous: + // If `U` is `StatusOr` and `T` is constructible and assignable from + // both `StatusOr` and `V`, the assignment is considered bug-prone and + // ambiguous thus will fail to compile. For example: + // StatusOr s1 = true; // s1.ok() && *s1 == true + // StatusOr s2 = false; // s2.ok() && *s2 == false + // s1 = s2; // ambiguous, `s1 = *s2` or `s1 = bool(s2)`? + template < + typename U = T, + typename = typename std::enable_if, std::is_assignable, + internal_statusor::IsForwardingAssignmentValid>::value>::type> + StatusOr& operator=(U&& v) { + this->Assign(std::forward(v)); + return *this; + } + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(absl::Status&& status); + StatusOr& operator=(absl::Status&& status); + + // Constructs the inner value T in-place using the provided args, using the + // T(args...) constructor. + template + explicit StatusOr(absl::in_place_t, Args&&... args); + template + explicit StatusOr(absl::in_place_t, std::initializer_list ilist, + Args&&... args); + + // Constructs the inner value T in-place using the provided args, using the + // T(U) (direct-initialization) constructor. Only valid if T can be + // constructed from a U. Can accept move or copy constructors. Explicit if + // U is not convertible to T. To avoid ambiguity, this is disabled if U is + // a StatusOr, where J is convertible to T. + // Style waiver for implicit conversion granted in cl/209187539. 
+ template , + std::is_constructible, + std::is_convertible>::value, + int> = 0> + StatusOr(U&& u) // NOLINT + : StatusOr(absl::in_place, std::forward(u)) {} + + template , + std::is_constructible, + absl::negation>>::value, + int> = 0> + explicit StatusOr(U&& u) // NOLINT + : StatusOr(absl::in_place, std::forward(u)) {} + +#endif // SWIG + + // Returns this->status().ok() + ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns util::OkStatus(). +#ifdef SWIG + const ::util::Status& status() const; +#else // SWIG + const absl::Status& status() const&; + absl::Status status() &&; +#endif // SWIG + + // Returns a reference to the held value if `this->ok()`. Otherwise, throws + // `absl::BadStatusOrAccess` if exception is enabled, or `LOG(FATAL)` if + // exception is disabled. + // If you have already checked the status using `this->ok()` or + // `operator bool()`, you probably want to use `operator*()` or `operator->()` + // to access the value instead of `value`. + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.value(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.value(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).value(); + // + // The `std::move` on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. +#ifdef SWIG + const T& value() const; +#else // SWIG + const T& value() const&; + T& value() &; + const T&& value() const&&; + T&& value() &&; +#endif // SWIG + +#ifndef SWIG + // Returns a reference to the current value. + // + // REQUIRES: this->ok() == true, otherwise the behavior is undefined. 
+ // + // Use this->ok() or `operator bool()` to verify that there is a current + // value. Alternatively, see value() for a similar API that guarantees + // CHECK-failing if there is no current value. + const T& operator*() const&; + T& operator*() &; + const T&& operator*() const&&; + T&& operator*() &&; +#endif // SWIG + +#ifndef SWIG + // Returns a pointer to the current value. + // + // REQUIRES: this->ok() == true, otherwise the behavior is undefined. + // + // Use this->ok() or `operator bool()` to verify that there is a current + // value. + const T* operator->() const; + T* operator->(); +#endif // SWIG + +#ifndef SWIG + // Returns a copy of the current value if this->ok() == true. Otherwise + // returns a default value. + template + T value_or(U&& default_value) const&; + template + T value_or(U&& default_value) &&; +#endif // SWIG + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; + +#ifndef SWIG + // Reconstructs the inner value T in-place using the provided args, using the + // T(args...) constructor. Returns reference to the reconstructed `T`. + template + T& emplace(Args&&... args) { + if (ok()) { + this->Clear(); + this->MakeValue(std::forward(args)...); + } else { + this->MakeValue(std::forward(args)...); + this->status_ = absl::OkStatus(); + } + return this->data_; + } + + template < + typename U, typename... Args, + absl::enable_if_t< + std::is_constructible&, Args&&...>::value, + int> = 0> + T& emplace(std::initializer_list ilist, Args&&... 
args) { + if (ok()) { + this->Clear(); + this->MakeValue(ilist, std::forward(args)...); + } else { + this->MakeValue(ilist, std::forward(args)...); + this->status_ = absl::OkStatus(); + } + return this->data_; + } +#endif // SWIG + + private: +#ifndef SWIG + using internal_statusor::StatusOrData::Assign; + template + void Assign(const StatusOr& other); + template + void Assign(StatusOr&& other); +#endif // SWIG +}; + +#ifndef SWIG +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +template +StatusOr::StatusOr() : Base(absl::Status(absl::StatusCode::kUnknown, "")) {} + +template +StatusOr::StatusOr(const T& value) : Base(value) {} + +template +StatusOr::StatusOr(const absl::Status& status) : Base(status) {} + +template +StatusOr& StatusOr::operator=(const absl::Status& status) { + this->Assign(status); + return *this; +} + +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} + +template +StatusOr::StatusOr(absl::Status&& status) : Base(std::move(status)) {} + +template +StatusOr& StatusOr::operator=(absl::Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template +template +inline void StatusOr::Assign(const StatusOr& other) { + if (other.ok()) { + this->Assign(other.value()); + } else { + this->Assign(other.status()); + } +} + +template +template +inline void StatusOr::Assign(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).value()); + } else { + this->Assign(std::move(other).status()); + } +} +template +template +StatusOr::StatusOr(absl::in_place_t, Args&&... args) + : Base(absl::in_place, std::forward(args)...) {} + +template +template +StatusOr::StatusOr(absl::in_place_t, std::initializer_list ilist, + Args&&... args) + : Base(absl::in_place, ilist, std::forward(args)...) {} + +template +const absl::Status& StatusOr::status() const& { + return this->status_; +} +template +absl::Status StatusOr::status() && { + return ok() ? 
absl::OkStatus() : std::move(this->status_); +} + +template +const T& StatusOr::value() const& { + if (!this->ok()) internal_statusor::ThrowBadStatusOrAccess(this->status_); + return this->data_; +} + +template +T& StatusOr::value() & { + if (!this->ok()) internal_statusor::ThrowBadStatusOrAccess(this->status_); + return this->data_; +} + +template +const T&& StatusOr::value() const&& { + if (!this->ok()) { + internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_)); + } + return std::move(this->data_); +} + +template +T&& StatusOr::value() && { + if (!this->ok()) { + internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_)); + } + return std::move(this->data_); +} + +template +const T& StatusOr::operator*() const& { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::operator*() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::operator*() const&& { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::operator*() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +const T* StatusOr::operator->() const { + this->EnsureOk(); + return &this->data_; +} + +template +T* StatusOr::operator->() { + this->EnsureOk(); + return &this->data_; +} + +template +template +T StatusOr::value_or(U&& default_value) const& { + if (ok()) { + return this->data_; + } + return std::forward(default_value); +} + +template +template +T StatusOr::value_or(U&& default_value) && { + if (ok()) { + return std::move(this->data_); + } + return std::forward(default_value); +} + +template +void StatusOr::IgnoreError() const { + // no-op +} + +#endif // SWIG + + +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ diff --git a/tensorflow_lite_support/cc/port/default/statusor_internals.h b/tensorflow_lite_support/cc/port/default/statusor_internals.h new file mode 100644 index 000000000..56d466162 --- /dev/null +++ 
b/tensorflow_lite_support/cc/port/default/statusor_internals.h @@ -0,0 +1,409 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_ + +#include +#include + +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/utility/utility.h" + +namespace tflite { +namespace support { + +template +class ABSL_MUST_USE_RESULT StatusOr; + +namespace internal_statusor { + +// Detects whether `T` is constructible or convertible from `StatusOr`. +template +using IsConstructibleOrConvertibleFromStatusOr = + absl::disjunction&>, + std::is_constructible&>, + std::is_constructible&&>, + std::is_constructible&&>, + std::is_convertible&, T>, + std::is_convertible&, T>, + std::is_convertible&&, T>, + std::is_convertible&&, T>>; + +// Detects whether `T` is constructible or convertible or assignable from +// `StatusOr`. +template +using IsConstructibleOrConvertibleOrAssignableFromStatusOr = + absl::disjunction, + std::is_assignable&>, + std::is_assignable&>, + std::is_assignable&&>, + std::is_assignable&&>>; + +// Detects whether direct initializing `StatusOr` from `U` is ambiguous, i.e. +// when `U` is `StatusOr` and `T` is constructible or convertible from `V`. 
+template +struct IsDirectInitializationAmbiguous + : public absl::conditional_t< + std::is_same>, + U>::value, + std::false_type, + IsDirectInitializationAmbiguous< + T, absl::remove_cv_t>>> {}; + +template +struct IsDirectInitializationAmbiguous> + : public IsConstructibleOrConvertibleFromStatusOr {}; + +// Checks against the constraints of the direction initialization, i.e. when +// `StatusOr::StatusOr(U&&)` should participate in overload resolution. +template +using IsDirectInitializationValid = absl::disjunction< + // Short circuits if T is basically U. + std::is_same>>, + absl::negation, + absl::remove_cv_t>>, + std::is_same>>, + std::is_same>>, + IsDirectInitializationAmbiguous>>>; + +// This trait detects whether `StatusOr::operator=(U&&)` is ambiguous, which +// is equivalent to whether all the following conditions are met: +// 1. `U` is `StatusOr`. +// 2. `T` is constructible and assignable from `V`. +// 3. `T` is constructible and assignable from `U` (i.e. `StatusOr`). +// For example, the following code is considered ambiguous: +// (`T` is `bool`, `U` is `StatusOr`, `V` is `bool`) +// StatusOr s1 = true; // s1.ok() && s1.ValueOrDie() == true +// StatusOr s2 = false; // s2.ok() && s2.ValueOrDie() == false +// s1 = s2; // ambiguous, `s1 = s2.ValueOrDie()` or `s1 = bool(s2)`? +template +struct IsForwardingAssignmentAmbiguous + : public absl::conditional_t< + std::is_same>, + U>::value, + std::false_type, + IsForwardingAssignmentAmbiguous< + T, absl::remove_cv_t>>> {}; + +template +struct IsForwardingAssignmentAmbiguous> + : public IsConstructibleOrConvertibleOrAssignableFromStatusOr {}; + +// Checks against the constraints of the forwarding assignment, i.e. whether +// `StatusOr::operator(U&&)` should participate in overload resolution. +template +using IsForwardingAssignmentValid = absl::disjunction< + // Short circuits if T is basically U. 
+ std::is_same>>, + absl::negation, + absl::remove_cv_t>>, + std::is_same>>, + std::is_same>>, + IsForwardingAssignmentAmbiguous>>>; + +class Helper { + public: + // Move type-agnostic error handling to the .cc. + static void HandleInvalidStatusCtorArg(absl::Status*); + ABSL_ATTRIBUTE_NORETURN static void Crash(const absl::Status& status); +}; + +// Construct an instance of T in `p` through placement new, passing Args... to +// the constructor. +// This abstraction is here mostly for the gcc performance fix. +template +void PlacementNew(void* p, Args&&... args) { +#if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. + if (p == nullptr) __builtin_unreachable(); +#endif + new (p) T(std::forward(args)...); +} + +// Helper base class to hold the data and all operations. +// We move all this to a base class to allow mixing with the appropriate +// TraitsBase specialization. +template +class StatusOrData { + template + friend class StatusOrData; + + public: + StatusOrData() = delete; + + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + StatusOrData(StatusOrData&& other) noexcept { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template + explicit StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + template + explicit StatusOrData(StatusOrData&& other) { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template + explicit StatusOrData(absl::in_place_t, Args&&... args) + : data_(std::forward(args)...) 
{ + MakeStatus(); + } + + explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } + explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } + + explicit StatusOrData(const absl::Status& status) : status_(status) { + EnsureNotOk(); + } + explicit StatusOrData(absl::Status&& status) : status_(std::move(status)) { + EnsureNotOk(); + } + + StatusOrData& operator=(const StatusOrData& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(other.data_); + else + Assign(other.status_); + return *this; + } + + StatusOrData& operator=(StatusOrData&& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else + Assign(std::move(other.status_)); + return *this; + } + + ~StatusOrData() { + if (ok()) { + status_.~Status(); + data_.~T(); + } else { + status_.~Status(); + } + } + + // TODO(b/140189837): Remove the SFINAE condition after cleanup. + template ::value, int> = 0> + void Assign(U&& value) { + if (ok()) { + data_ = std::forward(value); + } else { + MakeValue(std::forward(value)); + status_ = absl::OkStatus(); + } + } + + // TODO(b/140189837): Remove this after cleanup. + // This overload is to handle the case where `T` is a `const` type. + // `StatusOr` supports assignment for `const` types though it's forbidden by + // other standard types like `std::optional`. + template ::value, int> = 0> + void Assign(U&& value) { + if (ok()) { + data_.~T(); + MakeValue(std::forward(value)); + } else { + MakeValue(std::forward(value)); + status_ = absl::OkStatus(); + } + } + + void Assign(const absl::Status& status) { + Clear(); + status_ = status; + EnsureNotOk(); + } + + void Assign(absl::Status&& status) { + Clear(); + status_ = std::move(status); + EnsureNotOk(); + } + + bool ok() const { return status_.ok(); } + + protected: + // status_ will always be active after the constructor. + // We make it a union to be able to initialize exactly how we need without + // waste. + // Eg. 
in the copy constructor we use the default constructor of Status in + // the ok() path to avoid an extra Ref call. + union { + absl::Status status_; + }; + + // data_ is active iff status_.ok()==true + struct Dummy {}; + union { + // When T is const, we need some non-const object we can cast to void* for + // the placement new. dummy_ is that object. + Dummy dummy_; + T data_; + }; + + void Clear() { + if (ok()) data_.~T(); + } + + void EnsureOk() const { + if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_); + } + + void EnsureNotOk() { + if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (ie. data_) through placement new with the passed + // argument. + template + void MakeValue(Arg&&... arg) { + internal_statusor::PlacementNew(&dummy_, std::forward(arg)...); + } + + // Construct the status (ie. status_) through placement new with the passed + // argument. + template + void MakeStatus(Args&&... args) { + internal_statusor::PlacementNew(&status_, + std::forward(args)...); + } +}; + +// Helper base classes to allow implicitly deleted constructors and assignment +// operators in `StatusOr`. For example, `CopyCtorBase` will explicitly delete +// the copy constructor when T is not copy constructible and `StatusOr` will +// inherit that behavior implicitly. 
+template ::value> +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = default; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = delete; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template ::value> +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = default; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = delete; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template ::value&& + std::is_copy_assignable::value> +struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = default; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template +struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = delete; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template ::value&& + std::is_move_assignable::value> +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& 
operator=(MoveAssignBase&&) = default; +}; + +template +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& operator=(MoveAssignBase&&) = delete; +}; + +void ThrowBadStatusOrAccess(absl::Status status); + +} // namespace internal_statusor +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_ diff --git a/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc new file mode 100644 index 000000000..548e679af --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h" + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace support { + +absl::Status TfLiteInterpreterWrapper::InitializeWithFallback( + std::function*)> + interpreter_initializer, + const tflite::proto::ComputeSettings& compute_settings) { + if (compute_settings.has_preference() || + compute_settings.has_tflite_settings()) { + return absl::UnimplementedError( + "Acceleration via ComputeSettings is not supported yet."); + } + RETURN_IF_ERROR(interpreter_initializer(&interpreter_)); + return interpreter_->AllocateTensors() != kTfLiteOk + ? absl::InternalError( + "TFLite interpreter: AllocateTensors() failed.") + : absl::OkStatus(); +} + +absl::Status TfLiteInterpreterWrapper::InvokeWithFallback( + const std::function& + set_inputs) { + RETURN_IF_ERROR(set_inputs(interpreter_.get())); + return interpreter_->Invoke() != kTfLiteOk + ? absl::InternalError("TFLite interpreter: Invoke() failed.") + : absl::OkStatus(); +} + +absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() { + return interpreter_->Invoke() != kTfLiteOk + ? absl::InternalError("TFLite interpreter: Invoke() failed.") + : absl::OkStatus(); +} + +void TfLiteInterpreterWrapper::Cancel() { + // NOP +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/tensorflow_lite_support/cc/port/default/tflite_wrapper.h new file mode 100644 index 000000000..3fd489f7b --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/tflite_wrapper.h @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h" +#include "tensorflow/lite/interpreter.h" + +namespace tflite { +namespace support { + +// Wrapper for a TfLiteInterpreter that may be accelerated[1]. This is NOT yet +// implemented: this class only provides a first, minimal interface in the +// meanwhile. +// +// [1] See tensorflow/lite/experimental/acceleration for more details. +class TfLiteInterpreterWrapper { + public: + TfLiteInterpreterWrapper() = default; + + virtual ~TfLiteInterpreterWrapper() = default; + + // Calls `interpreter_initializer` and then `AllocateTensors`. Future + // implementation of this method will attempt to apply the provided + // `compute_settings` with a graceful fallback in case a failure occurs. + // Note: before this gets implemented, do NOT call this method with non-empty + // `compute_settings` otherwise an unimplemented error occurs. + absl::Status InitializeWithFallback( + std::function*)> + interpreter_initializer, + const tflite::proto::ComputeSettings& compute_settings); + + // Calls `set_inputs` and then Invoke() on the interpreter. Future + // implementation of this method will perform a graceful fallback in case a + // failure occur due to the `compute_settings` provided at initialization + // time. 
+ absl::Status InvokeWithFallback( + const std::function& + set_inputs); + + // Calls Invoke() on the interpreter. Caller must have set up inputs + // before-hand. + absl::Status InvokeWithoutFallback(); + + // Cancels the current running TFLite invocation on CPU. This method is not + // yet implemented though it is safe to use it as it acts as a NOP. + void Cancel(); + + // Accesses the underlying interpreter for other methods. + tflite::Interpreter& operator*() { return *interpreter_; } + tflite::Interpreter* operator->() { return interpreter_.get(); } + tflite::Interpreter& operator*() const { return *interpreter_; } + tflite::Interpreter* operator->() const { return interpreter_.get(); } + tflite::Interpreter* get() const { return interpreter_.get(); } + + TfLiteInterpreterWrapper(const TfLiteInterpreterWrapper&) = delete; + TfLiteInterpreterWrapper& operator=(const TfLiteInterpreterWrapper&) = delete; + + private: + std::unique_ptr interpreter_; +}; + +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ diff --git a/tensorflow_lite_support/cc/port/integral_types.h b/tensorflow_lite_support/cc/port/integral_types.h new file mode 100644 index 000000000..c8eaca83e --- /dev/null +++ b/tensorflow_lite_support/cc/port/integral_types.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_ + +// TODO: is not portable C. Take a close look at this when we add +// mobile support. +#include +#include + +typedef signed char schar; +typedef int8_t int8; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; + +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint32_t char32; +typedef uint64_t uint64; + +typedef unsigned long uword_t; + +#define GG_LONGLONG(x) x##LL +#define GG_ULONGLONG(x) x##ULL +#define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also. +#define GG_LL_FORMAT_W L"ll" + +typedef uint64 Fprint; +static const Fprint kIllegalFprint = 0; +static const Fprint kMaxFprint = GG_ULONGLONG(0xFFFFFFFFFFFFFFFF); + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_ diff --git a/tensorflow_lite_support/cc/port/status_macros.h b/tensorflow_lite_support/cc/port/status_macros.h new file mode 100644 index 000000000..3890c7729 --- /dev/null +++ b/tensorflow_lite_support/cc/port/status_macros.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_ + +#include "tensorflow_lite_support/cc/port/default/status_macros.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_ diff --git a/tensorflow_lite_support/cc/port/statusor.h b/tensorflow_lite_support/cc/port/statusor.h new file mode 100644 index 000000000..f84c7568b --- /dev/null +++ b/tensorflow_lite_support/cc/port/statusor.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ + +#include "tensorflow_lite_support/cc/port/default/statusor.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ diff --git a/tensorflow_lite_support/cc/port/tflite_wrapper.h b/tensorflow_lite_support/cc/port/tflite_wrapper.h new file mode 100644 index 000000000..df8cb832b --- /dev/null +++ b/tensorflow_lite_support/cc/port/tflite_wrapper.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_ + +#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_ diff --git a/tensorflow_lite_support/cc/task/README.md b/tensorflow_lite_support/cc/task/README.md new file mode 100644 index 000000000..74aee8a4e --- /dev/null +++ b/tensorflow_lite_support/cc/task/README.md @@ -0,0 +1,100 @@ +#TFLite Task library - C++ + +A flexible and ready-to-use library for common machine learning model types, +such as classification and detection. + +## Text Task Librarys + +### QuestionAnswerer + +`QuestionAnswerer` API is able to load +[Mobile BERT](https://tfhub.dev/tensorflow/mobilebert/1) or +[AlBert](https://tfhub.dev/tensorflow/albert_lite_base/1) TFLite models and +answer question based on context. + +Use the C++ API to answer questions as follows: + +```cc +using tflite::support::task::text::qa::BertQuestionAnswerer; +using tflite::support::task::text::qa::QaAnswer; +// Create API handler with Mobile Bert model. +auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswerer("/path/to/mobileBertModel", "/path/to/vocab"); +// Or create API handler with Albert model. 
+// auto qa_client = BertQuestionAnswerer::CreateAlbertQuestionAnswerer("/path/to/alBertModel", "/path/to/sentencePieceModel"); + + +std::string context = + "Nikola Tesla (Serbian Cyrillic: Никола Тесла; 10 " + "July 1856 – 7 January 1943) was a Serbian American inventor, electrical " + "engineer, mechanical engineer, physicist, and futurist best known for his " + "contributions to the design of the modern alternating current (AC) " + "electricity supply system."; +std::string question = "When was Nikola Tesla born?"; +// Run inference with `context` and a given `question` to the context, and get top-k +// answers ranked by logits. +const std::vector answers = qa_client->Answer(context, question); +// Access QaAnswer results. +for (const QaAnswer& item : answers) { + std::cout << absl::StrFormat("Text: %s logit=%f start=%d end=%d", item.text, + item.pos.logit, item.pos.start, item.pos.end) + << std::endl; +} +// Output: +// Text: 10 July 1856 logit=16.8527 start=17 end=19 +// ... (and more) +// +// So the top-1 answer is: "10 July 1856". +``` + +In the above code, `item.text` is the text content of an answer. We use a span +with closed interval `[item.pos.start, item.pos.end]` to denote predicted tokens +in the answer, and `item.pos.logit` is the sum of span logits to represent the +confidence score. + +### NLClassifier + +`NLClassifier` API is able to load any TFLite models for natural language +classaification task such as language detection or sentiment detection. + +The API expects a TFLite model with the following input/output tensor: +Input tensor0: + (kTfLiteString) - input of the model, accepts a string. +Output tensor0: + (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64) + - output scores for each class, if type is one of the Int types, + dequantize it to double +Output tensor1: optional + (kTfLiteString) + - output classname for each class, should be of the same length with + scores. 
If this tensor is not present, the API uses score indices as + classnames. +By default the API tries to find the input/output tensors with default +configurations in NLClassifierOptions, with tensor name prioritized over +tensor index. The option is configurable for different TFLite models. + +Use the C++ API to perform language ID classification as follows: + +```cc +using tflite::support::task::text::nlclassifier::NLClassifier; +using tflite::support::task::core::Category; +auto classifier = NLClassifier::CreateNLClassifier("/path/to/model"); +// Or create a customized NLClassifierOptions +// NLClassifierOptions options = +// { +// .output_score_tensor_name = myOutputScoreTensorName, +// .output_label_tensor_name = myOutputLabelTensorName, +// } +// auto classifier = NLClassifier::CreateNLClassifier("/path/to/model", options); +std::string context = "What language is this?"; +std::vector categories = classifier->Classify(context); +// Access category results. +for (const Categoryr& category : categories) { + std::cout << absl::StrFormat("Language: %s Probability: %f", category.class_name, category_.score) + << std::endl; +} +// Output: +// Language: en Probability=0.9 +// ... (and more) +// +// So the top-1 answer is 'en'. 
+``` diff --git a/tensorflow_lite_support/cc/task/core/BUILD b/tensorflow_lite_support/cc/task/core/BUILD new file mode 100644 index 000000000..5e7edc7e2 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/BUILD @@ -0,0 +1,104 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tflite_engine", + srcs = ["tflite_engine.cc"], + hdrs = [ + "tflite_engine.h", + ], + deps = [ + ":external_file_handler", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:tflite_wrapper", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/tools:verifier", + ], +) + +cc_library( + name = "base_task_api", + hdrs = [ + "base_task_api.h", + ], + deps = [ + ":tflite_engine", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/port:tflite_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) + +cc_library( + name = "task_api_factory", + hdrs = [ + "task_api_factory.h", + ], + deps = [ + ":base_task_api", + ":tflite_engine", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite/c:common", + 
"@org_tensorflow//tensorflow/lite/kernels:op_macros", + ], +) + +cc_library( + name = "task_utils", + srcs = [ + "task_utils.cc", + ], + hdrs = [ + "task_utils.h", + ], + deps = [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", + "@org_tensorflow//tensorflow/lite/kernels:op_macros", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "category", + hdrs = ["category.h"], +) + +cc_library( + name = "external_file_handler", + srcs = ["external_file_handler.cc"], + hdrs = ["external_file_handler.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/tensorflow_lite_support/cc/task/core/base_task_api.h b/tensorflow_lite_support/cc/task/core/base_task_api.h new file mode 100644 index 000000000..897e377f0 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/base_task_api.h @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/port/tflite_wrapper.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { + +class BaseUntypedTaskApi { + public: + explicit BaseUntypedTaskApi(std::unique_ptr engine) + : engine_{std::move(engine)} {} + + virtual ~BaseUntypedTaskApi() = default; + + const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); } + const metadata::ModelMetadataExtractor* GetMetadataExtractor() const { + return engine_->metadata_extractor(); + } + + protected: + std::unique_ptr engine_; +}; + +template +class BaseTaskApi : public BaseUntypedTaskApi { + public: + explicit BaseTaskApi(std::unique_ptr engine) + : BaseUntypedTaskApi(std::move(engine)) {} + // BaseTaskApi is neither copyable nor movable. + BaseTaskApi(const BaseTaskApi&) = delete; + BaseTaskApi& operator=(const BaseTaskApi&) = delete; + + // Cancels the current running TFLite invocation on CPU. + // + // Usually called on a different thread than the one inference is running on. + // Calling Cancel() will cause the underlying TFLite interpreter to return an + // error, which will turn into a `CANCELLED` status and empty results. Calling + // Cancel() at the other time will not take any effect on the current or + // following invocation. 
It is perfectly fine to run inference again on the + // same instance after a cancelled invocation. If the TFLite inference is + // partially delegated on CPU, logs a warning message and only cancels the + // invocation running on CPU. Other invocation which depends on the output of + // the CPU invocation will not be executed. + void Cancel() { engine_->Cancel(); } + + protected: + // Subclasses need to populate input_tensors from api_inputs. + virtual absl::Status Preprocess( + const std::vector& input_tensors, + InputTypes... api_inputs) = 0; + + // Subclasses need to construct OutputType object from output_tensors. + // Original inputs are also provided as they may be needed. + virtual StatusOr Postprocess( + const std::vector& output_tensors, + InputTypes... api_inputs) = 0; + + // Returns the tensors associated with the given input/output indexes. + template + std::vector GetTensors(const std::vector& tensor_indices) { + tflite::Interpreter* interpreter = engine_->interpreter(); + std::vector tensors; + tensors.reserve(tensor_indices.size()); + for (int index : tensor_indices) { + tensors.push_back(interpreter->tensor(index)); + } + return tensors; + } + + std::vector GetInputTensors() { + return GetTensors(engine_->interpreter()->inputs()); + } + + std::vector GetOutputTensors() { + return GetTensors(engine_->interpreter()->outputs()); + } + + // Performs inference using tflite::support::TfLiteInterpreterWrapper + // InvokeWithoutFallback(). + StatusOr Infer(InputTypes... args) { + tflite::support::TfLiteInterpreterWrapper* interpreter_wrapper = + engine_->interpreter_wrapper(); + // Note: AllocateTensors() is already performed by the interpreter wrapper + // at InitInterpreter time (see TfLiteEngine). + RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); + absl::Status status = interpreter_wrapper->InvokeWithoutFallback(); + if (!status.ok()) { + return status.GetPayload(tflite::support::kTfLiteSupportPayload) + .has_value() + ? 
status + : CreateStatusWithPayload(status.code(), status.message()); + } + return Postprocess(GetOutputTensors(), args...); + } + + // Performs inference using tflite::support::TfLiteInterpreterWrapper + // InvokeWithFallback() to benefit from automatic fallback from delegation to + // CPU where applicable. + StatusOr InferWithFallback(InputTypes... args) { + tflite::support::TfLiteInterpreterWrapper* interpreter_wrapper = + engine_->interpreter_wrapper(); + // Note: AllocateTensors() is already performed by the interpreter wrapper + // at InitInterpreter time (see TfLiteEngine). + RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); + auto set_inputs_nop = [](tflite::Interpreter* interpreter) -> absl::Status { + // NOP since inputs are populated at Preprocess() time. + return absl::OkStatus(); + }; + absl::Status status = + interpreter_wrapper->InvokeWithFallback(set_inputs_nop); + if (!status.ok()) { + return status.GetPayload(tflite::support::kTfLiteSupportPayload) + .has_value() + ? status + : CreateStatusWithPayload(status.code(), status.message()); + } + return Postprocess(GetOutputTensors(), args...); + } +}; + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ diff --git a/tensorflow_lite_support/cc/task/core/category.h b/tensorflow_lite_support/cc/task/core/category.h new file mode 100644 index 000000000..5cd8fee82 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/category.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ +#include + +namespace tflite { +namespace support { +namespace task { +namespace core { + +// Result for classification APIs. +struct Category { + std::string class_name; + double score; + Category(const std::string& class_name, double score) + : class_name(class_name), score(score) {} + + friend bool operator==(const Category& lhs, const Category& rhs) { + return lhs.score == rhs.score && lhs.class_name == rhs.class_name; + } + + friend bool operator!=(const Category& lhs, const Category& rhs) { + return !(lhs == rhs); + } +}; + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ diff --git a/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/tensorflow_lite_support/cc/task/core/external_file_handler.cc new file mode 100644 index 000000000..bb9ecb62c --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/external_file_handler.cc @@ -0,0 +1,195 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { +namespace { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::TfLiteSupportStatus; + +// Gets the offset aligned to page size for mapping given files into memory by +// file descriptor correctly, as according to mmap(2), the offset used in mmap +// must be a multiple of sysconf(_SC_PAGE_SIZE). +int64 GetPageSizeAlignedOffset(int64 offset) { + int64 aligned_offset = offset; + int64 page_size = sysconf(_SC_PAGE_SIZE); + if (offset % page_size != 0) { + aligned_offset = offset / page_size * page_size; + } + return aligned_offset; +} + +} // namespace + +/* static */ +StatusOr> +ExternalFileHandler::CreateFromExternalFile(const ExternalFile* external_file) { + // Use absl::WrapUnique() to call private constructor: + // https://abseil.io/tips/126. 
+ std::unique_ptr handler = + absl::WrapUnique(new ExternalFileHandler(external_file)); + + RETURN_IF_ERROR(handler->MapExternalFile()); + + return handler; +} + +absl::Status ExternalFileHandler::MapExternalFile() { + if (!external_file_.file_content().empty()) { + return absl::OkStatus(); + } + if (external_file_.file_name().empty() && + !external_file_.has_file_descriptor_meta()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "ExternalFile must specify at least one of 'file_content', file_name' " + "or 'file_descriptor_meta'.", + TfLiteSupportStatus::kInvalidArgumentError); + } + // Obtain file descriptor, offset and size. + int fd = -1; + if (!external_file_.file_name().empty()) { + owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY); + if (owned_fd_ < 0) { + const std::string error_message = absl::StrFormat( + "Unable to open file at %s", external_file_.file_name()); + switch (errno) { + case ENOENT: + return CreateStatusWithPayload( + StatusCode::kNotFound, error_message, + TfLiteSupportStatus::kFileNotFoundError); + case EACCES: + case EPERM: + return CreateStatusWithPayload( + StatusCode::kPermissionDenied, error_message, + TfLiteSupportStatus::kFilePermissionDeniedError); + case EINTR: + return CreateStatusWithPayload(StatusCode::kUnavailable, + error_message, + TfLiteSupportStatus::kFileReadError); + case EBADF: + return CreateStatusWithPayload(StatusCode::kFailedPrecondition, + error_message, + TfLiteSupportStatus::kFileReadError); + default: + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("%s, errno=%d", error_message, errno), + TfLiteSupportStatus::kFileReadError); + } + } + fd = owned_fd_; + } else { + fd = external_file_.file_descriptor_meta().fd(); + if (fd < 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Provided file descriptor is invalid: %d < 0", fd), + TfLiteSupportStatus::kInvalidArgumentError); + } + buffer_offset_ = 
external_file_.file_descriptor_meta().offset(); + buffer_size_ = external_file_.file_descriptor_meta().length(); + } + // Get actual file size. Always use 0 as offset to lseek(2) to get the actual + // file size, as SEEK_END returns the size of the file *plus* offset. + size_t file_size = lseek(fd, /*offset=*/0, SEEK_END); + if (file_size <= 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Unable to get file size, errno=%d", errno), + TfLiteSupportStatus::kFileReadError); + } + // Deduce buffer size if not explicitly provided through file descriptor. + if (buffer_size_ <= 0) { + buffer_size_ = file_size - buffer_offset_; + } + // Check for out of range issues. + if (file_size <= buffer_offset_) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Provided file offset (%d) exceeds or matches actual " + "file length (%d)", + buffer_offset_, file_size), + TfLiteSupportStatus::kInvalidArgumentError); + } + if (file_size < buffer_size_ + buffer_offset_) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Provided file length + offset (%d) exceeds actual " + "file length (%d)", + buffer_size_ + buffer_offset_, file_size), + TfLiteSupportStatus::kInvalidArgumentError); + } + // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with + // extra leading bytes and adjust buffer_size_ to account for the extra + // leading bytes. + buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_); + buffer_aligned_size_ = buffer_size_ + buffer_offset_ - buffer_aligned_offset_; + // Map into memory. 
+ buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED, + fd, buffer_aligned_offset_); + if (buffer_ == MAP_FAILED) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Unable to map file to memory buffer, errno=%d", errno), + TfLiteSupportStatus::kFileMmapError); + } + return absl::OkStatus(); +} + +absl::string_view ExternalFileHandler::GetFileContent() { + if (external_file_.has_file_content()) { + return external_file_.file_content(); + } else { + return absl::string_view(static_cast(buffer_) + + buffer_offset_ - buffer_aligned_offset_, + buffer_size_); + } +} + +ExternalFileHandler::~ExternalFileHandler() { + if (buffer_ != MAP_FAILED) { + munmap(buffer_, buffer_aligned_size_); + } + if (owned_fd_ >= 0) { + close(owned_fd_); + } +} + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/core/external_file_handler.h b/tensorflow_lite_support/cc/task/core/external_file_handler.h new file mode 100644 index 000000000..400444f8e --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/external_file_handler.h @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { + +// Handler providing easy access to the contents of a file specified by an +// ExternalFile proto [1]. Takes care (if needed, depending on the provided +// proto fields) of opening and/or mapping the file in memory at creation time, +// as well as closing and/or unmapping at destruction time. +// +// [1]: support/c/task/core/proto/external_file.proto +class ExternalFileHandler { + public: + // Creates an ExternalFileHandler from the input ExternalFile proto and + // returns a pointer to the new object. Ownership is transferred to the + // caller. Returns an error if the creation failed, which may happen if the + // provided ExternalFile can't be opened or mapped into memory. + // + // Warning: Does not take ownership of `external_file`, which must refer to a + // valid proto that outlives this object. + static StatusOr> CreateFromExternalFile( + const ExternalFile* external_file); + + ~ExternalFileHandler(); + + // Returns the content of the ExternalFile as a string_view guaranteed to be + // valid as long as the ExternalFileHandler is alive. + absl::string_view GetFileContent(); + + private: + // Private constructor, called from CreateFromExternalFile(). + explicit ExternalFileHandler(const ExternalFile* external_file) + : external_file_(*external_file) {} + + // Opens (if provided by path) and maps (if provided by path or file + // descriptor) the external file in memory. 
Does nothing otherwise, as file + // contents are already loaded in memory. + absl::Status MapExternalFile(); + + // Reference to the input ExternalFile. + const ExternalFile& external_file_; + + // The file descriptor of the ExternalFile if provided by path, as it is + // opened and owned by this class. Set to -1 otherwise. + int owned_fd_{-1}; + + // Points to the memory buffer mapped from the file descriptor of the + // ExternalFile, if provided by path or file descriptor. + void* buffer_{}; + + // The mapped memory buffer offset, if any. + int64 buffer_offset_{}; + // The size in bytes of the mapped memory buffer, if any. + int64 buffer_size_{}; + + // As mmap(2) requires the offset to be a multiple of sysconf(_SC_PAGE_SIZE): + + // The aligned mapped memory buffer offset, if any. + int64 buffer_aligned_offset_{}; + // The aligned mapped memory buffer size in bytes taking into account the + // offset shift introduced by buffer_aligned_memory_offset_, if any. + int64 buffer_aligned_size_{}; +}; + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_ diff --git a/tensorflow_lite_support/cc/task/core/proto/BUILD b/tensorflow_lite_support/cc/task/core/proto/BUILD new file mode 100644 index 000000000..3af8f12a3 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/BUILD @@ -0,0 +1,30 @@ +load("@org_tensorflow//tensorflow/core/platform:build_config.bzl", "tf_proto_library") +load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_portable_proto_library") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +tf_proto_library( + name = "external_file_proto", + srcs = ["external_file.proto"], + cc_api_version = 2, +) + +tf_portable_proto_library( + name = "external_file_portable_proto", + config = "proto_config.pbtxt", + header_outs = 
["//tensorflow_lite_support/cc/task/core/proto/external_file.proto.h"], + proto_deps = [ + ":external_file_proto", + ], +) + +cc_library( + name = "external_file_proto_inc", + hdrs = ["external_file_proto_inc.h"], + deps = [":external_file_portable_proto"], +) diff --git a/tensorflow_lite_support/cc/task/core/proto/external_file.proto b/tensorflow_lite_support/cc/task/core/proto/external_file.proto new file mode 100644 index 000000000..a3228f51c --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/external_file.proto @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.support.task.core; + + +// Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or +// plain-text labels file). The files can be specified by one of the following +// three ways: +// +// (1) file contents loaded in `file_content`. +// (2) file path in `file_name`. +// (3) file descriptor through `file_descriptor_meta` as returned by open(2). +// +// If more than one field of these fields is provided, they are used in this +// precedence order. +// Next id: 5 +message ExternalFile { + // The path to the file to open and mmap in memory + optional string file_name = 1; + + // The file contents as a byte array. 
+ optional bytes file_content = 2; + + // The file descriptor to a file opened with open(2), with optional additional + // offset and length information. + optional FileDescriptorMeta file_descriptor_meta = 4; + + // Deprecated field numbers. + reserved 3; +} + +// A proto defining file descriptor metadata for mapping file into memory using +// mmap(2). +message FileDescriptorMeta { + // File descriptor as returned by open(2). + optional int32 fd = 1; + + // Optional length of the mapped memory. If not specified, the actual file + // size is used at runtime. + // + // This is an advanced option, e.g. this can be used on Android to specify the + // length of a given asset obtained from AssetFileDescriptor#getLength(). + optional int64 length = 2; + + // Optional starting offset in the file referred to by the file descriptor + // `fd`. + // + // This is an advanced option, e.g. this can be used on Android to specify the + // offset of a given asset obtained from AssetFileDescriptor#getStartOffset(). + optional int64 offset = 3; +} + diff --git a/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h b/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h new file mode 100644 index 000000000..017aa6511 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/external_file.pb.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt b/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt new file mode 100644 index 000000000..dafb0fde0 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt @@ -0,0 +1,2 @@ +allow_all: true +optimize_mode: LITE_RUNTIME diff --git a/tensorflow_lite_support/cc/task/core/task_api_factory.h b/tensorflow_lite_support/cc/task/core/task_api_factory.h new file mode 100644 index 000000000..dc0f2551f --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/task_api_factory.h @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { +template +using EnableIfBaseUntypedTaskApiSubclass = typename std::enable_if< + std::is_base_of::value>::type*; + +// Template creator for all subclasses of BaseTaskApi +class TaskAPIFactory { + public: + TaskAPIFactory() = delete; + + template = nullptr> + static StatusOr> CreateFromBuffer( + const char* buffer_data, size_t buffer_size, + std::unique_ptr resolver = + absl::make_unique(), + int num_threads = 1) { + auto engine = absl::make_unique(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromFlatBuffer(buffer_data, buffer_size)); + return CreateFromTfLiteEngine(std::move(engine), num_threads); + } + + template = nullptr> + static StatusOr> CreateFromFile( + const string& file_name, + std::unique_ptr resolver = + absl::make_unique(), + int num_threads = 1) { + auto engine = absl::make_unique(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromFile(file_name)); + return CreateFromTfLiteEngine(std::move(engine), num_threads); + } + + template = nullptr> + static StatusOr> CreateFromFileDescriptor( + int file_descriptor, + std::unique_ptr resolver = + absl::make_unique(), + int num_threads = 1) { + auto engine = absl::make_unique(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromFileDescriptor(file_descriptor)); + return 
CreateFromTfLiteEngine(std::move(engine), num_threads); + } + + template = nullptr> + static StatusOr> CreateFromExternalFileProto( + const ExternalFile* external_file, + std::unique_ptr resolver = + absl::make_unique(), + int num_threads = 1) { + auto engine = absl::make_unique(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto(external_file)); + return CreateFromTfLiteEngine(std::move(engine), num_threads); + } + + private: + template = nullptr> + static StatusOr> CreateFromTfLiteEngine( + std::unique_ptr engine, int num_threads) { + RETURN_IF_ERROR(engine->InitInterpreter(num_threads)); + return absl::make_unique(std::move(engine)); + } +}; + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_ diff --git a/tensorflow_lite_support/cc/task/core/task_utils.cc b/tensorflow_lite_support/cc/task/core/task_utils.cc new file mode 100644 index 000000000..b9d277c74 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/task_utils.cc @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +#include + +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { + +double Dequantize(const TfLiteTensor& tensor, int index) { + int32_t quantized_value = 0; + switch (tensor.type) { + case kTfLiteUInt8: + quantized_value = GetTensorData(&tensor)[index]; + break; + case kTfLiteInt8: + quantized_value = GetTensorData(&tensor)[index]; + break; + case kTfLiteInt16: + quantized_value = GetTensorData(&tensor)[index]; + break; + default: + TF_LITE_FATAL( + absl::StrCat( + "Invalid tensor type for dequantization ", tensor.name, + ". Requested kTfLiteUInt8, kTfLiteInt8 or kTfLiteInt16, got ", + TfLiteTypeGetName(tensor.type), ".") + .c_str()); + } + return tensor.params.scale * (quantized_value - tensor.params.zero_point); +} + +std::string GetStringAtIndex(const TfLiteTensor* labels, int index) { + const auto& strref = tflite::GetString(labels, index); + return std::string(strref.str, strref.len); +} + +std::string LoadBinaryContent(const char* filename) { + std::ifstream input_file(filename, std::ios::binary | std::ios::ate); + // Find buffer size from input file, and load the buffer. + size_t buffer_size = input_file.tellg(); + std::string buffer(buffer_size, '\0'); + input_file.seekg(0, std::ios::beg); + input_file.read(const_cast(buffer.c_str()), buffer_size); + return buffer; +} + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/core/task_utils.h b/tensorflow_lite_support/cc/task/core/task_utils.h new file mode 100644 index 000000000..5a8acb0d9 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/task_utils.h @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/type_to_tflitetype.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { + +// Checks if data type of tensor is T and returns the pointer casted to T if +// applicable, returns nullptr if tensor type is not T. +// See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. +template +T* TypedTensor(const TfLiteTensor* tensor_ptr) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + return nullptr; +} + +// Checks and returns type of a tensor, fails if tensor type is not T. +template +T* AssertAndReturnTypedTensor(const TfLiteTensor* tensor) { + if (T* v = TypedTensor(tensor)) return v; + // TODO(b/150903834): throw exceptions instead + TF_LITE_ASSERT(tensor->data.raw); + TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name, + ". 
Requested ", + TfLiteTypeGetName(typeToTfLiteType()), ", got ", + TfLiteTypeGetName(tensor->type), ".") + .c_str()); +} + +// Populates tensor with array of data, fails if data type doesn't match tensor +// type or has not the same number of elements. +template +inline void PopulateTensor(const T* data, int num_elements, + TfLiteTensor* tensor) { + T* v = AssertAndReturnTypedTensor(tensor); + size_t bytes = num_elements * sizeof(T); + // TODO(b/150903834): throw exceptions instead + TF_LITE_ASSERT(tensor->bytes == bytes); + memcpy(v, data, bytes); +} + +// Populates tensor with vector of data, fails if data type doesn't match tensor +// type or has not the same number of elements. +template +inline void PopulateTensor(const std::vector& data, TfLiteTensor* tensor) { + return PopulateTensor(data.data(), data.size(), tensor); +} + +template <> +inline void PopulateTensor(const std::vector& data, + TfLiteTensor* tensor) { + if (tensor->type != kTfLiteString) { + TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name, + ". Requested STRING, got ", + TfLiteTypeGetName(tensor->type), ".") + .c_str()); + } + tflite::DynamicBuffer input_buf; + for (const auto& value : data) { + input_buf.AddString(value.data(), value.length()); + } + input_buf.WriteToTensorAsVector(tensor); +} + +// Populates tensor one data item, fails if data type doesn't match tensor +// type. +template +inline void PopulateTensor(const T& data, TfLiteTensor* tensor) { + T* v = AssertAndReturnTypedTensor(tensor); + *v = data; +} + +template <> +inline void PopulateTensor(const std::string& data, + TfLiteTensor* tensor) { + tflite::DynamicBuffer input_buf; + input_buf.AddString(data.data(), data.length()); + input_buf.WriteToTensorAsVector(tensor); +} + +// Populates a vector from the tensor, fails if data type doesn't match tensor +// type. 
+template +inline void PopulateVector(const TfLiteTensor* tensor, std::vector* data) { + AssertAndReturnTypedTensor(tensor); + const T* results = GetTensorData(tensor); + size_t num = tensor->bytes / sizeof(tensor->type); + data->reserve(num); + for (int i = 0; i < num; i++) { + data->emplace_back(results[i]); + } +} + +template <> +inline void PopulateVector(const TfLiteTensor* tensor, + std::vector* data) { + AssertAndReturnTypedTensor(tensor); + int num = GetStringCount(tensor); + data->reserve(num); + for (int i = 0; i < num; i++) { + const auto& strref = tflite::GetString(tensor, i); + data->emplace_back(strref.str, strref.len); + } +} + +// Returns the reversely sorted indices of a vector. +template +std::vector ReverseSortIndices(const std::vector& v) { + std::vector idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + + std::stable_sort(idx.begin(), idx.end(), + [&v](size_t i1, size_t i2) { return v[i2] < v[i1]; }); + + return idx; +} + +// Returns the original (dequantized) value of the 'index'-th element of +// 'tensor. +double Dequantize(const TfLiteTensor& tensor, int index); + +// Returns the index-th string from the tensor. +std::string GetStringAtIndex(const TfLiteTensor* labels, int index); + +// Loads binary content of a file into a string. +std::string LoadBinaryContent(const char* filename); + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_ diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/tensorflow_lite_support/cc/task/core/tflite_engine.cc new file mode 100644 index 000000000..b14af1594 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/tflite_engine.cc @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/tools/verifier.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { + +namespace { +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::TfLiteSupportStatus; +} // namespace + +TfLiteEngine::TfLiteEngine(std::unique_ptr resolver) + : resolver_(std::move(resolver)), verifier_(resolver_.get()) {} + +int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) { + return std::vsnprintf(error_message, sizeof(error_message), format, args); +} + +bool TfLiteEngine::Verifier::Verify(const char* data, int length, + tflite::ErrorReporter* reporter) { + return tflite::Verify(data, length, *op_resolver_, reporter); +} + +absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data, + size_t buffer_size) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + const char* final_buffer_data = buffer_data; + model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( + final_buffer_data, buffer_size, &verifier_, &error_reporter_); + + if (model_ == nullptr) { + // To be replaced with a proper 
switch-case when TF Lite model builder + // returns a `TfLiteStatus` code capturing this type of error. + if (absl::StrContains(error_reporter_.error_message, + "The model is not a valid Flatbuffer")) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, error_reporter_.error_message, + TfLiteSupportStatus::kInvalidFlatBufferError); + } else { + // TODO(b/154917059): augment status with another `TfLiteStatus` code when + // ready. And use a new `TfLiteStatus::kCoreTfLiteError` for the TFLS + // code, instead of the unspecified `kError`. + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrCat( + "Could not build model from the provided pre-loaded flatbuffer: ", + error_reporter_.error_message)); + } + } + + ASSIGN_OR_RETURN( + model_metadata_extractor_, + tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer( + buffer_data, buffer_size)); + + return absl::OkStatus(); +} + +absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + external_file_.set_file_name(file_name); + ASSIGN_OR_RETURN( + model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(&external_file_)); + return BuildModelFromFlatBuffer(model_file_handler_->GetFileContent().data(), + model_file_handler_->GetFileContent().size()); +} + +absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor); + ASSIGN_OR_RETURN( + model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(&external_file_)); + return BuildModelFromFlatBuffer(model_file_handler_->GetFileContent().data(), + model_file_handler_->GetFileContent().size()); +} + +absl::Status TfLiteEngine::BuildModelFromExternalFileProto( + const ExternalFile* external_file) { + if 
(model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + ASSIGN_OR_RETURN(model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(external_file)); + return BuildModelFromFlatBuffer(model_file_handler_->GetFileContent().data(), + model_file_handler_->GetFileContent().size()); +} + +absl::Status TfLiteEngine::InitInterpreter(int num_threads) { + tflite::proto::ComputeSettings compute_settings; + return InitInterpreter(compute_settings, num_threads); +} + +absl::Status TfLiteEngine::InitInterpreter( + const tflite::proto::ComputeSettings& compute_settings, int num_threads) { + if (model_ == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInternal, + "TF Lite FlatBufferModel is null. Please make sure to call " + "one of the BuildModelFrom methods before."); + } + + auto initializer = + [this, num_threads](std::unique_ptr* interpreter_out) + -> absl::Status { + if (tflite::InterpreterBuilder(*model_, *resolver_)( + interpreter_out, num_threads) != kTfLiteOk) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrCat("Could not build the TF Lite interpreter: ", + error_reporter_.error_message)); + } + if (interpreter_out == nullptr) { + return CreateStatusWithPayload(StatusCode::kInternal, + "TF Lite interpreter is null."); + } + return absl::OkStatus(); + }; + + absl::Status status = + interpreter_.InitializeWithFallback(initializer, compute_settings); + + if (!status.ok() && + !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) { + status = CreateStatusWithPayload(status.code(), status.message()); + } + return status; +} + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.h b/tensorflow_lite_support/cc/task/core/tflite_engine.h new file mode 100644 index 000000000..6b0109b58 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/tflite_engine.h @@ -0,0 +1,148 @@ +/* 
Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ + +#include + +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow_lite_support/cc/port/tflite_wrapper.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" + +namespace tflite { +namespace support { +namespace task { +namespace core { + +// TfLiteEngine encapsulates logic for TFLite model initialization, inference +// and error reporting. +class TfLiteEngine { + public: + explicit TfLiteEngine( + std::unique_ptr resolver = + absl::make_unique()); + // Model is neither copyable nor movable. + TfLiteEngine(const TfLiteEngine&) = delete; + TfLiteEngine& operator=(const TfLiteEngine&) = delete; + + // Accessors. 
+ tflite::FlatBufferModel* model() const { return model_.get(); } + tflite::Interpreter* interpreter() const { return interpreter_.get(); } + tflite::support::TfLiteInterpreterWrapper* interpreter_wrapper() { + return &interpreter_; + } + const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const { + return model_metadata_extractor_.get(); + } + + // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data + // whose ownership remains with the caller, and which must outlive the current + // object. This performs extra verification on the input data using + // tflite::Verify. + absl::Status BuildModelFromFlatBuffer(const char* buffer_data, + size_t buffer_size); + + // Builds the TF Lite model from a given file. + absl::Status BuildModelFromFile(const std::string& file_name); + + // Builds the TF Lite model from a given file descriptor using mmap(2). + absl::Status BuildModelFromFileDescriptor(int file_descriptor); + + // Builds the TFLite model from the provided ExternalFile proto, which must + // outlive the current object. + absl::Status BuildModelFromExternalFileProto( + const ExternalFile* external_file); + + // Initializes interpreter with encapsulated model. + // Note: setting num_threads to -1 has for effect to let TFLite runtime set + // the value. + absl::Status InitInterpreter(int num_threads = 1); + + // Same as above, but allows specifying `compute_settings` for acceleration. + absl::Status InitInterpreter( + const tflite::proto::ComputeSettings& compute_settings, + int num_threads = 1); + + // Cancels the on-going `Invoke()` call if any and if possible. This method + // can be called from a different thread than the one where `Invoke()` is + // running. + void Cancel() { interpreter_.Cancel(); } + + protected: + // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the + // error into a string so that it can be used to complement tensorflow::Status + // error messages. 
+ struct ErrorReporter : public tflite::ErrorReporter { + // Last error message captured by this error reporter. + char error_message[256]; + int Report(const char* format, va_list args) override; + }; + // Custom error reporter capturing low-level TF Lite error messages. + ErrorReporter error_reporter_; + + private: + // Direct wrapper around tflite::TfLiteVerifier which checks the integrity of + // the FlatBuffer data provided as input. + class Verifier : public tflite::TfLiteVerifier { + public: + explicit Verifier(const tflite::OpResolver* op_resolver) + : op_resolver_(op_resolver) {} + bool Verify(const char* data, int length, + tflite::ErrorReporter* reporter) override; + // The OpResolver to be used to build the TF Lite interpreter. + const tflite::OpResolver* op_resolver_; + }; + + // TF Lite model and interpreter for actual inference. + std::unique_ptr model_; + + // Interpreter wrapper built from the model. + tflite::support::TfLiteInterpreterWrapper interpreter_; + + // TFLite Metadata extractor built from the model. + std::unique_ptr + model_metadata_extractor_; + + // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model to + // actual implementation. Defaults to TF Lite BuiltinOpResolver. + std::unique_ptr resolver_; + + // Extra verifier for FlatBuffer input data. + Verifier verifier_; + + // ExternalFile and corresponding ExternalFileHandler for models loaded from + // disk or file descriptor. 
+ ExternalFile external_file_; + std::unique_ptr model_file_handler_; +}; + +} // namespace core +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD b/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD new file mode 100644 index 000000000..71083a288 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD @@ -0,0 +1,25 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "nl_classifier", + srcs = [ + "nl_classifier.cc", + ], + hdrs = [ + "nl_classifier.h", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/utils:common_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + ], +) diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc new file mode 100644 index 000000000..25e84a2ec --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -0,0 +1,244 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +#include "absl/algorithm/container.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/utils/common_utils.h" + +namespace tflite { +namespace support { +namespace task { +namespace text { +namespace nlclassifier { + +using ::absl::StatusCode; +using ::flatbuffers::Offset; +using ::flatbuffers::Vector; +using ::tflite::TensorMetadata; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::task::core::Dequantize; +using ::tflite::support::task::core::GetStringAtIndex; +using ::tflite::support::task::core::PopulateTensor; +using ::tflite::support::utils::LoadVocabFromBuffer; + +const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; } + +void NLClassifier::SetOptions(const NLClassifierOptions& options) { + options_ = options; +} + +void NLClassifier::SetLabelsVector( + std::unique_ptr> labels_vector) { + labels_vector_ = std::move(labels_vector); +} + +std::vector NLClassifier::Classify(const std::string& text) { + // The NLClassifier implementation for Preprocess() and Postprocess() never + // returns errors: just call value(). 
+ return Infer(text).value(); +} + +absl::Status NLClassifier::Preprocess( + const std::vector& input_tensors, const std::string& input) { + PopulateTensor( + input, + FindTensorWithNameOrIndex( + input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), + options_.input_tensor_name, options_.input_tensor_index)); + return absl::OkStatus(); +} + +StatusOr> NLClassifier::Postprocess( + const std::vector& output_tensors, + const std::string& /*input*/) { + auto scores = FindTensorWithNameOrIndex( + output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), + options_.output_score_tensor_name, options_.output_score_tensor_index); + auto labels_tensor = FindTensorWithNameOrIndex( + output_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), + options_.output_label_tensor_name, options_.output_label_tensor_index); + + bool use_index_as_labels = + (labels_vector_ == nullptr) && (labels_tensor == nullptr); + // Some models output scores with transposed shape [1, categories] + int categories = + scores->dims->size == 2 ? scores->dims->data[1] : scores->dims->data[0]; + + std::vector predictions; + predictions.reserve(categories); + + bool should_dequantize = scores->type == kTfLiteUInt8 || + scores->type == kTfLiteInt8 || + scores->type == kTfLiteInt16; + for (int index = 0; index < categories; index++) { + std::string label; + if (use_index_as_labels) { + label = std::to_string(index); + } else if (labels_vector_ == nullptr) { + label = GetStringAtIndex(labels_tensor, index); + } else { + label = (*labels_vector_)[index]; + } + if (should_dequantize) { + predictions.emplace_back(label, Dequantize(*scores, index)); + } else { + predictions.emplace_back(label, + scores->type == kTfLiteFloat32 + ? 
GetTensorData(scores)[index] + : GetTensorData(scores)[index]); + } + } + + return predictions; +} + +absl::Status NLClassifier::CheckStatusAndSetOptions( + const NLClassifierOptions& options, NLClassifier* nl_classifier) { + nl_classifier->SetOptions(options); + // input tensor should be type STRING + auto input_tensor = FindTensorWithNameOrIndex( + nl_classifier->GetInputTensors(), + nl_classifier->GetMetadataExtractor()->GetInputTensorMetadata(), + options.input_tensor_name, options.input_tensor_index); + if (input_tensor == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("No input tensor found with name ", + options.input_tensor_name, " or at index ", + options.input_tensor_index), + TfLiteSupportStatus::kInputTensorNotFoundError); + } + if (input_tensor->type != kTfLiteString) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for input tensor ", input_tensor->name, + ". Requested STRING, got ", + TfLiteTypeGetName(input_tensor->type), "."), + TfLiteSupportStatus::kInvalidInputTensorTypeError); + } + + // output score tensor should be type + // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) + std::vector output_tensors = + nl_classifier->GetOutputTensors(); + const Vector>* output_tensor_metadatas = + nl_classifier->GetMetadataExtractor()->GetOutputTensorMetadata(); + + const auto scores = FindTensorWithNameOrIndex( + output_tensors, output_tensor_metadatas, options.output_score_tensor_name, + options.output_score_tensor_index); + if (scores == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("No output score tensor found with name ", + options.output_score_tensor_name, " or at index ", + options.output_score_tensor_index), + TfLiteSupportStatus::kOutputTensorNotFoundError); + } + static constexpr TfLiteType valid_types[] = { + kTfLiteUInt8, kTfLiteInt8, kTfLiteInt16, kTfLiteFloat32, kTfLiteFloat64}; + if 
(!absl::c_linear_search(valid_types, scores->type)) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for score tensor ", scores->name, + ". Requested one of these types: " + "INT8/UINT8/INT16/FLOAT32/FLOAT64, got ", + TfLiteTypeGetName(scores->type), "."), + TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + + // Extract associated label file from output score tensor if one exists, a + // well-formatted metadata should have same number of tensors with the model. + if (output_tensor_metadatas && + output_tensor_metadatas->size() == output_tensors.size()) { + for (const auto& metadata : *output_tensor_metadatas) { + if (metadata->name() && + metadata->name()->string_view() == options.output_score_tensor_name) { + const auto associated_files = metadata->associated_files(); + if (associated_files && associated_files->size() >= 0 && + associated_files->Get(0)->name()) { + StatusOr label_buffer = + nl_classifier->GetMetadataExtractor()->GetAssociatedFile( + associated_files->Get(0)->name()->str()); + if (label_buffer.ok()) { + nl_classifier->SetLabelsVector( + absl::make_unique>(LoadVocabFromBuffer( + label_buffer.value().data(), label_buffer.value().size()))); + } + } + } + } + } + + // output label tensor should be type STRING if the one exists + auto labels = FindTensorWithNameOrIndex( + output_tensors, output_tensor_metadatas, options.output_label_tensor_name, + options.output_label_tensor_index); + if (labels != nullptr && labels->type != kTfLiteString) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for label tensor ", scores->name, + ". 
Requested STRING, got ", + TfLiteTypeGetName(scores->type), "."), + TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + return absl::OkStatus(); +} + +StatusOr> NLClassifier::CreateNLClassifier( + const char* model_buffer_data, size_t model_buffer_size, + const NLClassifierOptions& options, + std::unique_ptr resolver) { + std::unique_ptr nl_classifier; + ASSIGN_OR_RETURN( + nl_classifier, + core::TaskAPIFactory::CreateFromBuffer( + model_buffer_data, model_buffer_size, std::move(resolver))); + RETURN_IF_ERROR(CheckStatusAndSetOptions(options, nl_classifier.get())); + return std::move(nl_classifier); +} + +StatusOr> NLClassifier::CreateNLClassifier( + const std::string& path_to_model, const NLClassifierOptions& options, + std::unique_ptr resolver) { + std::unique_ptr nl_classifier; + ASSIGN_OR_RETURN(nl_classifier, + core::TaskAPIFactory::CreateFromFile( + path_to_model, std::move(resolver))); + RETURN_IF_ERROR(CheckStatusAndSetOptions(options, nl_classifier.get())); + return std::move(nl_classifier); +} + +StatusOr> NLClassifier::CreateNLClassifier( + int fd, const NLClassifierOptions& options, + std::unique_ptr resolver) { + std::unique_ptr nl_classifier; + ASSIGN_OR_RETURN(nl_classifier, + core::TaskAPIFactory::CreateFromFileDescriptor( + fd, std::move(resolver))); + RETURN_IF_ERROR(CheckStatusAndSetOptions(options, nl_classifier.get())); + return std::move(nl_classifier); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h new file mode 100644 index 000000000..53a09c216 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_ + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/category.h" + +namespace tflite { +namespace support { +namespace task { +namespace text { +namespace nlclassifier { + +// Options to identify input and output tensors of the model +struct NLClassifierOptions { + int input_tensor_index = 0; + int output_score_tensor_index = 0; + // By default there is no output label tensor. The label file can be attached + // to the output score tensor metadata. See NLClassifier for more + // information. + int output_label_tensor_index = -1; + std::string input_tensor_name = "INPUT"; + std::string output_score_tensor_name = "OUTPUT_SCORE"; + std::string output_label_tensor_name = "OUTPUT_LABEL"; +}; + +// Classifier API for NLClassification tasks, categorizes string into different +// classes. +// +// The API expects a TFLite model with the following input/output tensor: +// Input tensor: +// (kTfLiteString) - input of the model, accepts a string. 
+// Output score tensor: +// (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64) +// - output scores for each class, if type is one of the Int types, +// dequantize it to double +// - can have an optional associated file in metadata for labels, the file +// should be a plain text file with one label per line, the number of +// labels should match the number of categories the model outputs. +// Output label tensor: optional +// (kTfLiteString) +// - output classname for each class, should be of the same length with +// scores. If this tensor is not present, the API uses score indices as +// classnames. +// - will be ignored if output score tensor already has an associated label +// file. +// +// By default the API tries to find the input/output tensors with default +// configurations in NLClassifierOptions, with tensor name prioritized over +// tensor index. The option is configurable for different TFLite models. +class NLClassifier : public core::BaseTaskApi, + const std::string&> { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a NLClassifier from TFLite model buffer. + static StatusOr> CreateNLClassifier( + const char* model_buffer_data, size_t model_buffer_size, + const NLClassifierOptions& options = {}, + std::unique_ptr resolver = + absl::make_unique()); + + // Creates a NLClassifier from TFLite model file. + static StatusOr> CreateNLClassifier( + const std::string& path_to_model, const NLClassifierOptions& options = {}, + std::unique_ptr resolver = + absl::make_unique()); + + // Creates a NLClassifier from TFLite model file descriptor. + static StatusOr> CreateNLClassifier( + int fd, const NLClassifierOptions& options = {}, + std::unique_ptr resolver = + absl::make_unique()); + + // Performs classification on a string input, returns classified results. 
+ std::vector Classify(const std::string& text); + + protected: + const NLClassifierOptions& GetOptions() const; + void SetOptions(const NLClassifierOptions& options); + void SetLabelsVector(std::unique_ptr> labels_vector); + absl::Status Preprocess(const std::vector& input_tensors, + const std::string& input) override; + + StatusOr> Postprocess( + const std::vector& output_tensors, + const std::string& input) override; + + // Gets the tensor from a vector of tensors by checking tensor name first and + // tensor index second, return nullptr if no tensor is found. + template + static TensorType* FindTensorWithNameOrIndex( + const std::vector& tensors, + const flatbuffers::Vector>* + metadata_array, + const std::string& name, int index) { + if (metadata_array != nullptr && metadata_array->size() == tensors.size()) { + for (int i = 0; i < metadata_array->size(); i++) { + if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) { + return tensors[i]; + } + } + } + + for (TensorType* tensor : tensors) { + if (tensor->name == name) { + return tensor; + } + } + return index >= 0 && index < tensors.size() ? tensors[index] : nullptr; + } + + // Set options and validate model with options. + static absl::Status CheckStatusAndSetOptions( + const NLClassifierOptions& options, NLClassifier* nl_classifier); + + private: + NLClassifierOptions options_; + // labels vector initialized from output tensor's associated file, if one + // exists. 
+ std::unique_ptr> labels_vector_; +}; + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_ diff --git a/tensorflow_lite_support/cc/task/text/qa/BUILD b/tensorflow_lite_support/cc/task/text/qa/BUILD new file mode 100644 index 000000000..b7055f7bf --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/BUILD @@ -0,0 +1,39 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "question_answerer", + hdrs = [ + "question_answerer.h", + ], + deps = [ + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], +) + +cc_library( + name = "bert_question_answerer", + srcs = [ + "bert_question_answerer.cc", + ], + hdrs = [ + "bert_question_answerer.h", + ], + deps = [ + ":question_answerer", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc new file mode 100644 index 000000000..b5dbda463 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc @@ -0,0 +1,289 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +namespace tflite { +namespace support { +namespace task { +namespace text { +namespace qa { + +using ::tflite::support::task::core::PopulateTensor; +using ::tflite::support::task::core::PopulateVector; +using ::tflite::support::task::core::ReverseSortIndices; +using ::tflite::support::text::tokenizer::BertTokenizer; +using ::tflite::support::text::tokenizer::SentencePieceTokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; + +StatusOr> +BertQuestionAnswerer::CreateBertQuestionAnswerer( + const std::string& path_to_model, const std::string& path_to_vocab) { + std::unique_ptr api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile( + path_to_model, + absl::make_unique(), + kNumLiteThreads)); + api_to_init->InitializeVocab(path_to_vocab); + return api_to_init; +} + +StatusOr> +BertQuestionAnswerer::CreateBertQuestionAnswererFromBinary( + const char* model_buffer_data, size_t model_buffer_size, + const char* vocab_buffer_data, size_t vocab_buffer_size) { + std::unique_ptr api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + 
core::TaskAPIFactory::CreateFromBuffer( + model_buffer_data, model_buffer_size, + absl::make_unique(), + kNumLiteThreads)); + api_to_init->InitializeVocabFromBinary(vocab_buffer_data, vocab_buffer_size); + return api_to_init; +} + +StatusOr> +BertQuestionAnswerer::CreateAlbertQuestionAnswerer( + const std::string& path_to_model, const std::string& path_to_spmodel) { + std::unique_ptr api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile( + path_to_model, + absl::make_unique(), + kNumLiteThreads)); + api_to_init->InitializeSPModel(path_to_spmodel); + return api_to_init; +} + +StatusOr> +BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBinary( + const char* model_buffer_data, size_t model_buffer_size, + const char* spmodel_buffer_data, size_t spmodel_buffer_size) { + std::unique_ptr api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromBuffer( + model_buffer_data, model_buffer_size, + absl::make_unique(), + kNumLiteThreads)); + api_to_init->InitializeSPModelFromBinary(spmodel_buffer_data, + spmodel_buffer_size); + return api_to_init; +} + +std::vector BertQuestionAnswerer::Answer( + const std::string& context, const std::string& question) { + // The BertQuestionAnswererer implementation for Preprocess() and + // Postprocess() never returns errors: just call value(). + return Infer(context, question).value(); +} + +absl::Status BertQuestionAnswerer::Preprocess( + const std::vector& input_tensors, const std::string& context, + const std::string& query) { + token_to_orig_map_.clear(); + + // The orig_tokens is used for recovering the answer string from the index, + // while the processed_tokens is lower-cased and used to generate input of + // the model. 
+ orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty()); + std::vector processed_tokens(orig_tokens_); + + std::string processed_query = query; + if (kUseLowerCase) { + for (auto& token : processed_tokens) { + absl::AsciiStrToLower(&token); + } + absl::AsciiStrToLower(&processed_query); + } + + TokenizerResult query_tokenize_results; + query_tokenize_results = tokenizer_->Tokenize(processed_query); + + std::vector query_tokens = query_tokenize_results.subwords; + if (query_tokens.size() > kMaxQueryLen) { + query_tokens.resize(kMaxQueryLen); + } + + // Example: + // context: tokenize me please + // all_doc_tokens: token ##ize me plea ##se + // token_to_orig_index: [0, 0, 1, 2, 2] + + std::vector all_doc_tokens; + std::vector token_to_orig_index; + for (size_t i = 0; i < processed_tokens.size(); i++) { + const std::string& token = processed_tokens[i]; + std::vector sub_tokens = tokenizer_->Tokenize(token).subwords; + for (const std::string& sub_token : sub_tokens) { + token_to_orig_index.emplace_back(i); + all_doc_tokens.emplace_back(sub_token); + } + } + + // -3 accounts for [CLS], [SEP] and [SEP]. + int max_context_len = kMaxSeqLen - query_tokens.size() - 3; + if (all_doc_tokens.size() > max_context_len) { + all_doc_tokens.resize(max_context_len); + } + + std::vector tokens; + tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size()); + std::vector segment_ids; + segment_ids.reserve(kMaxSeqLen); + + // Start of generating the features. + tokens.emplace_back("[CLS]"); + segment_ids.emplace_back(0); + + // For query input. + for (const auto& query_token : query_tokens) { + tokens.emplace_back(query_token); + segment_ids.emplace_back(0); + } + + // For Separation. + tokens.emplace_back("[SEP]"); + segment_ids.emplace_back(0); + + // For Text Input. 
+ for (int i = 0; i < all_doc_tokens.size(); i++) { + auto& doc_token = all_doc_tokens[i]; + tokens.emplace_back(doc_token); + segment_ids.emplace_back(1); + token_to_orig_map_[tokens.size()] = token_to_orig_index[i]; + } + + // For ending mark. + tokens.emplace_back("[SEP]"); + segment_ids.emplace_back(1); + + std::vector input_ids(tokens.size()); + input_ids.reserve(kMaxSeqLen); + // Convert tokens back into ids + for (int i = 0; i < tokens.size(); i++) { + auto& token = tokens[i]; + tokenizer_->LookupId(token, &input_ids[i]); + } + + std::vector input_mask; + input_mask.reserve(kMaxSeqLen); + input_mask.insert(input_mask.end(), tokens.size(), 1); + + int zeros_to_pad = kMaxSeqLen - input_ids.size(); + input_ids.insert(input_ids.end(), zeros_to_pad, 0); + input_mask.insert(input_mask.end(), zeros_to_pad, 0); + segment_ids.insert(segment_ids.end(), zeros_to_pad, 0); + + // input_tensors[0]: input_ids INT32[1, 384] + PopulateTensor(input_ids, input_tensors[0]); + // input_tensors[1]: input_mask INT32[1, 384] + PopulateTensor(input_mask, input_tensors[1]); + // input_tensors[2]: segment_ids INT32[1, 384] + PopulateTensor(segment_ids, input_tensors[2]); + + return absl::OkStatus(); +} + +StatusOr> BertQuestionAnswerer::Postprocess( + const std::vector& output_tensors, + const std::string& /*lowercased_context*/, + const std::string& /*lowercased_query*/) { + // convert output tensors back to string, float maps back here + + std::vector end_logits; + std::vector start_logits; + + // output_tensors[0]: end_logits FLOAT[1, 384] + PopulateVector(output_tensors[0], &end_logits); + // output_tensors[1]: start_logits FLOAT[1, 384] + PopulateVector(output_tensors[1], &start_logits); + + auto start_indices = ReverseSortIndices(start_logits); + auto end_indices = ReverseSortIndices(end_logits); + + std::vector orig_results; + for (int start_index = 0; start_index < kPredictAnsNum; start_index++) { + for (int end_index = 0; end_index < kPredictAnsNum; end_index++) { + int start 
= start_indices[start_index]; + int end = end_indices[end_index]; + + if (!token_to_orig_map_.contains(start + kOutputOffset) || + !token_to_orig_map_.contains(end + kOutputOffset) || end < start || + (end - start + 1) > kMaxAnsLen) { + continue; + } + orig_results.emplace_back( + QaAnswer::Pos(start, end, start_logits[start] + end_logits[end])); + } + } + + std::sort(orig_results.begin(), orig_results.end()); + + std::vector answers; + for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) { + auto orig_pos = orig_results[i]; + answers.emplace_back( + orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end) + : "", + orig_pos); + } + + return answers; +} + +std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) { + int start_index = token_to_orig_map_[start + kOutputOffset]; + int end_index = token_to_orig_map_[end + kOutputOffset]; + + return absl::StrJoin(orig_tokens_.begin() + start_index, + orig_tokens_.begin() + end_index + 1, " "); +} + +void BertQuestionAnswerer::InitializeVocab(const std::string& path_to_vocab) { + tokenizer_ = absl::make_unique(path_to_vocab); +} + +void BertQuestionAnswerer::InitializeVocabFromBinary( + const char* vocab_buffer_data, size_t vocab_buffer_size) { + tokenizer_ = + absl::make_unique(vocab_buffer_data, vocab_buffer_size); +} + +void BertQuestionAnswerer::InitializeSPModel( + const std::string& path_to_spmodel) { + tokenizer_ = absl::make_unique(path_to_spmodel); +} + +void BertQuestionAnswerer::InitializeSPModelFromBinary( + const char* spmodel_buffer_data, size_t spmodel_buffer_size) { + tokenizer_ = absl::make_unique(spmodel_buffer_data, + spmodel_buffer_size); +} + +} // namespace qa +} // namespace text +} // namespace task +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h new file mode 100644 index 000000000..2f1fcb5cf --- /dev/null 
+++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h @@ -0,0 +1,135 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" + +namespace tflite { +namespace support { +namespace task { +namespace text { +namespace qa { + +// BertQA task API, performs tokenization for models (BERT, Albert, etc.) in +// preprocess and returns most possible answers. +// +// In particular, the branch of BERT models use WordPiece tokenizer, and the +// branch of Albert models use SentencePiece tokenizer, respectively. 
+// +// Factory methods: +// CreateBertQuestionAnswerer(path_to_model, path_to_vocab) +// Creates a BertQuestionAnswerer from TFLite model file and vocab file for +// WordPiece tokenizer. Used in C++ environment. +// One suitable model is: +// https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +// +// CreateBertQuestionAnswererFromBinary(model_buffer_data, model_buffer_size, +// vocab_buffer_data, vocab_buffer_size) +// Creates a BertQuestionAnswerer from TFLite model buffer and vocab file +// buffer for WordPiece tokenizer. Used in Java (JNI) environment. +// +// CreateAlbertQuestionAnswerer(path_to_model, path_to_spmodel) +// Creates an AlbertQuestionAnswerer from TFLite model file and +// SentencePiece model file. Used in C++ environment. +// One suitable model is: +// https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +// +// CreateAlbertQuestionAnswererFromBinary(model_buffer_data, +// model_buffer_size, +// spmodel_buffer_data, +// spmodel_buffer_size) +// Creates an AlbertQuestionAnswerer from TFLite model file buffer and +// SentencePiece model file buffer. Used in Java (JNI) environment. +class BertQuestionAnswerer : public QuestionAnswerer { + public: + // TODO(b/150904655): add support to parameterize.
+ static constexpr int kMaxQueryLen = 64; + static constexpr int kMaxSeqLen = 384; + static constexpr int kPredictAnsNum = 5; + static constexpr int kMaxAnsLen = 32; + // TODO(b/151954803): clarify the offset usage + static constexpr int kOutputOffset = 1; + static constexpr int kNumLiteThreads = 4; + static constexpr bool kUseLowerCase = true; + + static StatusOr> CreateBertQuestionAnswerer( + const std::string& path_to_model, const std::string& path_to_vocab); + + static StatusOr> + CreateBertQuestionAnswererFromBinary(const char* model_buffer_data, + size_t model_buffer_size, + const char* vocab_buffer_data, + size_t vocab_buffer_size); + + static StatusOr> + CreateAlbertQuestionAnswerer(const std::string& path_to_model, + const std::string& path_to_spmodel); + + static StatusOr> + CreateAlbertQuestionAnswererFromBinary(const char* model_buffer_data, + size_t model_buffer_size, + const char* spmodel_buffer_data, + size_t spmodel_buffer_size); + + explicit BertQuestionAnswerer(std::unique_ptr engine) + : QuestionAnswerer(std::move(engine)) {} + + // Answers question based on the context. + std::vector Answer(const std::string& context, + const std::string& question) override; + + private: + absl::Status Preprocess(const std::vector& input_tensors, + const std::string& lowercased_context, + const std::string& lowercased_query) override; + + StatusOr> Postprocess( + const std::vector& output_tensors, + const std::string& lowercased_context, + const std::string& lowercased_query) override; + + void InitializeVocab(const std::string& path_to_vocab); + void InitializeVocabFromBinary(const char* vocab_buffer_data, + size_t vocab_buffer_size); + void InitializeSPModel(const std::string& path_to_spmodel); + void InitializeSPModelFromBinary(const char* spmodel_buffer_data, + size_t spmodel_buffer_size); + + std::string ConvertIndexToString(int start, int end); + + std::unique_ptr tokenizer_; + // Maps index of input token to index of untokenized word from original input. 
+ absl::flat_hash_map token_to_orig_map_; + // Original tokens of context. + std::vector orig_tokens_; +}; + +} // namespace qa +} // namespace text +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ diff --git a/tensorflow_lite_support/cc/task/text/qa/question_answerer.h b/tensorflow_lite_support/cc/task/text/qa/question_answerer.h new file mode 100644 index 000000000..bb65ef3ad --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/question_answerer.h @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ + +#include +#include +#include + +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace support { +namespace task { +namespace text { +namespace qa { + +// Struct for the Answer to QuestionAnswerer. +struct QaAnswer { + // struct to represent the logit and offset of the answer related to context. 
+ struct Pos { + Pos(int arg_start, int arg_end, float arg_logit) + : start(arg_start), end(arg_end), logit(arg_logit) {} + int start, end; + float logit; + bool operator<(const Pos& rhs) const { return rhs.logit < logit; } + }; + + QaAnswer(std::string arg_text, Pos arg_pos) + : text(std::move(arg_text)), pos(arg_pos) {} + std::string text; + Pos pos; +}; + +// Interface for an Question-Answer API. +class QuestionAnswerer + : public core::BaseTaskApi, const std::string&, + const std::string&> { + public: + explicit QuestionAnswerer(std::unique_ptr engine) + : BaseTaskApi(std::move(engine)) {} + + virtual std::vector Answer(const std::string& context, + const std::string& question) = 0; +}; + +} // namespace qa +} // namespace text +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/BUILD b/tensorflow_lite_support/cc/text/tokenizers/BUILD new file mode 100644 index 000000000..e70020ca6 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/BUILD @@ -0,0 +1,151 @@ +# This package contains C++ support libraries that Java libraries can invoke. 
+load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load( + "@org_tensorflow//tensorflow/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tokenizer", + hdrs = [ + "tokenizer.h", + ], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "tokenizer_jni_lib", + srcs = [ + "tokenizer_jni_lib.cc", + ], + hdrs = [ + "tokenizer_jni_lib.h", + ], + deps = [ + ":tokenizer", + "//tensorflow_lite_support/cc/utils:jni_utils", + "@org_tensorflow//tensorflow/lite/java/jni", + ], +) + +cc_library( + name = "bert_tokenizer", + srcs = [ + "bert_tokenizer.cc", + ], + hdrs = [ + "bert_tokenizer.h", + ], + deps = [ + ":tokenizer", + "//tensorflow_lite_support/cc/utils:common_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_googlesource_code_re2//:re2", + "@org_tensorflow_text//tensorflow_text/core/kernels:regex_split", + "@org_tensorflow_text//tensorflow_text/core/kernels:wordpiece_tokenizer", + ], +) + +cc_library( + name = "bert_tokenizer_jni_lib", + srcs = [ + "bert_tokenizer_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + ":bert_tokenizer", + ":tokenizer_jni_lib", + "//tensorflow_lite_support/cc/utils:jni_utils", + "@org_tensorflow//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) + +tflite_jni_binary( + name = "libbert_tokenizer_jni.so", + deps = [ + ":bert_tokenizer_jni_lib", + ], +) + +cc_library( + name = "bert_tokenizer_runtime", + srcs = ["libbert_tokenizer_jni.so"], + alwayslink = 1, +) + +android_library( + name = "bert_tokenizer_jni", + custom_package = "org.tensorflow.lite.support.text", + manifest = "DummyManifest.xml", + resource_files = [], + deps = [ + ":bert_tokenizer_runtime", # build_cleaner: skip + ], +) + +cc_library( + name = "sentencepiece_tokenizer", + hdrs = [ + "sentencepiece_tokenizer.h", + 
], + deps = [ + ":tokenizer", + "@com_google_sentencepiece//src:sentencepiece_processor", + ], +) + +cc_library( + name = "sentencepiece_jni_lib", + srcs = [ + "sentencepiece_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + ":sentencepiece_tokenizer", + ":tokenizer_jni_lib", + "//tensorflow_lite_support/cc/utils:jni_utils", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) + +cc_library( + name = "sentencepiece_runtime", + srcs = ["libsentencepiece_jni.so"], + alwayslink = 1, +) + +tflite_jni_binary( + name = "libsentencepiece_jni.so", + deps = [ + ":sentencepiece_jni_lib", + ], +) + +android_library( + name = "sentencepiece_jni", + custom_package = "org.tensorflow.lite.support.text", + manifest = "DummyManifest.xml", + resource_files = [], + deps = [ + ":sentencepiece_runtime", # build_cleaner: skip + ], +) diff --git a/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml b/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml new file mode 100644 index 000000000..ff025072c --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml @@ -0,0 +1,19 @@ + + + + diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc new file mode 100644 index 000000000..8ffd52597 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece( + const std::vector& vocab) + : vocab_{vocab} { + for (int i = 0; i < vocab_.size(); ++i) { + index_map_[vocab_[i]] = i; + } +} + +tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains( + absl::string_view key, bool* value) const { + *value = index_map_.contains(key); + return tensorflow::text::LookupStatus(); +} + +bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key, + int* result) const { + auto it = index_map_.find(key); + if (it == index_map_.end()) { + return false; + } + *result = it->second; + return true; +} + +bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id, + absl::string_view* result) const { + if (vocab_id >= vocab_.size() || vocab_id < 0) { + return false; + } + *result = vocab_[vocab_id]; + return true; +} + +TokenizerResult BertTokenizer::Tokenize(const std::string& input) { + return TokenizeWordpiece(input); +} + +WordpieceTokenizerResult BertTokenizer::TokenizeWordpiece( + const std::string& input) { + WordpieceTokenizerResult result; + std::vector& subwords = result.subwords; + std::vector& wp_absolute_begin_offset = result.wp_begin_offset; + std::vector& wp_absolute_end_offset = result.wp_end_offset; + + std::vector tokens; + // NO lint as int64 is not defined in tensorflow lite scope + std::vector begin_offsets; // NOLINT + std::vector end_offsets; // NOLINT + + // Run through tokenize function + tensorflow::text::RegexSplit(input, delim_re_, true, include_delim_re_, + &tokens, &begin_offsets, &end_offsets); + + for (int token_index = 0; token_index < tokens.size(); token_index++) { + auto& token = 
tokens[token_index]; + int num_word_pieces = 0; + tensorflow::text::LookupStatus status = WordpieceTokenize( + token, options_.max_bytes_per_token, options_.max_chars_per_subtoken, + options_.suffix_indicator, options_.use_unknown_token, + options_.unknown_token, options_.split_unknown_chars, &vocab_, + &subwords, &wp_absolute_begin_offset, &wp_absolute_end_offset, + &num_word_pieces); + + result.row_lengths.emplace_back(num_word_pieces); + // for the last num_word_pieces added into wp_absolute_begin_offset and + // wp_absolute_end_offset, offset them with begin_offsets[token_index] + int absolute_offset_size = wp_absolute_begin_offset.size(); + for (int i = num_word_pieces; i > 0; i--) { + wp_absolute_begin_offset[absolute_offset_size - i] += + begin_offsets[token_index]; + wp_absolute_end_offset[absolute_offset_size - i] += + begin_offsets[token_index]; + } + if (!status.success) { + return result; + } + } + + return result; +} + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h new file mode 100644 index 000000000..bc3d31015 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "re2/re2.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/utils/common_utils.h" +#include "tensorflow_text/core/kernels/regex_split.h" +#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h" + +namespace tflite::support::text::tokenizer { + +constexpr char kDefaultDelimRe[] = + R"((\s+|[!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))"; +constexpr char kDefaultIncludeDelimRe[] = + R"(([!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))"; +constexpr int kDefaultMaxBytesPerToken = 100; +constexpr int kDefaultMaxCharsPerSubToken = 100; +constexpr char kDefaultSuffixIndicator[] = "##"; +constexpr bool kDefaultUseUnknownToken = true; +constexpr char kDefaultUnknownToken[] = "[UNK]"; +constexpr bool kDefaultSplitUnknownChars = false; + +// Result of wordpiece tokenization including subwords and offsets. +// Example: +// input: tokenize me please +// subwords: token ##ize me plea ##se +// wp_begin_offset: [0, 5, 9, 12, 16] +// wp_end_offset: [ 5, 8, 11, 16, 18] +// row_lengths: [2, 1, 1] +struct WordpieceTokenizerResult : TokenizerResult { + std::vector wp_begin_offset; + std::vector wp_end_offset; + std::vector row_lengths; +}; +// Options to create a BertTokenizer. 
+struct BertTokenizerOptions { + int max_bytes_per_token = kDefaultMaxBytesPerToken; + int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken; + std::string suffix_indicator = kDefaultSuffixIndicator; + bool use_unknown_token = kDefaultUseUnknownToken; + std::string unknown_token = kDefaultUnknownToken; + bool split_unknown_chars = kDefaultSplitUnknownChars; + std::string delim_str = kDefaultDelimRe; + std::string include_delim_str = kDefaultIncludeDelimRe; +}; + +// A flat-hash-map based implementation of WordpieceVocab, used in +// BertTokenizer to invoke tensorflow::text::WordpieceTokenize within. +class FlatHashMapBackedWordpiece : public tensorflow::text::WordpieceVocab { + public: + explicit FlatHashMapBackedWordpiece(const std::vector& vocab); + + tensorflow::text::LookupStatus Contains(absl::string_view key, + bool* value) const override; + bool LookupId(absl::string_view key, int* result) const; + bool LookupWord(int vocab_id, absl::string_view* result) const; + int VocabularySize() const { return vocab_.size(); } + + private: + // All words indexed position in vocabulary file. + std::vector vocab_; + absl::flat_hash_map index_map_; +}; + +// Wordpiece tokenizer for bert models. Initialized with a vocab file or vector. +class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer { + public: + // Initialize the tokenizer from vocab vector and tokenizer configs. + explicit BertTokenizer(const std::vector& vocab, + const BertTokenizerOptions& options = {}) + : vocab_{FlatHashMapBackedWordpiece(vocab)}, + options_{options}, + delim_re_{options.delim_str}, + include_delim_re_{options.include_delim_str} {} + + // Initialize the tokenizer from file path to vocab and tokenizer configs. + explicit BertTokenizer(const std::string& path_to_vocab, + const BertTokenizerOptions& options = {}) + : BertTokenizer(utils::LoadVocabFromFile(path_to_vocab), options) {} + + // Initialize the tokenizer from buffer and size of vocab and tokenizer + // configs. 
+ BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size, + const BertTokenizerOptions& options = {}) + : BertTokenizer( + utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size), + options) {} + + // Perform tokenization, return tokenized results containing the subwords. + TokenizerResult Tokenize(const std::string& input) override; + + // Perform tokenization, return wordpiece-specific tokenized result including + // subwords and offsets + WordpieceTokenizerResult TokenizeWordpiece(const std::string& input); + + // Check if a certain key is included in the vocab. + tensorflow::text::LookupStatus Contains(const absl::string_view key, + bool* value) const { + return vocab_.Contains(key, value); + } + + // Find the id of a wordpiece. + bool LookupId(absl::string_view key, int* result) const override { + return vocab_.LookupId(key, result); + } + + // Find the wordpiece from an id. + bool LookupWord(int vocab_id, absl::string_view* result) const override { + return vocab_.LookupWord(vocab_id, result); + } + + int VocabularySize() const { return vocab_.VocabularySize(); } + + private: + tflite::support::text::tokenizer::FlatHashMapBackedWordpiece vocab_; + BertTokenizerOptions options_; + RE2 delim_re_; + RE2 include_delim_re_; +}; + +} // namespace tflite::support::text::tokenizer + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc new file mode 100644 index 000000000..748f7f477 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include + +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { + +using ::tflite::support::text::tokenizer::BertTokenizer; +using ::tflite::support::text::tokenizer::BertTokenizerOptions; +using ::tflite::support::utils::StringListToVector; + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT + JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token, + jint max_chars_per_sub_token, jstring jsuffix_indicator, + jboolean use_unknown_token, jstring junknown_token, + jboolean split_unknown_chars) { + // Convert java.util.List into std::vector + std::vector vocab = StringListToVector(env, vocab_list); + + // Convert jstrings to std::string + const char* raw_suffix_indicator = + env->GetStringUTFChars(jsuffix_indicator, JNI_FALSE); + std::string suffix_indicator(raw_suffix_indicator); + + const char* raw_unknown_token = + env->GetStringUTFChars(junknown_token, JNI_FALSE); + std::string unknown_token(raw_unknown_token); + + auto handle = absl::MakeUnique( + vocab, BertTokenizerOptions{ + .max_bytes_per_token = max_bytes_per_token, + .max_chars_per_subtoken = max_chars_per_sub_token, + .suffix_indicator = suffix_indicator, + .use_unknown_token = static_cast(use_unknown_token), + .unknown_token = 
unknown_token, + .split_unknown_chars = static_cast(split_unknown_chars), + .delim_str = text::tokenizer::kDefaultDelimRe, + .include_delim_str = text::tokenizer::kDefaultIncludeDelimRe}); + + env->ReleaseStringUTFChars(jsuffix_indicator, raw_suffix_indicator); + env->ReleaseStringUTFChars(junknown_token, raw_unknown_token); + + return reinterpret_cast(handle.release()); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT + JNIEnv* env, jobject thiz, jlong handle) { + delete reinterpret_cast(handle); + return 0; +} + +extern "C" JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize( + JNIEnv* env, jobject thiz, jlong handle, jstring jtext) { + return nativeTokenize(env, handle, jtext); +} + +extern "C" JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT + JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) { + return nativeConvertTokensToIds(env, handle, jtokens); +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc new file mode 100644 index 000000000..24dd11c91 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { + +using ::tflite::support::text::tokenizer::SentencePieceTokenizer; +using ::tflite::support::utils::GetMappedFileBuffer; + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT + JNIEnv* env, jobject obj, jobject model_buffer) { + auto model = GetMappedFileBuffer(env, model_buffer); + auto handle = + absl::MakeUnique(model.data(), model.size()); + return reinterpret_cast(handle.release()); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT + JNIEnv* env, jobject obj, jlong handle) { + delete reinterpret_cast(handle); + return 0; +} + +extern "C" JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT + JNIEnv* env, jobject thiz, jlong handle, jstring jtext) { + return nativeTokenize(env, handle, jtext); +} + +extern "C" JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT + JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) { + return nativeConvertTokensToIds(env, handle, jtokens); +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h new file mode 
100644 index 000000000..ed5d3da75 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ + +#include +#include +#include + +#include "src/sentencepiece_processor.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +// SentencePiece tokenizer. Initialized with a model file. +class SentencePieceTokenizer : public Tokenizer { + public: + // Initialize the SentencePiece tokenizer from model file path. + explicit SentencePieceTokenizer(const std::string& path_to_model) { + CHECK_OK(sp_.Load(path_to_model)); + } + + explicit SentencePieceTokenizer(const char* spmodel_buffer_data, + size_t spmodel_buffer_size) { + absl::string_view buffer_binary(spmodel_buffer_data, spmodel_buffer_size); + CHECK_OK(sp_.LoadFromSerializedProto(buffer_binary)); + } + + // Perform tokenization, return tokenized results. 
+ TokenizerResult Tokenize(const std::string& input) override { + TokenizerResult result; + std::vector& subwords = result.subwords; + CHECK_OK(sp_.Encode(input, &subwords)); + return result; + } + + // Find the id of a string token. + bool LookupId(absl::string_view key, int* result) const override { + *result = sp_.PieceToId(key); + return true; + } + + // Find the string token of an id. + bool LookupWord(int vocab_id, absl::string_view* result) const override { + *result = sp_.IdToPiece(vocab_id); + return true; + } + + private: + sentencepiece::SentencePieceProcessor sp_; +}; + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h new file mode 100644 index 000000000..681d0a94e --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" + +namespace tflite::support::text::tokenizer { + +struct TokenizerResult { + std::vector subwords; +}; + +// Interface of general tokenizer. +class Tokenizer { + public: + // Perform tokenization to get tokenized results. + virtual TokenizerResult Tokenize(const std::string& input) = 0; + + // Find the id of a string token. + virtual bool LookupId(absl::string_view key, int* result) const = 0; + + // Find the string token from an id. + virtual bool LookupWord(int vocab_id, absl::string_view* result) const = 0; + + // Destructor. + virtual ~Tokenizer() = default; +}; + +} // namespace tflite::support::text::tokenizer + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc new file mode 100644 index 000000000..a72523be5 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" + +namespace tflite { +namespace support { + +using ::tflite::support::text::tokenizer::Tokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::support::utils::CheckNotNull; +using ::tflite::support::utils::JStringToString; +using ::tflite::support::utils::kIllegalStateException; + +jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) { + if (handle == 0) { + env->ThrowNew(env->FindClass(kIllegalStateException), + "Vocab not initialized!"); + return nullptr; + } + + Tokenizer* tokenizer = reinterpret_cast(handle); + + // Get the tokenization results. + const TokenizerResult tokenize_result = + tokenizer->Tokenize(JStringToString(env, jtext)); + std::vector subwords = tokenize_result.subwords; + + jclass string_class = CheckNotNull(env, env->FindClass("java/lang/String")); + jobjectArray result = CheckNotNull( + env, env->NewObjectArray(subwords.size(), string_class, nullptr)); + + for (int i = 0; i < subwords.size(); ++i) { + jstring text = CheckNotNull(env, env->NewStringUTF(subwords[i].data())); + if (env->ExceptionCheck()) { + return nullptr; + } + + env->SetObjectArrayElement(result, i, text); + } + + return result; +} + +jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle, + jobjectArray jtokens) { + if (handle == 0) { + env->ThrowNew(env->FindClass(kIllegalStateException), + "vocab not initialized!"); + return nullptr; + } + + Tokenizer* tokenizer = reinterpret_cast(handle); + + // Get the token ids. 
+ const int count = env->GetArrayLength(jtokens); + jintArray result = env->NewIntArray(count); + jint* jid_ptr = env->GetIntArrayElements(result, nullptr); + + for (int i = 0; i < count; i++) { + auto jstr = + reinterpret_cast(env->GetObjectArrayElement(jtokens, i)); + const char* token = env->GetStringUTFChars(jstr, JNI_FALSE); + int id; + tokenizer->LookupId(token, &id); + jid_ptr[i] = id; + env->ReleaseStringUTFChars(jstr, token); + } + env->ReleaseIntArrayElements(result, jid_ptr, 0); + return result; +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h new file mode 100644 index 000000000..fc7285c68 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_ + +#include + +#include + +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { + +jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext); + +jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle, + jobjectArray jtokens); + +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_ diff --git a/tensorflow_lite_support/cc/utils/BUILD b/tensorflow_lite_support/cc/utils/BUILD new file mode 100644 index 000000000..b818d5ae3 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/BUILD @@ -0,0 +1,28 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "jni_utils", + srcs = [ + "jni_utils.cc", + ], + hdrs = [ + "jni_utils.h", + ], + deps = [ + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/java/jni", + ], +) + +cc_library( + name = "common_utils", + srcs = [ + "common_utils.cc", + ], + hdrs = [ + "common_utils.h", + ], +) diff --git a/tensorflow_lite_support/cc/utils/common_utils.cc b/tensorflow_lite_support/cc/utils/common_utils.cc new file mode 100644 index 000000000..e19699118 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/common_utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/utils/common_utils.h" + +#include + +namespace tflite { +namespace support { +namespace utils { + +struct membuf : std::streambuf { + membuf(char* begin, char* end) { this->setg(begin, begin, end); } +}; + +std::vector LoadVocabFromFile(const std::string& path_to_vocab) { + std::vector vocab_from_file; + std::ifstream in(path_to_vocab.c_str()); + std::string str; + while (std::getline(in, str)) { + if (!str.empty()) vocab_from_file.push_back(str); + } + in.close(); + + return vocab_from_file; +} + +std::vector LoadVocabFromBuffer(const char* vocab_buffer_data, + const size_t vocab_buffer_size) { + membuf sbuf(const_cast(vocab_buffer_data), + const_cast(vocab_buffer_data + vocab_buffer_size)); + std::vector vocab_from_file; + std::istream in(&sbuf); + std::string str; + while (std::getline(in, str)) { + if (!str.empty()) vocab_from_file.push_back(str); + } + return vocab_from_file; +} + +} // namespace utils +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/utils/common_utils.h b/tensorflow_lite_support/cc/utils/common_utils.h new file mode 100644 index 000000000..def102086 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/common_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_ + +#include +#include + +namespace tflite { +namespace support { +namespace utils { + +// read a vocab file, create a vector of strings +std::vector LoadVocabFromFile(const std::string& path_to_vocab); + +std::vector LoadVocabFromBuffer(const char* vocab_buffer_data, + const size_t vocab_buffer_size); +} // namespace utils +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_ diff --git a/tensorflow_lite_support/cc/utils/jni_utils.cc b/tensorflow_lite_support/cc/utils/jni_utils.cc new file mode 100644 index 000000000..da6603e99 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/jni_utils.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { +namespace utils { + +std::string JStringToString(JNIEnv* env, jstring jstr) { + if (jstr == nullptr) { + return std::string(); + } + const char* cstring = env->GetStringUTFChars(jstr, nullptr); + std::string result(cstring); + env->ReleaseStringUTFChars(jstr, cstring); + return result; +} + +std::vector StringListToVector(JNIEnv* env, jobject list_object) { + jobject j_iterator = env->CallObjectMethod( + list_object, env->GetMethodID(env->GetObjectClass(list_object), + "iterator", "()Ljava/util/Iterator;")); + std::vector result; + jmethodID has_next = + env->GetMethodID(env->GetObjectClass(j_iterator), "hasNext", "()Z"); + jmethodID get_next = env->GetMethodID(env->GetObjectClass(j_iterator), "next", + "()Ljava/lang/Object;"); + while (env->CallBooleanMethod(j_iterator, has_next)) { + jstring jstr = + static_cast(env->CallObjectMethod(j_iterator, get_next)); + const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE); + result.emplace_back(std::string(raw_str)); + env->ReleaseStringUTFChars(jstr, raw_str); + } + return result; +} + +absl::string_view GetMappedFileBuffer(JNIEnv* env, const jobject& file_buffer) { + return absl::string_view( + static_cast(env->GetDirectBufferAddress(file_buffer)), + static_cast(env->GetDirectBufferCapacity(file_buffer))); +} + +} // namespace utils +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/utils/jni_utils.h b/tensorflow_lite_support/cc/utils/jni_utils.h new file mode 100644 index 000000000..b16d37179 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/jni_utils.h @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_ + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" + +namespace tflite { +namespace support { +namespace utils { + +const char kIllegalStateException[] = "java/lang/IllegalStateException"; + +// Check if t is nullptr, throw IllegalStateException if it is. +// Used to verify different types of jobjects are correctly created from jni. +template +T CheckNotNull(JNIEnv* env, T&& t) { + if (t == nullptr) { + env->ThrowNew(env->FindClass(kIllegalStateException), ""); + return nullptr; + } + return std::forward(t); +} + +// Convert a vector into an Java ArrayList using a converter. 
+template +jobject ConvertVectorToArrayList(JNIEnv* env, const std::vector& results, + std::function converter) { + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_ctor = + env->GetMethodID(array_list_class, "", "(I)V"); + jint initial_capacity = static_cast(results.size()); + jobject array_list_object = + env->NewObject(array_list_class, array_list_ctor, initial_capacity); + jmethodID array_list_add_method = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + + for (const auto& ans : results) { + env->CallBooleanMethod(array_list_object, array_list_add_method, + converter(ans)); + } + return array_list_object; +} + +std::string JStringToString(JNIEnv* env, jstring jstr); + +std::vector StringListToVector(JNIEnv* env, jobject list_object); + +// Gets a mapped file buffer from a java object representing a file. +absl::string_view GetMappedFileBuffer(JNIEnv* env, const jobject& file_buffer); +} // namespace utils +} // namespace support +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_ diff --git a/tensorflow_lite_support/metadata/cc/BUILD b/tensorflow_lite_support/metadata/cc/BUILD index 5508fe0e2..903c6995f 100644 --- a/tensorflow_lite_support/metadata/cc/BUILD +++ b/tensorflow_lite_support/metadata/cc/BUILD @@ -20,13 +20,13 @@ cc_library( "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/metadata:metadata_schema_cc", - "//third_party/libzip:zip", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@flatbuffers", + "@org_libzip//:zip", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/schema:schema_fbs", "@org_tensorflow//tensorflow/lite/tools:verifier", diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc 
b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc index ac39d7732..5c2e6bb26 100644 --- a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "third_party/libzip/lib/zip.h" +#include "lib/zip.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/verifier.h" diff --git a/tensorflow_lite_support/opensource/opensource_only.files b/tensorflow_lite_support/opensource/opensource_only.files index d9a5d3db6..5c7cee55e 100644 --- a/tensorflow_lite_support/opensource/opensource_only.files +++ b/tensorflow_lite_support/opensource/opensource_only.files @@ -5,9 +5,12 @@ tensorflow_lite_support/third_party/android/android.bzl.tpl tensorflow_lite_support/third_party/android/android_configure.BUILD.tpl tensorflow_lite_support/third_party/android/android_configure.bzl tensorflow_lite_support/third_party/com_google_absl.BUILD +tensorflow_lite_support/third_party/libyuv.BUILD +tensorflow_lite_support/third_party/libzip.BUILD tensorflow_lite_support/third_party/pybind11.BUILD tensorflow_lite_support/third_party/python_runtime/BUILD tensorflow_lite_support/third_party/six.BUILD tensorflow_lite_support/third_party/toolchains/java/BUILD +tensorflow_lite_support/third_party/zlib.BUILD tensorflow_lite_support/tools/ci_build/build_all.sh tensorflow_lite_support/tools/ci_build/common.sh \ No newline at end of file diff --git a/tensorflow_lite_support/tools/build_rules/expand_template.bzl b/tensorflow_lite_support/tools/build_rules/expand_template.bzl new file mode 100644 index 000000000..717860ca8 --- /dev/null +++ b/tensorflow_lite_support/tools/build_rules/expand_template.bzl @@ -0,0 +1,50 @@ +"""Build macro for libzip.""" + +# forked from 
kythe/kythe/tools/build_rules/expand_template.bzl
+def _expand_template_impl(ctx):
+    ctx.actions.expand_template(
+        template = ctx.file.template,
+        output = ctx.outputs.out,
+        substitutions = ctx.attr.substitutions,
+    )
+
+expand_template = rule(
+    attrs = {
+        "out": attr.output(mandatory = True),
+        "substitutions": attr.string_dict(mandatory = True),
+        "template": attr.label(
+            mandatory = True,
+            allow_single_file = True,
+        ),
+    },
+    output_to_genfiles = True,
+    implementation = _expand_template_impl,
+)
+
+def cmake_substitutions(vars, defines = {}):
+    """Returns a dict of template substitutions combining `vars` and `defines`.
+
+    Args:
+      vars: will be turned into a dict replacing `${key}` and `@key@` with `value`.
+      defines: will be turned into a dict replacing `#cmakedefine` with `#define {value}`
+        if present is true, otherwise `/* #undef %s */`.
+    Returns:
+      substitutions
+    """
+    subs = {}
+    for key, value in vars.items():
+        subs["${%s}" % (key,)] = str(value) if value != None else ""
+        subs["@%s@" % (key,)] = str(value) if value != None else ""
+
+    # TODO(shahms): Better handling of #cmakedefine delimiters and line endings to
+    #   avoid the prefix-substitution problem.
+    # Potentially allow value to be: True, False, None or string.
+    #   True/False => Same as current
+    #   None => assume no suffix value, include \n in sub and replacement
+    #   string => use string to lookup in vars and assume ${} or @@ tail?
+ for macro, present in defines.items(): + if present: + subs["#cmakedefine %s" % macro] = "#define %s" % macro + else: + subs["#cmakedefine %s" % macro] = "/* #undef %s */" % macro + return subs diff --git a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff new file mode 100644 index 000000000..0cd2dffa4 --- /dev/null +++ b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff @@ -0,0 +1,14 @@ +diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel +index 9fceffe..e7f9d01 100644 +--- a/absl/time/internal/cctz/BUILD.bazel ++++ b/absl/time/internal/cctz/BUILD.bazel +@@ -69,8 +69,5 @@ cc_library( + "include/cctz/zone_info_source.h", + ], + linkopts = select({ +- ":osx": [ +- "-framework Foundation", +- ], + ":ios": [ + "-framework Foundation", + ], \ No newline at end of file diff --git a/third_party/libyuv.BUILD b/third_party/libyuv.BUILD new file mode 100644 index 000000000..4b39a8c0a --- /dev/null +++ b/third_party/libyuv.BUILD @@ -0,0 +1,25 @@ +# Description: +# The libyuv package provides implementation yuv image conversion, rotation +# and scaling. 
+
+licenses(["notice"])  # BSD license
+
+exports_files(["LICENSE"])
+
+cc_library(
+    name = "libyuv",
+    srcs = glob(
+        [
+            "source/*.cc",
+            "include/libyuv/*.h",
+        ],
+    ),
+    hdrs = [
+        "include/libyuv.h",
+        "include/libyuv/compare.h",
+        "include/libyuv/convert.h",
+        "include/libyuv/video_common.h",
+    ],
+    includes = ["include"],
+    visibility = ["//visibility:public"],
+)
diff --git a/third_party/libzip.BUILD b/third_party/libzip.BUILD
new file mode 100644
index 000000000..b69ccf410
--- /dev/null
+++ b/third_party/libzip.BUILD
@@ -0,0 +1,189 @@
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+load("@org_tensorflow_lite_support//tensorflow_lite_support/tools:build_rules/expand_template.bzl", "cmake_substitutions", "expand_template")
+
+_CMAKE_VARIABLES = {
+    "INT16_T_LIBZIP": 2,
+    "INT32_T_LIBZIP": 4,
+    "INT64_T_LIBZIP": 8,
+    "INT8_T_LIBZIP": 1,
+    "INT_LIBZIP": 4,
+    "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>",
+    "LONG_LIBZIP": 8,
+    "LONG_LONG_LIBZIP": 8,
+    "PACKAGE_VERSION": "1.5.1",
+    "PACKAGE_VERSION_MAJOR": "1",
+    "PACKAGE_VERSION_MICRO": "1",
+    "PACKAGE_VERSION_MINOR": "5",
+    "SHORT_LIBZIP": 2,
+    "SIZEOF_OFF_T": 8,
+    "SIZE_T_LIBZIP": 8,
+    "SSIZE_T_LIBZIP": 8,
+    "UINT16_T_LIBZIP": 2,
+    "UINT32_T_LIBZIP": 4,
+    "UINT64_T_LIBZIP": 8,
+    "UINT8_T_LIBZIP": 1,
+    "__INT16_LIBZIP": None,
+    "__INT32_LIBZIP": None,
+    "__INT64_LIBZIP": None,
+    "__INT8_LIBZIP": None,
+}
+
+_CMAKE_VARIABLES.update(dict([
+    (
+        "ZIP_{sign}INT{size}_T".format(
+            sign = sign.upper(),
+            size = size,
+        ),
+        "{sign}int{size}_t".format(
+            sign = sign.lower(),
+            size = size,
+        ),
+    )
+    for sign in ("U", "")
+    for size in (8, 16, 32, 64)
+]))
+
+_SUBSTITUTIONS = {
+    "@PACKAGE@": "libzip",
+    "@VERSION@": "1.5.1",  # Keep in sync with actual package!
+}
+
+_DEFINES = {
+    "HAVE_CLONEFILE": False,
+    "HAVE_COMMONCRYPTO": False,
+    "HAVE_CRYPTO": False,
+    "HAVE_DIRENT_H": False,
+    "HAVE_FICLONERANGE": False,
+    "HAVE_FILENO": True,
+    "HAVE_FSEEK": True,
+    "HAVE_FSEEKO": True,
+    "HAVE_FTELLO": True,
+    "HAVE_FTS_H": True,
+    "HAVE_GETPROGNAME": False,
+    "HAVE_GNUTLS": False,
+    "HAVE_LIBBZ2": False,
+    "HAVE_MKSTEMP": True,
+    "HAVE_NDIR_H": False,
+    "HAVE_OPEN": True,
+    "HAVE_OPENSSL": False,
+    "HAVE_SETMODE": False,
+    "HAVE_SHARED": True,
+    "HAVE_SNPRINTF": True,
+    "HAVE_SSIZE_T_LIBZIP": True,
+    "HAVE_STDBOOL_H": True,
+    "HAVE_STRCASECMP": True,
+    "HAVE_STRDUP": True,
+    "HAVE_STRICMP": False,
+    "HAVE_STRINGS_H": True,
+    "HAVE_STRTOLL": True,
+    "HAVE_STRTOULL": True,
+    "HAVE_STRUCT_TM_TM_ZONE": False,
+    "HAVE_SYS_DIR_H": False,
+    "HAVE_SYS_NDIR_H": False,
+    "HAVE_UNISTD_H": True,
+    "HAVE__CHMOD": False,
+    "HAVE__CLOSE": False,
+    "HAVE__DUP": False,
+    "HAVE__FDOPEN": False,
+    "HAVE__FILENO": False,
+    "HAVE__OPEN": False,
+    "HAVE__SETMODE": False,
+    "HAVE__SNPRINTF": False,
+    "HAVE__STRDUP": False,
+    "HAVE__STRICMP": False,
+    "HAVE__STRTOI64": False,
+    "HAVE__STRTOUI64": False,
+    "HAVE__UMASK": False,
+    "HAVE__UNLINK": False,
+    "HAVE___PROGNAME": False,
+    "WORDS_BIGENDIAN": False,
+}
+
+_DEFINES.update(dict([(
+    key,
+    value != None,
+) for key, value in _CMAKE_VARIABLES.items()]))
+
+_SUBSTITUTIONS.update(cmake_substitutions(
+    defines = _DEFINES,
+    vars = _CMAKE_VARIABLES,
+))
+
+expand_template(
+    name = "config_h",
+    out = "config.h",
+    substitutions = _SUBSTITUTIONS,
+    template = "cmake-config.h.in",
+)
+
+_VARS = {
+    "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>",
+    "PACKAGE_VERSION": "1.5.1",
+    "PACKAGE_VERSION_MAJOR": "1",
+    "PACKAGE_VERSION_MICRO": "1",
+    "PACKAGE_VERSION_MINOR": "5",
+}
+
+_VARS.update(dict([
+    (
+        "ZIP_{sign}INT{size}_T".format(
+            sign = sign.upper(),
+            size = size,
+        ),
+        "{sign}int{size}_t".format(
+            sign = sign.lower(),
+            size = size,
+        ),
+    )
+    for sign in ("U", "")
+    for size in (8, 16, 32, 64)
+]))
+
+expand_template( + name = "zipconf_h", + out = "lib/zipconf.h", + substitutions = cmake_substitutions( + defines = { + "LIBZIP_VERSION": True, + "LIBZIP_VERSION_MAJOR": True, + "LIBZIP_VERSION_MICRO": True, + "LIBZIP_VERSION_MINOR": True, + "ZIP_STATIC": False, + }, + vars = _VARS, + ), + template = "cmake-zipconf.h.in", +) + +cc_library( + name = "zip", + srcs = glob( + [ + "lib/*.c", + "lib/*.h", + ], + exclude = [ + "lib/*win32*", + "lib/zip_random_uwp.c", + "lib/*crypto*", + "lib/*aes*", + "lib/*bzip2*", + ], + ) + [ + "config.h", + ], + hdrs = [ + "lib/zip.h", + "lib/zipconf.h", + ], + copts = [ + "-DHAVE_CONFIG_H", + ], + includes = ["lib"], + deps = [ + "@zlib", + ], +) diff --git a/third_party/tensorflow_text_fix_local_config_tf.patch b/third_party/tensorflow_text_fix_local_config_tf.patch new file mode 100644 index 000000000..b482b1bd6 --- /dev/null +++ b/third_party/tensorflow_text_fix_local_config_tf.patch @@ -0,0 +1,38 @@ +diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD +index bdca365..5f5cdf6 100644 +--- a/tensorflow_text/core/kernels/BUILD ++++ b/tensorflow_text/core/kernels/BUILD +@@ -16,8 +16,7 @@ OSS_DEPS = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", +- "@local_config_tf//:libtensorflow_framework", +- "@local_config_tf//:tf_header_lib", ++ "@org_tensorflow//tensorflow/core:tensorflow_opensource", + ] + + cc_library( +diff --git a/tensorflow_text/tftext.bzl b/tensorflow_text/tftext.bzl +index aa5e275..5eaff73 100644 +--- a/tensorflow_text/tftext.bzl ++++ b/tensorflow_text/tftext.bzl +@@ -44,8 +44,7 @@ def py_tf_text_library( + copts = [ "-pthread", ], + alwayslink = 1, + deps = cc_op_kernels + [ +- "@local_config_tf//:libtensorflow_framework", +- "@local_config_tf//:tf_header_lib", ++ "@org_tensorflow//tensorflow/core:tensorflow_opensource", + ], + ) + +@@ -55,8 +54,7 @@ def py_tf_text_library( + linkshared = 1, + deps = [ + ":" + 
library_name, +- "@local_config_tf//:libtensorflow_framework", +- "@local_config_tf//:tf_header_lib", ++ "@org_tensorflow//tensorflow/core:tensorflow_opensource", + ], + ) + \ No newline at end of file diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD new file mode 100644 index 000000000..275782e06 --- /dev/null +++ b/third_party/zlib.BUILD @@ -0,0 +1,39 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "zlib", + srcs = [ + "adler32.c", + "compress.c", + "crc32.c", + "crc32.h", + "deflate.c", + "deflate.h", + "gzclose.c", + "gzguts.h", + "gzlib.c", + "gzread.c", + "gzwrite.c", + "infback.c", + "inffast.c", + "inffast.h", + "inffixed.h", + "inflate.c", + "inflate.h", + "inftrees.c", + "inftrees.h", + "trees.c", + "trees.h", + "uncompr.c", + "zutil.c", + "zutil.h", + ], + hdrs = [ + "zconf.h", + "zlib.h", + ], + copts = ["-Wno-implicit-function-declaration"], + includes = ["."], +)