diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py index a83be194e66..114f06dc6a6 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm.py +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -82,12 +82,16 @@ def call(self, graph_module: torch.fx.GraphModule): @staticmethod def can_fuse( - input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram + input_node: torch.fx.Node, + bn: torch.fx.Node, + program: ExportedProgram, ) -> bool: """ Determine whether a BatchNorm node can be fused with the preceding convolution or linear node. """ + is_conv = input_node.target == exir_ops.edge.aten.convolution.default + # All users of the batch_norm node must be getitem ops. # batch_norm returns a 3-element tuple. # Each user must only access the first element of the tuple. @@ -110,6 +114,10 @@ def can_fuse( ].count(False): return False + # Check the rank of the convolution input - only Conv1d and 2d are supported. + if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4): + return False + return True def _fuse_ops( diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index a095fa236fe..e5196b6a6a0 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -55,6 +55,26 @@ def forward(self, x): y = y + y return self.bn(y) + class ModelConv3dBN(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + kernel_size: Tuple[int, int, int], + ): + super().__init__() + op = torch.nn.Conv3d + self.conv3d = op(in_features, out_features, kernel_size) + self.bn = torch.nn.BatchNorm3d(out_features) + self.forward(torch.randn(2, 2, 4, 4, 4) * 2 + 2) # update the BN stats + + def forward(self, x): + y = self.conv3d(x) + y = self.bn(y) + y = self.conv3d(y) + y = y + y + return self.bn(y) + def test_fp32_conv_batch_norm_fusion(self): for transpose in 
[False, True]: ( @@ -142,3 +162,18 @@ def forward(self, x): .to_edge_transform_and_lower() .check_count({self.bn_name: 1}) ) + + def test_fp32_conv3d_batch_norm_doesnt_partition(self): + """ + Conv3d is not currently supported by XNNPACK. We also don't support standalone + batch norms yet (i.e. batch norms that are not fused with a conv). As such, we don't + want to partition the standalone batch norm and then fail to lower. + """ + ( + Tester(self.ModelConv3dBN(2, 2, (2, 2, 2)), (torch.randn(2, 2, 4, 4, 4),)) + .export() + .dump_artifact() + .to_edge_transform_and_lower() + .check_count({self.bn_name: 2}) + .run_method_and_compare_outputs() + )