
Rename WeightOnlyPreset to WeightOnlyPtqPreset
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10965 from trevor-m:tmorris-perstreamcomms-false fed944b06a9fb830a79b874f68732bdd41a9cd75
PiperOrigin-RevId: 621138691
doyeonkim0 authored and tensorflower-gardener committed Apr 2, 2024
1 parent fe05bc9 commit 4ffb060
Showing 31 changed files with 199 additions and 138 deletions.
44 changes: 22 additions & 22 deletions tensorflow/c/eager/tape.h
@@ -140,8 +140,8 @@ class GradientTape {

// Returns whether any tensor in a list of tensors is being watched and has
// a trainable dtype.
- bool ShouldRecord(gtl::ArraySlice<int64_t> tensor_ids,
- gtl::ArraySlice<tensorflow::DataType> dtypes) const;
+ bool ShouldRecord(absl::Span<const int64_t> tensor_ids,
+ absl::Span<const tensorflow::DataType> dtypes) const;

// Adds this tensor to the list of watched tensors.
//
@@ -158,8 +158,8 @@ class GradientTape {
// nullptr instead of building zeros when build_default_zeros_grads == true.
void RecordOperation(
const string& op_type, const std::vector<TapeTensor>& output_tensors,
- gtl::ArraySlice<int64_t> input_tensor_id,
- gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ absl::Span<const int64_t> input_tensor_id,
+ absl::Span<const tensorflow::DataType> input_dtypes,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);

@@ -174,8 +174,8 @@ class GradientTape {
// is set to false.
Status ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
- const gtl::ArraySlice<int64_t> target_tensor_ids,
- const gtl::ArraySlice<int64_t> source_tensor_ids,
+ const absl::Span<const int64_t> target_tensor_ids,
+ const absl::Span<const int64_t> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
bool build_default_zeros_grads = true);
@@ -283,8 +283,8 @@ class ForwardAccumulator {
Status Accumulate(
const string& op_type, const std::vector<TapeTensor>& input_tensors,
const std::vector<TapeTensor>& output_tensors,
- gtl::ArraySlice<int64_t> input_tensor_id,
- gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ absl::Span<const int64_t> input_tensor_id,
+ absl::Span<const tensorflow::DataType> input_dtypes,
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
@@ -306,8 +306,8 @@ class ForwardAccumulator {

// Indicates whether the forward accumulator should run on an operation with
// the specified inputs and dtypes.
- bool ShouldRecord(gtl::ArraySlice<int64_t> tensor_ids,
- gtl::ArraySlice<tensorflow::DataType> dtypes);
+ bool ShouldRecord(absl::Span<const int64_t> tensor_ids,
+ absl::Span<const tensorflow::DataType> dtypes);

// Temporarily push or pop transient state for this accumulator.
//
@@ -392,8 +392,8 @@ inline bool IsDtypeTrainable(DataType dtype) {

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
- gtl::ArraySlice<int64_t> tensor_ids,
- gtl::ArraySlice<tensorflow::DataType> dtypes) const {
+ absl::Span<const int64_t> tensor_ids,
+ absl::Span<const tensorflow::DataType> dtypes) const {
CHECK_EQ(tensor_ids.size(), dtypes.size());
for (int i = 0; i < tensor_ids.size(); ++i) {
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
@@ -414,8 +414,8 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
const string& op_type, const std::vector<TapeTensor>& output_tensors,
- gtl::ArraySlice<int64_t> input_tensor_id,
- gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ absl::Span<const int64_t> input_tensor_id,
+ absl::Span<const tensorflow::DataType> input_dtypes,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
@@ -530,7 +530,7 @@ struct BackpropInitialState {
// are needed, are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
- gtl::ArraySlice<int64_t> target, const TensorTape& tensor_tape,
+ absl::Span<const int64_t> target, const TensorTape& tensor_tape,
OpTape<BackwardFunction, TapeTensor>* op_tape,
const std::unordered_set<int64_t>& sources_set, bool persistent_tape) {
std::vector<int64_t> tensor_stack;
@@ -605,7 +605,7 @@ std::vector<int64_t> InitialStack(
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
- gtl::ArraySlice<int64_t> target_tensor_ids,
+ absl::Span<const int64_t> target_tensor_ids,
const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
const OpTape<BackwardFunction, TapeTensor>& op_tape,
@@ -690,8 +690,8 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
- const gtl::ArraySlice<int64_t> target_tensor_ids,
- const gtl::ArraySlice<int64_t> source_tensor_ids,
+ const absl::Span<const int64_t> target_tensor_ids,
+ const absl::Span<const int64_t> source_tensor_ids,
const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
bool build_default_zeros_grads) {
@@ -907,8 +907,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
- gtl::ArraySlice<int64_t> tensor_ids,
- gtl::ArraySlice<tensorflow::DataType> dtypes) {
+ absl::Span<const int64_t> tensor_ids,
+ absl::Span<const tensorflow::DataType> dtypes) {
if (call_state_.top().backward_tape != nullptr) {
// If we're forwarding Accumulate calls to backward_tape's RecordOperation,
// we should also delegate ShouldRecord.
@@ -1031,8 +1031,8 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const string& op_type, const std::vector<TapeTensor>& input_tensors,
const std::vector<TapeTensor>& output_tensors,
- gtl::ArraySlice<int64_t> input_tensor_id,
- gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ absl::Span<const int64_t> input_tensor_id,
+ absl::Span<const tensorflow::DataType> input_dtypes,
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
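The tape.h hunks above replace gtl::ArraySlice<T> parameters with absl::Span<const T> without touching any call sites. Below is a minimal standalone sketch of why the swap is source-compatible (not part of this commit; AllNonNegative is a hypothetical stand-in for signatures like ShouldRecord): absl::Span<const T> binds implicitly to vectors, arrays, and brace-initialized lists, much as gtl::ArraySlice<T> did.

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

// Hypothetical helper mirroring the parameter shape of the updated
// ShouldRecord(absl::Span<const int64_t>, ...) declarations.
bool AllNonNegative(absl::Span<const int64_t> tensor_ids) {
  for (int64_t id : tensor_ids) {
    if (id < 0) return false;
  }
  return true;
}

int main() {
  std::vector<int64_t> ids = {1, 2, 3};
  bool from_vector = AllNonNegative(ids);        // implicit conversion from std::vector
  bool from_braces = AllNonNegative({4, 5, 6});  // conversion from a brace-initialized list
  return (from_vector && from_braces) ? 0 : 1;
}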
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD
@@ -20,6 +20,7 @@ cc_library(
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:config",
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq",
+ "//tensorflow/compiler/mlir/quantization/stablehlo/cc:weight_only_ptq",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables",
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h"
+ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"
@@ -43,6 +44,7 @@ namespace tensorflow {
namespace {

using ::mlir::quant::stablehlo::StaticRangePtqComponent;
+ using ::mlir::quant::stablehlo::WeightOnlyPtqComponent;
using ::stablehlo::quantization::PopulateDefaults;
using ::stablehlo::quantization::QuantizationConfig;
using ::tensorflow::SignatureDef;
@@ -131,13 +133,25 @@ absl::StatusOr<mlir::ModuleOp> RunQuantization(
return absl::InternalError("Failed to run legalize TF to StableHLO.");
}

- StaticRangePtqComponent static_range_ptq_component(
- module_op.getContext(), quantization_py_function_lib, saved_model_dir,
- /*signature_keys=*/exported_names, saved_model_tags, signature_def_map,
- GetFunctionAliases(*saved_model_bundle));
+ absl::StatusOr<mlir::ModuleOp> quantized_module_op;
+ if (quantization_config.has_static_range_ptq_preset()) {
+ StaticRangePtqComponent static_range_ptq_component(
+ module_op.getContext(), quantization_py_function_lib, saved_model_dir,
+ /*signature_keys=*/exported_names, saved_model_tags, signature_def_map,
+ GetFunctionAliases(*saved_model_bundle));
+
+ quantized_module_op =
+ static_range_ptq_component.Run(module_op, updated_config);
+ } else if (quantization_config.has_weight_only_ptq_preset()) {
+ WeightOnlyPtqComponent weight_only_ptq_component(module_op.getContext());
+ quantized_module_op =
+ weight_only_ptq_component.Run(module_op, updated_config);
+ } else {
+ return absl::InvalidArgumentError(
+ "Quantization config must have either static_range_ptq_preset or "
+ "weight_only_ptq_preset.");
+ }

- absl::StatusOr<mlir::ModuleOp> quantized_module_op =
- static_range_ptq_component.Run(module_op, updated_config);
if (!quantized_module_op.ok()) {
return absl::InternalError("Failed to run quantization. Status msg: " +
quantized_module_op.status().ToString());
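For context on the dispatch added to RunQuantization above, here is a minimal sketch of how a caller would select the new weight-only branch, assuming the standard C++ API generated from quantization_config.proto (the sketch itself is not part of this commit). Because the presets live in a oneof, setting weight_only_ptq_preset clears any static_range_ptq_preset, so at most one of the has_*() checks used by the new branching is true.

#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

int main() {
  ::stablehlo::quantization::QuantizationConfig config;
  // Selecting the weight-only PTQ preset; RunQuantization would then construct
  // WeightOnlyPtqComponent instead of StaticRangePtqComponent.
  config.mutable_weight_only_ptq_preset();
  bool weight_only = config.has_weight_only_ptq_preset();    // true
  bool static_range = config.has_static_range_ptq_preset();  // false
  return (weight_only && !static_range) ? 0 : 1;
}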
5 changes: 4 additions & 1 deletion tensorflow/compiler/mlir/quantization/stablehlo/BUILD
@@ -8,6 +8,7 @@ load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
package_group(
name = "internal_visibility_allowlist_package",
packages = [
+ "//learning/brain/mlir/quantization/stablehlo/python/integration_test/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/compiler/mlir/quantization/...",
"//tensorflow/compiler/mlir/tf2xla/transforms/...",
@@ -745,7 +746,9 @@ tf_proto_library(
# py_proto_library(
# name = "quantization_config_py_pb2",
# api_version = 2,
- # visibility = [":internal_visibility_allowlist_package"],
+ # visibility = [
+ # ":internal_visibility_allowlist_package",
+ # ],
# deps = [":quantization_config_proto"],
# )
# copybara:uncomment_end
@@ -985,7 +985,7 @@ def test_matmul_weight_only_model(
)

config = qc.QuantizationConfig(
- weight_only_preset=qc.WeightOnlyPreset(),
+ weight_only_ptq_preset=qc.WeightOnlyPtqPreset(),
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
)
quantization.quantize_saved_model(
@@ -1077,7 +1077,7 @@ def test_conv_weight_only_model(
)

config = qc.QuantizationConfig(
- weight_only_preset=qc.WeightOnlyPreset(),
+ weight_only_ptq_preset=qc.WeightOnlyPtqPreset(),
tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
)
quantization.quantize_saved_model(
@@ -63,7 +63,7 @@ def quantize_saved_model(
if not (
config.HasField('static_range_ptq_preset')
and len(config.static_range_ptq_preset.representative_datasets) == 1
- ) and not config.HasField('weight_only_preset'):
+ ) and not config.HasField('weight_only_ptq_preset'):
raise ValueError(
'`quantize_saved_model` currently only supports static-range PTQ with a'
' single signature or weight-only quantization.'
@@ -98,7 +98,7 @@ def quantize_saved_model(
signature_def_map_serialized=signature_def_map_serialized,
py_function_library=py_function_lib.PyFunctionLibrary(),
)
- elif config.HasField('weight_only_preset'):
+ elif config.HasField('weight_only_ptq_preset'):
pywrap_quantization.weight_only_ptq(
src_saved_model_path,
dst_saved_model_path,
@@ -77,8 +77,9 @@ message StaticRangePtqPreset {
bool enable_full_int_quantization = 3;
}

- // Applies int8 per-tensor weight-only quantization for all dot_general op.
- message WeightOnlyPreset {}
+ // Applies int8 per-tensor weight-only post-training quantization for all
+ // dot_general op.
+ message WeightOnlyPtqPreset {}

// Metadata specific to the input TensorFlow SavedModel, which may be required
// to identify the specific MetaGraphDef to quantize, for example.
@@ -322,7 +323,7 @@ message QuantizationConfig {
oneof preset {
// Performs best-effort static-range post-training quantization (PTQ).
StaticRangePtqPreset static_range_ptq_preset = 1;
- WeightOnlyPreset weight_only_preset = 7;
+ WeightOnlyPtqPreset weight_only_ptq_preset = 7;
}

// TF SavedModel specific information for the input model.
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/partitioning_utils_test.cc
@@ -99,8 +99,8 @@ class PartitioningUtilsTest : public ::testing::Test {
// where each node has type `dtype` and arg/ret nodes have
// indices `arg_index` and `ret_index`.
void SubGraph(Graph* subgraph, DataType dtype,
- gtl::ArraySlice<int> arg_indices,
- gtl::ArraySlice<int> ret_indices) {
+ absl::Span<const int> arg_indices,
+ absl::Span<const int> ret_indices) {
Scope s = Scope::NewRootScope();
Scope s1 = s.WithDevice("/job:a/replica:0/task:0/device:CPU:0");
CHECK_EQ(arg_indices.size(), ret_indices.size());
6 changes: 3 additions & 3 deletions tensorflow/core/common_runtime/placer_test.cc
@@ -1966,7 +1966,7 @@ TEST_P(SoftPlacementPlacerTest,
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
NDef("id1", "Identity", {"a"},
{{"T", DT_RESOURCE},
- {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
+ {"_class", absl::Span<const string>({"loc:@id2"})}}),
NDef("id2", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
},
// FunctionLib
@@ -2013,7 +2013,7 @@ TEST_F(PlacerTest, RequestedDeviceCanBeOverridden) {
NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}, kCPU),
NDef("id1", "Identity", {"id_a"},
{{"T", DT_RESOURCE},
- {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
+ {"_class", absl::Span<const string>({"loc:@id2"})}}),
NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
},
// FunctionLib
@@ -2076,7 +2076,7 @@ TEST_P(SoftPlacementPlacerTest,
NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
NDef("id1", "Identity", {"id_a"},
{{"T", DT_RESOURCE},
- {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
+ {"_class", absl::Span<const string>({"loc:@id2"})}}),
NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
},
// FunctionLib
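The placer_test.cc hunks above build an absl::Span<const string> from a braced list of string literals inside a call expression. A small sketch of the lifetime rule this relies on (not from this commit; CopyClassAttr is a hypothetical stand-in for NDef's attr handling, which copies the values): the temporary array backing the initializer list lives until the end of the full expression, which is long enough for the callee to copy out of it, but the span must not be kept past that point.

#include <string>
#include <vector>

#include "absl/types/span.h"

// Hypothetical consumer that copies the span's contents, much as setting a
// "_class" attr on a NodeDef copies the supplied strings.
std::vector<std::string> CopyClassAttr(absl::Span<const std::string> values) {
  return std::vector<std::string>(values.begin(), values.end());
}

int main() {
  // The initializer-list temporary backing the span outlives this call, so
  // copying inside CopyClassAttr is safe; storing the span itself would not be.
  std::vector<std::string> cls = CopyClassAttr({"loc:@id2"});
  return cls.size() == 1 ? 0 : 1;
}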
@@ -150,7 +150,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, int64_t src_incarnation,
- gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
+ absl::Span<const Tensor> tensors_to_send, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
RendezvousInterface* rendezvous) {
std::vector<string> keys;
@@ -398,7 +398,7 @@ ProcessFunctionLibraryRuntime::IsMultiDevice(

namespace {
// Returns the local tensors referred by `args`.
- std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
+ std::vector<Tensor> GetLocalArgs(absl::Span<const FunctionArg> args) {
std::vector<Tensor> tensors;
for (const auto& arg : args) {
if (arg.index() == 0) {
@@ -1320,7 +1320,7 @@ Status ProcessFunctionLibraryRuntime::CreateRendezvous(
}

Status ProcessFunctionLibraryRuntime::GetComponentArgs(
- const gtl::ArraySlice<Tensor> args,
+ const absl::Span<const Tensor> args,
const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
// "Index"s of _Arg nodes are unique when all arguments are local Tensors.
@@ -1375,7 +1375,7 @@ Status ProcessFunctionLibraryRuntime::GetComponentArgs(

void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
- FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
+ FunctionLibraryRuntime::Handle handle, absl::Span<const Tensor> args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
FunctionLibraryRuntime::Options new_opts = opts;
@@ -1420,7 +1420,7 @@ void ProcessFunctionLibraryRuntime::Run(
// This method handles the simple remote call case (not multi-device).
void ProcessFunctionLibraryRuntime::RunInternal(
const FunctionLibraryRuntime::Options& opts,
- FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
+ FunctionLibraryRuntime::Handle handle, absl::Span<const FunctionArg> args,
std::vector<FunctionRet>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done) const {
@@ -1560,7 +1560,7 @@ void ProcessFunctionLibraryRuntime::Run(

Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& orig_opts,
- FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
+ FunctionLibraryRuntime::Handle handle, absl::Span<const Tensor> args,
std::vector<Tensor>* rets) const {
MultiDeviceFunctionData* multi_device_data = IsMultiDevice(handle);
if (multi_device_data && multi_device_data->enable_sync_execution) {
