154 changes: 117 additions & 37 deletions backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -12,6 +12,7 @@
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
get_node_arg,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
@@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

return False

def insert_input_transpose(self, node, input_node, graph_module):
@staticmethod
def memory_format_differs(shape):
"""Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
if len(shape) >= 4:
C = shape[1]
H = shape[2]
W = shape[3]
elif len(shape) == 3:
C = shape[0]
H = shape[1]
W = shape[2]
if len(shape) <= 2:
return False

return C > 1 and (H > 1 or W > 1)

@staticmethod
def is_channel_reshape(input_shape, output_shape):
"""Returns true if the reshape changes the channel dimension"""
if not len(input_shape) == len(output_shape) == 4:
return False

C_old = input_shape[1]
C_new = output_shape[1]

N_new = output_shape[0]
N_old = input_shape[0]

return (N_old != N_new) or (C_old != C_new)

@staticmethod
def insert_input_transpose(node, input_node, graph_module):
quantize = input_node.target == dq_op
q_params = input_node.args[1:] if quantize else None
with graph_module.graph.inserting_before(node):
permute_node = create_node(
graph_module.graph,
torch.ops.passthrough_to_tosa._transpose,
args=(input_node, list(self.NHWC_inverse_order)),
args=(
input_node,
list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
),
quantize=quantize,
q_params=q_params,
)
@@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module):
range(len(input_node.meta["val"].size()))
)

def insert_output_transpose(self, node, graph_module):
@staticmethod
def insert_output_transpose(node, graph_module):
with graph_module.graph.inserting_after(node):
permute_node = create_node(
graph_module.graph,
torch.ops.passthrough_to_tosa._transpose,
args=(node, list(self.NHWC_order)),
args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
)
permute_node.meta["tosa_dim_order"] = (
AnnotateChannelsLastDimOrder.NHWC_order
)
permute_node.meta["tosa_dim_order"] = self.NHWC_order
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
users = [user for user in node.users if user != permute_node]
for user in users:
@@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module):
q_params = node.args[0].args[1:]
insert_q_dq_pair(graph_module.graph, node, q_params)

@staticmethod
def _insert_squeeze_transpose(
input_shape, output_shape, node, input_node, graph_module
):
nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3

if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
input_shape
):
AnnotateChannelsLastDimOrder.insert_input_transpose(
node, input_node, graph_module
)

@staticmethod
def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
output_shape
):
AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)

@staticmethod
def _insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
):
nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
output_shape, input_shape
)

if (
channel_reshape or nhwc_to_nchw
) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
AnnotateChannelsLastDimOrder.insert_input_transpose(
node, input_node, graph_module
)
if (
channel_reshape or nchw_to_nhwc
) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
"""
Reshape operations are not equivalent in NCHW and NHWC.
To get around this, transposes need to be added if the previous or new shape
fulfil the following condition:
C > 1 and (H or W > 1)

This is relevant for the following operations;
squeeze: 4D -> 3D
unsqueeze: <4D -> 4D
view: <4D -> 4D
view: 4D -> <4D
view: 4D -> 4D
"""

def transpose_condition(shape):
if len(shape) != 4:
return False
C = shape[1]
H = shape[2]
W = shape[3]
return C > 1 and (H > 1 or W > 1)
Transposes are needed for operators transforming the input to a different rank, since 4D tensors are assumed to be in NHWC format, whereas all others are in NCHW format.
This is relevant for the following cases:
- squeeze: 4D -> <4D
- unsqueeze: 3D -> 4D
- view: <4D -> 4D
- view: 4D -> <4D
Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leading to one extra input and output transpose for this case.

Transposes can be avoided for shapes where there is no difference in actual memory, e.g. for
- H == W == 1
- C == 1
- 1D/2D tensors
"""
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if node.target == exir_ops.edge.aten.squeeze_copy.dims:
input_node = node.args[0]
input_shape = input_node.meta["val"].shape
if transpose_condition(input_shape):
self.insert_input_transpose(node, input_node, graph_module)
output_shape = node.meta["val"].shape

self._insert_squeeze_transpose(
input_shape, output_shape, node, input_node, graph_module
)

elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
input_node = get_node_arg(node.args, 0, default_value=False)
if input_node:
input_shape = input_node.meta["val"].shape
else:
input_shape = ()
output_shape = node.meta["val"].shape
if transpose_condition(output_shape):
self.insert_output_transpose(node, graph_module)

self._insert_unsqueeze_transpose(
input_shape, output_shape, node, graph_module
)

elif node.target == exir_ops.edge.aten.view_copy.default:
input_node = node.args[0]
input_shape = input_node.meta["val"].shape
output_shape = node.meta["val"].shape

old_shape = input_node.meta["val"].shape
new_shape = node.meta["val"].shape

if transpose_condition(old_shape):
self.insert_input_transpose(node, input_node, graph_module)

if transpose_condition(new_shape):
self.insert_output_transpose(node, graph_module)
self._insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
)

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
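
Why these transposes are needed, as a standalone sketch (not part of the diff): plain PyTorch is used below, and the permutations (0, 2, 3, 1) / (0, 3, 1, 2) are assumed to match the pass's NHWC_order / NHWC_inverse_order attributes. Whenever C > 1 and H or W > 1, reshaping the NHWC-ordered data directly gives a different element order than the NCHW reshape expressed in the graph, which is exactly the condition memory_format_differs() checks for.

import torch

# Logical NCHW tensor with C=2, H=2, W=3, i.e. memory_format_differs() is True.
x = torch.arange(12).reshape(1, 2, 2, 3)

# The backend keeps 4D tensors in NHWC dim order (assumed permutation (0, 2, 3, 1)).
nhwc = x.permute(0, 2, 3, 1).contiguous()

# Reshaping the NHWC data directly does not reproduce the NCHW reshape result...
direct = nhwc.reshape(1, 2, 6)

# ...so the pass first transposes back to NCHW (assumed inverse permutation
# (0, 3, 1, 2)) and only then reshapes, which does match the reference.
via_transpose = nhwc.permute(0, 3, 1, 2).reshape(1, 2, 6)

assert not torch.equal(direct, x.reshape(1, 2, 6))
assert torch.equal(via_transpose, x.reshape(1, 2, 6))
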
1 change: 1 addition & 0 deletions backends/arm/test/ops/test_view.py
@@ -43,6 +43,7 @@ class View(torch.nn.Module):
(torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
(torch.rand(5, 10, 1, 1), (1, 25, 2)),
(torch.rand(2, 50, 1, 1), (1, 100)),
(torch.rand(2, 3, 2, 3), (2, 3, 3, 2)),
]

def forward(self, x: torch.Tensor, new_shape):
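
As a rough usage sketch of the new static helpers (the import path below is inferred from the file name in the diff above, so treat it as an assumption), the predicates can be probed directly to see which of the test shapes actually lead to extra transposes:

from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
    AnnotateChannelsLastDimOrder,
)

P = AnnotateChannelsLastDimOrder

# (2, 50, 1, 1) -> (1, 100): 4D -> 2D, but H == W == 1, so the layouts coincide.
print(P.memory_format_differs((2, 50, 1, 1)))  # False

# (1, 1, 5, 10) -> (1, 1, 50, 1): C == 1, so again no transpose is needed.
print(P.memory_format_differs((1, 1, 5, 10)))  # False

# (2, 3, 2, 3) -> (2, 3, 3, 2): the newly added test case. The layouts differ,
# but N and C are unchanged, so it is not a channel reshape and the pass
# inserts no transposes around this 4D -> 4D view.
print(P.memory_format_differs((2, 3, 2, 3)))  # True
print(P.is_channel_reshape((2, 3, 2, 3), (2, 3, 3, 2)))  # False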