43 changes: 30 additions & 13 deletions backends/xnnpack/operators/node_visitor.py
@@ -130,21 +130,33 @@ def gen_ids_and_flags(
         # This will break if we change the way q/dq are partitioned

         # Tensor can still be input if its quantizing node is an input
-        is_input = self.is_graph_input(tensor)
+        if self.is_graph_input(tensor) or (
+            quant_params.is_input if quant_params else False
+        ):
+            tensor_input = tensor
+            if quant_params:
+                if quant_params.is_input and not self.is_graph_input(tensor):
+                    tensor_input = quant_params.q_input
+            assert (
+                tensor_input in self.external_ids.keys()
+            ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}"
+            ext_id = self.external_ids[tensor_input].external_id
+            xnn_graph.input_ids.append(id_out)
+            flag = self.external_ids[tensor_input].io_type
         # Tensor can still be output if its quantizing node is an output
-        is_output = self.is_graph_output(tensor)
-        # handle logic for input/output tensors
-        if is_input or is_output:
+        elif self.is_graph_output(tensor) or (
+            quant_params.is_output if quant_params else False
+        ):
+            tensor_output = tensor
+            if quant_params:
+                if quant_params.is_output and not self.is_graph_output(tensor):
+                    tensor_output = list(tensor.users)[0]
             assert (
-                tensor in self.external_ids.keys()
-            ), f"Tensor {tensor}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
-            ext_id = self.external_ids[tensor].external_id
-            if is_input:
-                xnn_graph.input_ids.append(id_out)
-                flag = XNN_VALUE_FLAG_EXTERNAL_INPUT
-            if is_output:
-                xnn_graph.output_ids.append(id_out)
-                flag = XNN_VALUE_FLAG_EXTERNAL_OUTPUT
+                tensor_output in self.external_ids.keys()
+            ), f"Tensor {tensor_output} is_output: ext_ids: {self.external_ids.keys()}"
+            ext_id = self.external_ids[tensor_output].external_id
+            xnn_graph.output_ids.append(id_out)
+            flag = self.external_ids[tensor_output].io_type

         return ext_id, id_out, flag
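The rewritten branches route the external-id lookup through the quantize/dequantize neighbor whenever the tensor is a graph boundary only via that neighbor. A minimal sketch of the routing with toy stand-ins (QuantParams here is a simplified dataclass, not the ExecuTorch class, and strings stand in for fx nodes):

from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class QuantParams:  # toy stand-in, for illustration only
    is_input: bool = False
    is_output: bool = False
    q_input: Any = None  # node feeding the quantize op

def resolve_external_key(
    tensor: Any,
    quant_params: Optional[QuantParams],
    graph_inputs: set,
    users: List[Any],
) -> Any:
    # Same idea as the branches above: if the tensor is only an input/output
    # *via* its quantizing node, the external id was registered under that
    # neighbor, not under the tensor itself.
    if tensor in graph_inputs or (quant_params.is_input if quant_params else False):
        if quant_params and quant_params.is_input and tensor not in graph_inputs:
            return quant_params.q_input  # input enters through a quantize node
        return tensor
    if quant_params and quant_params.is_output:
        return users[0]  # output exits through a dequantize node
    return tensor

# A graph input "x" that is immediately quantized: the visitor sees the
# quantize node's output, but the external id lives under "x".
qp = QuantParams(is_input=True, q_input="x")
assert resolve_external_key("q_of_x", qp, graph_inputs={"x"}, users=[]) == "x"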

@@ -230,6 +242,7 @@ def define_tensor(
         # Get new xnn id for tensor value
         ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params)
         dims = get_shape(tensor)
+        dims = [1] if len(dims) == 0 else dims

         # constant values serialize data
         buffer_idx = self.get_serialized_buffer(
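The single added line gives rank-0 (scalar) tensors a shape that can be serialized. A quick standalone check of the same expression in plain PyTorch (the helper name is ours, not from the diff):

import torch

def promoted_dims(t: torch.Tensor) -> list:
    # same expression as the added line: rank-0 scalars become shape [1]
    dims = list(t.shape)
    return [1] if len(dims) == 0 else dims

print(promoted_dims(torch.tensor(2.5)))  # [1]
print(promoted_dims(torch.randn(2, 3)))  # [2, 3]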
@@ -336,6 +349,10 @@ def get_serialized_buffer(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
+        else:
+            # ensure that the const is fp32
+            const_val = const_val.to(dtype=torch.float32).contiguous()
+
         if swap_nc_for_depthwise_weights:
             const_val = const_val.permute(
                 dims=((1, 0) + tuple(range(2, const_val.dim())))
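The new else-branch pins every non-quantized constant to fp32 before serialization, so stray float64 (or fp16) weights cannot leak into the buffer. The coercion in isolation:

import torch

# e.g. a constant that was accidentally created in double precision
const_val = torch.randn(4, dtype=torch.float64)

# mirrors the added else-branch: cast to fp32 and make contiguous
const_val = const_val.to(dtype=torch.float32).contiguous()

print(const_val.dtype, const_val.is_contiguous())  # torch.float32 True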
8 changes: 8 additions & 0 deletions backends/xnnpack/partition/configs.py
@@ -37,6 +37,9 @@

 SUPPORTED_MODULES = [
     torch.nn.Conv1d,
+    # TODO(T161981984) recomposed hardswish into a single node
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
     torch.nn.Conv2d,
     torch.nn.ReLU,
     torch.nn.Sigmoid,
@@ -78,6 +81,11 @@
     )
 }

+UNSUPPORTED_QUANT_MODULES = [
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
+]
+
 # TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
 SUPPORTED_QUANT_MODULES = [
     torch.clamp,
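Net effect of the two config edits: Hardswish and Hardsigmoid become lowerable in the fp32 flow while staying blocked in the quantized flow. A sanity check over the lists (SUPPORTED_MODULES abbreviated to the entries shown above):

import torch

SUPPORTED_MODULES = [
    torch.nn.Conv1d,
    torch.nn.Hardswish,
    torch.nn.Hardsigmoid,
    torch.nn.Conv2d,
]
UNSUPPORTED_QUANT_MODULES = [torch.nn.Hardswish, torch.nn.Hardsigmoid]

for mod in (torch.nn.Hardswish, torch.nn.Conv2d):
    fp32_ok = mod in SUPPORTED_MODULES
    quant_ok = mod not in UNSUPPORTED_QUANT_MODULES
    print(f"{mod.__name__}: fp32={fp32_ok} quantized={quant_ok}")
# Hardswish: fp32=True quantized=False
# Conv2d: fp32=True quantized=True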
24 changes: 20 additions & 4 deletions backends/xnnpack/partition/xnnpack_partitioner.py
@@ -17,6 +17,7 @@
     SUPPORTED_OPS,
     SUPPORTED_QUANT_MODULES,
     SUPPORTED_QUANT_OPS,
+    UNSUPPORTED_QUANT_MODULES,
 )
 from executorch.backends.xnnpack.partition.support_patterns import (
     get_add_graphs,
@@ -103,17 +104,25 @@ def __init__(
             Any, Callable[[torch.fx.Node], bool]
         ] = _OP_SUPPORT_CONSTRAINTS,
         supported_ops: Optional[List] = None,
+        unsupported_modules: Optional[List] = None,
     ):
         """
        @Arg constraints_dict: Dict mapping each node to a lambda function that
            returns True if backend constraints are met for that instance of the
            node.
        @Arg supported_ops: List of supported operators for partitioning
        """
+        self.unsupported_modules = unsupported_modules
         self.supported_ops = supported_ops
         self.constraints = constraints_dict
         assert len(self.constraints)

+    def check_common_constraints(self, node) -> bool:
+        if self.unsupported_modules and "source_fn" in node.meta:
+            return not node.meta["source_fn"][1] in self.unsupported_modules
+
+        return True
+
     @staticmethod
     def check_constraint(node) -> bool:
         """
@@ -132,7 +141,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         if self.supported_ops and node.target not in self.supported_ops:
             return False

-        return self.check_constraint(node)
+        return self.check_constraint(node) and self.check_common_constraints(node)

     def _constraint(target):  # noqa
         """
@@ -540,10 +549,11 @@ def __init__(
         self,
         supported_modules: List[Callable] = SUPPORTED_MODULES,
         supported_ops: Optional[List[Callable]] = SUPPORTED_OPS,
+        unsupported_modules: Optional[List[Callable]] = None,
     ):
         super().__init__()
         self.supported_modules = set(supported_modules)
-
+        self.unsupported_modules = unsupported_modules
         self.supported_ops = set(supported_ops or [])

         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
@@ -614,7 +624,10 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]:
         return generate_partitions_from_list_of_nodes(
             graph_module,
             matched_module_nodes,
-            XnnpackOperatorSupport(supported_ops=self.supported_ops),
+            XnnpackOperatorSupport(
+                supported_ops=self.supported_ops,
+                unsupported_modules=self.unsupported_modules,
+            ),
         )

     def tag_nodes(self, partitions: List[Partition]) -> None:
@@ -668,9 +681,12 @@ def __init__(
         self,
         supported_modules=SUPPORTED_QUANT_MODULES,
         supported_ops=SUPPORTED_QUANT_OPS,
+        unsupported_modules=UNSUPPORTED_QUANT_MODULES,
     ):
         supported_ops = supported_ops or []
-        super().__init__(supported_modules, supported_ops + self._QUANT_OPS)
+        super().__init__(
+            supported_modules, supported_ops + self._QUANT_OPS, unsupported_modules
+        )

     # TODO Refactor this
     # TODO Don't be greedy when pulling q->dq pairs for a given op, add convert tracker pass
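Hedged usage sketch: the enclosing class names are not visible in this diff, so XnnpackQuantizedPartitioner below is an assumption based on the file's naming; the keyword mirrors the new parameter threaded through above.

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackQuantizedPartitioner,  # assumed class name; not shown in the diff
)

# Default: Hardswish/Hardsigmoid-sourced nodes are excluded from quantized partitions
partitioner = XnnpackQuantizedPartitioner()

# Opt back in for experimentation by clearing the unsupported list
permissive = XnnpackQuantizedPartitioner(unsupported_modules=[])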