diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 8373af9e62b3..72968f333426 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -13027,7 +13027,7 @@ class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
-                self.conv = torch.nn.Conv2d(2, 4, 3, stride=2)
+                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                 self.dequant = torch.ao.quantization.DeQuantStub()

             def forward(self, x):
@@ -13058,7 +13058,7 @@ class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
-                self.conv = torch.nn.Conv2d(2, 4, 3, stride=2)
+                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                 self.relu = torch.nn.ReLU()
                 self.dequant = torch.ao.quantization.DeQuantStub()

@@ -13091,7 +13091,7 @@ class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
-                self.conv = torch.nn.Conv2d(2, 4, 3, stride=2)
+                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                 self.relu = torch.nn.ReLU()
                 self.dequant = torch.ao.quantization.DeQuantStub()

@@ -13119,6 +13119,38 @@ def forward(self, x):
         )
         self.run_test(model, input)

+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_qat_linear_relu_fused(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.quant = torch.ao.quantization.QuantStub()
+                self.linear = torch.nn.Linear(4, 2)
+                self.relu = torch.nn.ReLU()
+                self.dequant = torch.ao.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.linear(x)
+                x = self.relu(x)
+                x = self.dequant(x)
+                return x
+
+        model = M()
+        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+        model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]])
+        model = torch.ao.quantization.prepare_qat(model.train())
+        # Set fixed weight and bias to avoid flaky test.
+        model.linear.weight = torch.nn.Parameter(
+            _construct_tensor_for_quantization_test((2, 4), max_val=2)
+        )
+        model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
+        model = torch.ao.quantization.convert(model)
+
+        # Set fixed input to avoid flaky test.
+        input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12)
+        self.run_test(model, input)
+
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_qat_maxpool2d(self):
         class M(torch.nn.Module):
diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
index 4a50c5aebd8c..9270028b9880 100644
--- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
+++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
@@ -679,6 +679,10 @@ void UnpackQuantizedWeights(
   graph(%input, %packed_weight, %w_scale, %w_zero_point):
         %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
         return (%r) )";
+  std::string qlinear_relu = R"(
+  graph(%input, %packed_weight, %w_scale, %w_zero_point):
+        %r = quantized::linear_relu(%input, %packed_weight, %w_scale, %w_zero_point)
+        return (%r) )";
   std::string qconv1d = R"(
   graph(%input, %packed_params, %scale, %zero_point):
         %r = quantized::conv1d(%input, %packed_params, %scale, %zero_point)
@@ -722,6 +726,13 @@ void UnpackQuantizedWeights(
       "quantized::linear_unpack",
       QuantizedParamsType::LINEAR,
       caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qlinear_relu,
+      "quantized::linear_unpack",
+      QuantizedParamsType::LINEAR,
+      caffe2);
   unpackQuantizedWeightsHelper(
       graph,
       paramsDict,
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index 00dc18e6defa..931550b30496 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -58,6 +58,7 @@
     "quantized_layer_norm",
     "quantized_leaky_relu",
     "quantized_linear",
+    "quantized_linear_relu",
     "quantized_mul",
     "quantized_sigmoid",
     "slice",
@@ -737,6 +738,22 @@ def quantized_linear(
     return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


+@_onnx_symbolic("quantized::linear_relu")
+@_beartype.beartype
+def quantized_linear_relu(
+    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
+):
+    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
+    weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
+    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
+    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
+
+    output = opset9.linear(g, input, weight, bias)
+    output = opset9.relu(g, output)
+
+    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+
 @_onnx_symbolic("quantized::add")
 @_beartype.beartype
 def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index 7bc160e1fcc7..08928e77b87b 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -888,6 +888,24 @@ def quantized_linear(
     return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


+@_onnx_symbolic("quantized::linear_relu")
+@_beartype.beartype
+def quantized_linear_relu(
+    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
+):
+    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
+    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
+    q_bias = symbolic_helper.requantize_bias_helper(
+        g, bias, input_scale, weight_scale, axis
+    )
+    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
+
+    output = opset9.linear(g, input, weight, bias)
+    output = opset9.relu(g, output)
+
+    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+
 @_onnx_symbolic("quantized::conv1d_relu")
 @_beartype.beartype
 def quantized_conv1d_relu(