-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[Dynamo] Fix function overrides #120885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dynamo] Fix function overrides #120885
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/120885
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 183c5bb with merge base 5929d4e ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@StellarrZ remember to sign the CLA and comment "/easycla" to get the cla job to retrigger |
|
/easycla |
torch/_dynamo/variables/torch.py
Outdated
| assert not (kwargs or len(args) != 1) | ||
| return ConstantVariable.create( | ||
| any(has_torch_function(a) for a in args[0].unpack_var_sequence(tx)), | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you merge it with the other place?
pytorch/torch/_dynamo/variables/torch.py
Lines 398 to 405 in da559c9
| elif self.value in ( | |
| torch.overrides.has_torch_function, | |
| torch.overrides.has_torch_function_variadic, | |
| torch.overrides.has_torch_function_unary, | |
| ): | |
| assert not kwargs | |
| return ConstantVariable.create( | |
| any(has_torch_function(a) for a in args), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably merge it likes this. But IMO the original solution is stricter and cleaner. Please let me know your preference or if there's a better way to go
elif self.value in (
torch.overrides.has_torch_function,
torch.overrides.has_torch_function_variadic,
torch.overrides.has_torch_function_unary,
):
assert not kwargs
elems = (
args[0].unpack_var_sequence(tx)
if len(args) == 1 and isinstance(args[0], TupleVariable)
else args
)
return ConstantVariable.create(
any(has_torch_function(x) for x in elems),
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want to update has_torch_function as this, it would cover more cases and more concise:
def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker", tx: "torch._dynamo.symbolic_convert.InstructionTranslatorBase") -> bool:
from torch._dynamo.variables import UserDefinedObjectVariable
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
if vt.has_unpack_var_sequence(tx):
return any(has_torch_function(v, tx) for v in vt.unpack_var_sequence(tx))
else:
return isinstance(vt, TensorWithTFOverrideVariable) or (
isinstance(vt, UserDefinedObjectVariable)
and hasattr(vt.value, "__torch_function__")
)
|
Thanks @yanboliang for guiding! Unpacking inside Given this, strictly unpack in |
|
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
|
@pytorchbot label "topic: not user facing" |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
To check existence of `__torch_function__`, the code intended to iterate each element but got `TupleVariable` when the ordinary `has_torch_function()` was being used. Needs further unpack in this case Fixes #120653 Pull Request resolved: #120885 Approved by: https://github.com/yanboliang
To check existence of
__torch_function__, the code intended to iterate each element but gotTupleVariablewhen the ordinaryhas_torch_function()was being used. Needs further unpack in this caseFixes #120653
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @aakhundov