From e9066174a407639dba04aa839c8d6a63b6fb8351 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Thu, 24 Oct 2024 11:24:04 -0700
Subject: [PATCH] [ET-VK] Introduce AOT operator registry

## Changes

Move the following files to the root directory of the Vulkan backend:

* `backends/vulkan/partitioner/supported_ops.py` -> `backends/vulkan/op_registry.py`
* `backends/vulkan/_passes/custom_ops_defs.py` -> `backends/vulkan/custom_ops_lib.py`

In the new `op_registry.py` file, the way operator features are specified has been reworked to describe the capabilities of each operator's Vulkan implementation in much more detail. See the new `OpFeatures` class for more details. An example of registering a new operator with the export flow:

```
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_packed_dim=True,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        dim_list = node.args[1]
        assert isinstance(dim_list, list)
        if len(dim_list) != 1:
            return False

        keepdim = node.args[2]
        assert isinstance(keepdim, bool)
        if not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features
```

## Rationale

The purpose of these changes is to centralize operator definitions so that there is a single source of truth about the capabilities of each operator's Vulkan implementation. This way, the partitioner does not have to implement ad-hoc checks for specific operators (e.g. `is_valid_to_copy`), and graph transforms do not have to maintain their own operator metadata (e.g. the `USES_WEIGHTS` set in `insert_prepack_nodes`).
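For context, here is a rough sketch (not part of this patch, and simplified relative to the actual partitioner and pass logic in the diff below) of how consumers query the registry through the helpers defined in the new `op_registry.py`. The helper name `node_is_supported` is made up for illustration:

```
import torch

from executorch.backends.vulkan.op_registry import (
    get_op_features,
    handles_own_prepacking,
    vulkan_supported_ops,
)
from executorch.exir.dialects._ops import ops as exir_ops


# The partitioner reads per-op capabilities from a single OpFeatures entry
# instead of hard-coding operator-specific checks such as is_valid_to_copy.
def node_is_supported(node: torch.fx.Node) -> bool:
    if node.target not in vulkan_supported_ops:
        return False
    features = vulkan_supported_ops[node.target]
    # An op-specific validity check, when one was registered via check_node_fn.
    if features.check_node_fn is not None and not features.check_node_fn(node):
        return False
    # resize_fn doubles as the "supports dynamic shapes" signal.
    return features.resize_fn


# Graph transforms query prepacking behavior instead of maintaining their own
# USES_WEIGHTS set; convolution registers handles_own_prepacking=True below.
skip_prepack = handles_own_prepacking(exir_ops.edge.aten.convolution.default)
view_features = get_op_features(exir_ops.edge.aten.view_copy.default)
```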
Differential Revision: [D64915640](https://our.internmc.facebook.com/intern/diff/D64915640/) [ghstack-poisoned] --- backends/transforms/fuse_conv_with_clamp.py | 5 +- backends/transforms/targets.bzl | 2 +- backends/vulkan/TARGETS | 33 -- backends/vulkan/_passes/TARGETS | 27 +- .../vulkan/_passes/insert_prepack_nodes.py | 11 +- .../_passes/int4_weight_only_quantizer.py | 6 +- backends/vulkan/_passes/test_custom_ops.py | 124 ----- .../custom_ops_defs.py => custom_ops_lib.py} | 0 backends/vulkan/op_registry.py | 437 ++++++++++++++++++ backends/vulkan/partitioner/TARGETS | 2 +- backends/vulkan/partitioner/supported_ops.py | 159 ------- .../vulkan/partitioner/vulkan_partitioner.py | 39 +- backends/vulkan/serialization/TARGETS | 4 + backends/vulkan/serialization/targets.bzl | 24 + backends/vulkan/targets.bzl | 60 +++ backends/vulkan/test/test_vulkan_delegate.py | 2 +- .../source_transformation/vulkan_rope.py | 5 +- 17 files changed, 547 insertions(+), 393 deletions(-) delete mode 100644 backends/vulkan/_passes/test_custom_ops.py rename backends/vulkan/{_passes/custom_ops_defs.py => custom_ops_lib.py} (100%) create mode 100644 backends/vulkan/op_registry.py delete mode 100644 backends/vulkan/partitioner/supported_ops.py create mode 100644 backends/vulkan/serialization/TARGETS create mode 100644 backends/vulkan/serialization/targets.bzl diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py index 3903fe1bdf4..3f45296b26c 100644 --- a/backends/transforms/fuse_conv_with_clamp.py +++ b/backends/transforms/fuse_conv_with_clamp.py @@ -6,10 +6,9 @@ import sys +import executorch.backends.vulkan.custom_ops_lib # noqa + import torch -from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa - conv_with_clamp_op, -) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 47c518a8637..14725636f3d 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -70,7 +70,7 @@ def define_common_targets(): deps = [ ":utils", "//caffe2:torch", - "//executorch/backends/vulkan/_passes:custom_ops_defs", + "//executorch/backends/vulkan:custom_ops_lib", "//executorch/exir:pass_base", "//executorch/exir:sym_util", "//executorch/exir/dialects:lib", diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 4e0e83f2763..41893d29274 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -1,37 +1,4 @@ -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load(":targets.bzl", "define_common_targets") - oncall("executorch") define_common_targets(is_fbcode = True) - -runtime.python_library( - name = "vulkan_preprocess", - srcs = [ - "serialization/vulkan_graph_builder.py", - "serialization/vulkan_graph_schema.py", - "serialization/vulkan_graph_serialize.py", - "vulkan_preprocess.py", - ], - resources = [ - "serialization/schema.fbs", - ], - visibility = [ - "//executorch/...", - "//executorch/vulkan/...", - "@EXECUTORCH_CLIENTS", - ], - deps = [ - "//executorch/backends/transforms:addmm_mm_to_linear", - "//executorch/backends/transforms:fuse_batch_norm_with_conv", - "//executorch/backends/transforms:fuse_conv_with_clamp", - "//executorch/backends/transforms:fuse_dequant_linear", - "//executorch/backends/transforms:fuse_view_copy", - "//executorch/backends/transforms:remove_clone_ops", - "//executorch/backends/vulkan/_passes:vulkan_passes", - 
"//executorch/exir:graph_module", - "//executorch/exir/_serialize:_bindings", - "//executorch/exir/_serialize:lib", - "//executorch/exir/backend:backend_details", - ], -) diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index fa828640bf4..3f328deb485 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -3,31 +3,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") oncall("executorch") -runtime.python_library( - name = "custom_ops_defs", - srcs = [ - "custom_ops_defs.py", - ], - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", - ], - deps = [ - "//caffe2:torch", - ], -) - -python_unittest( - name = "test_custom_ops", - srcs = [ - "test_custom_ops.py", - ], - deps = [ - ":custom_ops_defs", - "//caffe2:torch", - ], -) - runtime.python_library( name = "insert_prepack_nodes", srcs = ["insert_prepack_nodes.py"], @@ -62,7 +37,7 @@ runtime.python_library( "//executorch/backends/...", ], deps = [ - ":custom_ops_defs", + "//executorch/backends/vulkan:custom_ops_lib", "//pytorch/ao:torchao", ] ) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index afedf7af694..4850850a409 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -8,10 +8,12 @@ from typing import List -import executorch.backends.vulkan._passes.custom_ops_defs # noqa +import executorch.backends.vulkan.custom_ops_lib # noqa import torch +from executorch.backends.vulkan.op_registry import handles_own_prepacking + from executorch.exir.dialects._ops import ops as exir_ops from torch._export.utils import is_buffer, is_param @@ -63,10 +65,9 @@ def is_non_weight_param_tensor(node: torch.fx.Node) -> bool: return False for user in node.users: - if user.op == "call_function" and ( - # pyre-ignore [16] - user.target in USES_WEIGHTS - or user.target.name() in USES_WEIGHTS + if user.op == "call_function" and handles_own_prepacking( + # noqa + user.target ): return False diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py index a0d208bb63f..71d96533d2c 100644 --- a/backends/vulkan/_passes/int4_weight_only_quantizer.py +++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py @@ -1,13 +1,11 @@ import logging from typing import Any, Callable, Dict, Optional, Type +import executorch.backends.vulkan.custom_ops_lib # noqa + import torch import torch.nn.functional as F -from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa - linear_weight_int4_op, -) - from torchao.quantization.GPTQ import _check_linear_int4_k from torchao.quantization.unified import Quantizer from torchao.quantization.utils import groupwise_affine_quantize_tensor diff --git a/backends/vulkan/_passes/test_custom_ops.py b/backends/vulkan/_passes/test_custom_ops.py deleted file mode 100644 index c68dd6d6796..00000000000 --- a/backends/vulkan/_passes/test_custom_ops.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch - -from .custom_ops_defs import conv_with_clamp_op # noqa - - -class TestCustomOps(unittest.TestCase): - def test_conv_with_clamp(self): - class ConvWithClamp(torch.nn.Module): - def __init__( - self, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - output_min, - output_max, - ): - super().__init__() - self.weight = weight - self.bias = bias - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.output_min = output_min - self.output_max = output_max - - def forward(self, x): - return torch.ops.et_vk.conv_with_clamp( - x, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.transposed, - self.output_padding, - self.groups, - self.output_min, - self.output_max, - ) - - model = ConvWithClamp( - weight=torch.randn(64, 64, 3, 3), - bias=torch.randn(64), - stride=[1], - padding=[0], - dilation=[1], - transposed=False, - output_padding=[0], - groups=1, - output_min=0, - output_max=float("inf"), - ) - x = torch.randn(2, 64, 10, 10) - custom_out = model(x) - - expected_out = torch.clamp( - torch.convolution( - x, - model.weight, - model.bias, - model.stride, - model.padding, - model.dilation, - model.transposed, - model.output_padding, - model.groups, - ), - min=model.output_min, - max=model.output_max, - ) - - self.assertEqual( - custom_out.shape, - expected_out.shape, - "custom op `conv_with_clamp` output shape matches expected", - ) - self.assertTrue(torch.allclose(custom_out, expected_out)) - - def test_grid_priors(self): - class GridPriors(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, stride, offset): - return torch.ops.et_vk.grid_priors(x, stride, offset) - - model = GridPriors() - sample_input = (torch.rand(2, 5, 2, 3), 4, 0.5) - custom_out = model(*sample_input) - - def calculate_expected_output(x, stride, offset): - height, width = x.shape[-2:] - shift_x = (torch.arange(0, width) + offset) * stride - shift_y = (torch.arange(0, height) + offset) * stride - shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) - shift_xx = shift_xx.reshape(-1) - shift_yy = shift_yy.reshape(-1) - shifts = torch.stack((shift_yy, shift_xx), dim=-1) - return shifts - - expected_out = calculate_expected_output(*sample_input) - - self.assertEqual( - custom_out.shape, - expected_out.shape, - "custom op `grid_priors` output shape matches expected", - ) - self.assertTrue(torch.allclose(custom_out, expected_out)) diff --git a/backends/vulkan/_passes/custom_ops_defs.py b/backends/vulkan/custom_ops_lib.py similarity index 100% rename from backends/vulkan/_passes/custom_ops_defs.py rename to backends/vulkan/custom_ops_lib.py diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py new file mode 100644 index 00000000000..30cd9f376dd --- /dev/null +++ b/backends/vulkan/op_registry.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import operator + +from typing import Callable, Dict, List, Optional, Union + +import executorch.backends.vulkan.custom_ops_lib # noqa + +import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from torch._subclasses.fake_tensor import FakeTensor + + +class TextureImplFeatures: + __slots__ = [ + # If the shader accounts for the packed dimension of the tensor, then + # all memory layouts are supported. + "uses_packed_dim", + # If the shader accounts for the axis map, then non standard memory + # layouts are supported. + "uses_axis_map", + # Specifies a specific set of memory layouts that the shader supports. + "supported_layouts", + ] + + def __init__( + self, + uses_packed_dim: bool = False, + uses_axis_map: bool = False, + supported_layouts: Optional[List[VkMemoryLayout]] = None, + ): + self.uses_packed_dim: bool = uses_packed_dim + self.uses_axis_map: bool = uses_axis_map + self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts + + +class OpFeatures: + __slots__ = [ + # None or TextureImplFeatures to specify implementation details of the texture + # based operator implementation. + "texture_impl", + # bool indicating if the operator has a buffer based implementation. + "buffer_impl", + # bool indicating if the operator has a resize function, which allows it to + # support dynamic shape tensors. + "resize_fn", + # bool indicating if the operator handles its own prepacking. If this is True, + # then the insert_prepack_nodes pass will not insert prepack nodes for the args + # of the op. + "handles_own_prepacking", + # Optional check function used during partitioning to determine if a node's + # inputs are supported by the operator implementation. 
+ "check_node_fn", + ] + + def __init__( + self, + texture_impl: Optional[TextureImplFeatures] = None, + buffer_impl: bool = False, + resize_fn: bool = False, + handles_own_prepacking: bool = False, + check_node_fn: Optional[Callable] = None, + ): + self.texture_impl: Optional[TextureImplFeatures] = texture_impl + self.buffer_impl: bool = buffer_impl + self.resize_fn: bool = resize_fn + self.handles_own_prepacking: bool = handles_own_prepacking + self.check_node_fn: Optional[Callable] = check_node_fn + + +OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload] + +vulkan_supported_ops: Dict[OpKey, OpFeatures] = {} + + +def update_features(aten_op): + def features_decorator(fn: Callable): + def update_features_impl(op: OpKey): + if op not in vulkan_supported_ops: + vulkan_supported_ops[op] = OpFeatures() + vulkan_supported_ops[op] = fn(vulkan_supported_ops[op]) + + if isinstance(aten_op, list): + for op in aten_op: + update_features_impl(op) + else: + update_features_impl(aten_op) + + return fn + + return features_decorator + + +@update_features( + [ + operator.getitem, + # Quantization related ops will be fused via graph passes + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + ] +) +def register_ephemeral_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + uses_axis_map=True, + ) + features.buffer_impl = True + features.resize_fn = True + return features + + +@update_features( + [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.pow.Tensor_Tensor, + ] +) +def register_binary_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + uses_axis_map=True, + ) + features.resize_fn = True + return features + + +@update_features( + [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.cos.default, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.hardshrink.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.sin.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten._to_copy.default, + ] +) +def register_unary_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + uses_axis_map=True, + ) + features.buffer_impl = True + features.resize_fn = True + return features + + +@update_features(exir_ops.edge.aten._to_copy.default) +def register_to_copy_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + uses_axis_map=True, + ) + features.resize_fn = True + + def check_to_copy_node(node: torch.fx.Node) -> bool: + float_dtypes = [torch.float16, torch.float32] + + if len(node.args) != 1: + return False + + in_arg = node.args[0] + if not isinstance(in_arg, torch.fx.Node): + return False + + in_tensor = in_arg.meta.get("val", None) + 
out_tensor = node.meta.get("val", None) + + if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor): + if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes: + return True + + return False + + features.check_node_fn = check_to_copy_node + + return features + + +@update_features( + [ + exir_ops.edge.aten.bmm.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.linear.default, + exir_ops.edge.et_vk.linear_weight_int4.default, + exir_ops.edge.aten._weight_int8pack_mm.default, + ] +) +def register_mm_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=False, + uses_axis_map=True, + supported_layouts=[ + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, + ], + ) + features.buffer_impl = True + features.resize_fn = True + features.handles_own_prepacking = True + return features + + +@update_features(exir_ops.edge.aten._weight_int8pack_mm.default) +def register_int8_mm_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=False, + uses_axis_map=False, + supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + ) + features.buffer_impl = True + features.resize_fn = True + features.handles_own_prepacking = True + return features + + +@update_features(exir_ops.edge.et_vk.linear_weight_int4.default) +def register_int4_mm_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=False, + uses_axis_map=False, + supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + ) + features.resize_fn = True + features.handles_own_prepacking = True + return features + + +@update_features( + [ + exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten._softmax.default, + ] +) +def register_softmax_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + ) + features.resize_fn = True + return features + + +@update_features( + [ + exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.amax.default, + exir_ops.edge.aten.amin.default, + ] +) +def register_reduce_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + ) + features.resize_fn = True + + def check_reduce_node(node: torch.fx.Node) -> bool: + dim_list = node.args[1] + assert isinstance(dim_list, list) + if len(dim_list) != 1: + return False + + keepdim = node.args[2] + assert isinstance(keepdim, bool) + if not keepdim: + return False + + return True + + features.check_node_fn = check_reduce_node + return features + + +@update_features( + [ + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, + ] +) +def register_2d_pool_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + features.resize_fn = True + return features + + +@update_features( + [ + exir_ops.edge.aten.convolution.default, + exir_ops.edge.et_vk.conv_with_clamp.default, + ] +) +def register_convolution_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + features.resize_fn = True + features.handles_own_prepacking = True + return features + + +@update_features("llama::sdpa_with_kv_cache") +def register_sdpa_op(features: OpFeatures): + features.texture_impl = 
TextureImplFeatures( + supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + ) + features.resize_fn = True + features.handles_own_prepacking = True + return features + + +@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) +def register_rotary_emb_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + ) + features.resize_fn = True + return features + + +@update_features(exir_ops.edge.aten.view_copy.default) +def register_view_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_packed_dim=True, + ) + features.resize_fn = True + return features + + +# Ops ported from PyTorch Vulkan backend. These ops commonly support channels +# packed tensors only and do not have a resize function. +@update_features( + [ + # Normalization + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.native_layer_norm.default, + # Shape Manipulation + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.t_copy.default, + # Indexing and lookup + exir_ops.edge.aten.flip.default, + exir_ops.edge.aten.index_select.default, + exir_ops.edge.aten.select_copy.int, + exir_ops.edge.aten.slice_copy.Tensor, + # Tensor combination + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split.Tensor, + exir_ops.edge.aten.repeat.default, + # Tensor creation + exir_ops.edge.aten.arange.start_step, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.constant_pad_nd.default, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.full_like.default, + exir_ops.edge.aten.ones.default, + exir_ops.edge.aten.ones_like.default, + exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.zeros.default, + exir_ops.edge.aten.zeros_like.default, + exir_ops.edge.et_vk.grid_priors.default, + ] +) +def register_ported_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + return features + + +# Ported ops that support their own prepacking. 
+@update_features( + [ + exir_ops.edge.aten.embedding.default, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.native_layer_norm.default, + ] +) +def register_ported_ops_with_prepacking(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + features.handles_own_prepacking = True + return features + + +## +## Utility Functions +## + + +def get_op_features(target: OpKey) -> OpFeatures: + if not isinstance(target, str): + if target not in vulkan_supported_ops: + # Check if the name of the op is in the dict + return vulkan_supported_ops.get(target.name(), OpFeatures()) + + return vulkan_supported_ops[target] + else: + return vulkan_supported_ops[target] + + +def handles_own_prepacking(target: OpKey) -> bool: + return get_op_features(target).handles_own_prepacking diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS index b11104c5902..d68a82ade05 100644 --- a/backends/vulkan/partitioner/TARGETS +++ b/backends/vulkan/partitioner/TARGETS @@ -5,7 +5,6 @@ oncall("executorch") runtime.python_library( name = "vulkan_partitioner", srcs = [ - "supported_ops.py", "vulkan_partitioner.py", ], visibility = [ @@ -13,6 +12,7 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ + "//executorch/backends/vulkan:op_registry", "//executorch/backends/vulkan:vulkan_preprocess", "//executorch/exir:delegate", "//executorch/exir:lib", diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py deleted file mode 100644 index 5a85c5f0ec1..00000000000 --- a/backends/vulkan/partitioner/supported_ops.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import operator - -from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa - conv_with_clamp_op, - grid_priors_op, -) - -from executorch.exir.dialects._ops import ops as exir_ops - - -class OpFeatures: - __slots__ = ["supports_texture", "supports_buffer", "supports_dynamic_shape"] - - def __init__( - self, - supports_dynamic_shape: bool = False, - supports_buffer: bool = False, - supports_texture: bool = True, - ): - self.supports_dynamic_shape = supports_dynamic_shape - self.supports_texture = supports_texture - self.supports_buffer = supports_buffer - - -class OpList: - def __init__(self): - self._ops = {} - - def __getitem__(self, op): - if op not in self._ops: - self._ops[op] = OpFeatures() - return self._ops[op] - - def __contains__(self, op): - return op in self._ops - - -PRIM_OPS = [ - operator.getitem, - # Quantization related ops will be fused via graph passes - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, -] - -SUPPORTS_DYNAMIC_SHAPE = [ - # Binary broadcasting - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten.div.Tensor_mode, - exir_ops.edge.aten.pow.Tensor_Tensor, - # Unary elementwise - exir_ops.edge.aten.abs.default, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.cos.default, - exir_ops.edge.aten.exp.default, - exir_ops.edge.aten.gelu.default, - exir_ops.edge.aten.hardshrink.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.neg.default, - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.sigmoid.default, - exir_ops.edge.aten.sin.default, - exir_ops.edge.aten.sqrt.default, - exir_ops.edge.aten.rsqrt.default, - exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten._to_copy.default, - # Matrix Multiplication - exir_ops.edge.aten.bmm.default, - exir_ops.edge.aten.mm.default, - exir_ops.edge.aten.addmm.default, - exir_ops.edge.aten.linear.default, - exir_ops.edge.et_vk.linear_weight_int4.default, - exir_ops.edge.aten._weight_int8pack_mm.default, - # Reduction - exir_ops.edge.aten._log_softmax.default, - exir_ops.edge.aten._softmax.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.amax.default, - exir_ops.edge.aten.amin.default, - # 2D Pooling - exir_ops.edge.aten.avg_pool2d.default, - exir_ops.edge.aten.max_pool2d_with_indices.default, - # Convolution - exir_ops.edge.aten.convolution.default, - exir_ops.edge.et_vk.conv_with_clamp.default, - # Llama ops - "llama::sdpa_with_kv_cache", - exir_ops.edge.et_vk.apply_rotary_emb.default, -] - -NO_DYNAMIC_SHAPE = [ - # Normalization - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - exir_ops.edge.aten.native_layer_norm.default, - # Shape Manipulation - exir_ops.edge.aten.squeeze_copy.dims, - exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.t_copy.default, - # Indexing and lookup - exir_ops.edge.aten.embedding.default, - exir_ops.edge.aten.flip.default, - exir_ops.edge.aten.index_select.default, - 
exir_ops.edge.aten.select_copy.int, - exir_ops.edge.aten.slice_copy.Tensor, - # Tensor combination - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.split.Tensor, - exir_ops.edge.aten.repeat.default, - # Tensor creation - exir_ops.edge.aten.arange.start_step, - exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.constant_pad_nd.default, - exir_ops.edge.aten.full.default, - exir_ops.edge.aten.full_like.default, - exir_ops.edge.aten.ones.default, - exir_ops.edge.aten.ones_like.default, - exir_ops.edge.aten.upsample_nearest2d.vec, - exir_ops.edge.aten.zeros.default, - exir_ops.edge.aten.zeros_like.default, - exir_ops.edge.et_vk.grid_priors.default, -] - - -def enumerate_supported_ops(): - ops = OpList() - - # Register in order of least to most capabilities - - for op in NO_DYNAMIC_SHAPE: - ops[op].supports_dynamic_shape = False - - for op in SUPPORTS_DYNAMIC_SHAPE: - ops[op].supports_dynamic_shape = True - - for op in PRIM_OPS: - ops[op].supports_texture = True - ops[op].supports_buffer = True - ops[op].supports_dynamic_shape = True - - return ops diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 109a61049d2..1ec7855c0d9 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -13,10 +13,7 @@ import torch -from executorch.backends.vulkan.partitioner.supported_ops import ( - enumerate_supported_ops, - OpList, -) +from executorch.backends.vulkan.op_registry import vulkan_supported_ops from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( @@ -43,8 +40,6 @@ class VulkanSupportedOperators(OperatorSupportBase): - _ops: OpList = enumerate_supported_ops() - def __init__(self, require_dynamic_shape: bool = False) -> None: super().__init__() self.require_dynamic_shapes = require_dynamic_shape @@ -144,25 +139,6 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: return False - def is_valid_to_copy(self, node: torch.fx.Node) -> bool: - float_dtypes = [torch.float16, torch.float32] - - if len(node.args) != 1: - return False - - in_arg = node.args[0] - if not isinstance(in_arg, torch.fx.Node): - return False - - in_tensor = in_arg.meta.get("val", None) - out_tensor = node.meta.get("val", None) - - if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor): - if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes: - return True - - return False - def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: @@ -186,17 +162,16 @@ def _is_node_supported( if self.is_in_local_scalar_dense_chain(node): return True - if target not in VulkanSupportedOperators._ops: + if target not in vulkan_supported_ops: return False - if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy( - node - ): - return False + features = vulkan_supported_ops[target] - features = VulkanSupportedOperators._ops[target] + if features.check_node_fn is not None: + if not features.check_node_fn(node): + return False - if self.require_dynamic_shapes and not features.supports_dynamic_shape: + if self.require_dynamic_shapes and not features.resize_fn: return False return self.all_args_compatible(node) diff --git a/backends/vulkan/serialization/TARGETS b/backends/vulkan/serialization/TARGETS new file mode 
100644 index 00000000000..6cbd0fa8fac --- /dev/null +++ b/backends/vulkan/serialization/TARGETS @@ -0,0 +1,4 @@ +load(":targets.bzl", "define_common_targets") +oncall("executorch") + +define_common_targets() diff --git a/backends/vulkan/serialization/targets.bzl b/backends/vulkan/serialization/targets.bzl new file mode 100644 index 00000000000..8f04976a54b --- /dev/null +++ b/backends/vulkan/serialization/targets.bzl @@ -0,0 +1,24 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.python_library( + name = "lib", + srcs = [ + "vulkan_graph_builder.py", + "vulkan_graph_schema.py", + "vulkan_graph_serialize.py", + ], + resources = [ + "schema.fbs", + ], + visibility = [ + "//executorch/...", + "//executorch/vulkan/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/exir:graph_module", + "//executorch/exir/_serialize:_bindings", + "//executorch/exir/_serialize:lib", + ], + ) diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index f37534b089c..3d83ee0160e 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -236,3 +236,63 @@ def define_common_targets(is_fbcode = False): # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, ) + + ## + ## AOT targets + ## + + runtime.python_library( + name = "custom_ops_lib", + srcs = [ + "custom_ops_lib.py" + ], + visibility = [ + "//executorch/...", + "//executorch/vulkan/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:torch", + ] + ) + + runtime.python_library( + name = "op_registry", + srcs = [ + "op_registry.py", + ], + visibility = [ + "//executorch/...", + "//executorch/vulkan/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + ":custom_ops_lib", + "//caffe2:torch", + "//executorch/exir/dialects:lib", + "//executorch/backends/vulkan/serialization:lib", + ] + ) + + runtime.python_library( + name = "vulkan_preprocess", + srcs = [ + "vulkan_preprocess.py", + ], + visibility = [ + "//executorch/...", + "//executorch/vulkan/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/backends/transforms:addmm_mm_to_linear", + "//executorch/backends/transforms:fuse_batch_norm_with_conv", + "//executorch/backends/transforms:fuse_conv_with_clamp", + "//executorch/backends/transforms:fuse_dequant_linear", + "//executorch/backends/transforms:fuse_view_copy", + "//executorch/backends/transforms:remove_clone_ops", + "//executorch/backends/vulkan/_passes:vulkan_passes", + "//executorch/backends/vulkan/serialization:lib", + "//executorch/exir/backend:backend_details", + ], + ) diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 54db1a4b778..0512485c649 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1381,7 +1381,7 @@ def __init__(self): super().__init__() def forward(self, x): - # torch.t is actually exported as aten::permute. + # torch.t is actually exported as aten::permut. 
return torch.t(x) sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),) diff --git a/examples/models/llama/source_transformation/vulkan_rope.py b/examples/models/llama/source_transformation/vulkan_rope.py index 0dce6aeb448..cdaf6f0baa7 100644 --- a/examples/models/llama/source_transformation/vulkan_rope.py +++ b/examples/models/llama/source_transformation/vulkan_rope.py @@ -4,12 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import executorch.backends.vulkan.custom_ops_lib # noqa import torch -from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa - apply_rotary_emb_op, -) - from executorch.examples.models.llama.rope import RotaryEmbedding