Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/llm/custom_ops/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ runtime.python_test(
],
preload_deps = [
":custom_ops_aot_lib",
":custom_ops_aot_py",
],
deps = [
"//caffe2:torch",
Expand Down
10 changes: 9 additions & 1 deletion extension/llm/custom_ops/test_quantized_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import torch
import torch.nn.functional as F

from .custom_ops import custom_ops_lib # noqa
from executorch.extension.llm.custom_ops import custom_ops # noqa


def is_fbcode():
    """Report whether ``torch.version`` lacks a ``git_version`` attribute.

    The tests in this file use the result to tell an internal (fbcode)
    build apart from an OSS build when deciding whether to skip a case.

    Returns:
        bool: True when ``torch.version`` has no ``git_version`` attribute,
        False otherwise.
    """
    has_git_version = hasattr(torch.version, "git_version")
    return not has_git_version


class SDPATestForCustomQuantizedSDPA(unittest.TestCase):
Expand Down Expand Up @@ -343,6 +347,7 @@ def _test_sdpa_common(
v_scale_fp32,
is_seq_at_dim_2,
)
print((ref_output - op_output).abs().max())
self.assertTrue(torch.allclose(ref_output, op_output, atol=atol))
# Following line crashes due to some weird issues in mkldnn with crash in mkl_sgemm with `wild jump`
# self.assertTrue(torch.allclose(ref_output, quantized_sdpa_ref_output, atol=1e-3))
Expand Down Expand Up @@ -386,6 +391,9 @@ def _test_sdpa_common(
)
self.assertTrue(torch.allclose(ref_output, op_output, atol=atol))

@unittest.skipIf(
not is_fbcode(), "in OSS error is too large 0.0002 for some reason"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does "in OSS error is too large 0.0002 for some reason" mean?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I ran this test in OSS, the error I got was 0.0002, which is too large for this test to pass. It does pass internally, so I am not sure whether this is a BLAS library issue or a Mac vs. Linux difference. But I wanted to enable the tests in OSS, since we have no coverage otherwise.

)
def test_sdpa_with_custom_quantized(self):
n_heads_kv = 8
n_heads_q = 8
Expand Down
9 changes: 8 additions & 1 deletion extension/llm/custom_ops/test_sdpa_with_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import torch
import torch.nn.functional as F

from .custom_ops import custom_ops_lib # noqa
from executorch.extension.llm.custom_ops import custom_ops # noqa


def is_fbcode():
    """Report whether ``torch.version`` lacks a ``git_version`` attribute.

    The tests in this file use the result to tell an internal (fbcode)
    build apart from an OSS build when deciding whether to skip a case.

    Returns:
        bool: True when ``torch.version`` has no ``git_version`` attribute,
        False otherwise.
    """
    git_version_present = hasattr(torch.version, "git_version")
    return not git_version_present


def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
Expand Down Expand Up @@ -604,6 +608,9 @@ def test_sdpa_with_cache_seq_len_llava_example(self):
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
)

@unittest.skipIf(
not is_fbcode(), "in OSS error is too large 0.0004 for some reason"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's enable in OSS. Otherwise we will never enable in OSS.

We can finetune atol and rtol during comparison.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little reluctant on this because, on one hand, yes, we can make it pass by tuning this knob, but on the other hand, if there is an issue we won't catch it. My preference would be to have this test at least running internally without compromise. I don't know the reason why it requires low atol.

)
def test_sdpa_with_cache_seq_len_130_gqa(self):
n_heads_kv = 8
n_heads_q = 32
Expand Down
2 changes: 2 additions & 0 deletions extension/llm/custom_ops/test_update_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch

from executorch.extension.llm.custom_ops import custom_ops # noqa


def run_in_subprocess(target):
"""
Expand Down
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ addopts =
# extension/
extension/llm/modules/test
extension/llm/export
extension/llm/custom_ops/test_sdpa_with_kv_cache.py
extension/llm/custom_ops/test_update_cache.py
extension/llm/custom_ops/test_quantized_sdpa.py
extension/pybindings/test
extension/training/pybindings/test
# Runtime
Expand Down
Loading