Conversation

ahkush
Contributor

@ahkush ahkush commented Aug 21, 2025

The issue cannot be reproduced using the original repro code provided in the issue description.

However, the underlying issue mentioned by the maintainer (missing functions in `builder.py` and `trace_rules.py`) was never addressed and can still be reproduced with this test case:

```python
import torch
from torch.nn.attention import _cur_sdpa_kernel_backends

@torch.compile(fullgraph=True)
def test_function_that_triggers_error():
    return _cur_sdpa_kernel_backends()

print("Calling torch.compile function...")
try:
    result = test_function_that_triggers_error()
    print(f"Success: {result}")
except Exception as e:
    print(f"ERROR: {e}")
    print(f"Error type: {type(e)}")
```

The original repro likely no longer triggers the issue because the code paths in the SDPA implementation have changed. A direct call to `_cur_sdpa_kernel_backends()` still exposes the underlying problem: certain `torch._C` functions that return non-Tensor values are not handled properly by Dynamo tracing.

This PR adds the missing functions to both `builder.py` and `trace_rules.py` so that these cases are handled properly during compilation.

@guilhermeleobas

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos


pytorch-bot bot commented Aug 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161169

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fa24351 with merge base 3d40642:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@guilhermeleobas
Collaborator

@ahkush, can you add a test case?

@ahkush
Contributor Author

ahkush commented Aug 21, 2025

@guilhermeleobas, I've added a test case as requested. Please let me know if any other updates are needed. Thanks!

Comment on lines 102 to 104

def test_sdpa_c_functions_no_graph_break(self):

counter = CompileCounter()
Collaborator

Can you add a test using the repro from the issue?

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

SDPA_BACKEND_PRIORITY = [
    SDPBackend.MATH,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.FLASH_ATTENTION,
]

@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)


@torch.compile(fullgraph=True)
def f(x):
    return scaled_dot_product_attention(x, x, x)

x = torch.rand(128, 64, 64, 256, dtype=torch.float16, device='cuda')
f(x)
```

Contributor Author

Sure, will do that. Shall I replace the one I added with the repro in the issue or add it as another test?

Collaborator

@guilhermeleobas guilhermeleobas Aug 22, 2025

It's up to you. If they test the same thing, you can remove one of them.

Contributor Author

I've added it as a separate test since they test different scenarios - one tests the direct function call, the other tests the original decorator usage pattern.

@guilhermeleobas
Collaborator

I think this needs rebasing onto main. Also, can you run `lintrunner`? First run `lintrunner init`, then `lintrunner -a` to format the changed files.

@ahkush ahkush force-pushed the fix-dynamo-sdpa-backend-functions branch from 69d1c26 to fa24351 on August 27, 2025 18:48
@ahkush
Contributor Author

ahkush commented Aug 28, 2025

@guilhermeleobas, thanks for the feedback! I've rebased onto main and run `lintrunner init` followed by `lintrunner -a` to format the files.

@guilhermeleobas
Collaborator

Hi @ahkush, thanks for your contribution. I just approved, but you need a second approval before it can be merged.

Contributor

@StrongerXi StrongerXi left a comment

Thanks, but please unlink this PR from #160691. The intent of that issue is to highlight the bad error message during context manager tracing, and the example was mainly there to illustrate that issue. :)

@ahkush ahkush changed the title from "160691: SDP Backend function fix" to "SDP Backend function fix" Sep 8, 2025
@ahkush
Contributor Author

ahkush commented Sep 15, 2025

@pytorchbot merge


pytorch-bot bot commented Sep 15, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@guilhermeleobas
Collaborator

@pytorchbot merge


pytorch-bot bot commented Sep 15, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@guilhermeleobas
Collaborator

@StrongerXi, can you merge this one?

@StrongerXi
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Sep 19, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
The issue cannot be reproduced using the original repro code provided in the issue description.

However, the underlying issue mentioned by the maintainer (missing functions in `builder.py` and `trace_rules.py`) was never addressed and can still be reproduced with this test case:

```python
import torch
from torch.nn.attention import _cur_sdpa_kernel_backends

@torch.compile(fullgraph=True)
def test_function_that_triggers_error():
    return _cur_sdpa_kernel_backends()

print("Calling torch.compile function...")
try:
    result = test_function_that_triggers_error()
    print(f"Success: {result}")
except Exception as e:
    print(f"ERROR: {e}")
    print(f"Error type: {type(e)}")
```

The original repro likely no longer triggers the issue due to code path changes in the SDPA implementation, while the direct call to `_cur_sdpa_kernel_backends()` exposes the underlying problem where certain torch._C functions returning non-Tensor values aren't properly handled by dynamo tracing.

I have implemented the changes by adding the missing functions to both `builder.py` and `trace_rules.py` to properly handle these cases during compilation.

@guilhermeleobas

Pull Request resolved: pytorch#161169
Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025