21 changes: 20 additions & 1 deletion backends/transforms/remove_clone_ops.py
@@ -25,8 +25,9 @@ class RemoveCloneOpsTransform(ExportPass):
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self) -> None:
def __init__(self, preserve_input_output_copies: bool = False) -> None:
super().__init__()
self._preserve_input_output_copies = preserve_input_output_copies

def _remove(self, graph_module: torch.fx.GraphModule) -> None:
dequant_nodes = []
@@ -38,6 +39,11 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if self._is_non_identity_clone(n):
continue

# If preserve_input_output_copies is set, don't remove clones that directly
# copy from input to output.
if self._is_input_output_copy(n) and self._preserve_input_output_copies:
continue

to_be_removed = n
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
@@ -76,3 +82,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
)

return False

def _is_input_output_copy(self, node: torch.fx.Node) -> bool:
"""Return True if the node copies directly from a graph input (placeholder) to a graph output."""

input_node = node.args[0]
if input_node.op != "placeholder":
return False

for user in node.users:
if user.op == "output":
return True

return False
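
A minimal usage sketch of the behavior change, assuming the edge-dialect clone lands in this pass's `clone_ops` set; the `Passthrough` module and the `to_edge` flow are illustrative, not part of this PR:

```python
import torch

from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import to_edge

class Passthrough(torch.nn.Module):
    # Hypothetical module whose forward is a bare clone, so the clone copies
    # a graph input (placeholder) straight to a graph output.
    def forward(self, x):
        return x.clone()

edge = to_edge(torch.export.export(Passthrough(), (torch.randn(2, 3),)))
gm = edge.exported_program().graph_module

# With the new flag, an input-to-output clone survives the pass, so a backend
# can still emit a real copy between user-visible buffers; without it, the
# clone would be stripped as a no-op.
RemoveCloneOpsTransform(preserve_input_output_copies=True)(gm)
```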
1 change: 1 addition & 0 deletions backends/xnnpack/_passes/TARGETS
@@ -8,6 +8,7 @@ runtime.python_library(
deps = [
"//caffe2:torch",
"//executorch/backends/transforms:addmm_mm_to_linear",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/transforms:lib",
"//executorch/backends/xnnpack/partition:partitioner_graphs",
"//executorch/backends/xnnpack/serialization:xnnpack_schema",
10 changes: 10 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -4,8 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List, Optional, Type

from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass

from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
@@ -42,6 +46,11 @@
from torch.export import ExportedProgram


class XNNPACKRemoveCloneOpsTransform(RemoveCloneOpsTransform):
def __init__(self):
super().__init__(preserve_input_output_copies=True)


class XNNPACKPassManager:
def __init__(
self,
@@ -58,6 +67,7 @@ def __init__(
if not passes:
# All the XNNPACK passes
self.passes = [
XNNPACKRemoveCloneOpsTransform,
# TODO - remove this pass once we have a better support for dim_order ops lowering
DimOrderOpsRevertPass,
ConvertToUpsampleBilinear2d,
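
For context, a hedged sketch of running the default pipeline, which now starts with the clone-preserving variant; the `AddOne` module and shapes are illustrative assumptions:

```python
import torch

from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.exir import to_edge

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x.clone() + 1.0

edge = to_edge(torch.export.export(AddOne(), (torch.randn(2, 3),)))

# Runs the default pass list above, XNNPACKRemoveCloneOpsTransform first.
transformed = XNNPACKPassManager(edge.exported_program()).transform()
```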
3 changes: 3 additions & 0 deletions backends/xnnpack/operators/__init__.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from . import ( # noqa
node_visitor,
op_abs,
@@ -14,6 +16,7 @@
op_cat,
op_ceiling,
op_clamp,
op_clone,
op_conv2d,
op_div,
op_dynamic_dequantize_ops,
60 changes: 60 additions & 0 deletions backends/xnnpack/operators/op_clone.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNCopy,
XNNGraph,
XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class CloneVisitor(NodeVisitor):
target = "aten.clone.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

# Sanity check that the input and output dim order are the same. We don't
# handle dim order conversions yet.
dim_order = node.kwargs.get("dim_order", None)
input_meta = node.args[0].meta["val"]
assert dim_order is None or list(input_meta.dim_order()) == dim_order

# input
input_id = vals_to_ids[get_input_node(node, 0)]

# output
output_id = vals_to_ids[node]

ser_node = XNode(
xnode_union=XNNCopy(
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
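
For readers unfamiliar with `dim_order()`, a standalone illustration of the sanity check above (the tensors are hypothetical):

```python
import torch

t = torch.randn(2, 3, 4, 5)
# Contiguous tensors report the identity dim order.
assert list(t.dim_order()) == [0, 1, 2, 3]

# A channels-last tensor reports a permuted dim order; a clone requesting a
# different dim_order than its input would fail the assert in define_node.
nhwc = t.to(memory_format=torch.channels_last)
assert list(nhwc.dim_order()) == [0, 2, 3, 1]
```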
3 changes: 3 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List, Type

@@ -22,6 +23,7 @@
CatConfig,
CeilConfig,
ClampConfig,
CloneDimOrderConfig,
ConstantPadConfig,
DeQuantizedPerTensorConfig,
DivConfig,
@@ -77,6 +79,7 @@
BMMConfig,
CatConfig,
CeilConfig,
CloneDimOrderConfig,
ConstantPadConfig,
ConvolutionConfig,
ClampConfig,
22 changes: 22 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -643,3 +643,25 @@ class SinConfig(GenericNodePartitionerConfig):

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class CloneDimOrderConfig(GenericNodePartitionerConfig):
target_name = "_clone_dim_order.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

# Only partition no-op _clone_dim_order nodes, i.e. those whose output dim
# order matches the input's. This is a conservative check that ignores
# dim-order ambiguity; we can relax it in the future.
dim_order = node.kwargs.get("dim_order", None)
input_meta = node.args[0].meta["val"]
if dim_order is not None and list(input_meta.dim_order()) != dim_order:
why(node, reason="Only dim-order preserving clones are supported.")
return False

return True
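
To make the effect concrete, a hedged end-to-end sketch; `Passthrough` is hypothetical, and the lowering entry points are the standard `executorch.exir` ones rather than anything this PR adds:

```python
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class Passthrough(torch.nn.Module):
    def forward(self, x):
        # A dim-order-preserving clone: eligible for partitioning per the
        # constraint above, so it no longer splits the XNNPACK partition.
        return x.clone() + 1.0

ep = torch.export.export(Passthrough(), (torch.randn(2, 3),))
lowered = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
```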
29 changes: 29 additions & 0 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1459,6 +1459,34 @@ Error defineBatchMatrixMultiplyNode(
return Error::Ok;
}

/*
* Defines a copy node in the XNN subgraph.
*/
Error defineCopyNode(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);

auto graph_node = node->xnode_union_as_XNNCopy();

xnn_status status = xnn_define_copy(
subgraph_ptr,
remapped_ids.at(graph_node->input_id()),
remapped_ids.at(graph_node->output_id()),
graph_node->flags());

ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create copy node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

return Error::Ok;
}

/*
Returns a Not Implemented error code. This function is meant to be
called when the compiler encounters an XNodeType from the flatbuffer
@@ -1763,6 +1791,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(Concatenate5)
_DEFINE(StaticSlice)
_DEFINE(BatchMatrixMultiply)
_DEFINE(Copy)
case fb_xnnpack::XNodeUnion::NONE:
default: // Adding here as a catch all, just in case
return &defineNotImplementedNode;
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
@@ -157,6 +157,7 @@ union XNodeUnion {
XNNTanh: _XNNNode1x1,
XNNExp: _XNNNode1x1,
XNNSin: _XNNNode1x1,
XNNCopy: _XNNNode1x1,
}

union XValueUnion {
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/schema.fbs
@@ -153,6 +153,7 @@ union XNodeUnion {
XNNTanh: _XNNNode1x1,
XNNExp: _XNNNode1x1,
XNNSin: _XNNNode1x1,
XNNCopy: _XNNNode1x1,
}

union XValueUnion {
6 changes: 6 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -352,6 +352,11 @@ class XNNSin(XNNNode1x1):
pass


@dataclass
class XNNCopy(XNNNode1x1):
pass


@dataclass
class XNNScaledDotProductAttention:
query_id: int
@@ -409,6 +414,7 @@ class XNNScaledDotProductAttention:
XNNTanh,
XNNExp,
XNNSin,
XNNCopy,
]
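
Finally, a small sketch of constructing the serialized node, mirroring `op_clone.py` above; the ids are illustrative placeholders:

```python
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNCopy,
    XNode,
)

# XNNCopy inherits input_id/output_id/flags from XNNNode1x1.
ser_node = XNode(
    xnode_union=XNNCopy(input_id=0, output_id=1, flags=0),
    debug_handle=0,
)
```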

