Labels: pallas, pytorch divergence (XLA behavior doesn't match PyTorch eager frontend), xla:tpu (TPU specific issues and PRs)
Description
🐛 Bug
Hey, I have found a consistent mismatch between the output of the flash_attention implementation and torch's F.scaled_dot_product_attention for the default non-causal case.
Is this still within the accuracy margin you were targeting?
Thanks for your help!
To Reproduce

```python
import torch
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel

scale = 1.0
for _ in range(10):
    q = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())
    k = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())
    v = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())

    output_torch = F.scaled_dot_product_attention(q, k, v, scale=scale)
    xm.mark_step()

    output = torch_xla.experimental.custom_kernel.flash_attention(q, k, v, sm_scale=scale)
    xm.mark_step()

    try:
        torch.testing.assert_close(output, output_torch)
    except AssertionError as e:
        print(e, '\n\n')
```

Outputs:
```
Tensor-likes are not close!

Mismatched elements: 579871 / 590848 (98.1%)
Greatest absolute difference: 0.0081862211227417 at index (0, 8, 453, 21) (up to 1e-05 allowed)
Greatest relative difference: 144.22958374023438 at index (0, 8, 379, 43) (up to 1.3e-06 allowed)

Tensor-likes are not close!

Mismatched elements: 579790 / 590848 (98.1%)
Greatest absolute difference: 0.008110404014587402 at index (0, 4, 502, 13) (up to 1e-05 allowed)
Greatest relative difference: 850.0736083984375 at index (0, 3, 489, 60) (up to 1.3e-06 allowed)

Tensor-likes are not close!

Mismatched elements: 579248 / 590848 (98.0%)
Greatest absolute difference: 0.008902788162231445 at index (0, 9, 145, 7) (up to 1e-05 allowed)
Greatest relative difference: 589.856201171875 at index (0, 15, 274, 8) (up to 1.3e-06 allowed)

Tensor-likes are not close!

Mismatched elements: 579571 / 590848 (98.1%)
Greatest absolute difference: 0.010603189468383789 at index (0, 8, 3, 49) (up to 1e-05 allowed)
Greatest relative difference: 332.93280029296875 at index (0, 0, 367, 32) (up to 1.3e-06 allowed)
```
Environment

- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v6
- torch_xla version:

```
Name: torch
Version: 2.8.0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: compressed-tensors, outlines, xgrammar
---
Name: torch-xla
Version: 2.8.0+git4190fc0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: absl-py, numpy, pyyaml, requests
Required-by:
```
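Possibly relevant background (my understanding, not verified against the Pallas kernel): flash attention accumulates the softmax normalizer and the weighted sum blockwise, so even at identical precision its floating-point reduction order differs from a monolithic softmax-then-matmul, and the two can disagree in the low bits. A minimal numpy illustration of order-dependent float32 summation:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(577).astype(np.float32)  # e.g. one row of attention weights

# Straight left-to-right accumulation.
seq = np.float32(0.0)
for val in x:
    seq += val

# Blockwise accumulation (block size 128 is just an illustrative tile size).
block = np.float32(0.0)
for start in range(0, len(x), 128):
    block += x[start:start + 128].sum(dtype=np.float32)

exact = x.sum(dtype=np.float64)
# Both float32 totals track the float64 sum closely but may differ
# from it (and from each other) in their last bits.
print(seq, block, exact)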