From 0ca3414c2c7c3f5754fc34cd7adeeddd20a51a7f Mon Sep 17 00:00:00 2001
From: Gustav Larsson
Date: Wed, 20 Sep 2023 17:04:50 -0700
Subject: [PATCH 1/4] torch->onnx export support: quantized::linear_relu

- Adds support for quantized::linear_relu
- Adds a weight packing pattern matcher
- Adds export support for opsets 10 and 13
- Adds a QAT test modeled after the conv2d+relu fusion test
---
 test/onnx/test_pytorch_onnx_onnxruntime.py    | 32 +++++++++++++++++++
 .../passes/onnx/unpack_quantized_weights.cpp  | 11 +++++++
 torch/onnx/symbolic_opset10.py                | 17 ++++++++++
 torch/onnx/symbolic_opset13.py                | 18 +++++++++++
 4 files changed, 78 insertions(+)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 8373af9e62b38..30ded8a31f1ff 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -13119,6 +13119,38 @@ def forward(self, x):
         )
         self.run_test(model, input)
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_qat_linear_relu_fused(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.quant = torch.ao.quantization.QuantStub()
+                self.linear = torch.nn.Linear(2, 4)
+                self.relu = torch.nn.ReLU()
+                self.dequant = torch.ao.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.linear(x)
+                x = self.relu(x)
+                x = self.dequant(x)
+                return x
+
+        model = M()
+        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+        model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]])
+        model = torch.ao.quantization.prepare_qat(model.train())
+        # Set fixed weight and bias to avoid flaky test.
+        model.linear.weight = torch.nn.Parameter(
+            _construct_tensor_for_quantization_test((2, 4), max_val=2)
+        )
+        model.linear.bias = torch.nn.Parameter(torch.arange(4))
+        model = torch.ao.quantization.convert(model)
+
+        # Set fixed input to avoid flaky test.
+ input = _construct_tensor_for_quantization_test((3, 2), offset=-384, max_val=12) + self.run_test(model, input) + @skipIfUnsupportedMinOpsetVersion(10) def test_qat_maxpool2d(self): class M(torch.nn.Module): diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 4a50c5aebd8c6..9270028b98808 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -679,6 +679,10 @@ void UnpackQuantizedWeights( graph(%input, %packed_weight, %w_scale, %w_zero_point): %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point) return (%r) )"; + std::string qlinear_relu = R"( + graph(%input, %packed_weight, %w_scale, %w_zero_point): + %r = quantized::linear_relu(%input, %packed_weight, %w_scale, %w_zero_point) + return (%r) )"; std::string qconv1d = R"( graph(%input, %packed_params, %scale, %zero_point): %r = quantized::conv1d(%input, %packed_params, %scale, %zero_point) @@ -722,6 +726,13 @@ void UnpackQuantizedWeights( "quantized::linear_unpack", QuantizedParamsType::LINEAR, caffe2); + unpackQuantizedWeightsHelper( + graph, + paramsDict, + qlinear_relu, + "quantized::linear_unpack", + QuantizedParamsType::LINEAR, + caffe2); unpackQuantizedWeightsHelper( graph, paramsDict, diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 00dc18e6defa8..931550b304967 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -58,6 +58,7 @@ "quantized_layer_norm", "quantized_leaky_relu", "quantized_linear", + "quantized_linear_relu", "quantized_mul", "quantized_sigmoid", "slice", @@ -737,6 +738,22 @@ def quantized_linear( return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +@_onnx_symbolic("quantized::linear_relu") +@_beartype.beartype +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + @_onnx_symbolic("quantized::add") @_beartype.beartype def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 7bc160e1fcc75..08928e77b87bf 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -888,6 +888,24 @@ def quantized_linear( return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +@_onnx_symbolic("quantized::linear_relu") +@_beartype.beartype +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + 
@_onnx_symbolic("quantized::conv1d_relu") @_beartype.beartype def quantized_conv1d_relu( From e0a6b6376a30fb78de89fde9f1c29807186b6f3b Mon Sep 17 00:00:00 2001 From: Gustav Larsson Date: Wed, 20 Sep 2023 18:23:41 -0700 Subject: [PATCH 2/4] Make bias float32. --- test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 30ded8a31f1ff..a8f929e262130 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -13144,7 +13144,7 @@ def forward(self, x): model.linear.weight = torch.nn.Parameter( _construct_tensor_for_quantization_test((2, 4), max_val=2) ) - model.linear.bias = torch.nn.Parameter(torch.arange(4)) + model.linear.bias = torch.nn.Parameter(torch.arange(4, dtype=torch.float32)) model = torch.ao.quantization.convert(model) # Set fixed input to avoid flaky test. From 990711c2356cb84916bd49ab9985a14e58498dd0 Mon Sep 17 00:00:00 2001 From: Gustav Larsson Date: Wed, 20 Sep 2023 19:22:16 -0700 Subject: [PATCH 3/4] Fix bias again, and correct conv tests. This has me confused, because the bias and shapes didn't match the in_feature/out_feature of the Conv2d that I was basing the Linear code on. This is because it was defining the ops incorrectly and then re-writing the weights swapping the inputs/outputs channels. It never complained that the original definition had a mismatch. I have fixed this for Linear now, but also fixed it for three cases of Conv2d. --- test/onnx/test_pytorch_onnx_onnxruntime.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index a8f929e262130..7aef2feed22b9 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -13027,7 +13027,7 @@ class M(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.ao.quantization.QuantStub() - self.conv = torch.nn.Conv2d(2, 4, 3, stride=2) + self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, x): @@ -13058,7 +13058,7 @@ class M(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.ao.quantization.QuantStub() - self.conv = torch.nn.Conv2d(2, 4, 3, stride=2) + self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) self.relu = torch.nn.ReLU() self.dequant = torch.ao.quantization.DeQuantStub() @@ -13091,7 +13091,7 @@ class M(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.ao.quantization.QuantStub() - self.conv = torch.nn.Conv2d(2, 4, 3, stride=2) + self.conv = torch.nn.Conv2d(4, 2, 3, stride=2) self.relu = torch.nn.ReLU() self.dequant = torch.ao.quantization.DeQuantStub() @@ -13125,7 +13125,7 @@ class M(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.ao.quantization.QuantStub() - self.linear = torch.nn.Linear(2, 4) + self.linear = torch.nn.Linear(4, 2) self.relu = torch.nn.ReLU() self.dequant = torch.ao.quantization.DeQuantStub() @@ -13144,7 +13144,7 @@ def forward(self, x): model.linear.weight = torch.nn.Parameter( _construct_tensor_for_quantization_test((2, 4), max_val=2) ) - model.linear.bias = torch.nn.Parameter(torch.arange(4, dtype=torch.float32)) + model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0])) model = torch.ao.quantization.convert(model) # Set fixed input to avoid flaky test. 
From 641fe7cb53bae2c7d9403ff739abd2f9c393fba0 Mon Sep 17 00:00:00 2001 From: Gustav Larsson Date: Wed, 20 Sep 2023 19:33:44 -0700 Subject: [PATCH 4/4] One more input shape fix. --- test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 7aef2feed22b9..72968f3334262 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -13148,7 +13148,7 @@ def forward(self, x): model = torch.ao.quantization.convert(model) # Set fixed input to avoid flaky test. - input = _construct_tensor_for_quantization_test((3, 2), offset=-384, max_val=12) + input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12) self.run_test(model, input) @skipIfUnsupportedMinOpsetVersion(10)
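
For reference, the decomposition emitted by the new quantized_linear_relu
symbolic functions can be sketched in eager mode as below. This is a
hand-written illustration, not code from the patch; the function name is
illustrative, it skips the export-time bias requantization step, and it
assumes quint8 activations.

    import torch

    def linear_relu_reference(q_input, q_weight, bias, op_scale, op_zero_point):
        # Dequantize the inputs, run float linear + relu, then requantize
        # at the op's output scale and zero-point -- the same structure the
        # opset 10/13 symbolics build out of ONNX ops.
        x = q_input.dequantize()
        w = q_weight.dequantize()
        y = torch.relu(torch.nn.functional.linear(x, w, bias))
        return torch.quantize_per_tensor(y, op_scale, op_zero_point, torch.quint8)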