From a4e105d2fb7d046dc8c6463398b791e61eb9ed24 Mon Sep 17 00:00:00 2001
From: Bala Varadarajan
Date: Sun, 23 Nov 2025 09:04:21 -0800
Subject: [PATCH] Fix fake fusion for convolutions without bias (#3353)

Summary:
Fix `AttributeError` when performing fake fusion on convolution layers
without bias by creating a zero-filled bias parameter instead of
attempting to access `requires_grad` on `None`.

Reviewed By: jerryzh168

Differential Revision: D87356763
---
 test/quantization/pt2e/test_quantize_pt2e.py | 32 ++++++++------------
 torchao/quantization/pt2e/utils.py           |  9 +++---
 torchao/testing/model_architectures.py       |  4 +--
 3 files changed, 20 insertions(+), 25 deletions(-)

diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 0b5fd64120..5530bbb322 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -157,29 +157,23 @@ def test_chunked_bn_fusion(self):
         n_chunks = 3
         in_channels = 1
         out_channels = 32
-        m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels)
-        m.bn.running_var = torch.nn.Parameter(
-            torch.rand(out_channels) * 1e-2, requires_grad=False
-        )
+        for bias in [True, False]:
+            m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels, bias=bias)
+            m.bn.running_var = torch.nn.Parameter(
+                torch.rand(out_channels) * 1e-2, requires_grad=False
+            )
 
-        m.eval()
-        example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
-        ref_outputs = m(*example_inputs)
-        traced_model = torch.export.export(m, example_inputs, strict=True).module()
-        traced_outputs = traced_model(*example_inputs)
-        prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
-        prepared_outputs = prepared_model(*example_inputs)
-
-        if isinstance(ref_outputs, (tuple, list)):
-            for ref, prepared, traced in zip(
-                ref_outputs, prepared_outputs, traced_outputs
-            ):
-                torch.testing.assert_close(ref, traced)
-                torch.testing.assert_close(traced, prepared)
-        else:
+            m.eval()
+            example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
+            ref_outputs = m(*example_inputs)
+            traced_model = torch.export.export(m, example_inputs, strict=True).module()
+            traced_outputs = traced_model(*example_inputs)
+            prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
+            prepared_outputs = prepared_model(*example_inputs)
             torch.testing.assert_close(ref_outputs, traced_outputs)
             torch.testing.assert_close(traced_outputs, prepared_outputs)
 
+
     def test_wo_annotate_conv_output_quantizer(self):
         # TODO: use OP_TO_ANNOTATOR
         class BackendAQuantizer(Quantizer):
diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py
index f3cbffa430..ccf4333489 100644
--- a/torchao/quantization/pt2e/utils.py
+++ b/torchao/quantization/pt2e/utils.py
@@ -710,10 +710,11 @@ def fold_bn_weights_into_conv_node(
         conv_args.append(None)
 
     if fake_fuse:
-        fused_weight, fused_bias = (
-            torch.nn.Parameter(conv_w, conv_w.requires_grad),
-            torch.nn.Parameter(conv_b, conv_b.requires_grad),
-        )
+        fused_weight = torch.nn.Parameter(conv_w, conv_w.requires_grad)
+        if conv_b is not None:
+            fused_bias = torch.nn.Parameter(conv_b, conv_b.requires_grad)
+        else:
+            fused_bias = torch.nn.Parameter(torch.zeros_like(bn_rm), requires_grad=conv_w.requires_grad)
     else:
         fused_weight, fused_bias = fuse_conv_bn_weights(
             conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py
index 4100a3cd76..1f45510899 100644
--- a/torchao/testing/model_architectures.py
+++ b/torchao/testing/model_architectures.py
@@ -82,11 +82,11 @@ def forward(self, x):
 
 class ConvWithSharedWeightInExportedModel(nn.Module):
     def __init__(
-        self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
     ) -> None:
         super().__init__()
         self.n_chunks = n_chunks
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
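
Note (not part of the patch): a minimal, self-contained sketch of the failure mode this change addresses, assuming only plain PyTorch. The helper name `fake_fuse_params` is hypothetical; it only mirrors the fixed `fake_fuse` branch of `fold_bn_weights_into_conv_node` shown above. Before the fix, reading `conv_b.requires_grad` raised `AttributeError` whenever the convolution was built with `bias=False`, since `conv_b` is `None` in that case.

    import torch

    def fake_fuse_params(conv_w, conv_b, bn_rm):
        # Hypothetical standalone helper mirroring the fixed fake_fuse branch;
        # conv_b may be None when the conv was created with bias=False.
        fused_weight = torch.nn.Parameter(conv_w, conv_w.requires_grad)
        if conv_b is not None:
            fused_bias = torch.nn.Parameter(conv_b, conv_b.requires_grad)
        else:
            # Zero-filled bias with one entry per output channel (the shape of
            # the BN running mean), instead of touching requires_grad on None.
            fused_bias = torch.nn.Parameter(
                torch.zeros_like(bn_rm), requires_grad=conv_w.requires_grad
            )
        return fused_weight, fused_bias

    conv = torch.nn.Conv2d(1, 32, 3, bias=False)  # bias-free conv hit the old error
    bn = torch.nn.BatchNorm2d(32)
    w, b = fake_fuse_params(conv.weight, conv.bias, bn.running_mean)
    assert b.shape == (32,)
    assert torch.all(b == 0)

Using `torch.zeros_like(bn_rm)` gives the synthetic bias the per-output-channel shape of the batch-norm statistics, and inheriting `requires_grad` from the weight keeps the fake-fused parameters consistent with the biased case.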