Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions backends/xnnpack/passes/convert_to_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from executorch.backends.transforms.addmm_mm_to_linear import (
apply_addmm_mm_to_linear_transform,
)
from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.source_matcher_utils import (
Expand All @@ -27,7 +26,7 @@
logger.setLevel(logging.WARNING)


class ConvertToLinearPass(XNNPACKPass):
class ConvertToLinearPass(ExportPass):
linear_modules = [
torch.nn.Linear,
torch.nn.functional.linear,
Expand Down Expand Up @@ -71,28 +70,24 @@ def get_arg(node: torch.fx.Node, arg: str):
map_ = {"input": 0, "weight": 1}
return None if arg == "bias" else node.args[map_[arg]]

def find_bias_for_mm(self, src_partition: SourcePartition, weight: torch.fx.Node):
def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node):
"""
For linear decomposed with mm + add, find bias in src partition
"""
out_channels = get_shape(weight)[0]
bias = None

# Try to find bias node in all nodes
for node in src_partition.nodes:
if is_param_node(self.exported_program, node) and node != weight:
bias = node

if bias is not None:
assert get_shape(bias) == [
out_channels
], f"Expected bias shape {[out_channels]} but got {get_shape(bias)}"
else:
assert exir_ops.edge.aten.add.Tensor not in [
node.target for node in src_partition.nodes
], f"Expecting to find bias for Linear module: {src_partition} but could not find it"

return bias
mm_users = list(mm_node.users.keys())
if len(mm_users) != 1:
return None

add_node = mm_users[0]
if add_node.target != exir_ops.edge.aten.add.Tensor:
return None

for arg in add_node.all_input_nodes:
if arg != mm_node and arg in src_partition.input_nodes:
return arg

return None

def create_linear(
self,
Expand All @@ -119,7 +114,7 @@ def create_linear(
src_partition.input_nodes + src_partition.params, # bias can be in params
)
if linear_bias is None and node.target == exir_ops.edge.aten.mm.default:
linear_bias = self.find_bias_for_mm(src_partition, linear_weight)
linear_bias = self.find_bias_for_mm(src_partition, node)

logger.debug(f"Found bias(?): {linear_bias} from node {node}")

Expand Down
Loading