[quant][graphmode][fx] Add support for functional conv1d and conv3d (#51155)

Summary:
This PR adds support for quantizing functional conv1d, conv3d, conv1d_relu, and conv3d_relu.
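
For illustration, a minimal end-to-end sketch of the new functional conv1d path using FX graph mode quantization; the module, calibration data, and qconfig below are illustrative assumptions, not code from this PR:

import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Hypothetical module that calls functional conv1d + relu, the pattern this PR enables.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.randn(3, 3, 3)  # (out_channels, in_channels, kernel_size)
        self.b = torch.randn(3)

    def forward(self, x):
        return F.relu(F.conv1d(x, self.w, self.b))

model = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)  # inserts observers for input, weight and output
prepared(torch.randn(2, 3, 8))              # calibration pass
quantized = convert_fx(prepared)            # expected to lower to quantized.conv1d_relu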

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_functional_conv

Reviewed By: vkuzo

[ghstack-poisoned]
jerryzh168 committed Jan 28, 2021
1 parent 2de4ecd commit 1ccf7b7
Showing 3 changed files with 86 additions and 22 deletions.
60 changes: 42 additions & 18 deletions test/quantization/test_quantize_fx.py
@@ -1598,14 +1598,21 @@ def forward(self, x):
def test_functional_conv(self):
""" Test for function conv and functional conv + relu
"""
convs = {
1: torch.nn.functional.conv1d,
2: torch.nn.functional.conv2d,
3: torch.nn.functional.conv3d,
}

class FuncConv(torch.nn.Module):
def __init__(self, use_bias, has_relu, f_relu):
def __init__(self, dim, use_bias, has_relu, f_relu):
super().__init__()
self.w = torch.randn(3, 3, 3, 3)
self.dim = dim
self.w = torch.randn(tuple([3] * (dim + 2)))
self.b = torch.randn(3) if use_bias else None
self.stride = (1, 1)
self.padding = (0, 0)
self.dilation = (1, 1)
self.stride = tuple([1] * dim)
self.padding = tuple([0] * dim)
self.dilation = tuple([1] * dim)
self.groups = 1
self.use_bias = use_bias
if has_relu:
@@ -1617,12 +1624,10 @@ def __init__(self, use_bias, has_relu, f_relu):
self.relu = torch.nn.Identity()

def forward(self, x):
x = F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
x = self.relu(x)
return x

data = (torch.randn((2, 3, 4, 4), dtype=torch.float),)

quant_type_to_prepare_expected_node_occurrence = {
QuantType.DYNAMIC: {},
# There should be 3 observers: after input, weight and activation.
@@ -1636,31 +1641,50 @@ def forward(self, x):
},
}
quant_type_to_qconv_fun = {
QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d),
QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d),
QuantType.STATIC: {
1: ns.call_function(torch.ops.quantized.conv1d),
2: ns.call_function(torch.ops.quantized.conv2d),
3: ns.call_function(torch.ops.quantized.conv3d)
},
QuantType.QAT: {
1: ns.call_function(torch.ops.quantized.conv1d),
2: ns.call_function(torch.ops.quantized.conv2d),
3: ns.call_function(torch.ops.quantized.conv3d)
},
}
quant_type_to_qconv_relu_fun = {
QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d_relu),
QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d_relu),
QuantType.STATIC: {
1: ns.call_function(torch.ops.quantized.conv1d_relu),
2: ns.call_function(torch.ops.quantized.conv2d_relu),
3: ns.call_function(torch.ops.quantized.conv3d_relu)
},
QuantType.QAT: {
1: ns.call_function(torch.ops.quantized.conv1d_relu),
2: ns.call_function(torch.ops.quantized.conv2d_relu),
3: ns.call_function(torch.ops.quantized.conv3d_relu)
},
}

options = itertools.product(
[1, 2, 3], # dims
self.static_quant_types,
(True, False), # use_bias
(True, False), # has_relu
(True, False), # functional relu
)
for quant_type, use_bias, has_relu, f_relu in options:
model = FuncConv(use_bias, has_relu, f_relu)
for dim, quant_type, use_bias, has_relu, f_relu in options:
data_dims = [2, 3] + [4] * dim
data = (torch.randn(tuple(data_dims), dtype=torch.float),)
model = FuncConv(dim, use_bias, has_relu, f_relu)
if has_relu:
qconv_fun = quant_type_to_qconv_relu_fun[quant_type]
qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim]
else:
qconv_fun = quant_type_to_qconv_fun[quant_type]
qconv_fun = quant_type_to_qconv_fun[quant_type][dim]

convert_node_occurrence = {
ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
ns.call_function(torch.quantize_per_tensor): 1,
qconv_fun: 1,
ns.call_method("dequantize"): 1 if quant_type != QuantType.DYNAMIC else 0
ns.call_method("dequantize"): 1
}
prepare_expected_node_occurrence = \
quant_type_to_prepare_expected_node_occurrence[quant_type]
20 changes: 16 additions & 4 deletions torch/quantization/fx/quantization_patterns.py
@@ -32,6 +32,8 @@
quantize_node,
get_per_tensor_qparams,
get_linear_prepack_op_for_dtype,
get_qconv_prepack_op,
get_qconv_op,
)

from .quantization_types import QuantizerCls
@@ -188,7 +190,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv1d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.functional.conv3d)
# TODO: add qat.Conv1d and qat.Conv3d
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@@ -198,8 +203,12 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@@ -212,8 +221,10 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
self.relu_node = node
node = node.args[0] # type: ignore
self.conv_node = node
if node.op == 'call_module':
if node.op == "call_module":
self.conv = quantizer.modules[self.conv_node.target]
elif node.op == "call_function":
self.conv = node.target

def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
debug: bool = False,
@@ -275,7 +286,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
args = load_arg(quantized=False)(self.conv_node.args)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
op_out = quantizer.quantized_graph.create_node(
"call_function", torch.nn.functional.conv2d, args, kwargs)
"call_function", self.conv, args, kwargs)
if self.relu_node:
relu_args = [op_out]
relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
@@ -300,13 +311,14 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
prepack_args = tuple([weight] + list(other_args))
prepack_op = get_qconv_prepack_op(self.conv)
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
"call_function", prepack_op, prepack_args, {})
assert activation_statically_quantized, \
"currently only static quantization is supported for conv"
# construct conv input
if activation_statically_quantized:
qconv_op = torch.ops.quantized.conv2d_relu if self.relu_node else torch.ops.quantized.conv2d
qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
activation_post_process = quantizer.activation_post_process_map[act_post_process_name]
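
To make the lowered pattern concrete, here is a standalone sketch (illustrative values, not code from this diff) of the op pair this handler now emits for functional conv1d + relu under static quantization; the raw op signatures are assumed to mirror the existing conv2d ones:

import torch

# Example quantized input/weight for a 1d conv: input is (N, C, L), weight is (O, I, K).
x_q = torch.quantize_per_tensor(torch.randn(2, 3, 8), scale=0.1, zero_point=64, dtype=torch.quint8)
w_q = torch.quantize_per_tensor(torch.randn(3, 3, 3), scale=0.05, zero_point=0, dtype=torch.qint8)

# conv1d_prepack(weight, bias, stride, padding, dilation, groups) packs the weight once;
# the fused conv1d_relu then consumes the packed weight plus an output scale/zero_point.
packed = torch.ops.quantized.conv1d_prepack(w_q, None, [1], [0], [1], 1)
out_q = torch.ops.quantized.conv1d_relu(x_q, packed, 0.2, 64)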
28 changes: 28 additions & 0 deletions torch/quantization/fx/utils.py
@@ -179,6 +179,34 @@ def get_linear_prepack_op_for_dtype(dtype):
else:
raise Exception("can't get linear prepack op for dtype:", dtype)

def get_qconv_prepack_op(conv_op: Callable) -> Callable:
prepack_ops = {
torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack
}
prepack_op = prepack_ops.get(conv_op, None)
assert prepack_op, "Didn't find prepack op for {}".format(conv_op)
return prepack_op

def get_qconv_op(conv_op: Callable, has_relu: bool) -> Callable:
qconv_op = {
# has relu
True: {
torch.nn.functional.conv1d: torch.ops.quantized.conv1d_relu,
torch.nn.functional.conv2d: torch.ops.quantized.conv2d_relu,
torch.nn.functional.conv3d: torch.ops.quantized.conv3d_relu
},
False: {
torch.nn.functional.conv1d: torch.ops.quantized.conv1d,
torch.nn.functional.conv2d: torch.ops.quantized.conv2d,
torch.nn.functional.conv3d: torch.ops.quantized.conv3d
}
}
qconv = qconv_op[has_relu].get(conv_op)
assert qconv, "Can't find corresponding quantized conv op for {} {}".format(conv_op, has_relu)
return qconv
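
For reference, a small usage sketch of these helpers; importing them from torch.quantization.fx.utils (the internal module this diff touches) is an assumption, not a public API:

import torch
import torch.nn.functional as F
from torch.quantization.fx.utils import get_qconv_prepack_op, get_qconv_op

# Each functional conv maps to its dim-matched prepack op and (optionally fused) quantized op.
prepack_op = get_qconv_prepack_op(F.conv1d)  # torch.ops.quantized.conv1d_prepack
qconv_relu = get_qconv_op(F.conv3d, True)    # torch.ops.quantized.conv3d_relu
qconv = get_qconv_op(F.conv2d, False)        # torch.ops.quantized.conv2d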

# Returns a function that can get a new attribute name for module with given
# prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
