From b196b1e3dfff2c7cfb55cb4d2356308bf078fa0e Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 13 Oct 2025 13:18:02 -0700 Subject: [PATCH] Revert "[ET-VK] Add Fusing for Conv/Binary Ops, Clamp/Binary Ops, and Clamp/Clamp (#14415)" This reverts commit a5d7e5c2d9f619f3d1d11745e9fb4852fa74ca2c. --- .../transforms/fuse_clamp_with_binary_op.py | 123 --- backends/transforms/fuse_clamps.py | 105 --- backends/transforms/fuse_conv_with_clamp.py | 10 +- backends/transforms/targets.bzl | 32 - backends/vulkan/custom_ops_lib.py | 757 ------------------ backends/vulkan/op_registry.py | 8 - .../runtime/graph/ops/glsl/binary_op.glsl | 59 +- .../runtime/graph/ops/glsl/unary_op.glsl | 1 - .../runtime/graph/ops/impl/BinaryOp.cpp | 102 +-- backends/vulkan/targets.bzl | 2 - backends/vulkan/vulkan_preprocess.py | 10 +- 11 files changed, 19 insertions(+), 1190 deletions(-) delete mode 100644 backends/transforms/fuse_clamp_with_binary_op.py delete mode 100644 backends/transforms/fuse_clamps.py diff --git a/backends/transforms/fuse_clamp_with_binary_op.py b/backends/transforms/fuse_clamp_with_binary_op.py deleted file mode 100644 index 4155b2b7458..00000000000 --- a/backends/transforms/fuse_clamp_with_binary_op.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import sys - -import executorch.backends.vulkan.custom_ops_lib # noqa - -import torch - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - - -class FuseClampBinaryOpPass(ExportPass): - - FUSEABLE_CLAMP_OPS = [ - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, - ] - FUSEABLE_BINARY_OPS = [ - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.div.Tensor, - ] - - def exists_before(self, graph_module, node_a, node_b): - seen_a = False - for n in graph_module.graph.nodes: - if n is node_a: - seen_a = True - if n is node_b: - return seen_a - return False - - def get_output_min_max_from_activation(self, activation_node): - if activation_node.target == exir_ops.edge.aten.relu.default: - output_min = 0.0 - output_max = sys.float_info.max - elif activation_node.target == exir_ops.edge.aten.hardtanh.default: - output_min = -1.0 - output_max = 1.0 - if len(activation_node.args) > 1: - output_min = activation_node.args[1] - output_max = activation_node.args[2] - elif activation_node.target == exir_ops.edge.aten.clamp.default: - output_min = None - output_max = None - if len(activation_node.args) >= 2: - output_min = activation_node.args[1] - if len(activation_node.args) >= 3: - output_max = activation_node.args[2] - - return output_min, output_max - - def fuse_binary_op_with_clamp(self, graph_module: torch.fx.GraphModule): - fuseAdded = False - for clamp_node in graph_module.graph.nodes: - if clamp_node.op == "call_function": - if clamp_node.target in self.FUSEABLE_CLAMP_OPS: - preceding_op = clamp_node.args[0] - - if ( - preceding_op.op == "call_function" - and preceding_op.target in self.FUSEABLE_BINARY_OPS - ): - # Delete activation - output_min_max = self.get_output_min_max_from_activation( - clamp_node - ) - new_args = list(preceding_op.args) - new_args.append(output_min_max[0]) - new_args.append(output_min_max[1]) - new_args = tuple(new_args) - clamp_node.replace_all_uses_with(preceding_op) - graph_module.graph.erase_node(clamp_node) - - new_op = None - match preceding_op.target: - case exir_ops.edge.aten.add.Tensor: - new_op = ( - exir_ops.edge.et_vk.binary_add_with_clamp.default - ) - case exir_ops.edge.aten.sub.Tensor: - new_op = ( - exir_ops.edge.et_vk.binary_sub_with_clamp.default - ) - case exir_ops.edge.aten.mul.Tensor: - new_op = ( - exir_ops.edge.et_vk.binary_mul_with_clamp.default - ) - case exir_ops.edge.aten.div.Tensor: - new_op = ( - exir_ops.edge.et_vk.binary_div_with_clamp.default - ) - - # Create and insert node of custom op `binary__with_clamp` - with graph_module.graph.inserting_before(preceding_op): - binary_op_clamp_node = graph_module.graph.create_node( - "call_function", - new_op, - new_args, - ) - - preceding_op.replace_all_uses_with(binary_op_clamp_node) - graph_module.graph.erase_node(preceding_op) - - fuseAdded = True - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - return [fuseAdded, graph_module] - - def call(self, graph_module: torch.fx.GraphModule): - fuseAdded = True - while fuseAdded: - fuseAdded, graph_module = self.fuse_binary_op_with_clamp(graph_module) - - return PassResult(graph_module, True) diff --git a/backends/transforms/fuse_clamps.py b/backends/transforms/fuse_clamps.py deleted file mode 100644 index 6e5be508d54..00000000000 --- a/backends/transforms/fuse_clamps.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import sys - -import executorch.backends.vulkan.custom_ops_lib # noqa - -import torch - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - - -class FuseClampsPass(ExportPass): - - FUSEABLE_CLAMPS = [ - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, - ] - - def get_output_min_max_from_activation(self, activation_node): - if activation_node.target == exir_ops.edge.aten.relu.default: - output_min = 0.0 - output_max = sys.float_info.max - elif activation_node.target == exir_ops.edge.aten.hardtanh.default: - output_min = -1.0 - output_max = 1.0 - if len(activation_node.args) > 1: - output_min = activation_node.args[1] - output_max = activation_node.args[2] - elif activation_node.target == exir_ops.edge.aten.clamp.default: - output_min = None - output_max = None - if len(activation_node.args) >= 2: - output_min = activation_node.args[1] - if len(activation_node.args) >= 3: - output_max = activation_node.args[2] - - return output_min, output_max - - def call(self, graph_module: torch.fx.GraphModule): - fuseAdded = True - while fuseAdded: - fuseAdded = False - for clamp_2_node in graph_module.graph.nodes: - if clamp_2_node.op == "call_function": - if clamp_2_node.target in self.FUSEABLE_CLAMPS: - preceding_op = clamp_2_node.args[0] - if ( - preceding_op.op == "call_function" - and preceding_op.target in self.FUSEABLE_CLAMPS - ): - # Ensure the shapes match - if ( - "val" not in clamp_2_node.args[0].meta - or "val" not in preceding_op.args[0].meta - ): - continue - if len(clamp_2_node.args[0].meta["val"].shape) != len( - preceding_op.args[0].meta["val"].shape - ): - continue - - min_max1 = self.get_output_min_max_from_activation( - preceding_op - ) - min_max2 = self.get_output_min_max_from_activation( - clamp_2_node - ) - - min_max = [None, None] - - if min_max1[0] is None and min_max2[0] is not None: - min_max[0] = min_max2[0] - elif min_max1[0] is not None and min_max2[0] is None: - min_max[0] = min_max1[0] - else: - min_max[0] = min(min_max1[0], min_max2[0]) - - if min_max1[1] is None and min_max2[1] is not None: - min_max[1] = min_max2[1] - elif min_max1[1] is not None and min_max2[1] is None: - min_max[1] = min_max1[1] - else: - min_max[1] = max(min_max1[1], min_max2[1]) - - new_args = list(preceding_op.args) - - # Insert the new min/max at indices 1 and 2 - new_args.insert(1, min_max[0]) - new_args.insert(2, min_max[1]) - new_args = new_args[0:3] - preceding_op.args = tuple(new_args) - clamp_2_node.replace_all_uses_with(preceding_op) - graph_module.graph.erase_node(clamp_2_node) - fuseAdded = True - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py index 52fc1f4a413..3f45296b26c 100644 --- a/backends/transforms/fuse_conv_with_clamp.py +++ b/backends/transforms/fuse_conv_with_clamp.py @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class FuseConvClampPass(ExportPass): +class FuseClampPass(ExportPass): """ Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it. """ @@ -25,7 +25,6 @@ class FuseConvClampPass(ExportPass): FUSEABLE_ACTIVATIONS = [ exir_ops.edge.aten.relu.default, exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, ] def get_output_min_max_from_activation(self, activation_node): @@ -38,13 +37,6 @@ def get_output_min_max_from_activation(self, activation_node): if len(activation_node.args) > 1: output_min = activation_node.args[1] output_max = activation_node.args[2] - elif activation_node.target == exir_ops.edge.aten.clamp.default: - output_min = None - output_max = None - if len(activation_node.args) >= 2: - output_min = activation_node.args[1] - if len(activation_node.args) >= 3: - output_max = activation_node.args[2] return output_min, output_max diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index f354f2234bd..ca09d34c2fe 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -77,38 +77,6 @@ def define_common_targets(): ], ) - runtime.python_library( - name = "fuse_clamps", - srcs = ["fuse_clamps.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - ":utils", - "//caffe2:torch", - "//executorch/backends/vulkan:custom_ops_lib", - "//executorch/exir:pass_base", - "//executorch/exir:sym_util", - "//executorch/exir/dialects:lib", - ], - ) - - runtime.python_library( - name = "fuse_clamp_with_binary_op", - srcs = ["fuse_clamp_with_binary_op.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - ":utils", - "//caffe2:torch", - "//executorch/backends/vulkan:custom_ops_lib", - "//executorch/exir:pass_base", - "//executorch/exir:sym_util", - "//executorch/exir/dialects:lib", - ], - ) - runtime.python_library( name = "view_copy_to_squeeze_unsqueeze", srcs = ["view_copy_to_squeeze_unsqueeze.py"], diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 56d882fa075..6e5aa926d37 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -109,763 +109,6 @@ def conv_with_clamp_out_impl( ) lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd") -########################## -## conv_with_binary_add ## -########################## - - -def conv_with_binary_add_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, -): - return torch.add( - torch.convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ), - other, - ) - - -name = "conv_with_binary_add" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" -) -lib.impl(name, conv_with_binary_add_impl, "CompositeExplicitAutograd") -conv_with_binary_add_op = getattr(getattr(torch.ops, namespace), name) - -############################# -## conv_with_binary_add.out ## -############################# - - -def conv_with_binary_add_out_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, - out=None, -): - out = conv_with_binary_add_impl( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - other, - ) - return out - - -name = "conv_with_binary_add.out" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, conv_with_binary_add_out_impl, "CompositeExplicitAutograd") - -########################## -## conv_with_binary_sub ## -########################## - - -def conv_with_binary_sub_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, -): - return torch.sub( - torch.convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ), - other, - ) - - -name = "conv_with_binary_sub" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" -) -lib.impl(name, conv_with_binary_sub_impl, "CompositeExplicitAutograd") -conv_with_binary_sub_op = getattr(getattr(torch.ops, namespace), name) - -############################## -## conv_with_binary_sub.out ## -############################## - - -def conv_with_binary_sub_out_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, - out=None, -): - out = conv_with_binary_sub_impl( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - other, - ) - return out - - -name = "conv_with_binary_sub.out" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, conv_with_binary_sub_out_impl, "CompositeExplicitAutograd") - -########################## -## conv_with_binary_mul ## -########################## - - -def conv_with_binary_mul_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, -): - return torch.mul( - torch.convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ), - other, - ) - - -name = "conv_with_binary_mul" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" -) -lib.impl(name, conv_with_binary_mul_impl, "CompositeExplicitAutograd") -conv_with_binary_mul_op = getattr(getattr(torch.ops, namespace), name) - -############################## -## conv_with_binary_mul.out ## -############################## - - -def conv_with_binary_mul_out_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, - out=None, -): - out = conv_with_binary_mul_impl( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - other, - ) - return out - - -name = "conv_with_binary_mul.out" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, conv_with_binary_mul_out_impl, "CompositeExplicitAutograd") - -########################## -## conv_with_binary_div ## -########################## - - -def conv_with_binary_div_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, -): - return torch.div( - torch.convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ), - other, - ) - - -name = "conv_with_binary_div" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" -) -lib.impl(name, conv_with_binary_div_impl, "CompositeExplicitAutograd") -conv_with_binary_div_op = getattr(getattr(torch.ops, namespace), name) - -############################## -## conv_with_binary_div.out ## -############################## - - -def conv_with_binary_div_out_impl( - input, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - transposed=False, - output_padding=0, - groups=1, - other=None, - out=None, -): - out = conv_with_binary_div_impl( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - other, - ) - return out - - -name = "conv_with_binary_div.out" -lib.define( - f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, conv_with_binary_div_out_impl, "CompositeExplicitAutograd") - -########################### -## clamp_with_binary_add ## -########################### - - -def clamp_with_binary_add_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, -): - return torch.add( - torch.clamp( - input, - output_min, - output_max, - ), - other, - ) - - -name = "clamp_with_binary_add" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor" -) -lib.impl(name, clamp_with_binary_add_impl, "CompositeExplicitAutograd") -clamp_with_binary_add_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## clamp_with_binary_add.out ## -############################### - - -def clamp_with_binary_add_out_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, - out=None, -): - out = clamp_with_binary_add_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "clamp_with_binary_add.out" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, clamp_with_binary_add_out_impl, "CompositeExplicitAutograd") - -########################### -## clamp_with_binary_sub ## -########################### - - -def clamp_with_binary_sub_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, -): - return torch.sub( - torch.clamp( - input, - output_min, - output_max, - ), - other, - ) - - -name = "clamp_with_binary_sub" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor" -) -lib.impl(name, clamp_with_binary_sub_impl, "CompositeExplicitAutograd") -clamp_with_binary_sub_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## clamp_with_binary_sub.out ## -############################### - - -def clamp_with_binary_sub_out_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, - out=None, -): - out = clamp_with_binary_sub_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "clamp_with_binary_sub.out" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, clamp_with_binary_sub_out_impl, "CompositeExplicitAutograd") - -########################### -## clamp_with_binary_mul ## -########################### - - -def clamp_with_binary_mul_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, -): - return torch.mul( - torch.clamp( - input, - output_min, - output_max, - ), - other, - ) - - -name = "clamp_with_binary_mul" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor" -) -lib.impl(name, clamp_with_binary_mul_impl, "CompositeExplicitAutograd") -clamp_with_binary_mul_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## clamp_with_binary_mul.out ## -############################### - - -def clamp_with_binary_mul_out_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, - out=None, -): - out = clamp_with_binary_mul_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "clamp_with_binary_mul.out" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, clamp_with_binary_mul_out_impl, "CompositeExplicitAutograd") - -########################### -## clamp_with_binary_div ## -########################### - - -def clamp_with_binary_div_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, -): - return torch.div( - torch.clamp( - input, - output_min, - output_max, - ), - other, - ) - - -name = "clamp_with_binary_div" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor" -) -lib.impl(name, clamp_with_binary_div_impl, "CompositeExplicitAutograd") -clamp_with_binary_div_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## clamp_with_binary_div.out ## -############################### - - -def clamp_with_binary_div_out_impl( - input, - output_min=-float("inf"), - output_max=float("inf"), - other=None, - out=None, -): - out = clamp_with_binary_div_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "clamp_with_binary_div.out" -lib.define( - f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, clamp_with_binary_div_out_impl, "CompositeExplicitAutograd") - -########################### -## binary_add_with_clamp ## -########################### - - -def binary_add_with_clamp_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), -): - return torch.clamp( - torch.add( - input, - other, - ), - output_min, - output_max, - ) - - -name = "binary_add_with_clamp" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor" -) -lib.impl(name, binary_add_with_clamp_impl, "CompositeExplicitAutograd") -binary_add_with_clamp_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## binary_add_with_clamp.out ## -############################### - - -def binary_add_with_clamp_out_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), - out=None, -): - out = binary_add_with_clamp_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "binary_add_with_clamp.out" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, binary_add_with_clamp_impl, "CompositeExplicitAutograd") - -########################### -## binary_sub_with_clamp ## -########################### - - -def binary_sub_with_clamp_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), -): - return torch.clamp( - torch.sub( - input, - other, - ), - output_min, - output_max, - ) - - -name = "binary_sub_with_clamp" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor" -) -lib.impl(name, binary_sub_with_clamp_impl, "CompositeExplicitAutograd") -binary_sub_with_clamp_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## binary_sub_with_clamp.out ## -############################### - - -def binary_sub_with_clamp_out_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), - out=None, -): - out = binary_sub_with_clamp_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "binary_sub_with_clamp.out" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, binary_sub_with_clamp_impl, "CompositeExplicitAutograd") - -########################### -## binary_mul_with_clamp ## -########################### - - -def binary_mul_with_clamp_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), -): - return torch.clamp( - torch.mul( - input, - other, - ), - output_min, - output_max, - ) - - -name = "binary_mul_with_clamp" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor" -) -lib.impl(name, binary_mul_with_clamp_impl, "CompositeExplicitAutograd") -binary_mul_with_clamp_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## binary_mul_with_clamp.out ## -############################### - - -def binary_mul_with_clamp_out_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), - out=None, -): - out = binary_mul_with_clamp_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "binary_mul_with_clamp.out" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, binary_mul_with_clamp_impl, "CompositeExplicitAutograd") - -########################### -## binary_div_with_clamp ## -########################### - - -def binary_div_with_clamp_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), -): - return torch.clamp( - torch.div( - input, - other, - ), - output_min, - output_max, - ) - - -name = "binary_div_with_clamp" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor" -) -lib.impl(name, binary_div_with_clamp_impl, "CompositeExplicitAutograd") -binary_div_with_clamp_op = getattr(getattr(torch.ops, namespace), name) - -############################### -## binary_div_with_clamp.out ## -############################### - - -def binary_div_with_clamp_out_impl( - input, - other=None, - output_min=-float("inf"), - output_max=float("inf"), - out=None, -): - out = binary_div_with_clamp_impl( - input, - output_min, - output_max, - other, - ) - return out - - -name = "binary_div_with_clamp.out" -lib.define( - f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.impl(name, binary_div_with_clamp_impl, "CompositeExplicitAutograd") - - ################# ## grid_priors ## ################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 85d14b30e88..63b57a0e79c 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -219,10 +219,6 @@ def register_torchao_choose_qparams_affine(): exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.ge.Tensor, - exir_ops.edge.et_vk.binary_add_with_clamp.default, - exir_ops.edge.et_vk.binary_sub_with_clamp.default, - exir_ops.edge.et_vk.binary_mul_with_clamp.default, - exir_ops.edge.et_vk.binary_div_with_clamp.default, ] ) def register_binary_op(): @@ -250,10 +246,6 @@ def register_binary_op(): exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.round.default, exir_ops.edge.aten.leaky_relu.default, - exir_ops.edge.et_vk.clamp_with_binary_add.default, - exir_ops.edge.et_vk.clamp_with_binary_sub.default, - exir_ops.edge.et_vk.clamp_with_binary_mul.default, - exir_ops.edge.et_vk.clamp_with_binary_div.default, ] ) def register_unary_op(): diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl index ed420fcc72f..6f2a93667ea 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl @@ -69,9 +69,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} ${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "clamp_type", "0")} -${layout_declare_spec_const(C, "float", "min_val", "0")} -${layout_declare_spec_const(C, "float", "max_val", "0")} $if STORAGE == "buffer": const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); @@ -93,20 +90,7 @@ void main() { // Simple case; no broadcasting if (are_equal(inp, other)) { - T in_val = T(t_in[out_bufi]); - T other_val = T(t_other[out_bufi]); - if (clamp_type == 1) { - in_val = T(clamp(in_val, T(min_val), T(max_val))); - } - else if (clamp_type == 2) { - other_val = T(clamp(other_val, T(min_val), T(max_val))); - } - T out_val = T(op(in_val, other_val, T(alpha))); - if (clamp_type == 3) { - out_val = T(clamp(out_val, T(min_val), T(max_val))); - } - t_out[out_bufi] = out_val; - + t_out[out_bufi] = T(op(t_in[out_bufi], t_other[out_bufi], T(alpha))); return; } @@ -122,19 +106,7 @@ void main() { uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); uint other_bufi = tensor_idx_to_linear_idx(other, other_tidx); - T in_val = T(t_in[inp_bufi]); - T other_val = T(t_other[other_bufi]); - if (clamp_type == 1) { - in_val = T(clamp(in_val, T(min_val), T(max_val))); - } - else if (clamp_type == 2) { - other_val = T(clamp(other_val, T(min_val), T(max_val))); - } - T out_val = T(op(in_val, other_val, T(alpha))); - if (clamp_type == 3) { - out_val = T(clamp(out_val, T(min_val), T(max_val))); - } - t_out[out_bufi] = out_val; + t_out[out_bufi] = T(op(t_in[inp_bufi], t_other[other_bufi], T(alpha))); } #else // USING_TEXTURE @@ -154,10 +126,6 @@ void main() { // read axis mapped texel tidx_to_pos(in_idx, in_sizes, in_axis_map, packed_dim))); - if (clamp_type == 1) { - in_texel = clamp(in_texel, VEC4_T(min_val), VEC4_T(max_val)); - } - // broadcast on logical sizes ivec4 other_idx = broadcast_indices(tidx, other_sizes); VEC4_T other_texel = VEC4_T(load_texel( @@ -165,10 +133,6 @@ void main() { // read axis mapped texel tidx_to_pos(other_idx, other_sizes, other_axis_map, packed_dim))); - if (clamp_type == 2) { - in_texel = clamp(other_texel, VEC4_T(min_val), VEC4_T(max_val)); - } - // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment. if (broadcast_params.x > 0) { in_texel = in_texel.xxxx; @@ -177,20 +141,11 @@ void main() { other_texel = other_texel.xxxx; } - if (clamp_type != 3) { - write_texel_lpos( - t_out, - lpos, - VEC4_OUT_T(op(in_texel, other_texel, alpha)), - out_axis_map); - } - else { - write_texel_lpos( - t_out, - lpos, - VEC4_OUT_T(clamp(VEC4_OUT_T(op(in_texel, other_texel, alpha)), min_val, max_val)), - out_axis_map); - } + write_texel_lpos( + t_out, + lpos, + VEC4_OUT_T(op(in_texel, other_texel, alpha)), + out_axis_map); } #endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl index 5bc01fa7f57..bb7ce482a7a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl @@ -61,7 +61,6 @@ void main() { } VEC4_T in_texel = texelFetch(t_in, pos, 0); - imageStore(t_out, pos, VEC4_T(op(in_texel, minimum, maximum))); } diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 9575ca0dcdd..025b483eab7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -54,39 +54,13 @@ void resize_binary_op_node( graph->virtual_resize(out, new_out_sizes); } -int remove_clamp_from_name(std::string& op) { - if (op.find("clamp_0_with_") != std::string::npos) { - op.erase(op.find("clamp_0_with_"), 13); - - // Clamp input 0 - return 1; - } - if (op.find("clamp_1_with_") != std::string::npos) { - op.erase(op.find("clamp_1_with_"), 13); - - // Clamp input 1 - return 2; - } - if (op.find("_with_clamp") != std::string::npos) { - op.erase(op.find("_with_clamp"), 11); - - // Clamp output - return 3; - } - - // No clamp - return 0; -} - void add_binary_op_texture_node( ComputeGraph& graph, const ValueRef in1, const ValueRef in2, const ValueRef alpha, const ValueRef out, - const std::string& op_name, - const float min, - const float max) { + const std::string& op_name) { ValueRef arg1 = prepack_standard_like(graph, in1, out, true); ValueRef arg2 = prepack_standard_like(graph, in2, out, true); @@ -106,10 +80,7 @@ void add_binary_op_texture_node( std::string kernel_name("binary_"); kernel_name.reserve(kShaderNameReserve); - - std::string op = op_name; - int clamp_type = remove_clamp_from_name(op); - kernel_name += op; + kernel_name += op_name; add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(in1)); @@ -130,10 +101,7 @@ void add_binary_op_texture_node( // Specialization Constants {graph.hashed_layout_of(out), graph.hashed_layout_of(arg1), - graph.hashed_layout_of(arg2), - clamp_type, - min, - max}, + graph.hashed_layout_of(arg2)}, // Resize Args {}, // Resizing Logic @@ -146,9 +114,7 @@ void add_binary_op_buffer_node( const ValueRef in2, const ValueRef alpha, const ValueRef out, - const std::string& op_name, - const float min, - const float max) { + const std::string& op_name) { // check_binary_op_args(*t_in1, *t_in2, *t_out); float alpha_val = 1.0f; @@ -160,9 +126,7 @@ void add_binary_op_buffer_node( std::string kernel_name("binary_"); kernel_name.reserve(kShaderNameReserve); - std::string op = op_name; - int clamp_type = remove_clamp_from_name(op); - kernel_name += op; + kernel_name += op_name; add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(in1)); @@ -185,9 +149,7 @@ void add_binary_op_buffer_node( // Specialization Constants {graph.hashed_layout_of(out), graph.hashed_layout_of(in1), - graph.hashed_layout_of(in2), - min, - max}, + graph.hashed_layout_of(in2)}, // Resize Args {}, // Resizing Logic @@ -200,13 +162,11 @@ void add_binary_op_node( const ValueRef in2, const ValueRef alpha, const ValueRef out, - const std::string& op_name, - const float min = std::numeric_limits::infinity(), - const float max = -std::numeric_limits::infinity()) { + const std::string& op_name) { if (graph.is_buffer_storage(out)) { - add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name, min, max); + add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name); } else { - add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name, min, max); + add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name); } } @@ -222,40 +182,6 @@ void add_binary_op_node( graph, args[0], args[1], kDummyValueRef, args[2], #op_name); \ } -float get_val_or_inf_(ComputeGraph& graph, const ValueRef& val, bool max) { - if (!graph.val_is_none(val)) { - return graph.extract_scalar(val); - } - return max ? std::numeric_limits::infinity() - : -std::numeric_limits::infinity(); -} - -#define DEFINE_BINARY_OP_WITH_ALPHA_FN_CLAMPED(op_name) \ - void op_name(ComputeGraph& graph, const std::vector& args) { \ - return add_binary_op_node( \ - graph, \ - args[0], \ - args[1], \ - args[2], \ - args[5], \ - #op_name, \ - get_val_or_inf_(graph, args[3], false), \ - get_val_or_inf_(graph, args[4], true)); \ - } - -#define DEFINE_BINARY_OP_FN_CLAMPED(op_name) \ - void op_name(ComputeGraph& graph, const std::vector& args) { \ - return add_binary_op_node( \ - graph, \ - args[0], \ - args[1], \ - kDummyValueRef, \ - args[4], \ - #op_name, \ - get_val_or_inf_(graph, args[2], false), \ - get_val_or_inf_(graph, args[3], true)); \ - } - DEFINE_BINARY_OP_WITH_ALPHA_FN(add); DEFINE_BINARY_OP_WITH_ALPHA_FN(sub); @@ -273,11 +199,6 @@ DEFINE_BINARY_OP_FN(le); DEFINE_BINARY_OP_FN(gt); DEFINE_BINARY_OP_FN(ge); -DEFINE_BINARY_OP_FN_CLAMPED(add_with_clamp); -DEFINE_BINARY_OP_FN_CLAMPED(sub_with_clamp); -DEFINE_BINARY_OP_FN_CLAMPED(mul_with_clamp); -DEFINE_BINARY_OP_FN_CLAMPED(div_with_clamp); - REGISTER_OPERATORS { VK_REGISTER_OP(aten.add.Tensor, add); VK_REGISTER_OP(aten.sub.Tensor, sub); @@ -291,11 +212,6 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.le.Tensor, le); VK_REGISTER_OP(aten.gt.Tensor, gt); VK_REGISTER_OP(aten.ge.Tensor, ge); - - VK_REGISTER_OP(et_vk.binary_add_with_clamp.default, add_with_clamp); - VK_REGISTER_OP(et_vk.binary_sub_with_clamp.default, sub_with_clamp); - VK_REGISTER_OP(et_vk.binary_mul_with_clamp.default, mul_with_clamp); - VK_REGISTER_OP(et_vk.binary_div_with_clamp.default, div_with_clamp); } } // namespace vkcompute diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 42173e587ac..c48ce0a452b 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -392,8 +392,6 @@ def define_common_targets(is_fbcode = False): deps = [ "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", - "//executorch/backends/transforms:fuse_clamp_with_binary_op", - "//executorch/backends/transforms:fuse_clamps", "//executorch/backends/transforms:fuse_conv_with_clamp", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index d23f0a29126..876f7fa8900 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,11 +13,7 @@ import executorch.backends.vulkan.utils as utils from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform -from executorch.backends.transforms.fuse_clamp_with_binary_op import ( - FuseClampBinaryOpPass, -) -from executorch.backends.transforms.fuse_clamps import FuseClampsPass -from executorch.backends.transforms.fuse_conv_with_clamp import FuseConvClampPass +from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import ( ViewCopyToSqueezeUnsqueezePass, @@ -173,9 +169,7 @@ def preprocess( # noqa: C901 [ FuseBatchNormPass(program), FusePatternsPass(), - FuseClampsPass(), - FuseConvClampPass(), - FuseClampBinaryOpPass(), + FuseClampPass(), AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(),