Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion backends/xnnpack/_passes/fuse_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,16 @@ def call(self, graph_module: torch.fx.GraphModule):

@staticmethod
def can_fuse(
input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
input_node: torch.fx.Node,
bn: torch.fx.Node,
program: ExportedProgram,
) -> bool:
"""
Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
"""

is_conv = input_node.target == exir_ops.edge.aten.convolution.default

# All users of the batch_norm node must be getitem ops.
# batch_norm returns a 3-element tuple.
# Each user must only access the first element of the tuple.
Expand All @@ -110,6 +114,10 @@ def can_fuse(
].count(False):
return False

# Check the rank of the convolution input - only Conv1d and 2d are supported.
if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4):
return False

return True

def _fuse_ops(
Expand Down
35 changes: 35 additions & 0 deletions backends/xnnpack/test/passes/test_batch_norm_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ def forward(self, x):
y = y + y
return self.bn(y)

class ModelConv3dBN(torch.nn.Module):
    """Conv3d + BatchNorm3d model.

    Used to check that a batch norm following a 3-d convolution is neither
    fused nor partitioned (only Conv1d/Conv2d fusion is supported).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        # NOTE: a Conv3d kernel is a 3-tuple; the previous Tuple[int, int]
        # annotation was copied from the Conv2d model and was incorrect.
        kernel_size: Tuple[int, int, int],
    ):
        super().__init__()
        op = torch.nn.Conv3d
        self.conv3d = op(in_features, out_features, kernel_size)
        self.bn = torch.nn.BatchNorm3d(out_features)
        # Run one forward pass in training mode so the BN running stats are
        # non-trivial before the model is exported.
        self.forward(torch.randn(2, 2, 4, 4, 4) * 2 + 2)  # update the BN stats

    def forward(self, x):
        y = self.conv3d(x)
        y = self.bn(y)
        y = self.conv3d(y)
        y = y + y
        return self.bn(y)

def test_fp32_conv_batch_norm_fusion(self):
for transpose in [False, True]:
(
Expand Down Expand Up @@ -142,3 +162,18 @@ def forward(self, x):
.to_edge_transform_and_lower()
.check_count({self.bn_name: 1})
)

def test_fp32_conv3d_batch_norm_doesnt_partition(self):
    """
    Conv3d is not currently supported by XNNPACK. We also don't support standalone
    batch norms yet (i.e. batch norms that are not fused with a conv). As such, we don't
    want to partition the standalone batch norm and then fail to lower.
    """
    # Both batch norms must survive lowering (count == 2): neither can be
    # fused with the unsupported conv3d, nor partitioned on its own.
    # NOTE: removed a leftover .dump_artifact() debug call that printed the
    # exported graph on every test run.
    (
        Tester(self.ModelConv3dBN(2, 2, (2, 2, 2)), (torch.randn(2, 2, 4, 4, 4),))
        .export()
        .to_edge_transform_and_lower()
        .check_count({self.bn_name: 2})
        .run_method_and_compare_outputs()
    )
Loading