Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 81 additions & 49 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import logging
import operator

from collections import deque
from typing import Any

import executorch.backends.vulkan.utils as utils
Expand Down Expand Up @@ -332,81 +334,111 @@ def trace_node_users_to_constrain_repset( # noqa: C901
search_depth: list[int] | None = None,
) -> utils.TensorRepSet:
"""
For an ambiguous repset, try to constrain the repset by tracing the required
repsets of the users of `origin_node`. The idea is to try to find a representation
that can be used the longest without needing user nodes to insert a transition
for its arguments.
BFS over downstream users to constrain an ambiguous repset. Explores all
immediate users at each level before going deeper, so that nearby constrained
ops (e.g. linear requiring width_packed) are discovered before the search
budget is spent on a single deep branch.
"""
# Optionally limit the total number of nodes explored to improve export
# time. search_depth is a mutable list so that all branches of a fan-out
# share a single counter, preventing exponential blowup.
if self.max_trace_search_depth is not None:
if search_depth is None:
search_depth = [self.max_trace_search_depth]
search_depth[0] -= 1
if search_depth[0] <= 0:

queue: deque[torch.fx.Node] = deque()
queue.append(origin_node)

while queue:
if repset.is_constrained():
return repset

users_to_trace = origin_node.users
if self.max_trace_search_depth is not None:
search_depth[0] -= 1
if search_depth[0] <= 0:
return repset

node = queue.popleft()

users_to_trace = node.users

sync_outs_repr = True
if self.is_valid_op_node(node):
sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr

sync_outs_repr = True
if self.is_valid_op_node(origin_node):
sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr
if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr:
users_to_trace = []
for usage_node in node.users:
if (
usage_node.target == operator.getitem
and usage_node.args[1] == 1
):
users_to_trace.append(usage_node)

if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr:
users_to_trace = []
for usage_node in origin_node.users:
if usage_node.target == operator.getitem and usage_node.args[1] == 1:
users_to_trace.append(usage_node)
for usage_node in users_to_trace:
if repset.is_constrained():
return repset

for usage_node in users_to_trace:
arg_i_in_user = None
for i in range(len(usage_node.args)):
if origin_node == usage_node.args[i]:
arg_i_in_user = i
break
arg_i_in_user = None
for i in range(len(usage_node.args)):
if node == usage_node.args[i]:
arg_i_in_user = i
break

if arg_i_in_user is not None:
repset = self.constrain_repset_with_user(
usage_node, arg_i_in_user, repset, search_depth
if arg_i_in_user is None:
continue

if not self.is_valid_op_node(usage_node):
continue

cur_node_repsets = self.get_node_cached_repsets(usage_node)
req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user)

if not req_arg_repset.any_in_common(repset):
continue

repset = repset.make_intersect(req_arg_repset)

repset_propagates_to_output = (
cur_node_repsets.sync_primary_io_repr
and (
cur_node_repsets.sync_args_repr
or arg_i_in_user == cur_node_repsets.primary_arg_idx
)
)

if repset.is_constrained():
return repset
if repset_propagates_to_output:
queue.append(usage_node)

return repset

def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
"""
Attempts to constrain the repset of the argument at index `arg_i` of the op
associated with `op_repsets`. Does this with two stages:

1. First, account for any existing representation that has already been determined
for the argument. If no existing representation has been determined, then use
the output repset of the operator that produces the argument.
2. Then, try to trace through the users of the argument to find a representation
that can be used for as long as possible without needing a transition.
associated with `op_repsets`. Prefers downstream consumers' layout requirements
over the upstream source's existing layout, falling back to the source only when
downstream tracing does not fully constrain the repset.
"""
# If forcing fp16, then try to use texture storage whenever possible. This is
# a temporary stopgap measure until all buffer implementations properly account
# for potential overflow of fp16 representation range when doing math in fp16.
if self.force_fp16:
op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)

arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)

arg_repset = op_repsets.get_arg_repset(arg_i)
if arg_repset.is_constrained():
return

# First, trace downstream users to discover what layout they prefer.
arg_node = op_repsets.op_node.args[arg_i]

if isinstance(arg_node, list):
arg_node = arg_node[0]

arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset)
op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset)
arg_repset = op_repsets.get_arg_repset(arg_i)
if not arg_repset.is_constrained():
downstream_repset = self.trace_node_users_to_constrain_repset(
arg_node, arg_repset
)
op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset)

# Fall back to the upstream source's existing layout only if downstream
# tracing did not fully constrain the repset.
arg_repset = op_repsets.get_arg_repset(arg_i)
if not arg_repset.is_constrained():
arg_source_repset = self.get_arg_tensor_source_repset(
op_repsets.op_node, arg_i
)
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)

def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
"""
Expand Down
77 changes: 77 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,55 @@ def apply_rotary_emb_hf_impl(
lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd")
apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name)

##################################
## apply_rotary_emb_interleaved ##
##################################


def apply_rotary_emb_interleaved_impl(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    """Eager reference for EdgeTAM's pair-interleaved complex-number RoPE.

    Args:
        x: ``[B, N, C]`` tensor with (real, imag) pairs interleaved along C.
        freqs_cis: tensor of any rank whose flattened layout is ``[N, C]``;
            callers commonly pass 2D ``[N, C]`` or 4D ``[1, N, C/2, 2]`` from
            ``torch.view_as_real(...).unsqueeze(0)``. The (cos, sin) pairs are
            interleaved along the innermost axis of the flattened view.

    Returns:
        ``[B, N, C]`` tensor where, per pair index k::

            out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
            out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
    """
    batch, seq_len, channels = x.shape
    pairs = x.view(batch, seq_len, channels // 2, 2)
    x_re = pairs[..., 0]
    x_im = pairs[..., 1]
    # reshape (not view) so callers may pass freqs_cis at any rank.
    rot = freqs_cis.reshape(seq_len, channels // 2, 2)
    cos = rot[..., 0]
    sin = rot[..., 1]
    out_re = x_re * cos - x_im * sin
    out_im = x_re * sin + x_im * cos
    interleaved = torch.stack((out_re, out_im), dim=-1)
    return interleaved.view(batch, seq_len, channels)


def apply_rotary_emb_interleaved_meta(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    """Shape-only meta kernel for ``apply_rotary_emb_interleaved``.

    Keeps the op opaque during torch.export (no inlining of view/reshape
    calls into the exported graph) and places no constraint on the rank of
    ``freqs_cis`` — any shape with N * C elements is accepted by the Vulkan
    dispatcher.
    """
    # Output mirrors x exactly; freqs_cis does not affect the output shape.
    del freqs_cis
    return torch.empty_like(x)


# Register the custom op schema plus its CPU (eager reference) and Meta
# (shape propagation) kernels, then expose the bound op object.
name = "apply_rotary_emb_interleaved"
lib.define(f"{name}(Tensor x, Tensor freqs_cis) -> Tensor")
# CPU kernel preserves eager-mode reference semantics.
lib.impl(name, apply_rotary_emb_interleaved_impl, "CPU")
# Meta kernel keeps the op opaque in the exported graph.
lib.impl(name, apply_rotary_emb_interleaved_meta, "Meta")
apply_rotary_emb_interleaved_op = getattr(getattr(torch.ops, namespace), name)

########################
## q8ta_add ##
########################
Expand Down Expand Up @@ -960,6 +1009,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
lib.impl(name, select_as_symint_impl, "Meta")
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)

##########
## sdpa ##
##########


def sdpa_impl(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
):
if scale is None:
scale = 1.0 / (q.size(-1) ** 0.5)
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
if attn_mask is not None:
attn = attn + attn_mask
attn = torch.softmax(attn, dim=-1)
return torch.matmul(attn, v)


# Register the fused SDPA schema (optional additive mask and optional scale)
# with a single composite kernel, then expose the bound op object.
name = "sdpa"
lib.define(
    f"{name}(Tensor q, Tensor k, Tensor v, Tensor? attn_mask = None, float? scale = None) -> Tensor"
)
lib.impl(name, sdpa_impl, "CompositeExplicitAutograd")
sdpa_op = getattr(getattr(torch.ops, namespace), name)

################
## rms_norm ##
################
Expand Down
30 changes: 30 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops():
)


# =============================================================================
# SDPA.cpp (fused SDPA entry point)
# =============================================================================


@update_features("et_vk::sdpa")
def register_general_sdpa():
    """Declare Vulkan backend support for the fused et_vk::sdpa op.

    All inputs use contiguous layout in any storage type, floating-point
    dtypes only; the op supports dynamic-shape resize.
    """
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
    )


# =============================================================================
# RotaryEmbedding.cpp
# =============================================================================
Expand All @@ -1096,6 +1110,22 @@ def register_apply_rotary_emb_hf():
)


@update_features(exir_ops.edge.et_vk.apply_rotary_emb_interleaved.default)
def register_apply_rotary_emb_interleaved():
    """Declare Vulkan backend support for pair-interleaved RoPE.

    Floating-point dtypes only; supports dynamic-shape resize and
    high-dimensional tensors.
    """
    return OpFeatures(
        # freqs_cis is pinned to buffer storage so the shader can compute a
        # flat [N, C] linear address regardless of the tensor's declared rank
        # (callers commonly pass 4D [1, N, C/2, 2] without a preceding view).
        inputs_storage=[
            utils.CONTIGUOUS_ANY,  # x
            utils.CONTIGUOUS_BUFFER,  # freqs_cis
        ],
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
        supports_highdim=True,
    )


# =============================================================================
# Permute.cpp
# =============================================================================
Expand Down
Loading
Loading