SDP Backend function fix #161169
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161169
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit fa24351 with merge base 3d40642.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@ahkush, can you add a test case?
@guilhermeleobas, I've added a test case as requested. Please let me know if any other updates are needed. Thanks!
test/dynamo/test_sdpa.py (outdated)

```python
def test_sdpa_c_functions_no_graph_break(self):
    counter = CompileCounter()
```
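For context, a minimal sketch of what a test along these lines might look like, assuming `CompileCounter` from `torch._dynamo.testing` and written as a standalone function rather than a test-class method; only the test name and the `CompileCounter()` line come from the excerpt above, the shapes and assertion are illustrative:

```python
# Hypothetical reconstruction: only the test name and CompileCounter usage
# appear in the excerpt above; shapes, dtype, and the assertion are assumed.
import torch
from torch._dynamo.testing import CompileCounter


def test_sdpa_c_functions_no_graph_break():
    counter = CompileCounter()

    @torch.compile(backend=counter, fullgraph=True)
    def fn(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    q = k = v = torch.rand(2, 4, 8, 16)
    fn(q, k, v)
    # fullgraph=True raises on any graph break; a single compiled frame
    # indicates the SDPA-related torch._C calls were traced end to end.
    assert counter.frame_count == 1
```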
Can you add a test using the repro from the issue?
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

SDPA_BACKEND_PRIORITY = [
    SDPBackend.MATH,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.FLASH_ATTENTION,
]

@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)

@torch.compile(fullgraph=True)
def f(x):
    return scaled_dot_product_attention(x, x, x)

x = torch.rand(128, 64, 64, 256, dtype=torch.float16, device='cuda')
f(x)
```
Sure, will do that. Shall I replace the one I added with the repro in the issue or add it as another test?
It's up to you. If they test the same thing, then you can remove it.
I've added it as a separate test since they test different scenarios: one tests the direct function call, the other tests the original decorator usage pattern.
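For reference, a rough sketch of what that second, decorator-pattern test might look like, adapted from the repro above; restricting it to `SDPBackend.MATH` and using small CPU tensors are assumptions made so the sketch is self-contained, not necessarily what the PR's test does:

```python
import torch
from torch._dynamo.testing import CompileCounter
from torch.nn.attention import SDPBackend, sdpa_kernel

counter = CompileCounter()

# Same decorator usage pattern as the repro from the issue, limited to the
# MATH backend so it can run on CPU (an assumption for this sketch).
@sdpa_kernel(backends=[SDPBackend.MATH], set_priority=True)
def sdpa(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@torch.compile(backend=counter, fullgraph=True)
def f(x):
    return sdpa(x, x, x)

f(torch.rand(2, 4, 8, 16))
# fullgraph=True would have raised on a graph break; one compiled frame
# means the sdpa_kernel context manager was traced successfully.
assert counter.frame_count == 1
```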
I think this needs rebasing with main. Also, can you run lintrunner? First run `lintrunner init`, then `lintrunner -a`.
Force-pushed from 69d1c26 to fa24351.
@guilhermeleobas, thanks for the feedback! I've rebased with main and ran `lintrunner init` followed by `lintrunner -a` to format the files.
Hi @ahkush, thanks for your contribution. I just approved, but you need a second approval before it can be merged.
Thanks, but please unlink this PR from #160691. The intent of that issue is to highlight the bad error message during context manager tracing, and the example was mainly there to illustrate that issue :).
@pytorchbot merge
Pull workflow has not been scheduled for the PR yet. It could be because the author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.
@pytorchbot merge
Pull workflow has not been scheduled for the PR yet. It could be because the author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.
@StrongerXi, can you merge this one?
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The issue cannot be reproduced using the original repro code provided in the issue description. However, the underlying issue mentioned by the maintainer (missing functions in `builder.py` and `trace_rules.py`) was never addressed and can still be reproduced with this test case:

```python
import torch
from torch.nn.attention import _cur_sdpa_kernel_backends

@torch.compile(fullgraph=True)
def test_function_that_triggers_error():
    return _cur_sdpa_kernel_backends()

print("Calling torch.compile function...")
try:
    result = test_function_that_triggers_error()
    print(f"Success: {result}")
except Exception as e:
    print(f"ERROR: {e}")
    print(f"Error type: {type(e)}")
```

The original repro likely no longer triggers the issue due to code path changes in the SDPA implementation, while the direct call to `_cur_sdpa_kernel_backends()` exposes the underlying problem where certain `torch._C` functions returning non-Tensor values aren't properly handled by dynamo tracing. I have implemented the changes by adding the missing functions to both `builder.py` and `trace_rules.py` to properly handle these cases during compilation.

@guilhermeleobas

Pull Request resolved: pytorch#161169
Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi
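For illustration, a hedged sketch of the kind of pattern the fix is meant to support: consuming the non-Tensor return value of `_cur_sdpa_kernel_backends()` inside a `fullgraph=True` compiled function. The branching logic below is invented for this example and is not taken from the PR:

```python
import torch
from torch.nn.attention import SDPBackend, _cur_sdpa_kernel_backends


@torch.compile(fullgraph=True)
def scale_by_backend(x):
    # The call returns plain Python data (the currently enabled SDPA
    # backends), not a Tensor, so dynamo has to model it through
    # trace_rules.py/builder.py rather than placing it in the FX graph.
    backends = _cur_sdpa_kernel_backends()
    if SDPBackend.MATH in backends:
        return x * 2
    return x


print(scale_by_backend(torch.ones(4)))
```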
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos