Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
lib.impl(name, select_as_symint_impl, "Meta")
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)

##########
## sdpa ##
##########


def sdpa_impl(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
):
if scale is None:
scale = 1.0 / (q.size(-1) ** 0.5)
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
if attn_mask is not None:
attn = attn + attn_mask
attn = torch.softmax(attn, dim=-1)
return torch.matmul(attn, v)


# Register the et_vk::sdpa custom operator: declare its schema, bind the
# reference implementation, and expose the resulting op handle.
name = "sdpa"
lib.define(
    f"{name}(Tensor q, Tensor k, Tensor v, Tensor? attn_mask = None, float? scale = None) -> Tensor"
)
# CompositeExplicitAutograd: sdpa_impl serves as the decomposed reference
# used for eager execution and shape/dtype propagation during export.
lib.impl(name, sdpa_impl, "CompositeExplicitAutograd")
# Callable handle, i.e. torch.ops.<namespace>.sdpa.
sdpa_op = getattr(getattr(torch.ops, namespace), name)

################
##  rms_norm  ##
################
Expand Down
14 changes: 14 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops():
)


# =============================================================================
# SDPA.cpp (fused SDPA entry point)
# =============================================================================


@update_features("et_vk::sdpa")
def register_general_sdpa():
    """Declare Vulkan delegate support for the fused et_vk::sdpa op."""
    # Accept contiguous inputs in any storage type, floating-point dtypes
    # only, and allow output resizing for dynamic shapes.
    features = OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
    )
    return features


# =============================================================================
# RotaryEmbedding.cpp
# =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
* Macro Settings:
* - TILE_M
* - TILE_K4
*
* Optional:
* - LINEAR_FP_INPUT_TILE_VEC4_T — input tile vec4 type (default: VEC4_T).
*/

#extension GL_EXT_control_flow_attributes : require

#ifndef LINEAR_FP_INPUT_TILE_VEC4_T
#define LINEAR_FP_INPUT_TILE_VEC4_T VEC4_T
#endif

struct FPInputTile {
VEC4_T data[TILE_M][TILE_K4];
LINEAR_FP_INPUT_TILE_VEC4_T data[TILE_M][TILE_K4];
};

#ifdef DEBUG_MODE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

#include "linear_fp_input_tile.glslh"

VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
LINEAR_FP_INPUT_TILE_VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
#ifdef INPUT_BUFFER
return t_input[(m * ntexels_k) + k4];
return LINEAR_FP_INPUT_TILE_VEC4_T(t_input[(m * ntexels_k) + k4]);
#else
return texelFetch(t_input, ivec3(k4, m, 0), 0);
return LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_input, ivec3(k4, m, 0), 0));
#endif
}

Expand Down Expand Up @@ -53,7 +53,7 @@ void load_input_tile_with_checks(
if (m_start + m < M && k4_start + k4 < K4) {
in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4);
} else {
in_tile.data[m][k4] = VEC4_T(0.0);
in_tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,31 @@
* Macro Settings:
* - TILE_M
* - TILE_N4
*
* Optional:
* - LINEAR_FP_OUTPUT_TILE_VEC4_T — accumulator vec4 type (default: VEC4_T).
* Set this to `vec4` to force fp32 accumulation regardless of DTYPE; used
* by fused SDPA QK to avoid fp16 overflow in Q@K^T.
*/

#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH
#define LINEAR_FP_OUTPUT_TILE_GLSLH

#extension GL_EXT_control_flow_attributes : require

#ifndef LINEAR_FP_OUTPUT_TILE_VEC4_T
#define LINEAR_FP_OUTPUT_TILE_VEC4_T VEC4_T
#define LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT
#endif

struct FPOutTile {
VEC4_T data[TILE_M][TILE_N4];
LINEAR_FP_OUTPUT_TILE_VEC4_T data[TILE_M][TILE_N4];
};

void initialize(out FPOutTile out_tile) {
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
out_tile.data[m][n4] = VEC4_T(0);
out_tile.data[m][n4] = LINEAR_FP_OUTPUT_TILE_VEC4_T(0);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
#include "linear_fp_per_out_channel_params.glslh"
#include "linear_fp_weight_tile.glslh"

#if defined(LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT) == defined(LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT)
#define MAYBE_CAST_WVEC4(x) (x)
#else
#define MAYBE_CAST_WVEC4(x) LINEAR_FP_OUTPUT_TILE_VEC4_T(x)
#endif

void fp_accumulate_with_fp_weight(
inout FPOutTile accum,
FPInputTile in_tile,
Expand All @@ -29,23 +35,23 @@ void fp_accumulate_with_fp_weight(
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
accum.data[m][n4] =
fma(VEC4_T(in_tile.data[m][k4][0]),
w_tile.data[mul_4(k4)][n4],
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][0]),
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4)][n4]),
accum.data[m][n4]);

accum.data[m][n4] =
fma(VEC4_T(in_tile.data[m][k4][1]),
w_tile.data[mul_4(k4) + 1][n4],
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][1]),
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 1][n4]),
accum.data[m][n4]);

accum.data[m][n4] =
fma(VEC4_T(in_tile.data[m][k4][2]),
w_tile.data[mul_4(k4) + 2][n4],
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][2]),
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 2][n4]),
accum.data[m][n4]);

accum.data[m][n4] =
fma(VEC4_T(in_tile.data[m][k4][3]),
w_tile.data[mul_4(k4) + 3][n4],
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][3]),
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 3][n4]),
accum.data[m][n4]);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
#include "linear_fp_output_tile.glslh"

void write_output_x4(
const VEC4_T out_texel,
const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel,
const int n4,
const int m,
const int N4) {
#ifdef OUTPUT_BUFFER
t_output[m * N4 + n4] = out_texel;
t_output[m * N4 + n4] = VEC4_T(out_texel);
#else
imageStore(t_output, ivec3(n4, m, 0), out_texel);
imageStore(t_output, ivec3(n4, m, 0), VEC4_T(out_texel));
#endif
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

#include "linear_fp_weight_tile.glslh"

VEC4_T load_packed_weight_x4(
LINEAR_FP_WEIGHT_TILE_VEC4_T load_packed_weight_x4(
const int n4, const int dk, const int k4, const int b, const int K4, const int N4) {
#ifdef WEIGHT_BUFFER
return t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk];
return LINEAR_FP_WEIGHT_TILE_VEC4_T(t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]);
#else
return VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0));
return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0));
#endif
}

Expand Down Expand Up @@ -65,7 +65,7 @@ void load_packed_weight_tile_with_checks(
if (k4 < K4 && n4_start + n4 < N4) {
tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4);
} else {
tile.data[k][n4] = VEC4_T(0);
tile.data[k][n4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
* Macro Settings:
* - TILE_K
* - TILE_N4
*
* Optional:
* - LINEAR_FP_WEIGHT_TILE_VEC4_T — weight tile vec4 type (default: VEC4_T).
*/

#ifndef LINEAR_FP_WEIGHT_TILE_GLSLH
Expand All @@ -19,8 +22,13 @@

#include "common.glslh"

#ifndef LINEAR_FP_WEIGHT_TILE_VEC4_T
#define LINEAR_FP_WEIGHT_TILE_VEC4_T VEC4_T
#define LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT
#endif

struct FPWeightTile {
VEC4_T data[TILE_K][TILE_N4];
LINEAR_FP_WEIGHT_TILE_VEC4_T data[TILE_K][TILE_N4];
};

#ifdef DEBUG_MODE
Expand Down
Loading
Loading