From 1ccf7b73fe6aded6b104df0b76fd5e4a354e8b54 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 27 Jan 2021 16:31:32 -0800
Subject: [PATCH] [quant][graphmode][fx] Add support for functional conv1d and
 conv3d (#51155)

Summary:
This PR added support for quantizing functional conv1d, conv3d, conv1d_relu and conv3d_relu

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_functional_conv

Reviewed By: vkuzo

[ghstack-poisoned]
---
 test/quantization/test_quantize_fx.py            | 60 +++++++++++++------
 .../quantization/fx/quantization_patterns.py     | 20 +++++--
 torch/quantization/fx/utils.py                   | 28 +++++++++
 3 files changed, 86 insertions(+), 22 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index b2243eead1d0..295144bddc3f 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1598,14 +1598,21 @@ def forward(self, x):
     def test_functional_conv(self):
         """ Test for function conv and functional conv + relu
         """
+        convs = {
+            1: torch.nn.functional.conv1d,
+            2: torch.nn.functional.conv2d,
+            3: torch.nn.functional.conv3d,
+        }
+
         class FuncConv(torch.nn.Module):
-            def __init__(self, use_bias, has_relu, f_relu):
+            def __init__(self, dim, use_bias, has_relu, f_relu):
                 super().__init__()
-                self.w = torch.randn(3, 3, 3, 3)
+                self.dim = dim
+                self.w = torch.randn(tuple([3] * (dim + 2)))
                 self.b = torch.randn(3) if use_bias else None
-                self.stride = (1, 1)
-                self.padding = (0, 0)
-                self.dilation = (1, 1)
+                self.stride = tuple([1] * dim)
+                self.padding = tuple([0] * dim)
+                self.dilation = tuple([1] * dim)
                 self.groups = 1
                 self.use_bias = use_bias
                 if has_relu:
@@ -1617,12 +1624,10 @@ def __init__(self, use_bias, has_relu, f_relu):
                     self.relu = torch.nn.Identity()
 
             def forward(self, x):
-                x = F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
+                x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
                 x = self.relu(x)
                 return x
 
-        data = (torch.randn((2, 3, 4, 4), dtype=torch.float),)
-
         quant_type_to_prepare_expected_node_occurrence = {
             QuantType.DYNAMIC: {},
             # There should be 3 observers: after input, weight and activation.
@@ -1636,31 +1641,50 @@ def forward(self, x):
             },
         }
 
         quant_type_to_qconv_fun = {
-            QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d),
-            QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d),
+            QuantType.STATIC: {
+                1: ns.call_function(torch.ops.quantized.conv1d),
+                2: ns.call_function(torch.ops.quantized.conv2d),
+                3: ns.call_function(torch.ops.quantized.conv3d)
+            },
+            QuantType.QAT: {
+                1: ns.call_function(torch.ops.quantized.conv1d),
+                2: ns.call_function(torch.ops.quantized.conv2d),
+                3: ns.call_function(torch.ops.quantized.conv3d)
+            },
         }
         quant_type_to_qconv_relu_fun = {
-            QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d_relu),
-            QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d_relu),
+            QuantType.STATIC: {
+                1: ns.call_function(torch.ops.quantized.conv1d_relu),
+                2: ns.call_function(torch.ops.quantized.conv2d_relu),
+                3: ns.call_function(torch.ops.quantized.conv3d_relu)
+            },
+            QuantType.QAT: {
+                1: ns.call_function(torch.ops.quantized.conv1d_relu),
+                2: ns.call_function(torch.ops.quantized.conv2d_relu),
+                3: ns.call_function(torch.ops.quantized.conv3d_relu)
+            },
         }
 
         options = itertools.product(
+            [1, 2, 3],  # dims
             self.static_quant_types,
             (True, False),  # use_bias
             (True, False),  # has_relu
             (True, False),  # functional relu
         )
-        for quant_type, use_bias, has_relu, f_relu in options:
-            model = FuncConv(use_bias, has_relu, f_relu)
+        for dim, quant_type, use_bias, has_relu, f_relu in options:
+            data_dims = [2, 3] + [4] * dim
+            data = (torch.randn(tuple(data_dims), dtype=torch.float),)
+            model = FuncConv(dim, use_bias, has_relu, f_relu)
             if has_relu:
-                qconv_fun = quant_type_to_qconv_relu_fun[quant_type]
+                qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim]
             else:
-                qconv_fun = quant_type_to_qconv_fun[quant_type]
+                qconv_fun = quant_type_to_qconv_fun[quant_type][dim]
             convert_node_occurrence = {
-                ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
+                ns.call_function(torch.quantize_per_tensor): 1,
                 qconv_fun: 1,
-                ns.call_method("dequantize"): 1 if quant_type != QuantType.DYNAMIC else 0
+                ns.call_method("dequantize"): 1
             }
             prepare_expected_node_occurrence = \
                 quant_type_to_prepare_expected_node_occurrence[quant_type]
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 06f15240e761..7d0e3f7d0a78 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -32,6 +32,8 @@
     quantize_node,
     get_per_tensor_qparams,
     get_linear_prepack_op_for_dtype,
+    get_qconv_prepack_op,
+    get_qconv_op,
 )
 
 from .quantization_types import QuantizerCls
@@ -188,7 +190,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
 @register_quant_pattern(torch.nn.Conv1d)
 @register_quant_pattern(torch.nn.Conv2d)
 @register_quant_pattern(torch.nn.Conv3d)
+@register_quant_pattern(torch.nn.functional.conv1d)
 @register_quant_pattern(torch.nn.functional.conv2d)
+@register_quant_pattern(torch.nn.functional.conv3d)
+# TODO: add qat.Conv1d and qat.Conv3d
 @register_quant_pattern(torch.nn.qat.Conv2d)
 @register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
 @register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@@ -198,8 +203,12 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
 @register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
 @register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
 @register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
+@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
+@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
+@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
 @register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
+@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
 # just for error checks
 @register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@@ -212,8 +221,10 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
             self.relu_node = node
             node = node.args[0]  # type: ignore
         self.conv_node = node
-        if node.op == 'call_module':
+        if node.op == "call_module":
             self.conv = quantizer.modules[self.conv_node.target]
+        elif node.op == "call_function":
+            self.conv = node.target
 
     def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                 debug: bool = False,
@@ -275,7 +286,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                 args = load_arg(quantized=False)(self.conv_node.args)
                 kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                 op_out = quantizer.quantized_graph.create_node(
-                    "call_function", torch.nn.functional.conv2d, args, kwargs)
+                    "call_function", self.conv, args, kwargs)
                 if self.relu_node:
                     relu_args = [op_out]
                     relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
@@ -300,13 +311,14 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                 weight = load_arg(quantized=True)(self.conv_node.args[1])
                 other_args = load_arg(quantized=False)(self.conv_node.args[2:])
                 prepack_args = tuple([weight] + list(other_args))
+                prepack_op = get_qconv_prepack_op(self.conv)
                 packed_weight = quantizer.quantized_graph.create_node(
-                    'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
+                    "call_function", prepack_op, prepack_args, {})
                 assert activation_statically_quantized, \
                     "currently only static quantization is supported for conv"
                 # construct conv input
                 if activation_statically_quantized:
-                    qconv_op = torch.ops.quantized.conv2d_relu if self.relu_node else torch.ops.quantized.conv2d
+                    qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
                     conv_input = load_arg(quantized=True)(self.conv_node.args[0])
                     act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
                     activation_post_process = quantizer.activation_post_process_map[act_post_process_name]
diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py
index 8285e204b1ed..3671df81cd00 100644
--- a/torch/quantization/fx/utils.py
+++ b/torch/quantization/fx/utils.py
@@ -179,6 +179,34 @@ def get_linear_prepack_op_for_dtype(dtype):
     else:
         raise Exception("can't get linear prepack op for dtype:", dtype)
 
+def get_qconv_prepack_op(conv_op: Callable) -> Callable:
+    prepack_ops = {
+        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
+        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
+        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack
+    }
+    prepack_op = prepack_ops.get(conv_op, None)
+    assert prepack_op, "Didn't find prepack op for {}".format(conv_op)
+    return prepack_op
+
+def get_qconv_op(conv_op: Callable, has_relu: bool) -> Callable:
+    qconv_op = {
+        # has relu
+        True: {
+            torch.nn.functional.conv1d: torch.ops.quantized.conv1d_relu,
+            torch.nn.functional.conv2d: torch.ops.quantized.conv2d_relu,
+            torch.nn.functional.conv3d: torch.ops.quantized.conv3d_relu
+        },
+        False: {
+            torch.nn.functional.conv1d: torch.ops.quantized.conv1d,
+            torch.nn.functional.conv2d: torch.ops.quantized.conv2d,
+            torch.nn.functional.conv3d: torch.ops.quantized.conv3d
+        }
+    }
+    qconv = qconv_op[has_relu].get(conv_op)
+    assert qconv, "Can't find corresponding quantized conv op for {} {}".format(conv_op, has_relu)
+    return qconv
+
 # Returns a function that can get a new attribute name for module with given
 # prefix, for example,
 # >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
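
Usage sketch (illustrative only, not part of the patch): the snippet below shows the
workflow this change enables, i.e. FX graph mode quantization lowering a functional
conv1d followed by relu to a single quantized::conv1d_relu call. The FuncConv1dReLU
module and the random calibration data are made up for illustration, and the
prepare_fx/convert_fx calls assume the qconfig_dict-based API as it existed around
this PR (later releases also require an example_inputs argument).

import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig, quantize_fx

class FuncConv1dReLU(torch.nn.Module):
    """Hypothetical module mirroring the FuncConv pattern exercised by the new test."""
    def __init__(self):
        super().__init__()
        self.w = torch.randn(3, 3, 3)  # (out_channels, in_channels, kernel_size)
        self.b = torch.randn(3)

    def forward(self, x):
        # conv arguments are passed positionally, mirroring the test module
        x = F.conv1d(x, self.w, self.b, 1, 0, 1, 1)
        return F.relu(x)

model = FuncConv1dReLU().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# prepare: insert observers for the activations and the functional conv weight
prepared = quantize_fx.prepare_fx(model, qconfig_dict)

# calibrate with representative data (random here, purely for illustration)
prepared(torch.randn(2, 3, 4))

# convert: with this patch the resulting graph should contain
# quantize_per_tensor -> quantized::conv1d_relu -> dequantize,
# matching what test_functional_conv checks for dim == 1
quantized = quantize_fx.convert_fx(prepared)
print(quantized.graph)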