From 7a871fa8e222e3897971e66ef4f2318da65e81ab Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 9 Apr 2024 17:02:08 -0700
Subject: [PATCH] Convert scalar to tensor before quantizer annotation

Differential Revision: [D55946527](https://our.internmc.facebook.com/intern/diff/D55946527/)

[ghstack-poisoned]
---
 backends/qualcomm/quantizer/quantizer.py | 48 +++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 674314d991c..ed3e362660e 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -18,6 +18,7 @@
 
 from torch import Tensor
 from torch._ops import OpOverload
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
@@ -371,6 +372,51 @@ def set_per_channel_weight_dtype(
     def set_per_channel_quant(self, enable: bool) -> None:
         self.enable_per_channel_conv_quant = enable
 
+    def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
+        """
+        For a node like mul(x, 2), convert the scalar operand 2 to a tensor.
+        """
+        for n in gm.graph.nodes:
+            if n.op != "call_function" or n.target not in (
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul.Scalar,
+                torch.ops.aten.rsub.Scalar,
+            ):
+                continue
+
+            const_arg = None
+            non_const_arg = None
+            for arg in n.args:
+                if isinstance(arg, torch.fx.Node):
+                    non_const_arg = arg
+                else:
+                    const_arg = arg
+
+            if non_const_arg is None or const_arg is None:
+                continue
+
+            # Lift the scalar into a float32 buffer fed to the graph via get_attr.
+            tensor_constant = torch.tensor([const_arg], dtype=torch.float32)
+            tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
+                gm
+            )
+            gm.register_buffer(tensor_constant_name, tensor_constant)
+
+            fake_mode = n.meta["val"].fake_mode
+            with gm.graph.inserting_before(n):
+                get_attr_node = gm.graph.get_attr(tensor_constant_name)
+                get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)
+
+            if n.target == torch.ops.aten.rsub.Scalar:
+                n.args = (get_attr_node, non_const_arg) + n.args[2:]
+                n.target = torch.ops.aten.sub.Tensor
+            else:
+                n.args = (non_const_arg, get_attr_node) + n.args[2:]
+
+        gm.recompile()
+
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
         model = ReduceDynamicRange()(model).graph_module
@@ -378,7 +424,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = DecomposeScaledDotProductAttention()(model).graph_module
         model = DecomposeSilu()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
-
+        self._lift_constant_scalar_operands(model)
         return model
 
     def validate(self, model: GraphModule) -> None:
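
Reviewer note (illustration, not part of the patch): the pass exists because an op such as mul(x, 2) carries the constant 2 as a bare Python scalar, which the quantizer's observers never see; lifting it into a buffer accessed through a get_attr node makes both operands annotatable tensors. The sketch below replays the core of _lift_constant_scalar_operands on a standalone graph instead of going through QnnQuantizer; ToyModel, the input shape, and the use of torch.export.export / ExportedProgram.module() are assumptions made for the example, not anything taken from this diff.

import torch
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix


class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Traces to aten.mul.Tensor(x, 2), where 2 stays a bare Python scalar.
        return x * 2


# Export to an ATen-level graph whose nodes carry FakeTensor metadata in meta["val"].
gm = torch.export.export(ToyModel(), (torch.randn(1, 4),)).module()
print(gm.code)  # mul(x, 2): the scalar operand is invisible to observers

for n in gm.graph.nodes:
    if n.op == "call_function" and n.target == torch.ops.aten.mul.Tensor:
        scalar = n.args[1]  # the Python number 2
        buf = torch.tensor([scalar], dtype=torch.float32)
        name = get_new_attr_name_with_prefix("_tensor_constant_")(gm)
        gm.register_buffer(name, buf)
        with gm.graph.inserting_before(n):
            attr = gm.graph.get_attr(name)
            attr.meta["val"] = n.meta["val"].fake_mode.from_tensor(buf)
        # Both mul operands are now graph tensors that a quantizer can annotate.
        n.args = (n.args[0], attr)

gm.recompile()
print(gm.code)  # mul(x, _tensor_constant_0)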