From 75ac6b161adf2364a79e8875d30b5b11f9c279be Mon Sep 17 00:00:00 2001 From: George Gekov Date: Thu, 4 Sep 2025 15:45:08 +0100 Subject: [PATCH] Arm backend: Fix annotation of inplace ReLU The ResNet18 model uses a lot of ReLUs with inplace=True As a result of the correct annotation, we can pass the numerical accuracy check on resnet with lower atol. Change-Id: If629cf20df7bfeaa699c7ae8919c52f510cecb68 --- .../arm/quantizer/quantization_annotator.py | 13 ++++- backends/arm/test/models/test_resnet18.py | 2 +- backends/arm/test/ops/test_relu.py | 51 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index bea8fe2eddc..d7c85447dd5 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -392,7 +392,11 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.conv2d.padding, ], [torch.ops.aten.batch_norm.default, F.batch_norm], - [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], + [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + ], ], filter_fn=any_or_hardtanh_min_zero, ): @@ -408,6 +412,7 @@ def any_or_hardtanh_min_zero(n: Node): ] elif node.target in ( torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, torch.ops.aten.hardtanh.default, ): quant_properties.quant_output = _QuantProperty(0, output_act_qspec) @@ -444,7 +449,11 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.linear.default, torch.ops.aten.conv2d.padding, ], - [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], + [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + ], ], any_or_hardtanh_min_zero, ): diff --git a/backends/arm/test/models/test_resnet18.py b/backends/arm/test/models/test_resnet18.py index 6e965daeb8b..cbd8c39f4ce 100644 --- 
a/backends/arm/test/models/test_resnet18.py +++ b/backends/arm/test/models/test_resnet18.py @@ -54,7 +54,7 @@ def test_resnet_tosa_INT(per_channel_quantization): exir_op=[], use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, - atol=0.5, + atol=0.25, qtol=1, ) pipeline.run() diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py index 0b29bc24e75..0b76874d2eb 100644 --- a/backends/arm/test/ops/test_relu.py +++ b/backends/arm/test/ops/test_relu.py @@ -43,6 +43,28 @@ def forward(self, x): return self.relu(x) +test_data_conv_relu = { + # (test_name, test_data) + "4d_randn_inplace=True": (lambda: (torch.randn(1, 64, 96, 96) * 1000, True)), + "4d_randn_inplace=False": (lambda: (torch.randn(1, 64, 96, 96) * 1000, False)), +} + + +class Conv2d_Relu_Add(torch.nn.Module): + def __init__(self, inplace: bool = True): + super().__init__() + self.conv1 = torch.nn.Conv2d( + in_channels=64, out_channels=64, kernel_size=7, padding="same" + ) + self.relu = torch.nn.ReLU(inplace=inplace) + + def forward(self, x: torch.Tensor): + y = self.conv1(x) + z = self.relu(y) + out = x + z + return out + + @common.parametrize("test_data", test_data_suite) def test_relu_tosa_FP(test_data: torch.Tensor): pipeline = TosaPipelineFP[input_t1]( @@ -54,6 +76,35 @@ def test_relu_tosa_FP(test_data: torch.Tensor): pipeline.run() +# Test the folding of Conv2D with ReLU +@common.parametrize("test_data", test_data_conv_relu) +def test_conv_relu_folding_tosa_INT(test_data: torch.Tensor): + input_data, inplace = test_data() + pipeline = TosaPipelineINT[input_t1]( + Conv2d_Relu_Add(inplace=inplace), + (input_data,), + [], + [], + ) + # We should have : + # 3 quantize_per_tensor nodes: input activation , output of the conv-relu sequence, out of the add + # 4 dequantize_per_tensor nodes: into the conv2d input, into the add, output of the conv-relu sequence, before returning + # 2 dequantize_per_channel nodes: one for the weights and another one 
for the bias + # In case of incorrect annotation of the ReLU, we get separate Q/DQ around both the conv2d and the ReLU and + # therefore more quantize_per_tensor and dequantize_per_tensor nodes + pipeline.add_stage_after( + "quantize", + pipeline.tester.check_count, + { + "quantized_decomposed.quantize_per_tensor.default": 3, + "torch.ops.quantized_decomposed.dequantize_per_tensor.default": 4, + "quantized_decomposed.dequantize_per_channel.default": 2, + }, + suffix="quant_nodes", + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_relu_tosa_INT(test_data: torch.Tensor): pipeline = TosaPipelineINT[input_t1](