Adding MLIR debugging instrumentation to experimental MLIR Quantizer
PiperOrigin-RevId: 623867514
tensorflower-gardener committed May 22, 2024
1 parent 98a4c09 commit f985804
Showing 28 changed files with 238 additions and 128 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -31,6 +31,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
"//tensorflow/compiler/mlir/lite/debug",
"//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
"//tensorflow/compiler/mlir/tensorflow:error_util",
26 changes: 17 additions & 9 deletions tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"

#include <optional>
#include <string>
#include <unordered_set>

@@ -30,6 +31,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/debug/debug.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
@@ -53,17 +55,19 @@ std::string TfLiteToMlir(const absl::string_view tflite_op_name) {

// TODO(fengliuai): check the result for `fully_quantize` flag.
TfLiteStatus QuantizeModel(
const absl::string_view model_buffer, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize, std::string& output_buffer,
tflite::ErrorReporter* error_reporter, bool verify_numeric,
const absl::string_view model_buffer, const tflite::TensorType &input_type,
const tflite::TensorType &output_type,
const tflite::TensorType &inference_type,
const std::unordered_set<std::string> &operator_names,
bool disable_per_channel, bool fully_quantize, std::string &output_buffer,
tflite::ErrorReporter *error_reporter, bool verify_numeric,
bool whole_model_verify, bool legacy_float_scale,
const absl::flat_hash_set<std::string>& denylisted_ops,
const absl::flat_hash_set<std::string>& denylisted_nodes,
const absl::flat_hash_set<std::string> &denylisted_ops,
const absl::flat_hash_set<std::string> &denylisted_nodes,
const bool enable_variable_quantization,
bool disable_per_channel_for_dense_layers) {
bool disable_per_channel_for_dense_layers,
const std::optional<const tensorflow::converter::DebugOptions>
&debug_options) {
// Translate TFLite names to mlir op names.
absl::flat_hash_set<std::string> denylisted_mlir_op_names;
for (const auto& entry : denylisted_ops) {
@@ -85,6 +89,10 @@ TfLiteStatus QuantizeModel(

// Apply quantization passes.
PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit);
if (debug_options.has_value()) {
// Add debugging instrumentation
tensorflow::InitPassManager(pm, debug_options.value());
}
quant::QuantizationSpecs quant_specs;
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
quant_specs.post_training_quantization = true;
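For readers unfamiliar with it, tensorflow::InitPassManager attaches MLIR instrumentation to the pass manager according to the supplied DebugOptions. Below is a simplified sketch of that kind of wiring using only stock mlir::PassManager hooks; it is an illustration under assumptions, not the TensorFlow implementation, and the bool parameters stand in for DebugOptions fields.

#include "mlir/Pass/PassManager.h"

// Hypothetical helper: shows the general shape of pass-manager
// instrumentation. The real tensorflow::InitPassManager reads
// tensorflow::converter::DebugOptions instead of these bool parameters.
void AttachDebugInstrumentation(mlir::PassManager& pm, bool print_ir,
                                bool time_passes) {
  if (print_ir) {
    // Built-in MLIR hook: print the IR before and after every pass.
    pm.enableIRPrinting();
  }
  if (time_passes) {
    // Built-in MLIR hook: collect and report per-pass timing.
    pm.enableTiming();
  }
}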
21 changes: 12 additions & 9 deletions tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h
@@ -20,6 +20,7 @@ limitations under the License.

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/error_reporter.h"
@@ -45,17 +46,19 @@ namespace lite {
// of double, and call TOCO's quantization routines to maintain bit-exactness of
// the values with the TOCO quantizer.
TfLiteStatus QuantizeModel(
absl::string_view model_buffer, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize, std::string& output_buffer,
tflite::ErrorReporter* error_reporter, bool verify_numeric = false,
absl::string_view model_buffer, const tflite::TensorType &input_type,
const tflite::TensorType &output_type,
const tflite::TensorType &inference_type,
const std::unordered_set<std::string> &operator_names,
bool disable_per_channel, bool fully_quantize, std::string &output_buffer,
tflite::ErrorReporter *error_reporter, bool verify_numeric = false,
bool whole_model_verify = false, bool legacy_float_scale = true,
const absl::flat_hash_set<std::string>& denylisted_ops = {},
const absl::flat_hash_set<std::string>& denylisted_nodes = {},
const absl::flat_hash_set<std::string> &denylisted_ops = {},
const absl::flat_hash_set<std::string> &denylisted_nodes = {},
bool enable_variable_quantization = false,
bool disable_per_channel_for_dense_layers = false);
bool disable_per_channel_for_dense_layers = false,
const std::optional<const tensorflow::converter::DebugOptions>
&debug_options = std::nullopt);

} // namespace lite
} // namespace mlir
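A minimal caller-side sketch of the new optional debug_options parameter follows. It is illustrative only: the DebugOptions field used below (ir_dump_dir) is an assumption about the proto in tensorflow/compiler/mlir/lite/debug and may be named differently; every other argument mirrors the declaration above.

#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"

// Hypothetical usage: run post-training quantization with IR dumping enabled.
TfLiteStatus QuantizeWithDebugDumps(absl::string_view model_buffer,
                                    std::string& output_buffer,
                                    tflite::ErrorReporter* error_reporter) {
  tensorflow::converter::DebugOptions debug_options;
  // Assumed field name; adjust to the actual DebugOptions schema.
  debug_options.set_ir_dump_dir("/tmp/quantizer_ir_dumps");

  return mlir::lite::QuantizeModel(
      model_buffer, tflite::TensorType_INT8, tflite::TensorType_INT8,
      tflite::TensorType_INT8, /*operator_names=*/{},
      /*disable_per_channel=*/false, /*fully_quantize=*/true, output_buffer,
      error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false,
      /*legacy_float_scale=*/true, /*denylisted_ops=*/{},
      /*denylisted_nodes=*/{}, /*enable_variable_quantization=*/false,
      /*disable_per_channel_for_dense_layers=*/false, debug_options);
}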
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
@@ -65,6 +65,7 @@ package_group(
packages = [
"//tensorflow/...",
"//tensorflow_text/...",
"//waymo/ml/compiler/frontend/kernels/...",
"//waymo/onboard/ml/...",
],
)
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/gather_nd_op.cc
@@ -45,7 +45,8 @@ class GatherNdOp : public OpKernel {

Tensor out;
OP_REQUIRES_OK(
c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
c, functor::DoGatherNd<Device, T, Index, /*kDropBadIndices=*/false>(
c, params, indices, &out));
c->set_output(0, out);
}
};
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/gather_nd_op.h
@@ -43,7 +43,8 @@ struct GatherNdSlice {
typename TTypes<T>::Matrix Tout);
};

template <typename Device, typename T, typename Index>
template <typename Device, typename T, typename Index,
bool kDropBadIndices = false>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
const Tensor& indices, Tensor* out) {
if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
@@ -151,6 +152,10 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
indices_nd);
}

if constexpr (kDropBadIndices) {
return absl::OkStatus();
}

// bad_i will only return >= 0 on CPUs right now.
if (bad_i >= 0) {
auto shape = indices.shape();
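The new kDropBadIndices template parameter turns bad-index handling into a compile-time policy: when it is true, the error-reporting tail of DoGatherNd is skipped and out-of-range indices are silently dropped. A small standalone sketch of that pattern (illustrative only, not the kernel code):

#include <cstdio>
#include <vector>

// Illustrative compile-time "drop bad indices" switch, mirroring the
// kDropBadIndices pattern added to DoGatherNd. Returns the location of the
// first bad index, or -1 if there was none (the functor convention).
template <bool kDropBadIndices>
int GatherWithPolicy(const std::vector<int>& params,
                     const std::vector<int>& indices,
                     std::vector<int>& out) {
  out.clear();
  for (size_t loc = 0; loc < indices.size(); ++loc) {
    const int i = indices[loc];
    const bool out_of_bounds = i < 0 || i >= static_cast<int>(params.size());
    if (out_of_bounds) {
      if constexpr (kDropBadIndices) {
        continue;  // Silently skip the bad index.
      }
      return static_cast<int>(loc);  // Report which index was bad.
    }
    out.push_back(params[i]);
  }
  return -1;
}

int main() {
  std::vector<int> params = {10, 20, 30};
  std::vector<int> out;
  // Strict mode: returns 1 because indices[1] == 7 is out of range.
  printf("strict: %d\n", GatherWithPolicy<false>(params, {0, 7, 2}, out));
  // Drop mode: returns -1 and gathers only the valid entries {10, 30}.
  printf("drop:   %d\n", GatherWithPolicy<true>(params, {0, 7, 2}, out));
}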
63 changes: 42 additions & 21 deletions tensorflow/core/kernels/scatter_nd_op.cc
@@ -878,7 +878,7 @@ class IndexFlattener {
namespace {

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
@@ -925,7 +925,11 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
for (int i = 0; i < IXDIM; ++i) { \
output_shape_prefix[i] = shape.dim_size(i); \
} \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor; \
constexpr bool kShallDropBadIndices = \
kDropBadIndices || std::is_same<Device, GPUDevice>::value; \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM, \
kShallDropBadIndices> \
functor; \
bad_i = \
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
output_matrix, indices_flat, updates_flat, output_matrix); \
@@ -947,6 +951,9 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
slice_dim);
}
}
if constexpr (kDropBadIndices) {
return absl::OkStatus();
}
if (bad_i >= 0) {
auto slice_shape = indices.shape();
slice_shape.RemoveLastDims(1);
@@ -970,7 +977,8 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
// back to GPU. This is useful because the CPU implementation is deterministic
// and the GPU implementation is not. Tensor inputs to this function must be on
// the GPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
template <typename T, typename Index, scatter_nd_op::UpdateOp Op,
bool kDropBadIndices>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
@@ -1015,7 +1023,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
}

TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op, kDropBadIndices>(
c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));

// Copy 'host_out' to device.
@@ -1033,44 +1041,57 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
} // namespace

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value &&
tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Run on the CPU for integer types, since the GPU implementation uses
// atomics, which are not supported for all integer types.
if constexpr (std::is_same<Device, GPUDevice>::value &&
std::is_integral<T>::value) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
} else {
return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
out, allocate);
return DoScatterNdImpl<Device, T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
}
}
} // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()( \
const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index \
ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, /*kDropBadIndices=*/true>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>; \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
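Inside the macro, the functor's drop policy is now derived from both the caller's request and the device: the GPU functor is always instantiated with dropping enabled (it does not report a bad-index location), while the host-side error check further down is still compiled in unless kDropBadIndices was requested. A tiny illustration of how the flag resolves, written as static_asserts over the same boolean expression:

#include <type_traits>

struct CPUDevice {};
struct GPUDevice {};

// Mirrors the expression added in DoScatterNdImpl:
//   constexpr bool kShallDropBadIndices =
//       kDropBadIndices || std::is_same<Device, GPUDevice>::value;
template <typename Device, bool kDropBadIndices>
inline constexpr bool kShallDropBadIndices =
    kDropBadIndices || std::is_same<Device, GPUDevice>::value;

static_assert(!kShallDropBadIndices<CPUDevice, false>,
              "CPU keeps the bad-index check unless the caller opts out");
static_assert(kShallDropBadIndices<CPUDevice, true>,
              "CPU drops bad indices when the caller requests it");
static_assert(kShallDropBadIndices<GPUDevice, false>,
              "the GPU functor always drops bad indices");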
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scatter_nd_op.h
@@ -44,7 +44,7 @@ namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op, int IXDIM>
scatter_nd_op::UpdateOp op, int IXDIM, bool kDropBadIndices>
struct ScatterNdFunctor {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
Index operator()(
@@ -63,7 +63,7 @@ struct ScatterNdFunctor {
// right type (T) and shape. This tensor will not be zeroed out
// before the scatter is executed.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices = false>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate);
52 changes: 32 additions & 20 deletions tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
@@ -103,8 +103,9 @@ class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
namespace functor {

// Implementation of update functor for CPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM, kDropBadIndices> {
Index operator()(
const CPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
@@ -136,33 +137,44 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
i += ix_d * batch_strides[dim];
}
if (TF_PREDICT_FALSE(out_of_bounds)) {
if constexpr (kDropBadIndices) {
continue;
}
error_loc = loc;
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}

return error_loc;
}
};

#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index \
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/true>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)

#define REGISTER_SCATTER_ND_INDEX(type, op) \
REGISTER_SCATTER_ND_FULL(type, int32, op); \
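Because the drop policy is now part of the functor's type, REGISTER_SCATTER_ND_FULL has to explicitly instantiate both the true and false variants; otherwise translation units that only see the declaration could fail to link against the variant they use. A generic sketch of that explicit-instantiation pattern (illustrative, not the TensorFlow macro):

// functor.h -- declaration visible to every translation unit.
template <bool kDropBadIndices>
struct Functor {
  int operator()(int bad_index_loc) const;
};

// functor.cc -- out-of-line definition plus explicit instantiations.
template <bool kDropBadIndices>
int Functor<kDropBadIndices>::operator()(int bad_index_loc) const {
  if constexpr (kDropBadIndices) {
    return -1;  // Dropping: never surface an error location.
  }
  return bad_index_loc;  // Strict: surface the offending location.
}

// Both variants are instantiated here, mirroring the doubled macro body.
template struct Functor<true>;
template struct Functor<false>;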
10 changes: 6 additions & 4 deletions tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -124,8 +124,9 @@ __global__ void ScatterNdOpKernel(
namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, kDropBadIndices> {
Index operator()(
const GPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
@@ -164,8 +165,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {

} // namespace functor

#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
