torch_xla.experimental.custom_kernel.flash_attention output does not match F.scaled_dot_product_attention on TPU #8869

@NickLucche

Description

🐛 Bug

Hey, I've found a consistent mismatch between the output of the flash_attention implementation and torch's F.scaled_dot_product_attention for the default non-causal case.
Is this still within the accuracy margin you were targeting?

Thanks for your help!

To Reproduce

import torch
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel

scale = 1.0
for _ in range(10):
    q = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())
    k = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())
    v = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())

    output_torch = F.scaled_dot_product_attention(q, k, v, scale=scale)
    xm.mark_step()
    output = torch_xla.experimental.custom_kernel.flash_attention(q, k, v, sm_scale=scale)
    xm.mark_step()
    try:
        torch.testing.assert_close(output, output_torch)
    except AssertionError as e:
        print(e, '\n\n')

Outputs:

Tensor-likes are not close!

Mismatched elements: 579871 / 590848 (98.1%)
Greatest absolute difference: 0.0081862211227417 at index (0, 8, 453, 21) (up to 1e-05 allowed)
Greatest relative difference: 144.22958374023438 at index (0, 8, 379, 43) (up to 1.3e-06 allowed) 


Tensor-likes are not close!

Mismatched elements: 579790 / 590848 (98.1%)
Greatest absolute difference: 0.008110404014587402 at index (0, 4, 502, 13) (up to 1e-05 allowed)
Greatest relative difference: 850.0736083984375 at index (0, 3, 489, 60) (up to 1.3e-06 allowed) 


Tensor-likes are not close!

Mismatched elements: 579248 / 590848 (98.0%)
Greatest absolute difference: 0.008902788162231445 at index (0, 9, 145, 7) (up to 1e-05 allowed)
Greatest relative difference: 589.856201171875 at index (0, 15, 274, 8) (up to 1.3e-06 allowed) 


Tensor-likes are not close!

Mismatched elements: 579571 / 590848 (98.1%)
Greatest absolute difference: 0.010603189468383789 at index (0, 8, 3, 49) (up to 1e-05 allowed)
Greatest relative difference: 332.93280029296875 at index (0, 0, 367, 32) (up to 1.3e-06 allowed) 
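For reference, the spacing between adjacent bfloat16 values near 1.0 is 2**-7 ≈ 0.0078, which is right in the range of the ~0.008 absolute differences reported above, and relative differences blow up wherever the reference output is near zero. Below is a rough NumPy sketch of that effect; the to_bf16 truncation helper and the choice of where rounding is applied are illustrative assumptions, not how the Pallas kernel or SDPA actually round:

```python
import numpy as np

def to_bf16(x):
    # Emulate bfloat16 by truncating a float32 bit pattern to its top 16 bits
    # (8-bit exponent, 7-bit explicit mantissa). Real hardware rounds to
    # nearest, so this slightly overstates the error; it is only an
    # order-of-magnitude model.
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (u & np.uint32(0xFFFF0000)).view(np.float32)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
# Smaller than the (1, 16, 577, 64) repro above, purely to keep this quick.
q = rng.standard_normal((4, 128, 64))
k = rng.standard_normal((4, 128, 64))
v = rng.standard_normal((4, 128, 64))

# float64 reference with scale=1.0, matching the repro.
ref = softmax(q @ k.transpose(0, 2, 1)) @ v

# Same computation with inputs and intermediates truncated to bf16 precision.
qb, kb, vb = map(to_bf16, (q, k, v))
attn = to_bf16(softmax(to_bf16(qb @ kb.transpose(0, 2, 1))))
out = to_bf16(attn @ vb).astype(np.float64)

abs_diff = np.abs(out - ref).max()
rel_diff = (np.abs(out - ref) / np.abs(ref)).max()
print(f"max abs diff: {abs_diff:.4g}")  # orders of magnitude above a 1e-05 atol
print(f"max rel diff: {rel_diff:.4g}")  # blows up wherever ref is near zero
```

Under this model the absolute error lands in the same ballpark as the repro's, which suggests the mismatch is plausibly bf16 rounding rather than a kernel bug; the tolerances shown in the failures above (atol 1e-05, rtol 1.3e-06) are float32 defaults and far too tight for bfloat16 outputs.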

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v6
  • torch_xla version:
Name: torch
Version: 2.8.0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: compressed-tensors, outlines, xgrammar
---
Name: torch-xla
Version: 2.8.0+git4190fc0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: absl-py, numpy, pyyaml, requests
Required-by:

    Labels

    pallas
    pytorch divergence (XLA behavior doesn't match PyTorch eager frontend)
    xla:tpu (TPU specific issues and PRs)
