2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -438,6 +438,7 @@ def _tosa_pipeline(
             ConvertSplitToSlicePass(),
             QuantizeClampArgumentsPass(),
             RemoveGetItemPass(),
+            FuseBatchNorm2dPass(exported_program),
             DecomposeBatchNormNoStatsPass(),
             DecomposeLogitPass(),
             DecomposeMaskedFillPass(),
@@ -501,7 +502,6 @@ def _tosa_pipeline(
             RewriteBoolBitwiseToLogicalPass(),
             DecomposeRemainderPass(),
             DecomposeDivTensorModePass(),
-            FuseBatchNorm2dPass(exported_program),
             ConvertMmToBmmPass(),
             DecomposeGluPass(),
             DecomposeDivPass(),
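Moving FuseBatchNorm2dPass ahead of DecomposeBatchNormNoStatsPass is an ordering fix: passes rewrite the graph in sequence, and a fusion pass that pattern-matches batch-norm nodes finds nothing once an earlier pass has already decomposed them into primitive ops. A minimal toy sketch of that interaction, using plain strings instead of real FX graph nodes (decompose_bn and fuse_bn_into_conv here are hypothetical stand-ins, not the actual ExecuTorch passes):

def decompose_bn(ops: list[str]) -> list[str]:
    # Rewrite every batch_norm node into its primitive ops.
    out: list[str] = []
    for op in ops:
        out.extend(["rsqrt", "sub", "mul", "add"] if op == "batch_norm" else [op])
    return out

def fuse_bn_into_conv(ops: list[str]) -> list[str]:
    # Fold each conv2d + batch_norm pair into a single conv2d.
    out: list[str] = []
    i = 0
    while i < len(ops):
        if ops[i] == "conv2d" and i + 1 < len(ops) and ops[i + 1] == "batch_norm":
            out.append("conv2d")  # weights now carry the folded BN scale/shift
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out

graph = ["conv2d", "batch_norm", "relu"]
print(fuse_bn_into_conv(decompose_bn(graph)))  # decompose first: nothing left to fuse
print(decompose_bn(fuse_bn_into_conv(graph)))  # fuse first: ['conv2d', 'relu']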
8 changes: 6 additions & 2 deletions backends/arm/_passes/fuse_batch_norm2d_pass.py
@@ -56,6 +56,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
                 != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
             ):
                 continue
+            if get_first_fake_tensor(node).dtype == torch.bfloat16:
+                # Don't fuse if the data type is bfloat16, as the fused weights may
+                # not be accurate enough and cause significant accuracy drop.
+                continue
 
             # Get data from batchnorm
             input_node = node.all_input_nodes[0]
@@ -153,8 +157,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
             if not (
                 (input_bias_node is None)
                 or (
-                    isinstance(input_weight_node, Node)
-                    and input_weight_node.op == "placeholder"
+                    isinstance(input_bias_node, Node)
+                    and input_bias_node.op == "placeholder"
                 )
             ):
                 raise RuntimeError(
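For context on the bf16 guard added above: fusing an inference-mode batch norm into the preceding convolution is a pure reweighting, y = (conv(x) - mean) * gamma / sqrt(var + eps) + beta, so the BN collapses into scaled weights and a shifted bias. A standalone sketch of that folding in plain PyTorch (fold_bn_into_conv is a hypothetical helper for illustration, not the pass's own code):

import torch

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps):
    # Per-output-channel scale = gamma / sqrt(var + eps).
    scale = gamma / torch.sqrt(var + eps)
    w_fused = w * scale.reshape(-1, 1, 1, 1)
    b_fused = (b - mean) * scale + beta
    return w_fused, b_fused

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
bn = torch.nn.BatchNorm2d(8).eval()  # eval: use running statistics
bn.running_mean.uniform_(-1.0, 1.0)  # randomize stats so the demo is non-trivial
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(1, 3, 16, 16)
reference = bn(conv(x))

w_f, b_f = fold_bn_into_conv(
    conv.weight, conv.bias, bn.weight, bn.bias,
    bn.running_mean, bn.running_var, bn.eps,
)
fused = torch.nn.functional.conv2d(x, w_f, b_f)
print(torch.allclose(reference, fused, atol=1e-5))  # True in fp32

In bfloat16 each folded product w * scale rounds to only 8 mantissa bits, and those rounding errors accumulate across the conv's reduction dimension, which is why the pass now leaves bf16 graphs unfused rather than risk a significant accuracy drop.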
8 changes: 4 additions & 4 deletions backends/arm/test/misc/test_transpose_counts.py
@@ -445,16 +445,16 @@ def forward(self, x):
         Model6GruLinear(), (torch.randn(2, 16, 8),), 2
     ),
     "model_7_dwconv_batchnorm_linear": TransposeCountCase(
-        Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 3
+        Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 1
     ),
     "model_8_conv_batchnorm_maxpool_residual": TransposeCountCase(
-        Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 6
+        Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 4
     ),
     "model_9_dilated_conv_batchnorm_avgpool_residual": TransposeCountCase(
-        Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 6
+        Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 4
     ),
     "model_10_dwconv_batchnorm_linear_cat": TransposeCountCase(
-        Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 3
+        Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 1
     ),
 }
 
14 changes: 14 additions & 0 deletions backends/arm/test/ops/test_batch_norm.py
@@ -200,6 +200,20 @@ def test_native_batch_norm_legit_no_training_tosa_FP_conv(test_data: Tuple):
     pipeline.run()
 
 
+@common.parametrize("test_data", test_data_suite)
+def test_native_batch_norm_legit_no_training_tosa_FP_conv_fuses_before_decompose(
+    test_data: Tuple,
+):
+    test_data, model_params = test_data()
+    pipeline = TosaPipelineFP[input_t1](
+        BatchNorm2dConv(*model_params),
+        (test_data,),
+        aten_op=BatchNorm2dConv.aten_ops,
+    )
+    pipeline.count_tosa_ops({"CONV2D": 1, "RSQRT": 0, "SUB": 0})
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_suite)
 def test_native_batch_norm_legit_no_training_tosa_INT_conv(test_data: Tuple):
     test_data, model_params = test_data()
27 changes: 27 additions & 0 deletions backends/arm/test/passes/test_fuse_batchnorm_pass.py
@@ -54,6 +54,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class MergeOneOfTwoBNBf16(MergeOneOfTwoBN):
+    ops_before_pass: ClassVar[Dict[str, int]] = MergeOneOfTwoBN.ops_before_pass
+    ops_after_pass: ClassVar[Dict[str, int]] = MergeOneOfTwoBN.ops_before_pass
+
+    def __init__(self, affine: bool):
+        super().__init__(affine)
+        self.to(torch.bfloat16)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 256, 256, dtype=torch.bfloat16),)
+
+
 class MergeTwosOfTwoBN(torch.nn.Module):
     ops_before_pass: ClassVar[Dict[str, int]] = {
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
@@ -163,3 +175,18 @@ def test_fuse_batch_norm2d_tosa_FP(module: ModuleWithBatchNormAttrs) -> None:
         passes_with_exported_program=[FuseBatchNorm2dPass],
     )
     pipeline.run()
+
+
+def test_fuse_batch_norm2d_tosa_FP_bf16_skips_fusion() -> None:
+    module = cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBNBf16(True))
+    nn_module = cast(torch.nn.Module, module)
+    pipeline = PassPipeline[input_t](
+        nn_module,
+        module.get_inputs(),
+        quantize=False,
+        ops_before_pass=module.ops_before_pass,
+        ops_after_pass=module.ops_after_pass,
+        passes_with_exported_program=[FuseBatchNorm2dPass],
+        tosa_extensions=["bf16"],
+    )
+    pipeline.run()