diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index dce50af0259c96..5b198f143419af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -31,6 +31,8 @@ cc_library( "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite/debug", + "//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 12be81041d66de..65f519175fc37c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" +#include <optional> #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/debug/debug.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -53,17 +55,19 @@ std::string TfLiteToMlir(const absl::string_view tflite_op_name) { // TODO(fengliuai): check the result for `fully_quantize` flag. TfLiteStatus QuantizeModel( - const absl::string_view model_buffer, const tflite::TensorType& input_type, - const tflite::TensorType& output_type, - const tflite::TensorType& inference_type, - const std::unordered_set<std::string>& operator_names, - bool disable_per_channel, bool fully_quantize, std::string& output_buffer, - tflite::ErrorReporter* error_reporter, bool verify_numeric, + const absl::string_view model_buffer, const tflite::TensorType &input_type, + const tflite::TensorType &output_type, + const tflite::TensorType &inference_type, + const std::unordered_set<std::string> &operator_names, + bool disable_per_channel, bool fully_quantize, std::string &output_buffer, + tflite::ErrorReporter *error_reporter, bool verify_numeric, bool whole_model_verify, bool legacy_float_scale, - const absl::flat_hash_set<std::string>& denylisted_ops, - const absl::flat_hash_set<std::string>& denylisted_nodes, + const absl::flat_hash_set<std::string> &denylisted_ops, + const absl::flat_hash_set<std::string> &denylisted_nodes, const bool enable_variable_quantization, - bool disable_per_channel_for_dense_layers) { + bool disable_per_channel_for_dense_layers, + const std::optional<tensorflow::converter::DebugOptions> + &debug_options) { // Translate TFLite names to mlir op names. absl::flat_hash_set<std::string> denylisted_mlir_op_names; for (const auto& entry : denylisted_ops) { @@ -85,6 +89,10 @@ TfLiteStatus QuantizeModel( // Apply quantization passes.
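The new trailing `debug_options` parameter defaults to `std::nullopt` in the header change below, so existing callers are unaffected; when it is set, the hunk that follows installs the converter debug instrumentation via `tensorflow::InitPassManager` before the quantization passes run. A minimal caller-side sketch, assuming only the types visible in this diff; the `QuantizeWithDebug` wrapper and the choice to leave the proto at its defaults are illustrative, not part of the change:

#include <optional>
#include <string>

#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"

// Hypothetical caller: quantize a serialized TFLite flatbuffer with debug
// instrumentation enabled. Which DebugOptions fields to set is not shown in
// this diff, so the proto is left at its defaults here.
TfLiteStatus QuantizeWithDebug(absl::string_view model_buffer,
                               tflite::ErrorReporter* error_reporter,
                               std::string& output_buffer) {
  tensorflow::converter::DebugOptions debug_options;
  return mlir::lite::QuantizeModel(
      model_buffer, tflite::TensorType_FLOAT32, tflite::TensorType_FLOAT32,
      tflite::TensorType_INT8, /*operator_names=*/{},
      /*disable_per_channel=*/false, /*fully_quantize=*/true, output_buffer,
      error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false,
      /*legacy_float_scale=*/true, /*denylisted_ops=*/{},
      /*denylisted_nodes=*/{}, /*enable_variable_quantization=*/false,
      /*disable_per_channel_for_dense_layers=*/false, debug_options);
}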
PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); + if (debug_options.has_value()) { + // Add debugging instrumentation + tensorflow::InitPassManager(pm, debug_options.value()); + } quant::QuantizationSpecs quant_specs; quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 665766d700512d..339d562ce0a801 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" @@ -45,17 +46,19 @@ namespace lite { // of double, and call TOCO's quantization routines to maintain bit-exactness of // the values with the TOCO quantizer. TfLiteStatus QuantizeModel( - absl::string_view model_buffer, const tflite::TensorType& input_type, - const tflite::TensorType& output_type, - const tflite::TensorType& inference_type, - const std::unordered_set& operator_names, - bool disable_per_channel, bool fully_quantize, std::string& output_buffer, - tflite::ErrorReporter* error_reporter, bool verify_numeric = false, + absl::string_view model_buffer, const tflite::TensorType &input_type, + const tflite::TensorType &output_type, + const tflite::TensorType &inference_type, + const std::unordered_set &operator_names, + bool disable_per_channel, bool fully_quantize, std::string &output_buffer, + tflite::ErrorReporter *error_reporter, bool verify_numeric = false, bool whole_model_verify = false, bool legacy_float_scale = true, - const absl::flat_hash_set& denylisted_ops = {}, - const absl::flat_hash_set& denylisted_nodes = {}, + const absl::flat_hash_set &denylisted_ops = {}, + const absl::flat_hash_set &denylisted_nodes = {}, bool enable_variable_quantization = false, - bool disable_per_channel_for_dense_layers = false); + bool disable_per_channel_for_dense_layers = false, + const std::optional + &debug_options = std::nullopt); } // namespace lite } // namespace mlir diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 3790f64e0cec68..69f763abfda5a7 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -65,6 +65,7 @@ package_group( packages = [ "//tensorflow/...", "//tensorflow_text/...", + "//waymo/ml/compiler/frontend/kernels/...", "//waymo/onboard/ml/...", ], ) diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index c133556b4aaa43..9551bdd79d4ae5 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -45,7 +45,8 @@ class GatherNdOp : public OpKernel { Tensor out; OP_REQUIRES_OK( - c, functor::DoGatherNd(c, params, indices, &out)); + c, functor::DoGatherNd( + c, params, indices, &out)); c->set_output(0, out); } }; diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h index 09bad00c59b070..6059a2bbdafb31 100644 --- a/tensorflow/core/kernels/gather_nd_op.h +++ b/tensorflow/core/kernels/gather_nd_op.h @@ -43,7 +43,8 @@ struct GatherNdSlice { typename TTypes::Matrix Tout); 
}; -template +template Status DoGatherNd(OpKernelContext* c, const Tensor& params, const Tensor& indices, Tensor* out) { if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) { @@ -151,6 +152,10 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params, indices_nd); } + if constexpr (kDropBadIndices) { + return absl::OkStatus(); + } + // bad_i will only return >= 0 on CPUs right now. if (bad_i >= 0) { auto shape = indices.shape(); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 0f604b0e605879..ea369fd49a5ea2 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -878,7 +878,7 @@ class IndexFlattener { namespace { template + scatter_nd_op::UpdateOp Op, bool kDropBadIndices> Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate) { @@ -925,7 +925,11 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, for (int i = 0; i < IXDIM; ++i) { \ output_shape_prefix[i] = shape.dim_size(i); \ } \ - functor::ScatterNdFunctor functor; \ + constexpr bool kShallDropBadIndices = \ + kDropBadIndices || std::is_same::value; \ + functor::ScatterNdFunctor \ + functor; \ bad_i = \ functor(c->eigen_device(), slice_size, output_shape_prefix, \ output_matrix, indices_flat, updates_flat, output_matrix); \ @@ -947,6 +951,9 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, slice_dim); } } + if constexpr (kDropBadIndices) { + return absl::OkStatus(); + } if (bad_i >= 0) { auto slice_shape = indices.shape(); slice_shape.RemoveLastDims(1); @@ -970,7 +977,8 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, // back to GPU. This is useful because the CPU implementation is deterministic // and the GPU implementation is not. Tensor inputs to this function must be on // the GPU. -template +template Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate) { @@ -1015,7 +1023,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, } TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - TF_RETURN_IF_ERROR(DoScatterNd( + TF_RETURN_IF_ERROR(DoScatterNd( c, host_indices, host_updates, shape, &host_out, /*allocate=*/false)); // Copy 'host_out' to device. @@ -1033,15 +1041,15 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, } // namespace template + scatter_nd_op::UpdateOp Op, bool kDropBadIndices> Status DoScatterNd(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same::value && tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) { - return DoScatterNdOnCpu(c, indices, updates, shape, out, - allocate); + return DoScatterNdOnCpu( + c, indices, updates, shape, out, allocate); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -1049,11 +1057,11 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, // atomics, which are not supported for all integer types. 
if constexpr (std::is_same::value && std::is_integral::value) { - return DoScatterNdOnCpu(c, indices, updates, shape, out, - allocate); + return DoScatterNdOnCpu( + c, indices, updates, shape, out, allocate); } else { - return DoScatterNdImpl(c, indices, updates, shape, - out, allocate); + return DoScatterNdImpl( + c, indices, updates, shape, out, allocate); } } } // namespace functor @@ -1061,16 +1069,29 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ - template <> \ - Index ScatterNdFunctor::operator()( \ - const GPUDevice& d, const Index slice_size, \ - const Eigen::array output_shape_prefix, \ - typename TTypes::Tensor Tparams, \ - typename TTypes::ConstTensor Tindices, \ - typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput); \ - extern template struct ScatterNdFunctor; +#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ + template <> \ + Index \ + ScatterNdFunctor:: \ + operator()(const GPUDevice& d, const Index slice_size, \ + const Eigen::array output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput); \ + extern template struct ScatterNdFunctor; \ + template <> \ + Index ScatterNdFunctor:: \ + operator()(const GPUDevice& d, const Index slice_size, \ + const Eigen::array output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput); \ + extern template struct ScatterNdFunctor; #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \ diff --git a/tensorflow/core/kernels/scatter_nd_op.h b/tensorflow/core/kernels/scatter_nd_op.h index f9a2ce0ed6e12b..8d2e74b18ca864 100644 --- a/tensorflow/core/kernels/scatter_nd_op.h +++ b/tensorflow/core/kernels/scatter_nd_op.h @@ -44,7 +44,7 @@ namespace functor { // Functor used by ScatterOp to do the computations. template + scatter_nd_op::UpdateOp op, int IXDIM, bool kDropBadIndices> struct ScatterNdFunctor { // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. Index operator()( @@ -63,7 +63,7 @@ struct ScatterNdFunctor { // right type (T) and shape. This tensor will not be zeroed out // before the scatter is executed. template + scatter_nd_op::UpdateOp Op, bool kDropBadIndices = false> Status DoScatterNd(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate); diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h index b0123780cc6406..abdbc1ece968bf 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -103,8 +103,9 @@ class UpdateExecutor { namespace functor { // Implementation of update functor for CPU. 
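The new `kDropBadIndices` template parameter defaults to `false`, so existing instantiations keep today's behavior of reporting the first bad index; when it is `true`, out-of-range indices are skipped (the CPU functor below `continue`s past them) and the `if constexpr (kDropBadIndices) return absl::OkStatus();` blocks above suppress the bad-index error. Inside `DoScatterNdImpl` the functor flag is `kDropBadIndices` ORed with an is-GPU-device check, presumably because the GPU functor already skips bad indices rather than reporting `bad_i`. A sketch of how a kernel might opt in, assuming the template parameters keep the order suggested by this diff (Device, T, Index, [UpdateOp,] kDropBadIndices); the wrapper functions themselves are hypothetical:

// Gather that silently drops out-of-range indices.
// Assumes DoGatherNd<Device, T, Index, kDropBadIndices>.
Status GatherDroppingBadIndices(OpKernelContext* c, const Tensor& params,
                                const Tensor& indices, Tensor* out) {
  return functor::DoGatherNd<CPUDevice, float, int32,
                             /*kDropBadIndices=*/true>(c, params, indices, out);
}

// Scatter-add that silently drops out-of-range indices.
Status ScatterAddDroppingBadIndices(OpKernelContext* c, const Tensor& indices,
                                    const Tensor& updates,
                                    const TensorShape& shape, Tensor* out) {
  return functor::DoScatterNd<CPUDevice, float, int32,
                              scatter_nd_op::UpdateOp::ADD,
                              /*kDropBadIndices=*/true>(
      c, indices, updates, shape, out, /*allocate=*/true);
}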
-template -struct ScatterNdFunctor { +template +struct ScatterNdFunctor { Index operator()( const CPUDevice& d, const Index slice_size, const Eigen::array output_shape_prefix, @@ -136,33 +137,44 @@ struct ScatterNdFunctor { i += ix_d * batch_strides[dim]; } if (TF_PREDICT_FALSE(out_of_bounds)) { + if constexpr (kDropBadIndices) { + continue; + } error_loc = loc; break; - } else { - auto input_chip = Toutput.template chip<0>(i); - auto output_chip = input_chip; - auto update_chip = Tupdates.template chip<0>(loc); - update_executor::UpdateExecutor< - CPUDevice, decltype(input_chip), decltype(update_chip), - decltype(output_chip), OP>::Execute(d, input_chip, update_chip, - output_chip); } + auto input_chip = Toutput.template chip<0>(i); + auto output_chip = input_chip; + auto update_chip = Tupdates.template chip<0>(loc); + update_executor::UpdateExecutor< + CPUDevice, decltype(input_chip), decltype(update_chip), + decltype(output_chip), OP>::Execute(d, input_chip, update_chip, + output_chip); } return error_loc; } }; -#define REGISTER_SCATTER_ND_FULL(T, Index, op) \ - template Index \ - ScatterNdFunctor::operator()( \ - const CPUDevice& d, const Index slice_size, \ - const Eigen::array \ - output_shape_prefix, \ - typename TTypes::Tensor Tparams, \ - typename TTypes::ConstTensor Tindices, \ - typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput) +#define REGISTER_SCATTER_ND_FULL(T, Index, op) \ + template Index ScatterNdFunctor:: \ + operator()(const CPUDevice& d, const Index slice_size, \ + const Eigen::array \ + output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput); \ + template Index ScatterNdFunctor:: \ + operator()(const CPUDevice& d, const Index slice_size, \ + const Eigen::array \ + output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput) #define REGISTER_SCATTER_ND_INDEX(type, op) \ REGISTER_SCATTER_ND_FULL(type, int32, op); \ diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index fd1d4747c40982..4e528c58e6ba0f 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -124,8 +124,9 @@ __global__ void ScatterNdOpKernel( namespace functor { // Functor used by ScatterOp to do the computations. -template -struct ScatterNdFunctor { +template +struct ScatterNdFunctor { Index operator()( const GPUDevice& d, const Index slice_size, const Eigen::array output_shape_prefix, @@ -164,8 +165,9 @@ struct ScatterNdFunctor { } // namespace functor -#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ - template struct functor::ScatterNdFunctor; +#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ + template struct functor::ScatterNdFunctor; #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \ diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 32fc17a1ce5ae3..29e1073e2a3266 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -231,6 +231,7 @@ def mlir_quantize( denylisted_nodes=None, enable_variable_quantization=False, disable_per_channel_for_dense_layers=False, + debug_options_str="", ): """Quantize `input_data_str` with calibration results. 
@@ -259,6 +260,8 @@ def mlir_quantize( disable_per_channel_for_dense_layers: Bool indicating whether to do per-channel or per-tensor quantization in Fully Connected layers. Default value is False meaning per-channel quantization is enabled. + debug_options_str: Serialized proto describing TFLite converter debug + options, see `debug/debug_options.proto`. Returns: Quantized model in serialized form (e.g. a TFLITE model) with floating-point @@ -277,6 +280,7 @@ def mlir_quantize( denylisted_nodes, enable_variable_quantization, disable_per_channel_for_dense_layers, + debug_options_str, ) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index b6cfbaa60f632e..30895357924d74 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -33,6 +33,7 @@ from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb from tensorflow.lite.python import lite_constants as constants +from tensorflow.lite.python.convert import build_conversion_flags as _build_conversion_flags from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays from tensorflow.lite.python.convert import convert_jax_hlo as _convert_jax_hlo @@ -717,6 +718,7 @@ def _quantize( bias_type, allow_float, enable_variable_quantization, + debug_options, ): """Quantize the model.""" # pylint: disable=protected-access @@ -758,6 +760,7 @@ def _quantize( output_data_type=output_type, enable_variable_quantization=enable_variable_quantization, disable_per_channel_for_dense_layers=self._experimental_disable_per_channel_quantization_for_dense_layers, + debug_options_str=debug_options.SerializeToString(), ) else: return calibrate_quantize.calibrate_and_quantize( @@ -1098,7 +1101,9 @@ def _set_conversion_latency_metric(self, value): self._tflite_metrics.set_converter_latency(value) @convert_phase(Component.OPTIMIZE_TFLITE_MODEL) - def _optimize_tflite_model(self, model, quant_mode, quant_io=True): + def _optimize_tflite_model( + self, model, quant_mode, debug_options, quant_io=True + ): """Apply optimizations on a TFLite model.""" # Disable TFLite quantization pass when @@ -1126,6 +1131,7 @@ def _optimize_tflite_model(self, model, quant_mode, quant_io=True): q_bias_type, q_allow_float, q_variable_quantization, + debug_options, ) m_in_type = in_type if in_type else _dtypes.float32 @@ -1415,7 +1421,10 @@ def _convert_from_saved_model(self, graph_def): result = _convert_saved_model(**converter_kwargs) return self._optimize_tflite_model( - result, quant_mode, quant_io=self.experimental_new_quantizer + result, + quant_mode, + _build_conversion_flags(**converter_kwargs).debug_options, + quant_io=self.experimental_new_quantizer, ) def convert(self, graph_def, input_tensors, output_tensors): @@ -1461,7 +1470,10 @@ def convert(self, graph_def, input_tensors, output_tensors): ) return self._optimize_tflite_model( - result, self._quant_mode, quant_io=self.experimental_new_quantizer + result, + self._quant_mode, + _build_conversion_flags(**converter_kwargs).debug_options, + quant_io=self.experimental_new_quantizer, ) @@ -2036,7 +2048,10 @@ def convert(self): result = _convert_jax_hlo(**converter_kwargs) return self._optimize_tflite_model( - result, quant_mode, quant_io=self.experimental_new_quantizer + result, + quant_mode, + 
_build_conversion_flags(**converter_kwargs).debug_options, + quant_io=self.experimental_new_quantizer, ) @@ -2572,7 +2587,10 @@ def convert(self): ) return self._optimize_tflite_model( - result, quant_mode, quant_io=self.experimental_new_quantizer + result, + quant_mode, + _build_conversion_flags(**converter_kwargs).debug_options, + quant_io=self.experimental_new_quantizer, ) def get_input_arrays(self): diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py index 9d8a8bc11a6456..5badb07413ee56 100644 --- a/tensorflow/lite/python/wrap_toco.py +++ b/tensorflow/lite/python/wrap_toco.py @@ -52,6 +52,7 @@ def wrapped_experimental_mlir_quantize( denylisted_nodes, enable_variable_quantization, disable_per_channel_for_dense_layers, + debug_options_str, ): """Wraps experimental mlir quantize model.""" return _pywrap_toco_api.ExperimentalMlirQuantizeModel( @@ -67,6 +68,7 @@ def wrapped_experimental_mlir_quantize( denylisted_nodes, enable_variable_quantization, disable_per_channel_for_dense_layers, + debug_options_str, ) diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index 145a9aa0e0a57a..991eef9dde4e49 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -38,6 +38,8 @@ cc_library( ], deps = [ "//tensorflow/c:kernels", + "//tensorflow/c:tf_status_headers", + "//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc", "//tensorflow/compiler/mlir/lite/metrics:error_collector", "//tensorflow/compiler/mlir/lite/python:flatbuffer_to_mlir", "//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer", @@ -55,6 +57,7 @@ cc_library( "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/lite/python/interpreter_wrapper:python_utils", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_convert", "//tensorflow/lite/toco:toco_flags_proto_cc", @@ -67,7 +70,11 @@ cc_library( "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc", "//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf_headers", + "@flatbuffers//:runtime_cc", + "@local_tsl//tsl/platform:status", ] + select({ # This is required when running `tflite_convert` from `bazel`. # It requires to link with TensorFlow Ops to get the op definitions. diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index f88e59bd68fa4f..d4fc4b2f2bcdf9 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -14,14 +14,21 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/toco/python/toco_python_api.h" +#include + #include #include +#include #include #include -#include "google/protobuf/text_format.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" @@ -32,20 +39,20 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/toco/import_tensorflow.h" #include "tensorflow/lite/toco/logging/conversion_log_util.h" #include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h" +#include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_convert.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_tooling.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" @@ -298,7 +305,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, bool enable_whole_model_verify, PyObject* op_denylist, PyObject* node_denylist, bool enable_variable_quantization, - bool disable_per_channel_for_dense_layers) { + bool disable_per_channel_for_dense_layers, + PyObject* debug_options_proto_txt_raw) { using tflite::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; @@ -309,6 +317,38 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, return nullptr; } + std::optional debug_options = + tensorflow::converter::DebugOptions(); + if (debug_options_proto_txt_raw != nullptr) { + auto ConvertArg = [&](PyObject* obj, bool* error) { + char* buf; + Py_ssize_t len; + if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) { + *error = true; + return std::string(); + } else { + *error = false; + return std::string(buf, len); + } + }; + + bool error; + std::string debug_options_proto_txt = + ConvertArg(debug_options_proto_txt_raw, &error); + if (error) { + PyErr_SetString(PyExc_ValueError, "Toco flags are invalid."); + return nullptr; + } + + if (!debug_options->ParseFromString(debug_options_proto_txt)) { + PyErr_SetString(PyExc_ValueError, + "Failed to convert Toco to Python String."); + return nullptr; + } + } else { + debug_options = std::nullopt; + } + absl::flat_hash_set denylisted_ops; absl::flat_hash_set denylisted_nodes; if (ToStringSet(op_denylist, &denylisted_ops) == -1) { @@ -344,7 
+384,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, /*operator_names=*/{}, disable_per_channel, fully_quantize, output_model, error_reporter.get(), enable_numeric_verify, enable_whole_model_verify, /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes, - enable_variable_quantization, disable_per_channel_for_dense_layers); + enable_variable_quantization, disable_per_channel_for_dense_layers, + debug_options); if (status != kTfLiteOk) { error_reporter->exception(); return nullptr; @@ -452,7 +493,7 @@ PyObject* RegisterCustomOpdefs(PyObject* list) { Py_RETURN_TRUE; } -const std::vector<std::string> RetrieveCollectedErrors() { mlir::TFL::ErrorCollector* collector = mlir::TFL::ErrorCollector::GetErrorCollector(); std::vector<std::string> collected_errors; diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h index 37d42bcf170316..56ad98cfa988cd 100644 --- a/tensorflow/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -53,7 +53,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, PyObject* op_denylist = nullptr, PyObject* node_denylist = nullptr, bool enable_variable_quantization = false, - bool disable_per_channel_for_dense_layers = false); + bool disable_per_channel_for_dense_layers = false, + PyObject* debug_options_proto_txt_raw = nullptr); // Sparsifies model to encode sparse tensors with proper format. Throws error if // sparsification fails. @@ -63,7 +64,7 @@ PyObject* MlirSparsifyModel(PyObject* data); PyObject* RegisterCustomOpdefs(PyObject* list); // Returns the collected TFLite conversion errors. -const std::vector<std::string> RetrieveCollectedErrors(); +std::vector<std::string> RetrieveCollectedErrors(); // Returns MLIR string dump of the given Flatbuffer model. std::string FlatBufferFileToMlir(const std::string& model, diff --git a/tensorflow/python/_pywrap_toco_api.pyi b/tensorflow/python/_pywrap_toco_api.pyi index 213c6f14872f7d..319a678bcbcf48 100644 --- a/tensorflow/python/_pywrap_toco_api.pyi +++ b/tensorflow/python/_pywrap_toco_api.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ...) -> object: ... +def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ..., debug_options_proto_txt_raw: object = ...) -> object: ... def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ... def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ... def RegisterCustomOpdefs(custom_opdefs_txt_raw: object) -> object: ...
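Note that despite the `_proto_txt_raw` suffix, the new argument carries a binary-serialized `DebugOptions` proto: the Python layer fills `debug_options_str` with `debug_options.SerializeToString()`, and `MlirQuantizeModel` recovers it with `ParseFromString` before forwarding the resulting `std::optional` to `QuantizeModel`. A minimal sketch of that round trip in isolation (the function and variable names are illustrative only):

#include <string>

#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"

// Returns true if a default-constructed DebugOptions proto survives the
// serialize/parse round trip used between the Python and C++ layers.
bool RoundTripDebugOptions() {
  // Producer side: what the Python layer does via SerializeToString.
  tensorflow::converter::DebugOptions options;
  const std::string wire = options.SerializeAsString();

  // Consumer side: mirrors the parsing done in MlirQuantizeModel above.
  tensorflow::converter::DebugOptions parsed;
  return parsed.ParseFromString(wire);
}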
diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc index 39dc7802f28f0e..40f035db732550 100644 --- a/tensorflow/python/lite/toco_python_api_wrapper.cc +++ b/tensorflow/python/lite/toco_python_api_wrapper.cc @@ -61,14 +61,15 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { int output_data_type, bool enable_numeric_verify, bool enable_whole_model_verify, py::object op_blocklist, py::object node_blocklist, bool enable_variable_quantization, - bool disable_per_channel_for_dense_layers) { + bool disable_per_channel_for_dense_layers, + py::object debug_options_proto_txt_raw) { return tensorflow::PyoOrThrow(toco::MlirQuantizeModel( input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize, inference_type, input_data_type, output_data_type, enable_numeric_verify, enable_whole_model_verify, op_blocklist.ptr(), node_blocklist.ptr(), - enable_variable_quantization, - disable_per_channel_for_dense_layers)); + enable_variable_quantization, disable_per_channel_for_dense_layers, + debug_options_proto_txt_raw.ptr())); }, py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false, py::arg("fully_quantize") = true, py::arg("inference_type") = 9, @@ -79,6 +80,7 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { py::arg("node_blocklist") = py::none(), py::arg("enable_variable_quantization") = false, py::arg("disable_per_channel_for_dense_layers") = false, + py::arg("debug_options_proto_txt_raw") = nullptr, R"pbdoc( Returns a quantized model. )pbdoc"); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 85d54dd84e1dd8..c589bfb0bc5e49 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -772,18 +772,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { } } -absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - if (GpuDriver::WaitStreamOnEvent(context_, - absl::bit_cast(stream), - AsGpuEvent(event)->gpu_event())) { - return absl::OkStatus(); - } else { - return absl::InternalError( - "error waiting for CUDA event on external stream"); - } -} - Event::Status GpuExecutor::PollForEventStatus(Event* event) { return AsGpuEvent(event)->PollForStatus(); } diff --git a/third_party/xla/xla/stream_executor/event.cc b/third_party/xla/xla/stream_executor/event.cc index 3de5e9045d40fb..3b2131f28995a0 100644 --- a/third_party/xla/xla/stream_executor/event.cc +++ b/third_party/xla/xla/stream_executor/event.cc @@ -32,8 +32,4 @@ Event::Status Event::PollForStatus() { return stream_exec_->PollForEventStatus(this); } -absl::Status Event::WaitForEventOnExternalStream(std::intptr_t stream) { - return stream_exec_->WaitForEventOnExternalStream(stream, this); -} - } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/event.h b/third_party/xla/xla/stream_executor/event.h index f3fee06dde4f1b..cf014e55220208 100644 --- a/third_party/xla/xla/stream_executor/event.h +++ b/third_party/xla/xla/stream_executor/event.h @@ -41,7 +41,7 @@ class Event { kComplete, }; - Event(StreamExecutorInterface* stream_exec); + explicit Event(StreamExecutorInterface* stream_exec); // Releases any resources held by the Event object. virtual ~Event() = default; @@ -51,7 +51,9 @@ class Event { // Blocks `stream` on this event. `stream` is a raw platform-specific // stream (e.g. GpuStreamHandle). 
- absl::Status WaitForEventOnExternalStream(std::intptr_t stream); + virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream) { + return absl::UnimplementedError("Not supported for this Event."); + } Event(Event&&) = default; Event& operator=(Event&&) = default; diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index eb5babd09bc664..d43863f01115dd 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -197,6 +197,7 @@ gpu_only_cc_library( ":gpu_stream", ":gpu_types_header", "//xla/stream_executor:stream_executor_headers", + "@com_google_absl//absl/base", "@com_google_absl//absl/status", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc index e1d078122212c2..4cd66783ea382c 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc @@ -15,7 +15,11 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" +#include + +#include "absl/base/casts.h" #include "absl/status/status.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" @@ -45,5 +49,15 @@ absl::Status GpuEvent::Record(GpuStream* stream) { GpuEventHandle GpuEvent::gpu_event() { return gpu_event_; } +absl::Status GpuEvent::WaitForEventOnExternalStream(std::intptr_t stream) { + if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), + absl::bit_cast(stream), + gpu_event_)) { + return absl::OkStatus(); + } else { + return absl::InternalError("Error waiting for event on external stream"); + } +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.h b/third_party/xla/xla/stream_executor/gpu/gpu_event.h index 6574e50c426424..5ab851dfb60205 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_event.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ +#include + #include "absl/status/status.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_stream.h" @@ -47,6 +49,8 @@ class GpuEvent : public Event { // The underlying CUDA event element. GpuEventHandle gpu_event(); + absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; + private: // The Executor used to which this object and GpuEventHandle are bound. 
GpuExecutor* parent_; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 65d85dc0a7ac85..5024220a1ab45d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -237,9 +237,6 @@ class GpuExecutor : public StreamExecutor { absl::Status WaitForEvent(Stream* stream, Event* event) override; - absl::Status WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) override; - Event::Status PollForEventStatus(Event* event) override; absl::Status BlockHostUntilDone(Stream* stream) override; diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 8e11d3998748a5..ad2de24f6292b2 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -136,8 +136,6 @@ class MockStreamExecutor : public StreamExecutorInterface { (override)); MOCK_METHOD(absl::Status, WaitForEvent, (Stream * stream, Event* event), (override)); - MOCK_METHOD(absl::Status, WaitForEventOnExternalStream, - (std::intptr_t stream, Event* event), (override)); MOCK_METHOD(Event::Status, PollForEventStatus, (Event * event), (override)); MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); MOCK_METHOD(bool, CreateStreamDependency, (Stream * dependent, Stream* other), diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index 1003f659e231a9..60433afcbf9e3e 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -676,18 +676,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { } } -absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - if (GpuDriver::WaitStreamOnEvent(context_, - absl::bit_cast(stream), - AsGpuEvent(event)->gpu_event())) { - return absl::OkStatus(); - } else { - return absl::InternalError( - "error waiting for ROCM event on external stream"); - } -} - Event::Status GpuExecutor::PollForEventStatus(Event* event) { return AsGpuEvent(event)->PollForStatus(); } diff --git a/third_party/xla/xla/stream_executor/stream_executor_interface.h b/third_party/xla/xla/stream_executor/stream_executor_interface.h index 982a32ecd8006f..a6a5eb4c0e313d 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_interface.h +++ b/third_party/xla/xla/stream_executor/stream_executor_interface.h @@ -269,14 +269,6 @@ class StreamExecutorInterface { // Waits for the specified event at the end of the specified stream. virtual absl::Status WaitForEvent(Stream* stream, Event* event) = 0; - // Waits for the specified event at the end of the raw platform-specific - // stream. - virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - return absl::UnimplementedError( - "WaitForEventOnExternalStream not supported on this executor."); - } - // Requests the current status of the event from the underlying platform. virtual Event::Status PollForEventStatus(Event* event) = 0;
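With this refactor, waiting on an event from an externally owned stream goes through the `Event` object itself instead of a StreamExecutor method: the base `Event::WaitForEventOnExternalStream` now returns `UnimplementedError`, and `GpuEvent` overrides it using `GpuDriver::WaitStreamOnEvent`. A call-site sketch, assuming a GPU build where the external handle is a raw driver stream passed in as a pointer; the wrapper function and variable names are illustrative:

#include <cstdint>

#include "absl/base/casts.h"
#include "absl/status/status.h"
#include "xla/stream_executor/event.h"

// `event` was created and recorded through a StreamExecutor-owned stream;
// `external_stream_handle` is a raw platform stream (e.g. a cudaStream_t)
// owned by some other framework.
absl::Status WaitOnExternalStream(stream_executor::Event* event,
                                  void* external_stream_handle) {
  // Dispatches to GpuEvent::WaitForEventOnExternalStream on GPU builds;
  // Event implementations without an override return UnimplementedError.
  return event->WaitForEventOnExternalStream(
      absl::bit_cast<std::intptr_t>(external_stream_handle));
}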