From 4c5980cd51b41a548ce001cc067ddc1f60f2a2bb Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 10 Sep 2025 12:14:00 -0700 Subject: [PATCH] Forward fix Fix batch norm partitioning with Conv3d (#13696) Summary: Forward pyre fix https://www.internalfb.com/diff/D81069236 https://github.com/pytorch/executorch/pull/13696 bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: digantdesai Differential Revision: D82118651 --- backends/xnnpack/_passes/fuse_batch_norm.py | 10 ++++++++-- backends/xnnpack/test/passes/test_batch_norm_fusion.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py index 114f06dc6a6..a51920ed5ad 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm.py +++ b/backends/xnnpack/_passes/fuse_batch_norm.py @@ -115,8 +115,14 @@ def can_fuse( return False # Check the rank of the convolutution input - only Conv1d and 2d are supported. - if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4): - return False + if is_conv: + conv_input = input_node.args[0] + if ( + not isinstance(conv_input, torch.fx.Node) + or "val" not in conv_input.meta + or len(conv_input.meta["val"].shape) not in (3, 4) + ): + return False return True diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index e5196b6a6a0..4e4cd065fa5 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -60,7 +60,7 @@ def __init__( self, in_features: int, out_features: int, - kernel_size: Tuple[int, int], + kernel_size: Tuple[int, int, int], ): super().__init__() op = torch.nn.Conv3d