
fix(fx): make all make_fx invocations isolated (opaque to higher make_fx invocations) by default #93290

Conversation

@jon-chuang (Collaborator) commented Jan 30, 2023

Fixes #88996 (comment)

Example code:

import torch
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx

@torch.fx.wrap
def func(a, b):
    return b.expand([1, a.shape[0], b.shape[-1]])

a = torch.randn(3, 4)
b = torch.randn(4)

class TestMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if torch.overrides.resolve_name(func) in ["torch.Tensor.expand"]:
            print(f"TestMode: {func} {args} {kwargs}")
            # Trace the intercepted call with a nested make_fx. Before this PR,
            # the nodes created here leaked into the enclosing symbolic trace.
            wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
            gm = make_fx(wrapped, tracing_mode="real")(all_args)

        return func(*args, **kwargs)

with TestMode():
    gm = make_fx(func, tracing_mode="symbolic")(a, b)

gm.graph.print_tabular()

Before:

opcode         name        target               args                              kwargs
-------------  ----------  -------------------  --------------------------------  --------
placeholder    a_1         a_1                  ()                                {}
placeholder    b_1         b_1                  ()                                {}
call_function  detach      aten.detach.default  (b_1,)                            {}
call_function  detach_1    aten.detach.default  (detach,)                         {}
call_function  sym_size    aten.sym_size        (a_1, 0)                          {}
call_function  sym_size_1  aten.sym_size        (b_1, 0)                          {}
call_function  expand      aten.expand.default  (b_1, [1, sym_size, sym_size_1])  {}
call_function  detach_2    aten.detach.default  (expand,)                         {}
call_function  expand_1    aten.expand.default  (b_1, [1, sym_size, sym_size_1])  {}
output         output      output               (expand_1,)                       {}

After:

opcode         name        target               args                              kwargs
-------------  ----------  -------------------  --------------------------------  --------
placeholder    a_1         a_1                  ()                                {}
placeholder    b_1         b_1                  ()                                {}
call_function  sym_size    aten.sym_size        (a_1, 0)                          {}
call_function  sym_size_1  aten.sym_size        (b_1, 0)                          {}
call_function  expand      aten.expand.default  (b_1, [1, sym_size, sym_size_1])  {}
output         output      output               (expand,)                         {}

@pytorch-bot bot commented Jan 30, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93290

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9b7b95d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ezyang (Contributor) commented Jan 31, 2023

I... guess we can do this? It kind of feels better not to do this by default (and to let get_isolated_graphmodule be used for this case), because if you do it this way there is no way to have a single trace get recorded by multiple proxy tensor modes at once. This may seem like a weird thing to want, but in fact functorch nested grad does something like this (nested grad levels get recorded onto multiple tapes, one per grad level).

The countervailing argument is that make_fx doesn't return its outputs, so it is not true compute and shouldn't get traced. I could be convinced by this.

@jon-chuang (Collaborator, Author) commented Jan 31, 2023

Actually, when we call make_fx nested inside another make_fx, the inner make_fx now behaves exactly like get_isolated_graphmodule, modulo the kwargs wrapping.

So this PR isolates graph modules by default.

I believe this is a good change: making a tape of making a tape doesn't seem sensible, and it produces the weird artifacts observed above.

(Unless an inner tape is used in the definition of the exterior function?)
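
For concreteness, a rough sketch of the equivalence being claimed (this assumes get_isolated_graphmodule is the helper in torch.fx.experimental.proxy_tensor; names and signatures as I recall them, so treat it as illustrative):

import torch
from torch.fx.experimental.proxy_tensor import (
    get_isolated_graphmodule,
    make_fx,
    wrapper_and_args_for_make_fx,
)

def f(a, b):
    return b.expand([1, a.shape[0], b.shape[-1]])

args, kwargs = (torch.randn(3, 4), torch.randn(4)), {}

# Explicit isolation: flattens args/kwargs and traces f without leaking
# nodes into any enclosing make_fx trace.
gm_isolated = get_isolated_graphmodule(f, args, kwargs)

# With this PR, doing the same thing with a nested make_fx call (from inside
# another make_fx trace) is equivalent, apart from the kwargs flattening step:
wrapped, all_args = wrapper_and_args_for_make_fx(f, args, kwargs)
gm_nested = make_fx(wrapped, tracing_mode="real")(all_args)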

@jon-chuang changed the title fix(fx): disable other make_fx traces for nested make_fx → fix(fx): disable all make_fx traces except current for nested make_fx Jan 31, 2023
@jon-chuang changed the title fix(fx): disable all make_fx traces except current for nested make_fx → fix(fx): make all make_fx invocations isolated (opaque to higher make_fx invocations) by default Jan 31, 2023
@jon-chuang (Collaborator, Author) commented Jan 31, 2023

(Unless an inner tape is used in the definition of the exterior function?)

This is actually impossible: make_fx arrives at its output through mutation, so its output has no functional dependence on its inputs. Furthermore, it does not return any tensors, as you point out.
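
To illustrate (a trivial, purely illustrative snippet):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# The inner make_fx hands back a GraphModule built by mutating the tracer's
# graph; it returns no tensors that could feed the surrounding computation.
inner_gm = make_fx(lambda t: t * 2, tracing_mode="real")(torch.randn(2))
print(type(inner_gm))  # a GraphModule, not a Tensor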

@jon-chuang (Collaborator, Author) commented Jan 31, 2023

functorch nested grad does something like this (nested grad levels get recorded onto multiple tapes, one per grad level).

Could you point to an example? I tried constructing several examples where things might fail, including nesting torch.func.grad and torch.autograd.grad, but:

  1. for the former, make_fx is only ever called once (see the sketch below); the invocation produces the compiled function as a side effect.
  2. for the latter, autograd.grad is actually unnestable and produces a None result...

There is basically no compilation pathway that has nested tapes relying directly on an interior tape, because torch.autograd itself does not rely on tapes from make_fx; it uses the C++ torch._C._ImperativeEngine.
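
A minimal sketch of the former case (illustrative only; assumes torch.func from a recent PyTorch build):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x ** 3

x = torch.randn(())  # scalar input so grad-of-grad stays scalar-valued

# The nested grads compose into a single function; make_fx is invoked only
# once, at the outermost level, so no nested tapes are produced.
gm = make_fx(torch.func.grad(torch.func.grad(f)), tracing_mode="real")(x)
gm.graph.print_tabular()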

@ezyang (Contributor) commented Jan 31, 2023

Could you point to an example? I tried constructing several examples where things might fail, including nesting torch.func.grad and torch.autograd.grad, but:

It's not make_fx per se, but tape recording (which is like FX, but a bit different). See https://github.com/albanD/subclass_zoo/blob/main/simple_functorch.ipynb and search for "to compute higher order gradients".

But yeah, consider me convinced; we can land this.

@ezyang (Contributor) commented Jan 31, 2023

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jan 31, 2023
@pytorchmergebot (Collaborator)

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot (Collaborator)

Merge failed
Reason: 1 mandatory check(s) failed (Rule superuser).

@ezyang (Contributor) commented Jan 31, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot (Collaborator)

Merge failed
Reason: 1 mandatory check(s) failed (Rule superuser).

@ezyang (Contributor) commented Feb 1, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).

Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, release notes: fx
Projects: None yet
Linked issue: Nested FX tracing doesn't work when outer tracing mode is symbolic
4 participants