diff --git a/backends/xnnpack/passes/convert_to_linear.py b/backends/xnnpack/passes/convert_to_linear.py index 69f882523c8..2cef71bf927 100644 --- a/backends/xnnpack/passes/convert_to_linear.py +++ b/backends/xnnpack/passes/convert_to_linear.py @@ -13,9 +13,8 @@ from executorch.backends.transforms.addmm_mm_to_linear import ( apply_addmm_mm_to_linear_transform, ) -from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.utils.utils import is_param_node from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.utils.source_matcher_utils import ( @@ -27,7 +26,7 @@ logger.setLevel(logging.WARNING) -class ConvertToLinearPass(XNNPACKPass): +class ConvertToLinearPass(ExportPass): linear_modules = [ torch.nn.Linear, torch.nn.functional.linear, @@ -71,28 +70,24 @@ def get_arg(node: torch.fx.Node, arg: str): map_ = {"input": 0, "weight": 1} return None if arg == "bias" else node.args[map_[arg]] - def find_bias_for_mm(self, src_partition: SourcePartition, weight: torch.fx.Node): + def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node): """ For linear decomposed with mm + add, find bias in src partition """ - out_channels = get_shape(weight)[0] - bias = None - - # Try to find bias node in all nodes - for node in src_partition.nodes: - if is_param_node(self.exported_program, node) and node != weight: - bias = node - - if bias is not None: - assert get_shape(bias) == [ - out_channels - ], f"Expected bias shape {[out_channels]} but got {get_shape(bias)}" - else: - assert exir_ops.edge.aten.add.Tensor not in [ - node.target for node in src_partition.nodes - ], f"Expecting to find bias for Linear module: {src_partition} but could not find it" - return bias + mm_users = list(mm_node.users.keys()) + if len(mm_users) != 1: + return None + + add_node = mm_users[0] + if add_node.target != exir_ops.edge.aten.add.Tensor: + return None + + for arg in add_node.all_input_nodes: + if arg != mm_node and arg in src_partition.input_nodes: + return arg + + return None def create_linear( self, @@ -119,7 +114,7 @@ def create_linear( src_partition.input_nodes + src_partition.params, # bias can be in params ) if linear_bias is None and node.target == exir_ops.edge.aten.mm.default: - linear_bias = self.find_bias_for_mm(src_partition, linear_weight) + linear_bias = self.find_bias_for_mm(src_partition, node) logger.debug(f"Found bias(?): {linear_bias} from node {node}")