diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index bf16855afc1..872ba355c70 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -210,6 +210,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # If force_fp32_dynamic_linear is enabled, then we
+            # do not partition the bias node.
+            return (True, gemm_deps)
+
         if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
@@ -477,7 +482,15 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-        return valid_deps, list(set(deps) | set(src_partition.nodes))
+        # When using force_fp32_dynamic_linear, we want get_deps to overwrite
+        # the source partition nodes; otherwise we want to be greedy.
+        ret_deps = (
+            list(set(deps) & set(src_partition.nodes))
+            if self.force_fp32_dynamic_linear
+            else list(set(deps) | set(src_partition.nodes))
+        )
+
+        return valid_deps, ret_deps

     def supported_precision_types(self):
         return [
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index eccda406b80..30bb4f0aba2 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -31,6 +31,8 @@
     ToEdgeTransformAndLower,
 )

+from torch.export.graph_signature import ExportGraphSignature, InputKind
+
 try:
     from torchao.quantization.quant_api import (
         int8_dynamic_activation_int4_weight,
@@ -871,3 +873,71 @@ def test_linear_qd8_as_fp32(self):
                 "dequantize_per_channel.default": 1,  # 1: weight
             },
         )
+
+    def test_linear_fp32_with_force_as_mm(self):
+        def check_signature(
+            signature: ExportGraphSignature,
+            force_flag: bool,
+            use_bias: bool,
+            legacy_mode: bool,
+        ):
+            num_params = 0
+            if force_flag:
+                num_params = 1  # weight_param
+                if use_bias:
+                    num_params += 1  # bias_param
+            sign_params: int = 0
+            input_specs = signature.input_specs
+            for input_spec in input_specs:
+                if input_spec.kind == InputKind.PARAMETER:
+                    sign_params += 1
+            assert (
+                sign_params == num_params
+            ), f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"
+
+        for force_flag in (True, False):
+            for use_bias in (True, False):
+                for legacy_mode in (True, False):
+                    module = BaseLinear(
+                        in_size=8,
+                        input_channels=13,
+                        output_channels=17,
+                        use_bias=use_bias,
+                    )
+                    inputs = module.get_inputs()
+                    tester = Tester(module, inputs).export()
+                    partitioner = XnnpackPartitioner(
+                        force_fp32_dynamic_linear=force_flag
+                    )
+                    if legacy_mode:
+                        tester.to_edge()
+                        partitioner_stage = Partition(partitioner=partitioner)
+                        tester.partition(partition_stage=partitioner_stage)
+                        tester.check_not(
+                            [
+                                (
+                                    "executorch_exir_dialects_edge__ops_aten_addmm_default"
+                                    if use_bias
+                                    else "executorch_exir_dialects_edge__ops_aten_mm_default"
+                                )
+                            ]
+                        )
+                    else:
+                        to_edge_and_transform_stage = ToEdgeTransformAndLower(
+                            partitioners=[partitioner]
+                        )
+                        tester.to_edge_transform_and_lower(
+                            to_edge_and_transform_stage=to_edge_and_transform_stage
+                        )
+                        tester.check_not(
+                            ["executorch_exir_dialects_edge__ops_aten_linear_default"]
+                        )
+
+                    signature: ExportGraphSignature = (
+                        tester.get_artifact().exported_program().graph_signature
+                    )
+                    check_signature(signature, force_flag, use_bias, legacy_mode)
+
+                    tester.to_executorch()
+                    tester.serialize()
+                    tester.run_method_and_compare_outputs()
diff --git a/backends/xnnpack/test/ops/test_lstm.py b/backends/xnnpack/test/ops/test_lstm.py
index bfc6113c417..be209082b37 100644
--- a/backends/xnnpack/test/ops/test_lstm.py
+++ b/backends/xnnpack/test/ops/test_lstm.py
@@ -54,9 +54,8 @@ def test_fp32_lstm_force_dynamic_linear(self):
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
-            # Biases are owned by delegates
-            .check_not(["p_lstm_bias"])
+            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
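
For context, here is a minimal end-to-end sketch of the flag this patch touches, outside the test harness. It is a hypothetical example, not part of the diff: `TinyLinear`, its shapes, and the variable names are made up for illustration, while `XnnpackPartitioner(force_fp32_dynamic_linear=...)` and `to_edge_transform_and_lower` are the real entry points exercised above.

```python
# Hypothetical usage sketch -- not part of this patch. TinyLinear and the
# shapes are invented; the partitioner flag is the one fixed in this diff.
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(13, 17)

    def forward(self, x):
        return self.linear(x)


exported = torch.export.export(TinyLinear().eval(), (torch.randn(8, 13),))

# With the flag set, the fp32 linear is delegated with its weight (and, after
# this fix, its bias) left as parameters of the outer program instead of being
# baked into the XNNPACK delegate as static data.
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
executorch_program = edge.to_executorch()
```

Comparing `edge.exported_program().graph_signature.input_specs` with the flag on and off should show the same parameter-count difference that `check_signature` asserts in the new test.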