-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Add a HOP to bypass tracing of a wrapper function while tracing the wrapped function #153487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153487
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 194cdce with merge base 7c9d94e ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Usage: ```python from torch._higher_order_ops.wrap import wrap_generic # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return wrap_generic( fn, *args, wrapper_fn=functools.partial(my_hop_fn_impl, k=k), **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `wrap_generic` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
raise RuntimeError( | ||
f"WrapGenericHigherOrderVariable: Unsupported function {type(func_var)}" | ||
) | ||
gmod_kwargs = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using the wrapper_fn as kwargs seems pretty fragile like it might collide with inner kwargs. Maybe we could just pass the wrapper fn as an explicit input? Something like:
DynamoBypassingWrapper(wrapper_fn, inner_fn, *inner_args, **inner_kwargs)
The proxy node we created in dynamo could make wrapper_fn a string key that can be used to access gmod.meta.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, updated
Usage: ```python from torch._higher_order_ops.wrap import wrap_generic # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return wrap_generic( fn, *args, wrapper_fn=functools.partial(my_hop_fn_impl, k=k), **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `wrap_generic` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Usage: ```python from torch._higher_order_ops.wrap import wrap_generic # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return wrap_generic( fn, *args, wrapper_fn=functools.partial(my_hop_fn_impl, k=k), **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `wrap_generic` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
torch/_higher_order_ops/wrap.py
Outdated
import torch._dynamo # noqa: F401 | ||
from torch._dynamo import disable | ||
|
||
is_compiling = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this accurate? An alternative is:
How about we do something like this:
```python
def __call__(sefl, wrapper_fn_or_key, inner_fn, *inner_args, **inner_kwargs):
is_compiling = isinstance(wrapper_fn_or_key, str)
if is_compiling:
assert isinstance(inner_fn, torch.fx.GraphModule)
wrapper_fn = inner_fn.meta[wrapper_fn_or_key]
eles:
wrapper_fn = wrapper_fn_or_key
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks this looks cleaner
@@ -87,6 +87,41 @@ def wrapper(): | |||
wrap_with_autocast = WrapWithAutocast() | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe some doc?
…acing the wrapped function" Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
torch/_higher_order_ops/wrap.py
Outdated
@@ -87,6 +87,43 @@ def wrapper(): | |||
wrap_with_autocast = WrapWithAutocast() | |||
|
|||
|
|||
# This HOP allows you to bypass dynamo tracing of the wrapper function while | |||
# still tracing the inner function. | |||
# The wrapper function should receive a single callable argument, and return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: something like "two callables, the first callable being the wrapper_fn which takes a single argument inner_fn. The second callable is inner_fn and the rest of args and kwags are for the inner fn."
…acing the wrapped function" Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@pytorchbot merge |
…acing the wrapped function" Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
…acing the wrapped function" Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable) Details for Dev Infra teamRaised by workflow job |
…acing the wrapped function" Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
…acing the wrapped function" Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…rapped function (#153487) Usage: ```python from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper # Your ordinary function wrapper def my_hop_fn_impl(fn, *args, k=1, **kwargs): def wrapper(*args, **kwargs): out = fn(*args, **kwargs) if isinstance(out, tuple): return (out[0] + k,) return out + k return wrapper # Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph def my_hop_fn(fn, *args, k=1, **kwargs): return dynamo_bypassing_wrapper( functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs ) ``` Notes: - The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn. - The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable` Pull Request resolved: #153487 Approved by: https://github.com/ydwu4
Stack from ghstack (oldest at bottom):
Usage:
Notes:
wrapper_fn
passed todynamo_bypassing_wrapper
should have signatureCallable -> Callable
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames