From 68486af2178b399a75da10de7833f9a6191d2aa4 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Mon, 8 Sep 2025 17:32:54 -0700 Subject: [PATCH] Fix batch norm partitioning with Conv3d (#13696) Summary: Models with a batch norm following a conv3d cause an internal error during lowering. This diff fixes it by updating the partitioning logic to only rely on fusion with 1d and 2d convs. This is because XNNPACK doesn't currently support standalone batch norms and only partitions norms that can be fused. We can't fuse with Conv3d, because XNNPACK doesn't have an implementation. The partitioner constraint was missing logic to exclude Conv3d. Differential Revision: D81069236 --- backends/xnnpack/_passes/fuse_batch_norm.py | 10 +++++- .../test/passes/test_batch_norm_fusion.py | 35 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py index a83be194e66..114f06dc6a6 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm.py +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -82,12 +82,16 @@ def call(self, graph_module: torch.fx.GraphModule): @staticmethod def can_fuse( - input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram + input_node: torch.fx.Node, + bn: torch.fx.Node, + program: ExportedProgram, ) -> bool: """ Determine whether a BatchNorm node can be fused with the preceding convolution or linear node. """ + is_conv = input_node.target == exir_ops.edge.aten.convolution.default + # All users of the batch_norm node must be getitem ops. # batch_norm returns a 3-element tuple. # Each user must only access the first element of the tuple. @@ -110,6 +114,10 @@ def can_fuse( ].count(False): return False + # Check the rank of the convolution input - only Conv1d and 2d are supported. 
+ if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4): + return False + return True def _fuse_ops( diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index a095fa236fe..e5196b6a6a0 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -55,6 +55,26 @@ def forward(self, x): y = y + y return self.bn(y) + class ModelConv3dBN(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + kernel_size: Tuple[int, int], + ): + super().__init__() + op = torch.nn.Conv3d + self.conv3d = op(in_features, out_features, kernel_size) + self.bn = torch.nn.BatchNorm3d(out_features) + self.forward(torch.randn(2, 2, 4, 4, 4) * 2 + 2) # update the BN stats + + def forward(self, x): + y = self.conv3d(x) + y = self.bn(y) + y = self.conv3d(y) + y = y + y + return self.bn(y) + def test_fp32_conv_batch_norm_fusion(self): for transpose in [False, True]: ( @@ -142,3 +162,18 @@ def forward(self, x): .to_edge_transform_and_lower() .check_count({self.bn_name: 1}) ) + + def test_fp32_conv3d_batch_norm_doesnt_partition(self): + """ + Conv3d is not currently supported by XNNPACK. We also don't support standalone + batch norms yet (i.e. batch norms that are not fused with a conv). As such, we don't + want to partition the standalone batch norm and then fail to lower. + """ + ( + Tester(self.ModelConv3dBN(2, 2, (2, 2, 2)), (torch.randn(2, 2, 4, 4, 4),)) + .export() + .dump_artifact() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 2}) + .run_method_and_compare_outputs() + )