diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py index 114f06dc6a6..a51920ed5ad 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm.py +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -115,8 +115,14 @@ def can_fuse( return False # Check the rank of the convolutution input - only Conv1d and 2d are supported. - if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4): - return False + if is_conv: + conv_input = input_node.args[0] + if ( + not isinstance(conv_input, torch.fx.Node) + or "val" not in conv_input.meta + or len(conv_input.meta["val"].shape) not in (3, 4) + ): + return False return True diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index e5196b6a6a0..4e4cd065fa5 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -60,7 +60,7 @@ def __init__( self, in_features: int, out_features: int, - kernel_size: Tuple[int, int], + kernel_size: Tuple[int, int, int], ): super().__init__() op = torch.nn.Conv3d