14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -104,6 +104,19 @@ runtime.python_library(
],
)

runtime.python_library(
name = "insert_dtype_promotion",
srcs = ["insert_dtype_promotion.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
@@ -133,6 +146,7 @@ runtime.python_library(
":fold_qdq",
":fuse_patterns",
":fuse_quantized_ops",
":insert_dtype_promotion",
":insert_prepack_nodes",
":remove_asserts",
":remove_redundant_ops",
4 changes: 4 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -11,6 +11,9 @@
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
FuseQuantizedOpsTransform,
)
from executorch.backends.vulkan._passes.insert_dtype_promotion import (
InsertDtypePromotionPass,
)
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
from executorch.backends.vulkan._passes.remove_asserts import (
remove_asserts,
@@ -28,6 +31,7 @@
"FoldQDQPass",
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"InsertDtypePromotionPass",
"insert_prepack_nodes",
"remove_asserts",
"RemoveAssertsTransform",
102 changes: 102 additions & 0 deletions backends/vulkan/_passes/insert_dtype_promotion.py
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Set, Union

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]

# Binary ops whose first two args are tensor inputs that may need promotion
BINARY_OPS: Set[OpType] = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.div.Tensor_mode,
exir_ops.edge.aten.pow.Tensor_Tensor,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.ge.Tensor,
}


def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype:
"""Promote to common dtype following PyTorch type promotion rules."""
if a == b:
return a
# Any mix of different dtypes promotes to float32
return torch.float32


class InsertDtypePromotionPass(ExportPass):
"""
Insert _to_copy nodes before binary ops when the two tensor inputs have
different dtypes, promoting both to a common dtype.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
dirty = False
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in BINARY_OPS:
continue

lhs = node.args[0]
rhs = node.args[1]

if not isinstance(lhs, torch.fx.Node) or not isinstance(rhs, torch.fx.Node):
continue

if "val" not in lhs.meta or "val" not in rhs.meta:
continue

lhs_dtype = lhs.meta["val"].dtype
rhs_dtype = rhs.meta["val"].dtype

if lhs_dtype == rhs_dtype:
continue

promoted = _promote_dtype(lhs_dtype, rhs_dtype)

if lhs_dtype != promoted:
with graph_module.graph.inserting_before(node):
cast_lhs = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(lhs,),
{"dtype": promoted},
)
cast_lhs.meta["val"] = lhs.meta["val"].to(promoted)
node.replace_input_with(lhs, cast_lhs)
dirty = True

if rhs_dtype != promoted:
with graph_module.graph.inserting_before(node):
cast_rhs = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(rhs,),
{"dtype": promoted},
)
cast_rhs.meta["val"] = rhs.meta["val"].to(promoted)
node.replace_input_with(rhs, cast_rhs)
dirty = True

if dirty:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
dead_code_elimination_pass(graph_module)

return PassResult(graph_module, dirty)
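
A minimal standalone sketch of the rewrite this pass performs, using plain torch.fx with a `Tensor.to` call standing in for the edge-dialect `_to_copy` op. Note that `_promote_dtype` above simplifies full PyTorch promotion: any dtype mismatch resolves to float32.

```python
# Sketch only: plain torch.fx, with Tensor.to in place of the edge-dialect
# _to_copy op. Not the ExecuTorch pass itself.
import operator

import torch
import torch.fx


class MixedAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y  # e.g. x is int32, y is float32


gm = torch.fx.symbolic_trace(MixedAdd())

# Insert a cast before the add so both operands arrive as float32,
# mirroring the _to_copy insertion done by InsertDtypePromotionPass.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        lhs = node.args[0]
        with gm.graph.inserting_before(node):
            cast = gm.graph.call_method("to", (lhs,), {"dtype": torch.float32})
        node.replace_input_with(lhs, cast)
gm.recompile()

out = gm(torch.ones(2, dtype=torch.int32), torch.ones(2))
print(out.dtype)  # torch.float32
```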
22 changes: 20 additions & 2 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
# Vulkan compute graph. This annotation is used in later graph passes.
node.meta["etvk_tensorref"] = True

# Get the list of node users that do not handle their own prepacking
# Get the list of node users that need a prepack node inserted. This
# includes ops that don't handle their own prepacking, as well as ops
# that do handle their own prepacking but use this constant tensor as
# the primary input (since the op expects the primary input to be a GPU
# tensor, not a TensorRef).
nodes_to_replace_input = []
for user in node.users:
if user.op == "call_function" and not handles_own_prepacking(user.target):
if user.op != "call_function":
continue

if not handles_own_prepacking(user.target):
nodes_to_replace_input.append(user)
continue

# Most prepacking ops have the primary input at arg 0, but
# embedding is embedding(weight, indices, ...) where the
# primary input (indices) is at arg 1.
primary_arg_idx = 0
if user.target == exir_ops.edge.aten.embedding.default:
primary_arg_idx = 1

if node in user.args and user.args.index(node) == primary_arg_idx:
nodes_to_replace_input.append(user)

if len(nodes_to_replace_input) == 0:
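
To illustrate the embedding special case: the aten schema is `embedding(Tensor weight, Tensor indices, ...)`, so weight sits at arg 0 and the primary runtime input (indices) at arg 1, the opposite argument order from `torch.nn.functional.embedding`, which takes indices first.

```python
# Demonstrates the aten argument order the comment above depends on:
# weight at arg 0, indices (the primary input) at arg 1.
import torch

weight = torch.randn(10, 4)
indices = torch.tensor([1, 3, 5])

out = torch.ops.aten.embedding.default(weight, indices)
print(out.shape)  # torch.Size([3, 4])
```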
4 changes: 2 additions & 2 deletions backends/vulkan/op_registry.py
@@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
outputs_dtypes=utils.FP_T,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.FP_INT_T,
supports_resize=True,
are_node_inputs_supported_fn=check_to_copy_node,
)
10 changes: 5 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
@@ -46,7 +46,7 @@ binary_op:
- VALUE: half
- VALUE: float
- NAME: binary_eq_texture3d
OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5)))
OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5))
STORAGE: texture3d
generate_variant_forall:
DTYPE:
@@ -61,7 +61,7 @@
- VALUE: float
- VALUE: int32
- NAME: binary_lt_texture3d
OPERATOR: all(lessThan(X, Y))
OPERATOR: lessThan(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
@@ -77,7 +77,7 @@
- VALUE: float
- VALUE: int32
- NAME: binary_le_texture3d
OPERATOR: all(lessThanEqual(X, Y))
OPERATOR: lessThanEqual(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
@@ -93,7 +93,7 @@
- VALUE: float
- VALUE: int32
- NAME: binary_gt_texture3d
OPERATOR: all(greaterThan(X, Y))
OPERATOR: greaterThan(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
@@ -109,7 +109,7 @@
- VALUE: float
- VALUE: int32
- NAME: binary_ge_texture3d
OPERATOR: all(greaterThanEqual(X, Y))
OPERATOR: greaterThanEqual(X, Y)
STORAGE: texture3d
generate_variant_forall:
DTYPE:
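
For context, the removed `all(...)` wrapper collapsed each texel's per-component comparison into a single boolean; dropping it restores the elementwise semantics of the corresponding ATen comparison ops:

```python
# Elementwise semantics the comparison shaders must match: aten.lt and
# friends return a boolean tensor shaped like the inputs, not a single
# reduced value.
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([2.0, 2.0, 2.0, 2.0])

print(torch.lt(x, y))        # tensor([ True, False, False, False])
print(torch.lt(x, y).all())  # tensor(False), the old all(...) behavior
```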
44 changes: 44 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define IN_VEC4_T ${texel_type(IN_DTYPE)}
#define OUT_VEC4_T ${texel_type(OUT_DTYPE)}

${define_required_extensions("texture3d", IN_DTYPE)}
${define_required_extensions("texture3d", OUT_DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim);

if (any(greaterThanEqual(idx, out_sizes))) {
return;
}

IN_VEC4_T in_texel = IN_VEC4_T(texelFetch(t_in, pos, 0));
imageStore(t_out, pos, OUT_VEC4_T(in_texel));
}
19 changes: 19 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

view_convert_texture:
parameter_names_with_default_values:
IN_DTYPE: float
OUT_DTYPE: float
STORAGE: texture3d
generate_variant_forall:
combination:
parameter_names: [IN_DTYPE, OUT_DTYPE]
combos:
- parameter_values: [int32, float]
- parameter_values: [float, int32]
shader_variants:
- NAME: view_convert_texture
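
A hedged sketch of what the `combination` block expands to: one shader variant per dtype pair, with a dtype suffix appended per parameter. The exact suffix spelling below is an assumption, not taken from the codegen.

```python
# Assumed expansion of the YAML combination block; the real names come from
# the shader codegen (see add_dtype_suffix in View.cpp below).
combos = [("int32", "float"), ("float", "int32")]
for in_dtype, out_dtype in combos:
    print(f"view_convert_texture_{in_dtype}_{out_dtype}")
```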
36 changes: 23 additions & 13 deletions backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
@@ -9,8 +9,8 @@
#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
#include <set>

namespace vkcompute {

@@ -25,19 +25,29 @@ void resize_to_copy_op_node(
graph->virtual_resize(out, graph->sizes_of(self));
}

bool is_float_type(vkapi::ScalarType dtype) {
return dtype == vkapi::ScalarType::Float || dtype == vkapi::ScalarType::Half;
}

void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
static std::set<vkapi::ScalarType> supported_types = {
vkapi::ScalarType::Float, vkapi::ScalarType::Half};

VK_CHECK_COND(
supported_types.find(graph.dtype_of(in)) != supported_types.end() &&
supported_types.find(graph.dtype_of(out)) != supported_types.end(),
"Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ",
vkapi::to_string(graph.dtype_of(in)),
" <-> ",
vkapi::to_string(graph.dtype_of(out)));

graph.execute_nodes().emplace_back(new BlitNode(graph, in, out));
vkapi::ScalarType in_dtype = graph.dtype_of(in);
vkapi::ScalarType out_dtype = graph.dtype_of(out);

// Same-dtype or float<->half conversions can use BlitNode
if (in_dtype == out_dtype ||
(is_float_type(in_dtype) && is_float_type(out_dtype))) {
graph.execute_nodes().emplace_back(new BlitNode(graph, in, out));
return;
}

// Other conversions (e.g. int<->float) use view_convert shaders
if (graph.is_buffer_storage(in)) {
add_view_copy_convert_buffer_node(
graph, in, out, {}, resize_to_copy_op_node);
} else {
add_view_copy_convert_texture_node(
graph, in, out, {}, resize_to_copy_op_node);
}
}

void to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
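
The routing in `add_to_copy_node`, sketched in Python for brevity: blits for same-dtype and float<->half pairs, `view_convert` shaders for everything else, with the buffer or texture variant picked by storage type.

```python
# Python sketch of add_to_copy_node's dispatch; dtype names are illustrative
# strings, not the vkapi::ScalarType enum.
FLOAT_TYPES = {"Float", "Half"}


def pick_to_copy_impl(in_dtype: str, out_dtype: str, in_is_buffer: bool) -> str:
    if in_dtype == out_dtype or (
        in_dtype in FLOAT_TYPES and out_dtype in FLOAT_TYPES
    ):
        return "BlitNode"
    return "view_convert_buffer" if in_is_buffer else "view_convert_texture"


print(pick_to_copy_impl("Float", "Half", False))  # BlitNode
print(pick_to_copy_impl("Int", "Float", False))   # view_convert_texture
```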
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.cpp
@@ -127,6 +127,35 @@ void add_view_copy_buffer_node(
resize_fn));
}

void add_view_copy_convert_texture_node(
ComputeGraph& graph,
ValueRef in,
ValueRef out,
const std::vector<ValueRef>& resize_args,
const ExecuteNode::ResizeFunction& resize_fn) {
std::string kernel_name = "view_convert_texture";
add_dtype_suffix(kernel_name, graph.dtype_of(in));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Parameter Buffers
{},
// Push Constants
{{graph.sizes_pc_of(out)}},
// Specialization Constants
{graph.packed_dim_of(out)},
// Resize Args
resize_args,
// Resizing Logic
resize_fn));
}

void add_view_copy_convert_buffer_node(
ComputeGraph& graph,
ValueRef in,