Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
get_first_fake_tensor,
is_param_node,
)
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -43,6 +42,19 @@ def __init__(self, exported_program: ExportedProgram) -> None:
self.exported_program = exported_program
super().__init__()

@staticmethod
def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node):
consumer_node = list(node.users)[0]
if consumer_node.target == exir_ops.edge.aten.convolution.default:
consumer_node_inputs = consumer_node.all_input_nodes
groups = consumer_node.args[-1]
in_channels = consumer_node_inputs[0].meta["val"].shape[1]
out_channels = consumer_node_inputs[1].meta["val"].shape[0]
if (in_channels == groups) and (out_channels % in_channels) == 0:
return True

return False

def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
"""
returns True for w in the following sequence;
Expand All @@ -53,7 +65,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
consumer_node = list(node.users)[0]
if self.is_weight_node_for_depthwise_conv2d(consumer_node):
return True
if is_consumer_node_depthwise_conv2d(node):
if self._is_consumer_node_depthwise_conv2d(node):
# Check that node is the weight-argument and not input or bias
return consumer_node.args[1] == node

Expand Down
15 changes: 0 additions & 15 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import torch

from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import Node
Expand Down Expand Up @@ -155,19 +153,6 @@ def build_reshape_tosa_1_0(
)


def is_consumer_node_depthwise_conv2d(node: Node):
consumer_node = list(node.users)[0]
if consumer_node.target == exir_ops.edge.aten.convolution.default:
consumer_node_inputs = consumer_node.all_input_nodes
groups = consumer_node.args[-1]
in_channels = consumer_node_inputs[0].meta["val"].shape[1]
out_channels = consumer_node_inputs[1].meta["val"].shape[0]
if (in_channels == groups) and (out_channels % in_channels) == 0:
return True

return False


def tosa_shape(shape, dim_order):
reordered = tuple([shape[dim] for dim in dim_order])
# Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes,
Expand Down
Loading