Skip to content

Commit

Permalink
tensorflow: Add per-channel quantization for UniformQuantized Convolution ops.
Browse files Browse the repository at this point in the history

Commit: c03c47017de9453cd39696f13938706ea01c86f3
  • Loading branch information
A. Unique TensorFlower authored and sourcegraph-bot committed Dec 29, 2022
1 parent 67d8189 commit 25fa182
Show file tree
Hide file tree
Showing 14 changed files with 382 additions and 142 deletions.
17 changes: 17 additions & 0 deletions tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ td_library(
"passes/post_quantize.td",
"passes/prepare_lifting.td",
"passes/prepare_quantize.td",
"passes/preprocess_op.td",
"passes/quantize_composite_functions.td",
"passes/replace_cast_hacks_with_tf_xla_ops.td",
"passes/tf_quant_ops.td",
Expand Down Expand Up @@ -220,6 +221,20 @@ gentbl_cc_library(
deps = [":quant_td_files"],
)

# Generates the rewrite-pattern implementation for the PreprocessOp pass:
# runs mlir-tblgen with -gen-rewriters over passes/preprocess_op.td to emit
# passes/preprocess_op.inc, which is compiled into the pass library below.
gentbl_cc_library(
name = "preprocess_op_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
(
["-gen-rewriters"],
"passes/preprocess_op.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "passes/preprocess_op.td",
deps = [":quant_td_files"],
)

cc_library(
name = "tf_quant_ops",
srcs = [
Expand Down Expand Up @@ -323,6 +338,8 @@ cc_library(
"passes/prepare_quantize.cc",
"passes/prepare_quantize.inc",
"passes/prepare_quantize_drq.cc",
"passes/preprocess_op.cc",
"passes/preprocess_op.inc",
"passes/quantize.cc",
"passes/quantize_composite_functions.cc",
"passes/quantize_composite_functions.inc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ namespace mlir::quant {

std::unique_ptr<OpQuantSpec> GetUniformOpQuantSpec(Operation* op) {
auto spec = std::make_unique<OpQuantSpec>();
if (auto call_op = dyn_cast<TF::UniformQuantizedConvolutionHybridOp>(op)) {
if (isa<TF::UniformQuantizedConvolutionHybridOp>(op) ||
isa<TF::UniformQuantizedConvolutionOp>(op)) {
spec->coeff_op_quant_dim[1] = 3;
} else if (auto call_op = dyn_cast<TF::UniformQuantizedDotHybridOp>(op)) {
} else if (isa<TF::UniformQuantizedDotHybridOp>(op)) {
spec->coeff_op_quant_dim[1] = -1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,22 @@ std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizePass();
std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizePass(
QuantizationSpecs quant_specs, OpSet target_opset);

// Creates an instance of the PrepareQuantize pass, which will perfrom similar
// Creates an instance of the PrepareQuantize pass, which will perform similar
// transformations as TFL::PrepareQuantizePass.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareQuantizePass(
tensorflow::quantization::QuantizationMethod::ExperimentalMethod
quantization_method);

// Creates an instance of the PrepareQuantizeDRQ pass, which will
// perfrom similar transformations as TFL::PrepareQuantizeDynamicRangePass.
// perform similar transformations as TFL::PrepareQuantizeDynamicRangePass.
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareQuantizeDRQPass(
const QuantizationSpecs& quant_specs, OpSet op_set);

// Creates an instance of the PreprocessOp pass, which will perform op
// preprocessing to allow multi-axis quantization, prior to quantization.
std::unique_ptr<OperationPass<ModuleOp>> CreatePreprocessOpPass(
const QuantizationSpecs& quant_specs, OpSet op_set);

// Creates an instance of the PostQuantize pass, which will remove unnecessary
// ops from the final quantized graph.
std::unique_ptr<OperationPass<func::FuncOp>> CreatePostQuantizePass();
Expand Down Expand Up @@ -134,7 +139,7 @@ std::unique_ptr<OperationPass<func::FuncOp>>
CreateDuplicateShapeDeterminingConstantsPass();

// Creates a pass that creates a RestoreV2 op in the initializer function with
// type "restore_op" that initializes variables from checkpoint. It finds
// type "restore_op" that initializes variables from the checkpoint. It finds
// tf.AssignVariableOp(tf.VarHandleOp, tf.Const) patterns in the initializer
// function and replaces tf.Consts with the results of RestoreV2.
std::unique_ptr<OperationPass<ModuleOp>> CreateInsertRestoreOpPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ void PrepareQuantizePass::runOnOperation() {

// During the legalization, unsigned quantized type is used, so we have to
// convert all of them to signed.
RewritePatternSet patterns(&getContext());
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
patterns.add<quant::ConvertUnsignedToSigned<quantfork::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,90 +251,6 @@ void PrepareQuantizeDRQPass::removeAllStatsOp(func::FuncOp func) {
});
}

// Applies op_set-specific constant transformations to quantizable composite
// functions (tf.PartitionedCallOp whose callee is named "composite_*").
//
// For OpSet::UNIFORM_QUANTIZED, weights of tf.DepthwiseConv2dNative must be
// transformed from [H,W,C,M] to [H,W,1,C*M] where H=height, W=width,
// C=channel, M=multiplier. A reshape op is inserted between the constant op
// and the call so the constant stays safe for multi-use cases, and the callee
// is cloned with the updated argument type. Bias needs no transformation as
// its shape is already [C*M].
class PreprocessConstantOp : public OpRewritePattern<TF::PartitionedCallOp> {
 public:
  explicit PreprocessConstantOp(MLIRContext* context, OpSet op_set)
      : OpRewritePattern<TF::PartitionedCallOp>(context), op_set_(op_set) {}

  LogicalResult matchAndRewrite(TF::PartitionedCallOp op,
                                PatternRewriter& rewriter) const override {
    // Non-quantizable op.
    if (!op->hasAttr(kQuantTraitAttrName)) return failure();

    const auto f_attr = op.getFAttr().dyn_cast<FlatSymbolRefAttr>();
    // Fix: dyn_cast returns null for non-flat symbol refs; the previous code
    // called f_attr.getValue() without this check and would crash.
    if (!f_attr) return failure();
    const StringRef function_name = f_attr.getValue();
    if (!function_name.startswith("composite_")) return failure();

    // Only the depthwise-conv2d weight layout for the uniform-quantized op
    // set is handled here.
    if (!function_name.contains("depthwise_conv2d") ||
        op_set_ != OpSet::UNIFORM_QUANTIZED) {
      return failure();
    }

    const std::unique_ptr<OpQuantSpec> spec = GetTFOpQuantSpec(op);
    const absl::flat_hash_set<int>& quantizable_operands =
        spec->quantizable_operands;
    // Exactly one quantizable operand (the weight) is expected.
    if (quantizable_operands.size() != 1) return failure();
    const int weight_operand_idx = *quantizable_operands.begin();

    Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp();
    // Fix: the weight may be a block argument with no defining op; the
    // previous code dereferenced weight_op unconditionally.
    if (weight_op == nullptr) return failure();

    DenseFPElementsAttr attr;
    if (!matchPattern(weight_op->getResult(0), m_Constant(&attr))) {
      return failure();
    }

    // Compute the [H,W,1,C*M] shape. Non-4D weights and already-collapsed
    // weights (C == 1) are left untouched.
    const llvm::ArrayRef<int64_t> cur_shape = attr.getType().getShape();
    const int cur_rank = cur_shape.size();
    if (cur_rank != 4 || cur_shape[2] == 1) return failure();
    const TensorType new_shape = RankedTensorType::get(
        {cur_shape[0], cur_shape[1], 1, cur_shape[2] * cur_shape[3]},
        attr.getElementType());

    // Insert a reshape right after the constant so any other users of the
    // constant are unaffected.
    const RankedTensorType shape_spec_type =
        RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64));
    const DenseElementsAttr new_shape_const_attr =
        DenseElementsAttr::get(shape_spec_type, new_shape.getShape());
    rewriter.setInsertionPointAfter(weight_op);
    auto new_shape_const = rewriter.create<arith::ConstantOp>(
        weight_op->getLoc(), shape_spec_type, new_shape_const_attr);
    auto reshape_op = rewriter.create<TF::ReshapeOp>(
        weight_op->getLoc(), new_shape, weight_op->getResult(0),
        new_shape_const);
    op->setOperand(weight_operand_idx, reshape_op);

    // Clone the composite function with the reshaped weight type and retarget
    // the call; the original function remains for other callers.
    // NOTE(review): assumes the "composite_*" symbol always resolves to a
    // func::FuncOp at this point — confirm against the lifting pass.
    ModuleOp module = op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module);
    auto float_func =
        dyn_cast<func::FuncOp>(symbol_table.lookup(function_name));
    const OperandRange func_args = op.getArgs();
    func::FuncOp new_float_func = float_func.clone();

    SmallVector<Value> new_float_func_args{func_args.begin(),
                                           func_args.end()};
    new_float_func_args[weight_operand_idx] = reshape_op;
    new_float_func.getArgument(weight_operand_idx).setType(new_shape);
    new_float_func.setType(FunctionType::get(
        getContext(), TypeRange{ValueRange{new_float_func_args}},
        new_float_func.getResultTypes()));
    // SymbolTable::insert uniquifies the clone's name; read it back after.
    symbol_table.insert(new_float_func);

    op->setAttr("f", SymbolRefAttr::get(rewriter.getContext(),
                                        new_float_func.getName()));
    return success();
  }

  OpSet op_set_;
};

#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.inc"

void PrepareQuantizeDRQPass::runOnOperation() {
Expand All @@ -343,7 +259,6 @@ void PrepareQuantizeDRQPass::runOnOperation() {
ModuleOp module_op = getOperation();

populateWithGenerated(patterns);
patterns.add<PreprocessConstantOp>(ctx, op_set_);
patterns.add<PrepareDRQQuantizableOp>(ctx, quant_specs_,
enable_per_channel_quantization_);
FrozenRewritePatternSet frozen_patterns(std::move(patterns));
Expand Down
Loading

0 comments on commit 25fa182

Please sign in to comment.