From 332f760cab90fd9af4d18d94a5dafbb2bb5eb9f9 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 25 Aug 2023 16:55:48 -0700 Subject: [PATCH] handle static ints and floats (#140) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/140 In order to support MV3, which has decomposed hardswish and hardsigmoid. Decomp rules for both: ### Hardswish https://www.internalfb.com/code/fbsource/[9368f8417bd843ee8c91e24ac616ed7f4b194ed8]/xplat/caffe2/torch/_decomp/decompositions.py?lines=182-185 ### Hardsigmoid https://www.internalfb.com/code/fbsource/[9368f8417bd843ee8c91e24ac616ed7f4b194ed8]/xplat/caffe2/torch/_decomp/decompositions.py?lines=159-162 ### Fixing Zero-Dim tensors Both of these decompositions produce zero-dim tensors in the graph (the + 3 and the / 6). This breaks for XNNPACK because it does not have zero-dim tensors. Instead, if the static data is zero-dim, then we will interpret it as [1]. #### Fixing torch.int64 static data In the decomposition, 3 is converted via to_copy(torch.float32). However, 6 remains an int64. XNNPACK does not handle non-quantized integers, so we also cast all non-quantized static data to float32 values. 
Reviewed By: digantdesai Differential Revision: D48667679 fbshipit-source-id: 5edc3d881a599b0f1ee9fc6fddbc582db2f729ee --- backends/xnnpack/operators/node_visitor.py | 5 +++++ backends/xnnpack/partition/configs.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index e1f12bb51b8..5f4e871f90b 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -230,6 +230,7 @@ def define_tensor( # Get new xnn id for tensor value ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params) dims = get_shape(tensor) + dims = [1] if len(dims) == 0 else dims # constant values serialize data buffer_idx = self.get_serialized_buffer( @@ -336,6 +337,10 @@ def get_serialized_buffer( # Quantize buffer if static data is indeed quantized if quant_params is not None and not quant_params.is_dynamic: const_val = quant_params.quantize_tensor(const_val).contiguous() + else: + # ensure that the const is fp32 + const_val = const_val.to(dtype=torch.float32).contiguous() + if swap_nc_for_depthwise_weights: const_val = const_val.permute( dims=((1, 0) + tuple(range(2, const_val.dim()))) diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index a3b66d3fcb4..1a5f567bf43 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -37,6 +37,9 @@ SUPPORTED_MODULES = [ torch.nn.Conv1d, + # TODO(T161981984) recomposed hardswish into a single node + torch.nn.Hardswish, + torch.nn.Hardsigmoid, torch.nn.Conv2d, torch.nn.ReLU, torch.nn.Sigmoid,