41 changes: 22 additions & 19 deletions examples/models/llama2/source_transformation/sdpa.py
@@ -61,25 +61,28 @@ def forward(
# returns dequantized kv cache
# Not most optimal. Optimizations to follow next
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
# Note that this path still mutates k_cache and v_cache in place.
# When we are not using a quantized kv cache, this mutates the
# original kv cache.
# When we are using a quantized kv cache, this mutates the k_cache and
# v_cache returned from the cache update operation, which simply
# dequantizes the cache and returns the result.
# Future diffs will optimize this.
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
k_cache,
v_cache,
input_pos[-1].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
output = torch.ops.llama.custom_sdpa(
q,
k_cache,
v_cache,
input_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
else:
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
k_cache,
v_cache,
input_pos[0].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
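
For reference, here is a minimal pure-PyTorch sketch of what the `custom_sdpa` call above computes once `k_cache`/`v_cache` hold the (dequantized) keys and values. The `[bsz, seqlen, n_heads, head_dim]` layout and the start-position-offset causal mask are assumptions for illustration, not the kernel's exact implementation:

```python
import torch
import torch.nn.functional as F

def reference_custom_sdpa(q, k_cache, v_cache, start_pos):
    # q:       [bsz, seqlen, n_heads, head_dim] (assumed layout)
    # k_cache: [bsz, max_seq_len, n_heads, head_dim]; positions < start_pos + seqlen are valid
    seqlen = q.size(1)
    kv_len = start_pos + seqlen
    k = k_cache[:, :kv_len]
    v = v_cache[:, :kv_len]
    # Query token i sits at absolute position start_pos + i and may attend
    # to cache positions 0 .. start_pos + i (causal attention with offset).
    q_pos = torch.arange(start_pos, kv_len).unsqueeze(1)
    kv_pos = torch.arange(kv_len).unsqueeze(0)
    mask = kv_pos <= q_pos  # [seqlen, kv_len]; True means "may attend"
    # F.scaled_dot_product_attention expects [bsz, n_heads, seq, head_dim].
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_t, k_t, v_t, attn_mask=mask)
    return out.transpose(1, 2)  # back to [bsz, seqlen, n_heads, head_dim]
```

In the quantized path above, `k_cache` and `v_cache` are the dequantized tensors returned by `self.kv_cache.update`, so a sketch like this operates on those rather than on an internal cache.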


@@ -66,6 +66,12 @@ def test_simple(self, is_dynamic_shape=False):
torch.testing.assert_close(
float_out,
quantized_out,
# rtol is loosened because switching to custom_sdpa means we use the
# dequantized k and v instead of the original k and v, which leads to
# larger differences in the output.
# A subsequent diff in the stack will address this issue.
rtol=1e-01,
atol=1e-03,
)

input_pos = torch.tensor([3], dtype=torch.int64)
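A rough illustration of why the tolerance is loosened: an 8-bit affine quantize/dequantize round trip of k and v introduces per-element error on the order of half a quantization step, and that error propagates into the attention output. The per-tensor scheme and shapes below are illustrative only, not the test's actual quantizer:

```python
import torch

x = torch.randn(1, 16, 8, 64)  # stand-in for a k or v cache slice
scale = (x.max() - x.min()) / 255.0
zero_point = torch.round(-x.min() / scale)
x_q = torch.clamp(torch.round(x / scale + zero_point), 0, 255)
x_dq = (x_q - zero_point) * scale
# Maximum round-trip error is roughly scale / 2, which is why comparing the
# float path against the quantized path needs a looser rtol/atol.
print((x - x_dq).abs().max().item(), (scale / 2).item())
```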
141 changes: 73 additions & 68 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -754,6 +754,74 @@ void update_cache(
}
}

} // anonymous namespace

Tensor& flash_attention_kernel_out(
RuntimeContext& ctx,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
(void)ctx;
ET_KERNEL_CHECK(
ctx,
validate_flash_attention_args(query, key, value, attn_mask),
InvalidArgument,
output);

ET_KERNEL_CHECK(
ctx,
resize_tensor(output, query.sizes()) == Error::Ok,
InvalidArgument,
output);

auto q_seq_len = query.size(2);

ET_SWITCH_FLOAT_TYPES(
query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
// TODO: we need to re-evaluate this for ARM CPUs.
// There can be many of them, so instead of templatizing
// we might consider another approach.
if (q_seq_len >= 768) {
cpu_flash_attention<CTYPE, 256, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<CTYPE, 64, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else {
cpu_flash_attention<CTYPE, 32, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
}
});
return output;
}

/*
Input params
@param[in] q_projected Projected query with query weights.
@@ -900,74 +968,6 @@ Tensor& custom_sdpa_out(
});
return output;
}
} // anonymous namespace

Tensor& flash_attention_kernel_out(
KernelRuntimeContext& ctx,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
(void)ctx;
ET_KERNEL_CHECK(
ctx,
validate_flash_attention_args(query, key, value, attn_mask),
InvalidArgument,
output);

ET_KERNEL_CHECK(
ctx,
resize_tensor(output, query.sizes()) == Error::Ok,
InvalidArgument,
output);

auto q_seq_len = query.size(2);

ET_SWITCH_FLOAT_TYPES(
query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
// TODO we need to re-evaluate this for ARM CPUs
// And there can be many so instead of templatizing
// we might consider another appraoch
if (q_seq_len >= 768) {
cpu_flash_attention<CTYPE, 256, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<CTYPE, 64, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
} else {
cpu_flash_attention<CTYPE, 32, 512>(
output,
query,
key,
value,
dropout_p,
is_causal,
attn_mask,
scale);
}
});
return output;
}

/*
Input params
@param[in] q_projected Projected query with query weights.
@@ -1033,3 +1033,8 @@ EXECUTORCH_LIBRARY(
llama,
"sdpa_with_kv_cache.out",
torch::executor::native::sdpa_with_kv_cache_out);

EXECUTORCH_LIBRARY(
llama,
"custom_sdpa.out",
torch::executor::native::custom_sdpa_out);
13 changes: 13 additions & 0 deletions extension/llm/custom_ops/op_sdpa.h
@@ -31,6 +31,19 @@ Tensor& sdpa_with_kv_cache_out(
const optional<double> scale,
Tensor& output);

Tensor& custom_sdpa_out(
RuntimeContext& ctx,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output);

Tensor& flash_attention_kernel_out(
KernelRuntimeContext& ctx,
const Tensor& query,
62 changes: 58 additions & 4 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -82,6 +82,51 @@ at::Tensor sdpa_with_kv_cache_aten(
return output;
}

Tensor& custom_sdpa_out_no_context(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output) {
exec_aten::RuntimeContext context{};
return torch::executor::native::custom_sdpa_out(
context,
q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
output);
}

at::Tensor custom_sdpa_aten(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<at::Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<double> scale) {
auto output = at::empty_like(q);
WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
(q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
return output;
}

Tensor& update_quantized_cache_out_no_context(
const Tensor& value,
Tensor& cache,
@@ -115,6 +160,14 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
m.def(
"custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None) -> Tensor");
m.def(
"custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"update_quantized_cache(Tensor value, Tensor(a!) cache, "
"SymInt start_pos) -> Tensor");
@@ -123,17 +176,18 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
}

// TODO: Rename this file to op_custom_ops_aot.cpp
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
m.impl(
"sdpa_with_kv_cache.out",
WRAP_TO_ATEN(
torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
}

// TODO: Rename this file to op_custom_ops_aot.cpp
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
m.impl(
"custom_sdpa.out",
WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
m.impl(
"update_quantized_cache",
torch::executor::native::update_quantized_cache_aten);
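Once the AOT library above is built and loaded, the new op is callable from eager PyTorch via `torch.ops.llama.custom_sdpa`, following the schema registered in `TORCH_LIBRARY_FRAGMENT`. A hedged sketch of such a call; the import path of the module that loads the shared library, and the tensor shapes and layout, are assumptions:

```python
import torch
# Importing this module is assumed to load the shared library and register
# the llama::custom_sdpa op (see sdpa_with_kv_cache.py below).
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401

q = torch.randn(1, 1, 8, 64)         # [bsz, seqlen, n_heads, head_dim] (assumed)
k_cache = torch.randn(1, 16, 8, 64)  # dequantized key cache
v_cache = torch.randn(1, 16, 8, 64)  # dequantized value cache

out = torch.ops.llama.custom_sdpa(
    q,
    k_cache,
    v_cache,
    3,      # start_pos
    None,   # attn_mask
    0.0,    # dropout probability (ignored by the kernel)
    True,   # is_causal
)
```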
29 changes: 29 additions & 0 deletions extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -141,6 +141,35 @@ def fast_hadamard_transform_meta(mat):
return torch.empty_like(mat)


@impl(custom_ops_lib, "custom_sdpa", "Meta")
def custom_sdpa(
query,
key_cache,
value_cache,
start_pos,
attn_mask=None,
drpout_p=0.0,
is_causal=False,
scale=None,
):
seq_len = query.size(1)
_validate_params(
query,
key_cache,
value_cache,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
drpout_p,
is_causal,
scale,
)

return torch.empty_like(query)


def _validate_update_cache_params(
value,
cache,