36 changes: 19 additions & 17 deletions backends/vulkan/_passes/remove_redundant_ops.py
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.aten.expand_copy.default,
     }
 
     def __init__(self) -> None:
         super(RemoveRedundantOpsTransform, self).__init__()
 
     def _should_remove(self, node: torch.fx.Node) -> bool:
-        if node.target in self.redundant_ops:
-            return True
-
-        # Only remove to_copy if dtype does not change. Otherwise, memory format changes
-        # will be handled internally by the backend.
-        if (
-            node.target == exir_ops.edge.aten._to_copy.default
-            or node.target == torch.ops.aten._to_copy.default
-        ):
-            src_dtype = node.meta["val"].dtype
-            # pyre-ignore
-            dst_dtype = node.args[0].meta["val"].dtype
-            return src_dtype == dst_dtype
-
-        return False
+        if node.target not in self.redundant_ops:
+            return False
+
+        orig_node = node.args[0]
+        assert isinstance(orig_node, torch.fx.Node)
+
+        src_dtype = orig_node.meta["val"].dtype
+        dst_dtype = node.meta["val"].dtype
+
+        # Do not remove if the op is converting the dtype.
+        if src_dtype != dst_dtype:
+            return False
+
+        src_shape = orig_node.meta["val"].shape
+        dst_shape = node.meta["val"].shape
+
+        return src_shape == dst_shape
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.nodes:
             if not self._should_remove(node):
                 continue
 
-            with graph_module.graph.inserting_after(node):
-                node.replace_all_uses_with(node.args[0])
+            node.replace_all_uses_with(node.args[0])
 
         graph_module.graph.eliminate_dead_code()
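To make the new behavior concrete, here is a minimal, self-contained sketch of the same removal mechanic on a toy torch.fx graph. It is not the PR's code: torch.clone stands in for the redundant ops the pass matches, and the dtype/shape checks are collapsed into a comment.

import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        y = torch.clone(x)  # copy-like op: output has same dtype and shape as input
        return y + 1


gm = torch.fx.symbolic_trace(M())

for node in gm.graph.nodes:
    # Stand-in for _should_remove(): match the copy-like op directly. The real
    # pass also checks node.meta["val"] to confirm dtype and shape are unchanged.
    if node.op == "call_function" and node.target == torch.clone:
        node.replace_all_uses_with(node.args[0])

gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)  # the clone is gone; only the add remains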
16 changes: 9 additions & 7 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -6,22 +6,16 @@

 import logging
 import operator
-
 from typing import Any
 
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures
-
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
     VkStorageType,
 )
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.tensor import TensorSpec
 
@@ -130,15 +124,17 @@ def __init__(
         texture_limits: utils.ImageExtents,
         default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
         default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
+        force_fp16: bool = False,
     ):
         super().__init__()
         self.default_storage: VkStorageType = default_storage_type
         self.default_layout: VkMemoryLayout = default_memory_layout
         self.texture_limits = texture_limits
+        self.force_fp16 = force_fp16
 
         # Magic number to limit "lookahead" when tracing through users of an operator
         # to constrain the representation of its arguments/outputs.
-        self.max_trace_search_depth = 20
+        self.max_trace_search_depth = None
 
     def is_valid_op_node(self, node: Any) -> bool:
         """
@@ -361,6 +357,12 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
         2. Then, try to trace through the users of the argument to find a representation
            that can be used for as long as possible without needing a transition.
         """
+        # If forcing fp16, then try to use texture storage whenever possible. This is
+        # a temporary stopgap measure until all buffer implementations properly account
+        # for potential overflow of fp16 representation range when doing math in fp16.
+        if self.force_fp16:
+            op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)
+
         arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
         op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
28 changes: 3 additions & 25 deletions backends/vulkan/op_registry.py
@@ -7,17 +7,12 @@
 # pyre-unsafe
 
 import operator
-
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import executorch.backends.vulkan.custom_ops_lib # noqa
-
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from torch._subclasses.fake_tensor import FakeTensor
 
@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.sub,
         operator.lt,
         operator.gt,
         operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

 @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
 def register_to_copy_dim_order_op():
-    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
-    # removed as long as the operator is not changing the dtype, i.e. the operator call
-    # is modifying the dim order only. Therefore, check that the input and output dtypes
-    # are the same, if so the operator is safe to remove.
-    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
-        in_arg = node.args[0]
-        if not isinstance(in_arg, torch.fx.Node):
-            return False
-
-        in_tensor = in_arg.meta.get("val", None)
-        out_tensor = node.meta.get("val", None)
-
-        if in_tensor.dtype != out_tensor.dtype:
-            return False
-
-        return True
-
     return OpFeatures(
-        inputs_storage=utils.ANY_STORAGE,
+        inputs_storage=utils.ANY_BUFFER,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_dim_order_copy_node,
     )


@@ -709,7 +687,7 @@ def register_sdpa_ops():
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op():
     return OpFeatures(
-        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        inputs_storage=utils.CONTIGUOUS_ANY,
         supports_resize=True,
     )

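Both hunks in this file follow the module's decorator-based registration pattern: update_features maps one or more op keys to a function returning that op's OpFeatures. A hedged sketch of registering a hypothetical op the same way (the op choice and feature values are placeholders, not part of this PR):

import executorch.backends.vulkan.utils as utils
from executorch.backends.vulkan.op_registry import OpFeatures, update_features
from executorch.exir.dialects._ops import ops as exir_ops


@update_features(exir_ops.edge.aten.relu.default)
def register_relu_op():
    # Placeholder features; real registrations choose storage per op, as the
    # ANY_BUFFER and CONTIGUOUS_ANY changes in this diff do.
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )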
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -639,6 +639,10 @@ class ComputeGraph final {

   bool device_name_contains(const char* substr);
 
+  int64_t max_buffer_numel() {
+    return static_cast<int64_t>(context_->adapter_ptr()->max_buffer_numel());
+  }
+
   //
   // Graph Building
   //
@@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) {
   indices_tidx.data.xyz = out_tidx.data.yzw;
   indices_tidx.data.w = 0;
 
-  TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
-      indices_tidx, indices);
+  TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple(
+      indices, indices_tidx);
 
   const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
   return in_texel[elem_pos.comp];
@@ -61,7 +61,7 @@ void main() {
     return;
   }
 
-  TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
+  TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
   const int embedding_idx = load_embedding_idx(out_tidx);
 
   const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);
50 changes: 40 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
@@ -147,6 +147,20 @@ struct TensorIndex4D {
   ivec4 data;
 };
 
+TensorIndex4D zero_tensor4d_idx() {
+  TensorIndex4D tidx;
+  tidx.data = ivec4(0);
+  return tidx;
+}
+
+bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) {
+  return any(greaterThanEqual(tidx.data, meta.sizes[0]));
+}
+
+bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) {
+  return any(greaterThanEqual(tidx.data, meta.sizes));
+}
+
 //
 // TextureElementIndex
 //
@@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) {
   tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1);
 }
 
-TensorIndex4D zero_tensor4d_idx() {
-  TensorIndex4D tidx;
-  tidx.data = ivec4(0);
-  return tidx;
-}
-
 // Does not account for axis mapping or batches
-TensorIndex4D texture_pos_to_tensor_idx_simple(
-    const ivec3 pos, const TextureMetadata meta) {
+TensorIndex4D texture_pos_to_tensor4d_idx_simple(
+    const TextureMetadata meta, const ivec3 pos) {
   TensorIndex4D tidx;
   tidx.data.xyz = pos;
   tidx.data.w = 0;
@@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple(
 }
 
 // Does not account for axis mapping or batches
-TextureElementIndex tensor_idx_to_texture_element_idx_simple(
-    const TensorIndex4D tidx, const TextureMetadata meta) {
+ivec3 tensor4d_idx_to_texel_pos_simple(
+    const TextureMetadata meta, const TensorIndex4D tidx) {
+  ivec3 texel_pos;
+
+  const int packed_dim_idx = tidx.data[meta.packed_dim];
+
+  texel_pos = tidx.data.xyz;
+  texel_pos[meta.packed_dim] = div_4(packed_dim_idx);
+  return texel_pos;
+}
+
+// Does not account for axis mapping or batches
+TextureElementIndex tensor4d_idx_to_texture_element_idx_simple(
+    const TextureMetadata meta, const TensorIndex4D tidx) {
   const int packed_dim_idx = tidx.data[meta.packed_dim];
   TextureElementIndex tex_idx;
   tex_idx.pos = tidx.data.xyz;
@@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
   return tex_idx;
 }
 
+uint tensor4d_idx_to_linear_idx(
+    const BufferMetadata meta,
+    const TensorIndex4D tidx) {
+  uint lin_idx = 0;
+  for (int d = 0; d < 4; ++d) {
+    lin_idx += meta.strides[0][d] * tidx.data[d];
+  }
+  return lin_idx;
+}
+
 //
 // Debug utilities
 //
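The new tensor4d_idx_to_linear_idx and tensor4d_idx_to_texel_pos_simple helpers encode standard strided and packed-texel indexing. As a sanity check, the same arithmetic in plain Python; the strides and indices are illustrative assumptions, not values from the shader:

# Pure-Python mirror of the two GLSL helpers above.
def tensor4d_idx_to_linear_idx(strides, tidx):
    # lin_idx = sum over d of strides[d] * tidx[d]
    return sum(s * i for s, i in zip(strides, tidx))


def tensor4d_idx_to_texel_pos_simple(tidx_xyz, packed_dim):
    # Four elements are packed into each texel along packed_dim, hence the
    # integer divide by 4 (div_4 in the GLSL helper).
    texel_pos = list(tidx_xyz)
    texel_pos[packed_dim] //= 4
    return texel_pos


strides = (1, 8, 64, 256)  # assumed contiguous strides, width-fastest
tidx = (5, 3, 2, 0)
assert tensor4d_idx_to_linear_idx(strides, tidx) == 5 + 3 * 8 + 2 * 64  # 157
print(tensor4d_idx_to_texel_pos_simple(tidx[:3], packed_dim=0))  # [1, 3, 2]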