diff --git a/backends/transforms/fuse_clamp_with_binary_op.py b/backends/transforms/fuse_clamp_with_binary_op.py
new file mode 100644
index 00000000000..4155b2b7458
--- /dev/null
+++ b/backends/transforms/fuse_clamp_with_binary_op.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+import executorch.backends.vulkan.custom_ops_lib  # noqa
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class FuseClampBinaryOpPass(ExportPass):
+
+    FUSEABLE_CLAMP_OPS = [
+        exir_ops.edge.aten.relu.default,
+        exir_ops.edge.aten.hardtanh.default,
+        exir_ops.edge.aten.clamp.default,
+    ]
+    FUSEABLE_BINARY_OPS = [
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.div.Tensor,
+    ]
+
+    def exists_before(self, graph_module, node_a, node_b):
+        seen_a = False
+        for n in graph_module.graph.nodes:
+            if n is node_a:
+                seen_a = True
+            if n is node_b:
+                return seen_a
+        return False
+
+    def get_output_min_max_from_activation(self, activation_node):
+        if activation_node.target == exir_ops.edge.aten.relu.default:
+            output_min = 0.0
+            output_max = sys.float_info.max
+        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
+            output_min = -1.0
+            output_max = 1.0
+            if len(activation_node.args) > 1:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) > 2:
+                output_max = activation_node.args[2]
+        elif activation_node.target == exir_ops.edge.aten.clamp.default:
+            output_min = None
+            output_max = None
+            if len(activation_node.args) >= 2:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) >= 3:
+                output_max = activation_node.args[2]
+
+        return output_min, output_max
+
+    def fuse_binary_op_with_clamp(self, graph_module: torch.fx.GraphModule):
+        fuse_added = False
+        for clamp_node in graph_module.graph.nodes:
+            if clamp_node.op == "call_function":
+                if clamp_node.target in self.FUSEABLE_CLAMP_OPS:
+                    preceding_op = clamp_node.args[0]
+
+                    if (
+                        preceding_op.op == "call_function"
+                        and preceding_op.target in self.FUSEABLE_BINARY_OPS
+                    ):
+                        # Delete the clamp/activation node and forward its
+                        # bounds to the fused op's argument list.
+                        output_min_max = self.get_output_min_max_from_activation(
+                            clamp_node
+                        )
+                        new_args = list(preceding_op.args)
+                        new_args.append(output_min_max[0])
+                        new_args.append(output_min_max[1])
+                        new_args = tuple(new_args)
+                        clamp_node.replace_all_uses_with(preceding_op)
+                        graph_module.graph.erase_node(clamp_node)
+
+                        new_op = None
+                        match preceding_op.target:
+                            case exir_ops.edge.aten.add.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.binary_add_with_clamp.default
+                                )
+                            case exir_ops.edge.aten.sub.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.binary_sub_with_clamp.default
+                                )
+                            case exir_ops.edge.aten.mul.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.binary_mul_with_clamp.default
+                                )
+                            case exir_ops.edge.aten.div.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.binary_div_with_clamp.default
+                                )
+
+                        # Create and insert a node for the custom op
+                        # `binary_<op>_with_clamp`.
+                        with graph_module.graph.inserting_before(preceding_op):
+                            binary_op_clamp_node = graph_module.graph.create_node(
+                                "call_function",
+                                new_op,
+                                new_args,
+                            )
+
+                        preceding_op.replace_all_uses_with(binary_op_clamp_node)
+                        graph_module.graph.erase_node(preceding_op)
+
+                        fuse_added = True
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return fuse_added, graph_module
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        fuse_added = True
+        while fuse_added:
+            fuse_added, graph_module = self.fuse_binary_op_with_clamp(graph_module)
+
+        return PassResult(graph_module, True)
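+
+# Illustrative sketch (not executed by the pass): given an exported edge
+# graph containing the pattern on the left, this pass rewrites the two nodes
+# into the single fused et_vk custom op on the right:
+#
+#     y = torch.clamp(a + b, 0.0, 6.0)
+#         -> et_vk.binary_add_with_clamp(a, b, 0.0, 6.0)
+#     y = torch.relu(a * b)
+#         -> et_vk.binary_mul_with_clamp(a, b, 0.0, sys.float_info.max)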
diff --git a/backends/transforms/fuse_clamps.py b/backends/transforms/fuse_clamps.py
new file mode 100644
index 00000000000..6e5be508d54
--- /dev/null
+++ b/backends/transforms/fuse_clamps.py
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+import executorch.backends.vulkan.custom_ops_lib  # noqa
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class FuseClampsPass(ExportPass):
+
+    FUSEABLE_CLAMPS = [
+        exir_ops.edge.aten.relu.default,
+        exir_ops.edge.aten.hardtanh.default,
+        exir_ops.edge.aten.clamp.default,
+    ]
+
+    def get_output_min_max_from_activation(self, activation_node):
+        if activation_node.target == exir_ops.edge.aten.relu.default:
+            output_min = 0.0
+            output_max = sys.float_info.max
+        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
+            output_min = -1.0
+            output_max = 1.0
+            if len(activation_node.args) > 1:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) > 2:
+                output_max = activation_node.args[2]
+        elif activation_node.target == exir_ops.edge.aten.clamp.default:
+            output_min = None
+            output_max = None
+            if len(activation_node.args) >= 2:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) >= 3:
+                output_max = activation_node.args[2]
+
+        return output_min, output_max
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        fuse_added = True
+        while fuse_added:
+            fuse_added = False
+            for clamp_2_node in graph_module.graph.nodes:
+                if clamp_2_node.op == "call_function":
+                    if clamp_2_node.target in self.FUSEABLE_CLAMPS:
+                        preceding_op = clamp_2_node.args[0]
+                        if (
+                            preceding_op.op == "call_function"
+                            and preceding_op.target in self.FUSEABLE_CLAMPS
+                        ):
+                            # Ensure the shapes match
+                            if (
+                                "val" not in clamp_2_node.args[0].meta
+                                or "val" not in preceding_op.args[0].meta
+                            ):
+                                continue
+                            if len(clamp_2_node.args[0].meta["val"].shape) != len(
+                                preceding_op.args[0].meta["val"].shape
+                            ):
+                                continue
+
+                            min_max1 = self.get_output_min_max_from_activation(
+                                preceding_op
+                            )
+                            min_max2 = self.get_output_min_max_from_activation(
+                                clamp_2_node
+                            )
+
+                            # Compose the two clamps. For
+                            # clamp(clamp(x, a1, b1), a2, b2) the equivalent
+                            # single clamp is
+                            # clamp(x, max(a1, a2), min(max(b1, a2), b2));
+                            # None means unbounded on that side.
+                            inf = float("inf")
+                            a1 = -inf if min_max1[0] is None else min_max1[0]
+                            b1 = inf if min_max1[1] is None else min_max1[1]
+                            a2 = -inf if min_max2[0] is None else min_max2[0]
+                            b2 = inf if min_max2[1] is None else min_max2[1]
+                            new_min = max(a1, a2)
+                            new_max = min(max(b1, a2), b2)
+                            min_max = [
+                                None if new_min == -inf else new_min,
+                                None if new_max == inf else new_max,
+                            ]
+
+                            # Retarget the surviving node to clamp so its
+                            # schema matches the merged (min, max) args; relu
+                            # takes no bounds and hardtanh requires both.
+                            preceding_op.target = exir_ops.edge.aten.clamp.default
+                            preceding_op.args = (
+                                preceding_op.args[0],
+                                min_max[0],
+                                min_max[1],
+                            )
+                            clamp_2_node.replace_all_uses_with(preceding_op)
+                            graph_module.graph.erase_node(clamp_2_node)
+                            fuse_added = True
+
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
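+
+# Worked example (illustrative): clamp(clamp(x, 0.0, 6.0), 2.0, 4.0) collapses
+# to clamp(x, 2.0, 4.0) -- the intersection of the two ranges. When the ranges
+# do not overlap, e.g. clamp(clamp(x, 0.0, 1.0), 5.0, 6.0), the composition is
+# the constant 5.0, which the max(b1, a2) term in the formula above accounts
+# for: max(0.0, 5.0) = 5.0 and min(max(1.0, 5.0), 6.0) = 5.0.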
diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py
index 3f45296b26c..52fc1f4a413 100644
--- a/backends/transforms/fuse_conv_with_clamp.py
+++ b/backends/transforms/fuse_conv_with_clamp.py
@@ -14,7 +14,7 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-class FuseClampPass(ExportPass):
+class FuseConvClampPass(ExportPass):
     """
     Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it.
     """
@@ -25,6 +25,7 @@ class FuseClampPass(ExportPass):
     FUSEABLE_ACTIVATIONS = [
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.hardtanh.default,
+        exir_ops.edge.aten.clamp.default,
     ]
 
     def get_output_min_max_from_activation(self, activation_node):
@@ -37,6 +38,13 @@ def get_output_min_max_from_activation(self, activation_node):
         if len(activation_node.args) > 1:
             output_min = activation_node.args[1]
             output_max = activation_node.args[2]
+        elif activation_node.target == exir_ops.edge.aten.clamp.default:
+            output_min = None
+            output_max = None
+            if len(activation_node.args) >= 2:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) >= 3:
+                output_max = activation_node.args[2]
 
         return output_min, output_max
 
diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index ca09d34c2fe..f354f2234bd 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -77,6 +77,38 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "fuse_clamps",
+        srcs = ["fuse_clamps.py"],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            ":utils",
+            "//caffe2:torch",
+            "//executorch/backends/vulkan:custom_ops_lib",
+            "//executorch/exir:pass_base",
+            "//executorch/exir:sym_util",
+            "//executorch/exir/dialects:lib",
+        ],
+    )
+
+    runtime.python_library(
+        name = "fuse_clamp_with_binary_op",
+        srcs = ["fuse_clamp_with_binary_op.py"],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            ":utils",
+            "//caffe2:torch",
+            "//executorch/backends/vulkan:custom_ops_lib",
+            "//executorch/exir:pass_base",
+            "//executorch/exir:sym_util",
+            "//executorch/exir/dialects:lib",
+        ],
+    )
+
     runtime.python_library(
         name = "view_copy_to_squeeze_unsqueeze",
         srcs = ["view_copy_to_squeeze_unsqueeze.py"],
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 56e803b9127..e99883cadd9 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -109,6 +109,763 @@ def conv_with_clamp_out_impl(
     )
 
 
 lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")
 
+
+##########################
+## conv_with_binary_add ##
+##########################
+
+
+def conv_with_binary_add_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+):
+    return torch.add(
+        torch.convolution(
+            input,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        ),
+        other,
+    )
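+
+# Illustrative eager-mode equivalence for the conv_with_binary_* family
+# (assumed call shapes, for exposition only):
+#
+#     torch.ops.et_vk.conv_with_binary_add(
+#         x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, other
+#     )
+#
+# computes torch.convolution(x, w, b, ...) followed by torch.add(..., other).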
+
+name = "conv_with_binary_add"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor"
+)
+lib.impl(name, conv_with_binary_add_impl, "CompositeExplicitAutograd")
+conv_with_binary_add_op = getattr(getattr(torch.ops, namespace), name)
+
+##############################
+## conv_with_binary_add.out ##
+##############################
+
+
+def conv_with_binary_add_out_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+    out=None,
+):
+    out = conv_with_binary_add_impl(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        other,
+    )
+    return out
+
+
+name = "conv_with_binary_add.out"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, conv_with_binary_add_out_impl, "CompositeExplicitAutograd")
+
+##########################
+## conv_with_binary_sub ##
+##########################
+
+
+def conv_with_binary_sub_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+):
+    return torch.sub(
+        torch.convolution(
+            input,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        ),
+        other,
+    )
+
+
+name = "conv_with_binary_sub"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor"
+)
+lib.impl(name, conv_with_binary_sub_impl, "CompositeExplicitAutograd")
+conv_with_binary_sub_op = getattr(getattr(torch.ops, namespace), name)
+
+##############################
+## conv_with_binary_sub.out ##
+##############################
+
+
+def conv_with_binary_sub_out_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+    out=None,
+):
+    out = conv_with_binary_sub_impl(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        other,
+    )
+    return out
+
+
+name = "conv_with_binary_sub.out"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, conv_with_binary_sub_out_impl, "CompositeExplicitAutograd")
+
+##########################
+## conv_with_binary_mul ##
+##########################
+
+
+def conv_with_binary_mul_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+):
+    return torch.mul(
+        torch.convolution(
+            input,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        ),
+        other,
+    )
+
+
+name = "conv_with_binary_mul"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor"
+)
+lib.impl(name, conv_with_binary_mul_impl, "CompositeExplicitAutograd")
+conv_with_binary_mul_op = getattr(getattr(torch.ops, namespace), name)
+
+##############################
+## conv_with_binary_mul.out ##
+##############################
+
+
+def conv_with_binary_mul_out_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+    out=None,
+):
+    out = conv_with_binary_mul_impl(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        other,
+    )
+    return out
+
+
+name = "conv_with_binary_mul.out"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, conv_with_binary_mul_out_impl, "CompositeExplicitAutograd")
+
+##########################
+## conv_with_binary_div ##
+##########################
+
+
+def conv_with_binary_div_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+):
+    return torch.div(
+        torch.convolution(
+            input,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        ),
+        other,
+    )
+
+
+name = "conv_with_binary_div"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor"
+)
+lib.impl(name, conv_with_binary_div_impl, "CompositeExplicitAutograd")
+conv_with_binary_div_op = getattr(getattr(torch.ops, namespace), name)
+
+##############################
+## conv_with_binary_div.out ##
+##############################
+
+
+def conv_with_binary_div_out_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    other=None,
+    out=None,
+):
+    out = conv_with_binary_div_impl(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        other,
+    )
+    return out
+
+
+name = "conv_with_binary_div.out"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, conv_with_binary_div_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## clamp_with_binary_add ##
+###########################
+
+
+def clamp_with_binary_add_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+):
+    return torch.add(
+        torch.clamp(
+            input,
+            output_min,
+            output_max,
+        ),
+        other,
+    )
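+
+# Illustrative equivalence for the clamp_with_binary_* family (exposition
+# only; `x` and `other` are assumed to be same-shape tensors):
+#
+#     torch.ops.et_vk.clamp_with_binary_add(x, 0.0, 6.0, other)
+#     == torch.clamp(x, 0.0, 6.0) + other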
+
+name = "clamp_with_binary_add"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_add_impl, "CompositeExplicitAutograd")
+clamp_with_binary_add_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_add.out ##
+###############################
+
+
+def clamp_with_binary_add_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_add_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_add.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_add_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## clamp_with_binary_sub ##
+###########################
+
+
+def clamp_with_binary_sub_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+):
+    return torch.sub(
+        torch.clamp(
+            input,
+            output_min,
+            output_max,
+        ),
+        other,
+    )
+
+
+name = "clamp_with_binary_sub"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_sub_impl, "CompositeExplicitAutograd")
+clamp_with_binary_sub_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_sub.out ##
+###############################
+
+
+def clamp_with_binary_sub_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_sub_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_sub.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_sub_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## clamp_with_binary_mul ##
+###########################
+
+
+def clamp_with_binary_mul_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+):
+    return torch.mul(
+        torch.clamp(
+            input,
+            output_min,
+            output_max,
+        ),
+        other,
+    )
+
+
+name = "clamp_with_binary_mul"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_mul_impl, "CompositeExplicitAutograd")
+clamp_with_binary_mul_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_mul.out ##
+###############################
+
+
+def clamp_with_binary_mul_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_mul_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_mul.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_mul_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## clamp_with_binary_div ##
+###########################
+
+
+def clamp_with_binary_div_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+):
+    return torch.div(
+        torch.clamp(
+            input,
+            output_min,
+            output_max,
+        ),
+        other,
+    )
+
+
+name = "clamp_with_binary_div"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_div_impl, "CompositeExplicitAutograd")
+clamp_with_binary_div_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_div.out ##
+###############################
+
+
+def clamp_with_binary_div_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_div_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_div.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_div_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_add_with_clamp ##
+###########################
+
+
+def binary_add_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.add(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_add_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_add_with_clamp_impl, "CompositeExplicitAutograd")
+binary_add_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_add_with_clamp.out ##
+###############################
+
+
+def binary_add_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_add_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_add_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_add_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_sub_with_clamp ##
+###########################
+
+
+def binary_sub_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.sub(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_sub_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_sub_with_clamp_impl, "CompositeExplicitAutograd")
+binary_sub_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_sub_with_clamp.out ##
+###############################
+
+
+def binary_sub_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_sub_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_sub_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_sub_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_mul_with_clamp ##
+###########################
+
+
+def binary_mul_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.mul(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
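+
+# Illustrative equivalence for the binary_*_with_clamp family (exposition
+# only); these are the ops FuseClampBinaryOpPass emits:
+#
+#     torch.ops.et_vk.binary_add_with_clamp(x, y, 0.0, 6.0)
+#     == torch.clamp(x + y, 0.0, 6.0)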
+
+name = "binary_mul_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_mul_with_clamp_impl, "CompositeExplicitAutograd")
+binary_mul_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_mul_with_clamp.out ##
+###############################
+
+
+def binary_mul_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_mul_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_mul_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_mul_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_div_with_clamp ##
+###########################
+
+
+def binary_div_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.div(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_div_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_div_with_clamp_impl, "CompositeExplicitAutograd")
+binary_div_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_div_with_clamp.out ##
+###############################
+
+
+def binary_div_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_div_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
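+
+# Minimal smoke-test sketch (assumed usage, not part of this change):
+#
+#     x, y = torch.randn(2, 8), torch.randn(2, 8)
+#     ref = torch.clamp(torch.div(x, y), -1.0, 1.0)
+#     got = torch.ops.et_vk.binary_div_with_clamp(x, y, -1.0, 1.0)
+#     assert torch.allclose(ref, got)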
+
+name = "binary_div_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_div_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+
 #################
 ## grid_priors ##
 #################
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 8eb47ff467e..3cec6917e57 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -219,6 +219,10 @@ def register_torchao_choose_qparams_affine():
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.gt.Tensor,
         exir_ops.edge.aten.ge.Tensor,
+        exir_ops.edge.et_vk.binary_add_with_clamp.default,
+        exir_ops.edge.et_vk.binary_sub_with_clamp.default,
+        exir_ops.edge.et_vk.binary_mul_with_clamp.default,
+        exir_ops.edge.et_vk.binary_div_with_clamp.default,
     ]
 )
 def register_binary_op():
@@ -246,6 +250,10 @@ def register_binary_op():
         exir_ops.edge.aten.tanh.default,
         exir_ops.edge.aten.round.default,
         exir_ops.edge.aten.leaky_relu.default,
+        exir_ops.edge.et_vk.clamp_with_binary_add.default,
+        exir_ops.edge.et_vk.clamp_with_binary_sub.default,
+        exir_ops.edge.et_vk.clamp_with_binary_mul.default,
+        exir_ops.edge.et_vk.clamp_with_binary_div.default,
     ]
 )
 def register_unary_op():
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
index 6f2a93667ea..ed420fcc72f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -69,6 +69,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
 ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
 ${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
+// clamp_type: 0 = no clamp, 1 = clamp the first input, 2 = clamp the second
+// input, 3 = clamp the output.
+${layout_declare_spec_const(C, "int", "clamp_type", "0")}
+${layout_declare_spec_const(C, "float", "min_val", "0")}
+${layout_declare_spec_const(C, "float", "max_val", "0")}
 
 $if STORAGE == "buffer":
   const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
@@ -90,7 +93,20 @@ void main() {
 
   // Simple case; no broadcasting
   if (are_equal(inp, other)) {
-    t_out[out_bufi] = T(op(t_in[out_bufi], t_other[out_bufi], T(alpha)));
+    T in_val = T(t_in[out_bufi]);
+    T other_val = T(t_other[out_bufi]);
+    if (clamp_type == 1) {
+      in_val = T(clamp(in_val, T(min_val), T(max_val)));
+    }
+    else if (clamp_type == 2) {
+      other_val = T(clamp(other_val, T(min_val), T(max_val)));
+    }
+    T out_val = T(op(in_val, other_val, T(alpha)));
+    if (clamp_type == 3) {
+      out_val = T(clamp(out_val, T(min_val), T(max_val)));
+    }
+    t_out[out_bufi] = out_val;
+
     return;
   }
 
@@ -106,7 +122,19 @@ void main() {
   uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
   uint other_bufi = tensor_idx_to_linear_idx(other, other_tidx);
 
-  t_out[out_bufi] = T(op(t_in[inp_bufi], t_other[other_bufi], T(alpha)));
+  T in_val = T(t_in[inp_bufi]);
+  T other_val = T(t_other[other_bufi]);
+  if (clamp_type == 1) {
+    in_val = T(clamp(in_val, T(min_val), T(max_val)));
+  }
+  else if (clamp_type == 2) {
+    other_val = T(clamp(other_val, T(min_val), T(max_val)));
+  }
+  T out_val = T(op(in_val, other_val, T(alpha)));
+  if (clamp_type == 3) {
+    out_val = T(clamp(out_val, T(min_val), T(max_val)));
+  }
+  t_out[out_bufi] = out_val;
 }
 
 #else // USING_TEXTURE
@@ -126,6 +154,10 @@ void main() {
       // read axis mapped texel
       tidx_to_pos(in_idx, in_sizes, in_axis_map, packed_dim)));
 
+  if (clamp_type == 1) {
+    in_texel = clamp(in_texel, VEC4_T(min_val), VEC4_T(max_val));
+  }
+
   // broadcast on logical sizes
   ivec4 other_idx = broadcast_indices(tidx, other_sizes);
   VEC4_T other_texel = VEC4_T(load_texel(
@@ -133,6 +165,10 @@
       // read axis mapped texel
       tidx_to_pos(other_idx, other_sizes, other_axis_map, packed_dim)));
 
+  if (clamp_type == 2) {
+    other_texel = clamp(other_texel, VEC4_T(min_val), VEC4_T(max_val));
+  }
+
   // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
   if (broadcast_params.x > 0) {
     in_texel = in_texel.xxxx;
   }
@@ -141,11 +177,20 @@
     other_texel = other_texel.xxxx;
   }
 
-  write_texel_lpos(
-      t_out,
-      lpos,
-      VEC4_OUT_T(op(in_texel, other_texel, alpha)),
-      out_axis_map);
+  if (clamp_type != 3) {
+    write_texel_lpos(
+        t_out,
+        lpos,
+        VEC4_OUT_T(op(in_texel, other_texel, alpha)),
+        out_axis_map);
+  }
+  else {
+    write_texel_lpos(
+        t_out,
+        lpos,
+        VEC4_OUT_T(clamp(VEC4_OUT_T(op(in_texel, other_texel, alpha)), min_val, max_val)),
+        out_axis_map);
+  }
 }
 
 #endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
index bb7ce482a7a..5bc01fa7f57 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
@@ -61,6 +61,7 @@ void main() {
   }
 
   VEC4_T in_texel = texelFetch(t_in, pos, 0);
+
   imageStore(t_out, pos, VEC4_T(op(in_texel, minimum, maximum)));
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
index 025b483eab7..9575ca0dcdd 100644
--- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
@@ -54,13 +54,39 @@ void resize_binary_op_node(
   graph->virtual_resize(out, new_out_sizes);
 }
 
+// Strips the clamp marker from a fused op name and returns the clamp mode
+// encoded by the name, e.g. "add_with_clamp" -> "add", returning 3.
+int remove_clamp_from_name(std::string& op) {
+  if (op.find("clamp_0_with_") != std::string::npos) {
+    op.erase(op.find("clamp_0_with_"), 13);
+    // Clamp input 0
+    return 1;
+  }
+  if (op.find("clamp_1_with_") != std::string::npos) {
+    op.erase(op.find("clamp_1_with_"), 13);
+    // Clamp input 1
+    return 2;
+  }
+  if (op.find("_with_clamp") != std::string::npos) {
+    op.erase(op.find("_with_clamp"), 11);
+    // Clamp output
+    return 3;
+  }
+  // No clamp
+  return 0;
+}
+
 void add_binary_op_texture_node(
     ComputeGraph& graph,
     const ValueRef in1,
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const std::string& op_name) {
+    const std::string& op_name,
+    const float min,
+    const float max) {
   ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
   ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
 
@@ -80,7 +106,10 @@ void add_binary_op_texture_node(
 
   std::string kernel_name("binary_");
   kernel_name.reserve(kShaderNameReserve);
-  kernel_name += op_name;
+
+  std::string op = op_name;
+  int clamp_type = remove_clamp_from_name(op);
+  kernel_name += op;
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_dtype_suffix(kernel_name, graph.dtype_of(in1));
 
@@ -101,7 +130,10 @@ void add_binary_op_texture_node(
       // Specialization Constants
       {graph.hashed_layout_of(out),
        graph.hashed_layout_of(arg1),
-       graph.hashed_layout_of(arg2)},
+       graph.hashed_layout_of(arg2),
+       clamp_type,
+       min,
+       max},
       // Resize Args
       {},
       // Resizing Logic
@@ -114,7 +146,9 @@ void add_binary_op_buffer_node(
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const std::string& op_name) {
+    const std::string& op_name,
+    const float min,
+    const float max) {
   // check_binary_op_args(*t_in1, *t_in2, *t_out);
 
   float alpha_val = 1.0f;
@@ -126,7 +160,9 @@ void add_binary_op_buffer_node(
 
   std::string kernel_name("binary_");
   kernel_name.reserve(kShaderNameReserve);
-  kernel_name += op_name;
+  std::string op = op_name;
+  int clamp_type = remove_clamp_from_name(op);
+  kernel_name += op;
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_dtype_suffix(kernel_name, graph.dtype_of(in1));
 
@@ -149,7 +185,10 @@ void add_binary_op_buffer_node(
       // Specialization Constants
       {graph.hashed_layout_of(out),
       graph.hashed_layout_of(in1),
-       graph.hashed_layout_of(in2)},
+       graph.hashed_layout_of(in2),
+       clamp_type,
+       min,
+       max},
       // Resize Args
       {},
       // Resizing Logic
@@ -162,11 +200,13 @@ void add_binary_op_node(
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const std::string& op_name) {
+    const std::string& op_name,
+    const float min = -std::numeric_limits<float>::infinity(),
+    const float max = std::numeric_limits<float>::infinity()) {
   if (graph.is_buffer_storage(out)) {
-    add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name);
+    add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name, min, max);
   } else {
-    add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name);
+    add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name, min, max);
  }
 }
 
@@ -182,6 +222,40 @@ void add_binary_op_node(
         graph, args[0], args[1], kDummyValueRef, args[2], #op_name);          \
   }
 
+float get_val_or_inf_(ComputeGraph& graph, const ValueRef& val, bool max) {
+  if (!graph.val_is_none(val)) {
+    return graph.extract_scalar<float>(val);
+  }
+  return max ? std::numeric_limits<float>::infinity()
+             : -std::numeric_limits<float>::infinity();
+}
+
+#define DEFINE_BINARY_OP_WITH_ALPHA_FN_CLAMPED(op_name)                  \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    return add_binary_op_node(                                           \
+        graph,                                                           \
+        args[0],                                                         \
+        args[1],                                                         \
+        args[2],                                                         \
+        args[5],                                                         \
+        #op_name,                                                        \
+        get_val_or_inf_(graph, args[3], false),                          \
+        get_val_or_inf_(graph, args[4], true));                          \
+  }
+
+#define DEFINE_BINARY_OP_FN_CLAMPED(op_name)                             \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    return add_binary_op_node(                                           \
+        graph,                                                           \
+        args[0],                                                         \
+        args[1],                                                         \
+        kDummyValueRef,                                                  \
+        args[4],                                                         \
+        #op_name,                                                        \
+        get_val_or_inf_(graph, args[2], false),                          \
+        get_val_or_inf_(graph, args[3], true));                          \
+  }
+
 DEFINE_BINARY_OP_WITH_ALPHA_FN(add);
 DEFINE_BINARY_OP_WITH_ALPHA_FN(sub);
 
@@ -199,6 +273,11 @@ DEFINE_BINARY_OP_FN(le);
 DEFINE_BINARY_OP_FN(gt);
 DEFINE_BINARY_OP_FN(ge);
 
+DEFINE_BINARY_OP_FN_CLAMPED(add_with_clamp);
+DEFINE_BINARY_OP_FN_CLAMPED(sub_with_clamp);
+DEFINE_BINARY_OP_FN_CLAMPED(mul_with_clamp);
+DEFINE_BINARY_OP_FN_CLAMPED(div_with_clamp);
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
   VK_REGISTER_OP(aten.sub.Tensor, sub);
@@ -212,6 +291,11 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.le.Tensor, le);
   VK_REGISTER_OP(aten.gt.Tensor, gt);
   VK_REGISTER_OP(aten.ge.Tensor, ge);
+
+  VK_REGISTER_OP(et_vk.binary_add_with_clamp.default, add_with_clamp);
+  VK_REGISTER_OP(et_vk.binary_sub_with_clamp.default, sub_with_clamp);
+  VK_REGISTER_OP(et_vk.binary_mul_with_clamp.default, mul_with_clamp);
+  VK_REGISTER_OP(et_vk.binary_div_with_clamp.default, div_with_clamp);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index a9ba62b6f9f..170afe4dc44 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -381,6 +381,8 @@ def define_common_targets(is_fbcode = False):
         deps = [
             "//executorch/backends/transforms:addmm_mm_to_linear",
             "//executorch/backends/transforms:fuse_batch_norm_with_conv",
+            "//executorch/backends/transforms:fuse_clamp_with_binary_op",
+            "//executorch/backends/transforms:fuse_clamps",
             "//executorch/backends/transforms:fuse_conv_with_clamp",
"//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index d59bd9eff7d..8c902d6ba0d 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,7 +13,11 @@ import executorch.backends.vulkan.utils as utils from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform -from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass +from executorch.backends.transforms.fuse_clamp_with_binary_op import ( + FuseClampBinaryOpPass, +) +from executorch.backends.transforms.fuse_clamps import FuseClampsPass +from executorch.backends.transforms.fuse_conv_with_clamp import FuseConvClampPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import ( ViewCopyToSqueezeUnsqueezePass, @@ -168,7 +172,9 @@ def preprocess( # noqa: C901 [ FuseBatchNormPass(program), FusePatternsPass(), - FuseClampPass(), + FuseClampsPass(), + FuseConvClampPass(), + FuseClampBinaryOpPass(), AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(),