From 5a8198eb3c594aa18352930fd21f3c25bd7b7100 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 27 Oct 2020 19:57:22 -0700
Subject: [PATCH] [quant][graphmode][fx][fix] scalar as first input for add/mul (#46751)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46751

Currently we assume the first input for add/mul is a Node (Tensor), but
that is not always the case.

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_quantized_add
python test/test_quantization.py TestQuantizeFxOps.test_quantized_mul
python test/test_quantization.py TestQuantizeFxOps.test_quantized_add_relu
python test/test_quantization.py TestQuantizeFxOps.test_quantized_mul_relu

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D24494456

fbshipit-source-id: ef5e23ba60eb22a57771791f4934306b25c27c01
---
 aten/src/ATen/native/quantized/cpu/qadd.cpp   | 11 ++++++
 aten/src/ATen/native/quantized/cpu/qmul.cpp   | 14 ++++++++
 aten/src/ATen/native/quantized/library.cpp    |  4 +++
 test/quantization/test_quantize_fx.py         |  5 +++
 .../quantization/fx/quantization_patterns.py  | 34 +++++++++++++------
 torch/quantization/fx/quantize.py             | 20 +++++++----
 6 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp
index a12718502dd1..0b9bc6b8e901 100644
--- a/aten/src/ATen/native/quantized/cpu/qadd.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -243,6 +243,15 @@ Tensor qadd_scalar(Tensor qa, Scalar b) {
   return _add_scalar_out<ReLUFused>(qc, qa, b);
 }
 
+template <bool ReLUFused = false>
+Tensor qadd_scalar2(Scalar b, Tensor qa) {
+  TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
+              qa.qscheme() == kPerTensorSymmetric,
+              "Only per tensor quantization is supported in Add.");
+  auto qc = at::empty_like(qa, qa.suggest_memory_format());
+  return _add_scalar_out<ReLUFused>(qc, qa, b);
+}
+
 template <bool ReLUFused = false>
 Tensor qadd_scalar_out(Tensor qa, Scalar b, Tensor out) {
   check_inputs(qa, out);
@@ -269,10 +278,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar"), TORCH_FN(qadd_scalar));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar2"), TORCH_FN(qadd_scalar2));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar_out"), TORCH_FN(qadd_scalar_out));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu"), TORCH_FN(qadd));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.out"), TORCH_FN(qadd_out));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar"), TORCH_FN(qadd_scalar));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar2"), TORCH_FN(qadd_scalar2));
   m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar_out"), TORCH_FN(qadd_scalar_out));
   // deprecated functions, kept for backward compatibility
   m.impl(TORCH_SELECTIVE_NAME("quantized::add_out"), TORCH_FN(qadd_out));
diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp
index deeae36dc502..dccc2d718bf1 100644
--- a/aten/src/ATen/native/quantized/cpu/qmul.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp
@@ -136,6 +136,18 @@ class QMulScalar final {
   }
 };
 
+template <bool ReLUFused = false>
+class QMulScalar2 final {
+ public:
+  static Tensor run(Scalar b, Tensor qa) {
+    TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
+                qa.qscheme() == kPerTensorSymmetric,
+                "Only per tensor quantization is supported in Mul.");
+    auto qc = at::empty_like(qa, qa.suggest_memory_format());
+    return _mul_scalar_out<ReLUFused>(qc, qa, b);
+  }
+};
+
 template <bool ReLUFused = false>
 class QMulScalarOut final {
  public:
@@ -176,10 +188,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul"), TORCH_FN(QMul::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul.out"), TORCH_FN(QMulOut::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar"), TORCH_FN(QMulScalar::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar2"), TORCH_FN(QMulScalar2::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar_out"), TORCH_FN(QMulScalarOut::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu"), TORCH_FN(QMul::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.out"), TORCH_FN(QMulOut::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar"), TORCH_FN(QMulScalar::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar2"), TORCH_FN(QMulScalar2::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar_out"), TORCH_FN(QMulScalarOut::run));
   // deprecated functions, kept for backward compatibility
   m.impl(TORCH_SELECTIVE_NAME("quantized::mul_out"), TORCH_FN(QMulOut::run));
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index 3b866cf2fd12..3150fd986300 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -23,9 +23,11 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar(Tensor qa, Scalar b) -> Tensor qc"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar2(Scalar b, Tensor qa) -> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar(Tensor qa, Scalar b) -> Tensor qc"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar2(Scalar b, Tensor qa) -> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"));
   // deprecated functions, kept for backward compatibility
@@ -142,10 +144,12 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar(Tensor qa, Scalar b)-> Tensor qc"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar2(Scalar b, Tensor qa)-> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar(Tensor qa, Scalar b)-> Tensor qc"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar2(Scalar b, Tensor qa)-> Tensor qc"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"));
   // deprecated functions, kept for backward compatibility
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"));
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index d2b2d6c82ddf..2c6c3221fcfd 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -997,7 +997,10 @@ def __init__(self, is_inplace, is_scalar):
             def forward(self, x, y):
                 x = self.conv1(x)
                 y = 3 if self.is_scalar else self.conv2(y)
+                # x = x + y
                 x = self.op(x, y)
+                # x = y + x
+                x = self.op(y, x)
                 return x
 
         # TODO: decide whether we want to quantize or not
@@ -1040,6 +1043,8 @@ def forward(self, x, y):
                 y = 3 if self.is_scalar else self.conv2(y)
                 x = self.op(x, y)
                 x = self.relu(x)
+                x = self.op(y, x)
+                x = self.relu(x)
                 return x
 
         data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 2898a07a17c1..1f2763d6aeb4 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -46,7 +46,8 @@ def __init__(self, quantizer, node):
         # this is an indicator of whether all the inputs are Node or not
         # since some op might be quantized differently depending on whether
         # all inputs are tensors or not, e.g. add/mul
-        self.all_nodes = True
+        self.num_node_args = len(node.args)
+        self.all_node_args = True
 
     @abstractmethod
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
@@ -71,18 +72,24 @@ def __init__(self, quantizer, node):
             node = node.args[0]
         assert node.op == 'call_function' and node.target in [operator.add, torch.add]
         self.add_node = node
-        self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])
+        self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)])
 
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
-        if not self.all_nodes:
+        if self.num_node_args == 1:
             # add scalar
             if self.relu_node is not None:
                 op = torch.ops.quantized.add_relu
             else:
                 op = torch.ops.quantized.add
+
+            if isinstance(self.add_node.args[0], Node):
+                quantized_index = 0
+            else:
+                quantized_index = 1
+
             return quantizer.quantized_graph.create_node(
                 'call_function', op,
-                load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
+                load_arg(quantized=[quantized_index])(self.add_node.args), self.add_node.kwargs)
         else:
             activation_post_process = quantizer.activation_post_process_map[node.name]
             scale, zero_point = activation_post_process.calculate_qparams()
@@ -96,6 +103,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
             return quantizer.quantized_graph.create_node(
                 'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
 
+# TODO: merge with Add
 @register_quant_pattern(operator.mul)
 @register_quant_pattern(torch.mul)
 @register_quant_pattern((torch.nn.ReLU, operator.mul))
@@ -112,17 +120,23 @@ def __init__(self, quantizer, node):
             node = node.args[0]
         assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
         self.mul_node = node
-        self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])
+        self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)])
 
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
-        if not self.all_nodes:
+        if self.num_node_args == 1:
             # mul scalar
             if self.relu_node is not None:
                 op = torch.ops.quantized.mul_relu
             else:
                 op = torch.ops.quantized.mul
+
+            if isinstance(self.mul_node.args[0], Node):
+                quantized_index = 0
+            else:
+                quantized_index = 1
+
             return quantizer.quantized_graph.create_node(
-                'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
+                'call_function', op, load_arg(quantized=[quantized_index])(self.mul_node.args), self.mul_node.kwargs)
         else:
             activation_post_process = quantizer.activation_post_process_map[node.name]
             scale, zero_point = activation_post_process.calculate_qparams()
@@ -138,7 +152,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern(torch.cat)
 class Cat(QuantizeHandler):
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
-        if not self.all_nodes:
+        if not self.all_node_args:
             return NotImplemented
         activation_post_process = quantizer.activation_post_process_map[node.name]
         scale, zero_point = activation_post_process.calculate_qparams()
@@ -438,7 +452,7 @@ class DefaultNode(QuantizeHandler):
     ''' Common quantized op, first input and first output will be quantized
     '''
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
-        if not self.all_nodes:
+        if not self.all_node_args:
             return NotImplemented
         assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
             'call_function are handled in DefaultNode'
@@ -580,7 +594,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 # of quantizable objects (e.g. modules and functionals)
 class DefaultQuantizeHandler(QuantizeHandler):
     def convert(self, quantizer, node):
-        assert self.all_nodes
+        assert self.all_node_args
         root_module = quantizer.modules['']
         return quantize_node(
             root_module,
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 60da8142ee2d..3f7a699d97b3 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -384,7 +384,7 @@ def load_arg(a):
                 continue
 
             prefix = node.name + '_activation_post_process_'
-            root_node, _, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
+            root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
             if root_node is None:
                 env[node.name] = observed_graph.node_copy(node, load_arg)
             elif root_node is node:
@@ -458,15 +458,23 @@ def is_observed(input_arg):
                     # propagate observed property from input
                     if is_observed(node.args[0]):
                         observed_node_names_set.add(node.name)
-                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
-                    if node.args[0].name in observed_node_names_set:
+                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
+                    input_node = matched_nodes[-1]  # first node in the sequence
+
+                    def input_is_observed(arg):
+                        return isinstance(arg, Node) and arg.name in observed_node_names_set
+                    # This is checking if one of the arguments of add/mul
+                    # is an observed node
+                    # If both of the inputs are numbers,
+                    # we will not consider the output to be observed
+                    if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
                         observed_node_names_set.add(node.name)
                 elif isinstance(obj, StandaloneModuleQuantizeHandler):
                     assert node.op == 'call_module'
                     output_is_observed = self.modules[node.target]._output_is_observed
                     if output_is_observed:
                         observed_node_names_set.add(node.name)
-                elif qconfig is not None and obj.all_nodes:
+                elif qconfig is not None and obj.all_node_args:
                     # observer for outputs
                     new_observer = qconfig.activation()
                     # respect device affinity when adding observers
@@ -741,8 +749,8 @@ def is_quantized(node):
                 # the node is quantized in parent module
                 quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
             else:
-                # dequantize inputs for the node that are not quantized
-                env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
+                # copy quantized or non-quantized node
+                env[node.name] = self.quantized_graph.node_copy(node, load_x)
 
         # remove activation post process
         act_post_process_removed_graph = Graph()
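
Illustrative example (not part of the patch): a minimal sketch of the case this change handles, written against the FX graph mode quantization API of this era (prepare_fx/convert_fx driven by a qconfig_dict; newer releases also expect example_inputs). Module and variable names below are made up for illustration. Tracing "3 + x" records operator.add with the scalar as args[0] and the Tensor as args[1]; the convert step previously assumed the Tensor was always args[0], while the new quantized::add.Scalar2(Scalar b, Tensor qa) overload plus the quantized_index bookkeeping lets either argument order lower correctly.

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class ScalarFirstAdd(torch.nn.Module):
    """Toy module where the scalar is the *first* argument of add."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        # Scalar on the left: FX traces this as operator.add(3, x), i.e. the
        # Tensor operand sits at args[1] -- the case this patch adds support for.
        return 3 + x

model = ScalarFirstAdd().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
data = torch.rand(1, 1, 4, 4)

prepared = prepare_fx(model, qconfig_dict)   # insert observers
prepared(data)                               # calibrate
quantized = convert_fx(prepared)

# The converted graph keeps the Tensor operand quantized at index 1, so the
# call should dispatch to the quantized::add.Scalar2(Scalar, Tensor) kernel
# registered by this patch.
print(quantized.graph)
print(quantized(data).shape)

The new test lines in test/quantization/test_quantize_fx.py exercise the same pattern through x = self.op(y, x) with y = 3 (plus the relu variant), covering both quantized::add.Scalar2 and quantized::mul.Scalar2.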