[ONNX] Support quantized::linear_relu #109755

Closed
38 changes: 35 additions & 3 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -13027,7 +13027,7 @@ class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
-                self.conv = torch.nn.Conv2d(2, 4, 3, stride=2)
+                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):

@@ -13058,7 +13058,7 @@ class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
-                self.conv = torch.nn.Conv2d(2, 4, 3, stride=2)
+                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

@@ -13091,7 +13091,7 @@ class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
-                self.conv = torch.nn.Conv2d(2, 4, 3, stride=2)
+                self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
@gustavla (Contributor, Author) commented on Sep 21, 2023:
See my comment about these typos. If you expand the context below, the weights are replaced as follows:

        model.conv.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((2, 4, 3, 3), max_val=2)
        )

But this is not consistent with Conv2d(2, 4, 3); it is consistent with Conv2d(4, 2, 3), which is what I'm changing it to.
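For reference, torch.nn.Conv2d stores its weight as (out_channels, in_channels, kH, kW), so a (2, 4, 3, 3) weight tensor matches Conv2d(4, 2, 3), not Conv2d(2, 4, 3). A quick standalone check:

    import torch

    # Conv2d(in_channels=4, out_channels=2, kernel_size=3) carries a weight of
    # shape (out_channels, in_channels, kH, kW) == (2, 4, 3, 3), consistent
    # with the replacement tensor above.
    conv = torch.nn.Conv2d(4, 2, 3, stride=2)
    assert conv.weight.shape == torch.Size([2, 4, 3, 3])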

                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

@@ -13119,6 +13119,38 @@ def forward(self, x):
        )
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_linear_relu_fused(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.linear = torch.nn.Linear(4, 2)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.linear(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        model = M()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        model = torch.ao.quantization.fuse_modules(model.eval(), [["linear", "relu"]])
        model = torch.ao.quantization.prepare_qat(model.train())
        # Set fixed weight and bias to avoid flaky test.
        model.linear.weight = torch.nn.Parameter(
            _construct_tensor_for_quantization_test((2, 4), max_val=2)
        )
        model.linear.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
        model = torch.ao.quantization.convert(model)

        # Set fixed input to avoid flaky test. Linear(4, 2) expects a
        # trailing dimension of 4, so the input is (batch, in_features) = (3, 4).
        input = _construct_tensor_for_quantization_test((3, 4), offset=-384, max_val=12)
        self.run_test(model, input)

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_maxpool2d(self):
        class M(torch.nn.Module):
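For context, run_test round-trips the model through torch.onnx.export and checks the outputs against ONNX Runtime. A minimal standalone export of the model built in the new test might look like this (a sketch; the in-memory buffer and opset choice are illustrative, not part of the PR):

    import io

    import torch

    # `model` and `input` as constructed in test_qat_linear_relu_fused above;
    # exporting the converted model exercises the new quantized::linear_relu
    # symbolic directly.
    buffer = io.BytesIO()
    torch.onnx.export(model, input, buffer, opset_version=13)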
11 changes: 11 additions & 0 deletions torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
@@ -679,6 +679,10 @@ void UnpackQuantizedWeights(
  graph(%input, %packed_weight, %w_scale, %w_zero_point):
        %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qlinear_relu = R"(
  graph(%input, %packed_weight, %w_scale, %w_zero_point):
        %r = quantized::linear_relu(%input, %packed_weight, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv1d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv1d(%input, %packed_params, %scale, %zero_point)
@@ -722,6 +726,13 @@ void UnpackQuantizedWeights(
"quantized::linear_unpack",
QuantizedParamsType::LINEAR,
caffe2);
unpackQuantizedWeightsHelper(
graph,
paramsDict,
qlinear_relu,
"quantized::linear_unpack",
QuantizedParamsType::LINEAR,
caffe2);
unpackQuantizedWeightsHelper(
graph,
paramsDict,
17 changes: 17 additions & 0 deletions torch/onnx/symbolic_opset10.py
@@ -58,6 +58,7 @@
"quantized_layer_norm",
"quantized_leaky_relu",
"quantized_linear",
"quantized_linear_relu",
"quantized_mul",
"quantized_sigmoid",
"slice",
@@ -737,6 +738,22 @@ def quantized_linear(
    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    # Dequantize the operands, run linear followed by relu in float, then
    # requantize the result with the op's output scale and zero point.
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
18 changes: 18 additions & 0 deletions torch/onnx/symbolic_opset13.py
@@ -888,6 +888,24 @@ def quantized_linear(
    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    # Opset 13 supports per-channel quantization: keep the weight's
    # quantization axis and forward it when requantizing the bias.
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
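In both opsets the symbolic follows the usual dequantize, compute in float, requantize pattern; the opset 13 variant additionally threads the weight's per-channel quantization axis into the bias requantization. A rough eager-mode sketch of the bias arithmetic (hypothetical helper name, assuming the common convention that quantized bias uses scale = input_scale * weight_scale and zero point 0):

    import torch

    def requantize_bias_sketch(bias_fp32, input_scale, weight_scale):
        # weight_scale may be a scalar or, for per-channel quantization, a
        # 1-D tensor of per-output-channel scales; broadcasting then yields
        # one int32 bias value per channel, which is why the opset 13
        # symbolic forwards the `axis` returned for the weight.
        bias_scale = input_scale * weight_scale
        return torch.round(bias_fp32 / bias_scale).to(torch.int32)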