
Commit

Automated Code Change
PiperOrigin-RevId: 620398844
tensorflower-gardener committed Apr 3, 2024
1 parent 62d3ccd commit ef773b3
Showing 34 changed files with 148 additions and 102 deletions.
@@ -134,11 +134,14 @@ using OpQuantSpecGetter =
// Quantization scale spec of an op. The information defined in the MLIR
// interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should
// be checked first if present.
// TODO: b/323478683: Consider deprecating this.
struct OpQuantScaleSpec {
// Whether this op has a fixed range requirement (e.g. sigmoid)
bool has_fixed_output_range = false;
// Whether this op should have same result and operand scales (e.g. concat)
// Whether this op should have same operand and result scales (e.g. concat)
bool has_same_scale_requirement = false;
// Whether this op should have same operand and result type (e.g. gather)
bool has_same_operand_and_result_type_requirement = false;
// Returns the fixed output range, when has_fixed_output_range is set.
GetFixedOutputRangeFunc fixed_output_range_func;
// Returns whether same operands and results scales are required.
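For context, a minimal sketch (hypothetical, not part of this commit) of how a scale-spec getter might populate these fields for an op with a fixed output range, such as sigmoid:

// Hypothetical illustration only: marks a sigmoid-like op as having a fixed
// output range. stablehlo::LogisticOp maps inputs into (0, 1), so its output
// quantization parameters can be fixed up front. Assumes <memory>, the LLVM
// casting utilities, and the StableHLO dialect headers are available.
std::unique_ptr<OpQuantScaleSpec> GetExampleScaleSpec(Operation* op) {
  auto spec = std::make_unique<OpQuantScaleSpec>();
  if (llvm::isa<mlir::stablehlo::LogisticOp>(op)) {
    spec->has_fixed_output_range = true;
  }
  return spec;
}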
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD
@@ -40,6 +40,7 @@ tf_cc_test(
deps = [
":stablehlo_op_quant_spec",
"//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
"//tensorflow/compiler/mlir/quantization/common:func",
"//tensorflow/compiler/mlir/quantization/common:test_base",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
@@ -131,7 +131,7 @@ std::unique_ptr<OpQuantSpec> GetStableHloOpQuantSpec(Operation* op) {
return spec;
}

std::unique_ptr<OpQuantScaleSpec> GetStableHloQuantScaleSpec(Operation* op) {
std::unique_ptr<OpQuantScaleSpec> GetStableHloQuantConstraints(Operation* op) {
auto scale_spec = std::make_unique<OpQuantScaleSpec>();
if (llvm::isa<mlir::stablehlo::BroadcastInDimOp,
mlir::stablehlo::ConcatenateOp,
@@ -142,6 +142,10 @@ std::unique_ptr<OpQuantScaleSpec> GetStableHloQuantScaleSpec(Operation* op) {
mlir::stablehlo::SliceOp, mlir::stablehlo::TransposeOp>(op)) {
scale_spec->has_same_scale_requirement = true;
}
if (llvm::isa<mlir::stablehlo::DynamicSliceOp, mlir::stablehlo::GatherOp,
mlir::stablehlo::PadOp, mlir::stablehlo::SliceOp>(op)) {
scale_spec->has_same_operand_and_result_type_requirement = true;
}
return scale_spec;
}

@@ -165,7 +169,7 @@ bool IsOpQuantizableStableHlo(Operation* op) {
return false;
}

if (GetStableHloQuantScaleSpec(op)->has_same_scale_requirement) {
if (GetStableHloQuantConstraints(op)->has_same_scale_requirement) {
return true;
}

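As a usage sketch (hypothetical caller, not from this commit), the two flags returned by GetStableHloQuantConstraints drive separate decisions downstream:

// Hypothetical caller illustrating how the constraint flags are consumed.
void HandleOp(Operation* op) {
  const std::unique_ptr<OpQuantScaleSpec> constraints =
      GetStableHloQuantConstraints(op);
  if (constraints->has_same_scale_requirement) {
    // e.g. stablehlo.concatenate: propagate one set of quantization
    // parameters across all operands and results.
  }
  if (constraints->has_same_operand_and_result_type_requirement) {
    // e.g. stablehlo.gather: keep operand and result element types identical,
    // inserting an explicit requantize op when the graph needs a new scale.
  }
}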
@@ -28,9 +28,9 @@ namespace mlir::quant::stablehlo {
// Returns StableHLO quantization specs for an op.
std::unique_ptr<OpQuantSpec> GetStableHloOpQuantSpec(Operation* op);

// Returns quantization scale specs (fixed output, same scale) for a StableHLO
// op.
std::unique_ptr<OpQuantScaleSpec> GetStableHloQuantScaleSpec(Operation* op);
// Returns quantization constraints (ex: fixed output, same scale) given
// a StableHLO op.
std::unique_ptr<OpQuantScaleSpec> GetStableHloQuantConstraints(Operation* op);

// Checks if an op is quantizable in the StableHLO quantizer. The argument op
// is not necessarily a StableHLO op.
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/quantization/common/func.h"
#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"
#include "tensorflow/compiler/mlir/quantization/common/test_base.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -34,7 +35,9 @@ limitations under the License.
namespace mlir::quant::stablehlo {
namespace {

using ::mlir::stablehlo::GatherOp;
using ::testing::IsEmpty;
using ::testing::IsTrue;
using ::testing::NotNull;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;
@@ -284,5 +287,42 @@ TEST_F(GetStableHloOpQuantSpecTest,
UnorderedElementsAre(Pair(1, 3)));
}

using GetStableHloQuantConstraintsTest = ::mlir::quant::QuantizationTestBase;

TEST_F(GetStableHloQuantConstraintsTest,
HasSameOperandAndResultTypeRequirementSucceeds) {
// Quantizable ops: constants
// Non-quantizable ops: normal StableHLO ops and terminators
constexpr absl::string_view kModuleGather = R"mlir(
module {
func.func @main() -> (tensor<2x3x2x2xf32>) {
%0 = stablehlo.constant dense<1.0> : tensor<3x4x2xf32>
%1 = stablehlo.constant dense<2> : tensor<2x3x2xi64>
%2 = "stablehlo.gather"(%0, %1) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = array<i64: 1, 2, 2>,
indices_are_sorted = false
} : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32>
func.return %2 : tensor<2x3x2x2xf32>
}
}
)mlir";
OwningOpRef<ModuleOp> module_op = ParseModuleOpString(kModuleGather);
ASSERT_TRUE(module_op);

func::FuncOp main_fn = FindMainFuncOp(*module_op);
ASSERT_THAT(main_fn, NotNull());

Operation* gather_op = FindOperationOfType<GatherOp>(main_fn);
const auto spec = GetStableHloQuantConstraints(gather_op);

EXPECT_THAT(spec, NotNull());
EXPECT_THAT(spec->has_same_operand_and_result_type_requirement, IsTrue());
}

} // namespace
} // namespace mlir::quant::stablehlo
@@ -140,7 +140,7 @@ void PrepareQuantizePass::runOnOperation() {
MLIRContext* ctx = module_op.getContext();

auto func_op_quant_spec = GetStableHloOpQuantSpec;
auto func_op_quant_scale_spec = GetStableHloQuantScaleSpec;
auto func_op_quant_scale_spec = GetStableHloQuantConstraints;

for (auto func_op : module_op.getOps<func::FuncOp>()) {
// The function might contain more stats ops than required, and it will
@@ -515,9 +515,35 @@ class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern {
void rewrite(func::FuncOp entry_func_op, const Method& quantization_method,
PatternRewriter& rewriter) const override {
auto singular_op = *entry_func_op.getOps<SingularOpT>().begin();

Value singular_op_result = singular_op.getResult();
singular_op_result.setType(entry_func_op.getResultTypes()[0]);

// For ops that require same operand and result types, use an explicit
// requantize op rather than using `entry_func_op`'s result as the op result.
auto spec = GetStableHloQuantConstraints(singular_op);
const bool has_same_operand_and_result_type =
spec->has_same_operand_and_result_type_requirement;
if (has_same_operand_and_result_type) {
const Type operand_type = entry_func_op.getArgumentTypes()[0];
const Type func_result_type = entry_func_op.getResultTypes()[0];

// Get the quantized tensor manipulation op's output type and update.
const auto singular_op_result_type =
singular_op_result.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> singular_op_shape =
singular_op_result_type.getShape();
const TensorType new_singular_op_result_type =
singular_op_result_type.cloneWith(
singular_op_shape,
getElementTypeOrSelf(operand_type).cast<UniformQuantizedType>());
singular_op_result.setType(new_singular_op_result_type);

// Create requantization op and return.
rewriter.setInsertionPointAfter(singular_op);
CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op,
func_result_type);
} else {
singular_op_result.setType(entry_func_op.getResultTypes()[0]);
}
}
};
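The type surgery in the rewrite above keeps the result's shape while swapping in the operand's quantized element type. A standalone sketch of that idiom using the standard MLIR type API (hypothetical helper name):

// Hypothetical helper: rebuild a ranked tensor type with a new element type
// while preserving the shape, the same idiom as the cloneWith(...) call above.
mlir::RankedTensorType CloneWithElementType(mlir::RankedTensorType type,
                                            mlir::Type new_element_type) {
  return mlir::RankedTensorType::get(type.getShape(), new_element_type);
}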

@@ -664,7 +690,7 @@ class QuantizeOpWithRegionPattern
// Quantization parameters can be propagated only for same-scale ops and
// same-scale ops are quantized only when they are connected to quantized
// composite functions.
if (!GetStableHloQuantScaleSpec(op_with_region)
if (!GetStableHloQuantConstraints(op_with_region)
->has_same_scale_requirement ||
!IsConnectedWithQuantizedCompsiteFunction(op_with_region)) {
return failure();
@@ -866,7 +892,8 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
}

// Check whether the preceding op is a quantized same-scale op.
if (GetStableHloQuantScaleSpec(preceding_op)->has_same_scale_requirement) {
if (GetStableHloQuantConstraints(preceding_op)
->has_same_scale_requirement) {
for (const OpResult result : preceding_op->getResults()) {
const Type element_type = getElementTypeOrSelf(result.getType());
if (element_type.isa<UniformQuantizedType>()) {
@@ -893,7 +920,7 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) {
}

// Check whether the following op is a quantized same-scale op.
if (GetStableHloQuantScaleSpec(following_op)
if (GetStableHloQuantConstraints(following_op)
->has_same_scale_requirement) {
for (Value operand : following_op->getOperands()) {
const Type element_type = getElementTypeOrSelf(operand.getType());
@@ -923,7 +950,9 @@ class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern {
};

// Compute heavy patterns should be quantized for both server and ODML targets.
void PopulateComputeHeavyPatterns(
// Most patterns here are useful when quantized since they are compute heavy
// or memory bound.
void PopulateCommonQuantizationPatterns(
MLIRContext& ctx, RewritePatternSet& patterns,
const bool enable_per_channel_quantized_weight) {
patterns.add<XlaCallModuleOpToCallOp<QuantizeConvolutionOpPattern>>(
@@ -148,7 +148,7 @@ class StableHloQuantizationPattern : public OpRewritePattern<RootOpT> {
return failure();
}

if (GetStableHloQuantScaleSpec(candidate_op)
if (GetStableHloQuantConstraints(candidate_op)
->has_same_scale_requirement &&
!IsConnectedWithQuantizedCompsiteFunction(candidate_op)) {
return failure();
@@ -250,9 +250,10 @@ class StableHloQuantizationPattern : public OpRewritePattern<RootOpT> {
}
};

// Populates pattern for compute heavy operations.
void PopulateComputeHeavyPatterns(MLIRContext& ctx, RewritePatternSet& patterns,
bool enable_per_channel_quantized_weight);
// Populates common patterns that are usually compute heavy or memory bound.
void PopulateCommonQuantizationPatterns(
MLIRContext& ctx, RewritePatternSet& patterns,
bool enable_per_channel_quantized_weight);

// Populates conversion patterns for all quantizable ops, including
// ops that are not compute-heavy and data movement ops.
@@ -125,8 +125,8 @@ void QuantizePass::runOnOperation() {
PopulateQuantizeWeightOnlyPatterns(ctx, patterns);
}

PopulateComputeHeavyPatterns(ctx, patterns,
enable_per_channel_quantized_weight_);
PopulateCommonQuantizationPatterns(ctx, patterns,
enable_per_channel_quantized_weight_);

// Quantize all quantizable ops, including ops that are not compute-heavy.
if (enable_full_int_quantization_) {
@@ -776,5 +776,6 @@ module attributes {tf_saved_model.semantics} {
return %0 : tensor<2x3x2x2xf32>
}
// CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform<i8:f32, {{.*}}>>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform<i8:f32, {{.*}}>>
// CHECK: return %[[GATHER]] : tensor<2x3x2x2x!quant.uniform<i8:f32, {{.*}}>>
// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform<i8:f32, {{.*}}>>
// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor<2x3x2x2x!quant.uniform<i8:f32, {{.*}}>>
}
4 changes: 1 addition & 3 deletions tensorflow/compiler/mlir/tfrt/BUILD
@@ -472,9 +472,7 @@ tf_proto_library(

cc_library(
name = "passes",
visibility = [
"//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private.
],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/compiler/mlir/tfrt:tf_to_tfrt",
],
4 changes: 1 addition & 3 deletions tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD
@@ -88,9 +88,7 @@ td_library(
"tf_mlrt_tpu_ops.td",
],
includes = ["."],
visibility = [
"//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private.
],
visibility = ["//visibility:private"],
deps = [
":mlrt_td_files",
":tf_mlrt_td_files",
4 changes: 1 addition & 3 deletions tensorflow/core/BUILD
@@ -662,9 +662,7 @@ cc_library(

cc_library(
name = "dynamic_kernels_impl",
visibility = [
"//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private.
],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/core/kernels:sobol_op",
],
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/function_test.cc
@@ -635,7 +635,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
// Attrs
{},
// Nodes
{FDH::Const<int32>("shape", gtl::ArraySlice<int32>({1})),
{FDH::Const<int32>("shape", absl::Span<const int32>({1})),
FDH::Const<int32>("minval", 0),
FDH::Const<int32>("maxval", 10),
// A stateful node.
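The remaining hunks in this commit mechanically migrate gtl::ArraySlice<T> to absl::Span<const T>. In recent TensorFlow versions gtl::ArraySlice<T> is an alias for absl::Span<const T>, so call sites are unaffected; a minimal self-contained sketch of why they keep compiling:

#include <vector>
#include "absl/types/span.h"

// absl::Span<const T> is a non-owning view that is implicitly constructible
// from brace-init lists, std::vector, and C arrays, matching what
// gtl::ArraySlice<T> accepted.
int Sum(absl::Span<const int> xs) {
  int total = 0;
  for (int x : xs) total += x;
  return total;
}

// Usage: Sum({1, 2, 3}) and Sum(std::vector<int>{4, 5}) both compile unchanged.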
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/function_testlib.cc
@@ -127,7 +127,7 @@ FunctionDef BlockingOpFn() {

// TODO(phawkins): replace with C++ API for calling functions, when that exists.
Output Call(Scope* scope, const string& op_name, const string& fn_name,
gtl::ArraySlice<Input> inputs) {
absl::Span<const Input> inputs) {
NodeDef def;
NodeDefBuilder builder(op_name, fn_name, scope->graph()->op_registry());
for (const Input& input : inputs) {
2 changes: 1 addition & 1 deletion tensorflow/core/common_runtime/function_testlib.h
@@ -45,7 +45,7 @@ FunctionDef BlockingOpFn();
// Adds a function call to the given scope and returns the output for the node.
// TODO(phawkins): replace with C++ API for calling functions, when that exists.
Output Call(Scope* scope, const string& op_name, const string& fn_name,
gtl::ArraySlice<Input> inputs);
absl::Span<const Input> inputs);

} // namespace function
} // namespace test
26 changes: 13 additions & 13 deletions tensorflow/core/common_runtime/gradients.cc
@@ -99,7 +99,7 @@ static Node* AddZerosLike(Graph* g, NodeOut input) {
}
}

static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
static Node* AddSymGrad(Graph* g, Node* n, absl::Span<const NodeOut> grads) {
const int num_x = n->num_inputs();
const int num_y = n->num_outputs();
CHECK_EQ(num_y, grads.size());
@@ -151,18 +151,18 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {

class SymbolicGradientBuilder {
public:
SymbolicGradientBuilder(gtl::ArraySlice<NodeOut> y_node_outputs,
gtl::ArraySlice<NodeOut> x_node_outputs,
gtl::ArraySlice<NodeOut> y_grad_node_outputs,
SymbolicGradientBuilder(absl::Span<const NodeOut> y_node_outputs,
absl::Span<const NodeOut> x_node_outputs,
absl::Span<const NodeOut> y_grad_node_outputs,
std::vector<NodeOut>* x_grad_node_outputs,
Graph* graph);

Status Compute();

private:
gtl::ArraySlice<NodeOut> y_node_outputs_;
gtl::ArraySlice<NodeOut> x_node_outputs_;
gtl::ArraySlice<NodeOut> y_grad_node_outputs_;
absl::Span<const NodeOut> y_node_outputs_;
absl::Span<const NodeOut> x_node_outputs_;
absl::Span<const NodeOut> y_grad_node_outputs_;
std::vector<NodeOut>* x_grad_node_outputs_;
Graph* graph_; // Not owned.

@@ -209,9 +209,9 @@ class SymbolicGradientBuilder {
};

SymbolicGradientBuilder::SymbolicGradientBuilder(
gtl::ArraySlice<NodeOut> y_node_outputs,
gtl::ArraySlice<NodeOut> x_node_outputs,
gtl::ArraySlice<NodeOut> y_grad_node_outputs,
absl::Span<const NodeOut> y_node_outputs,
absl::Span<const NodeOut> x_node_outputs,
absl::Span<const NodeOut> y_grad_node_outputs,
std::vector<NodeOut>* x_grad_node_outputs, Graph* graph)
: y_node_outputs_(y_node_outputs),
x_node_outputs_(x_node_outputs),
@@ -405,9 +405,9 @@ Status SymbolicGradientBuilder::Compute() {
return absl::OkStatus();
}

Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
gtl::ArraySlice<NodeOut> x_node_outputs,
gtl::ArraySlice<NodeOut> y_grad_node_outputs,
Status AddSymbolicGradients(absl::Span<const NodeOut> y_node_outputs,
absl::Span<const NodeOut> x_node_outputs,
absl::Span<const NodeOut> y_grad_node_outputs,
std::vector<NodeOut>* x_grad_node_outputs,
Graph* graph) {
SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs,
6 changes: 3 additions & 3 deletions tensorflow/core/common_runtime/gradients.h
@@ -47,9 +47,9 @@ struct NodeOut {
// implementation only supports gradients for functions). In particular,
// the nodes in 'x_nodes' are currently restricted to have one output.

Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
gtl::ArraySlice<NodeOut> x_node_outputs,
gtl::ArraySlice<NodeOut> y_grad_node_outputs,
Status AddSymbolicGradients(absl::Span<const NodeOut> y_node_outputs,
absl::Span<const NodeOut> x_node_outputs,
absl::Span<const NodeOut> y_grad_node_outputs,
std::vector<NodeOut>* x_grad_node_outputs,
Graph* graph);

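One caution that applies to the Span-typed members introduced above: absl::Span does not own its storage, so SymbolicGradientBuilder's fields are valid only while the caller's containers outlive the builder. A small illustration (hypothetical type, not from this commit):

#include <vector>
#include "absl/types/span.h"

// Hypothetical type mirroring the builder's non-owning members.
struct Holder {
  absl::Span<const int> view;  // Valid only while the backing storage lives.
};

int main() {
  std::vector<int> data = {1, 2, 3};
  Holder ok{data};  // Fine: `data` outlives `ok`.
  // Holder bad{{1, 2, 3}};  // Dangling: the temporary dies with the statement.
  return 0;
}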
