
Conversation

@jamesjwu jamesjwu (Contributor) commented May 6, 2024

Stack from ghstack (oldest at bottom):

This is the first PR in a series where I try to organize our runtime wrappers a bit: specifically, I'd like to separate wrappers into objects that have (up to) 2 methods:
A pre-compile function, which takes in flat_fn and flat_args (inputs to the compiler) and wraps/modifies them
A post-compile function, which takes in a compiled_fn and runtime args and wraps the compiled_fn.

Extra metadata necessary to run the compile functions can be stored as attributes on the class. This way, when we think about caching, the set of attributes on the class should be the exact set of metadata that we need to serialize and save in the cache (along with common data, like fw_metadata).
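For concreteness, here is a minimal sketch of that shape (illustrative only: the pre_compile/post_compile names follow the description above, but the ExampleWrapper class and its signatures are simplified assumptions, not the exact interface in the PR):

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class ExampleWrapper:
    # Any metadata the wrapper needs later lives on the instance, so the set
    # of attributes is exactly what a cache entry would have to serialize.
    some_flag: bool = False

    def pre_compile(self, flat_fn: Callable, flat_args: List[Any]):
        # Wrap/modify the function and args handed to the compiler.
        def wrapped_fn(args: List[Any]):
            return flat_fn(args)
        return wrapped_fn, flat_args

    def post_compile(self, compiled_fn: Callable) -> Callable:
        # Wrap the compiled function with the matching runtime behavior.
        def runtime_fn(runtime_args: List[Any]):
            return compiled_fn(runtime_args)
        return runtime_fn
```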

pytorch-bot bot commented May 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125595

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 2d11a83 (failed to retrieve merge base, please contact dev infra):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jamesjwu added 2 commits May 6, 2024 11:12
jamesjwu added 2 commits May 6, 2024 12:28
@jamesjwu jamesjwu requested a review from bdhirsh May 7, 2024 15:16
@jamesjwu jamesjwu marked this pull request as draft May 7, 2024 15:18
@jamesjwu jamesjwu marked this pull request as ready for review May 7, 2024 15:46
trace_joint=False,
keep_input_mutations=aot_config.keep_inference_input_mutations,
disable_amp=disable_amp,
).post_compile(
Contributor commented:

Mostly just a callout: for the most part we have a 1:1 relationship between "pre_compile" and "post_compile" transformations. But for this particular runtime_wrapper bit, the transformation is something like:

in the inference path: aot_dispatch_base() makes the compile-time change, and uses create_runtime_wrapper to make the runtime change

in the training path: aot_dispatch_autograd() makes the compile-time change, and uses a combination of CompiledFunction + create_runtime_wrapper to make the runtime change.

So the actual "post_compile" runtime wrapper changes needed in the inference and training paths have some commonalities, which are shared in the create_runtime_wrapper helper.
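A rough sketch of that shape (hypothetical function names; this is not the actual aot_dispatch_base/aot_dispatch_autograd code, just an illustration of two compile paths sharing one runtime-wrapping helper):

```python
from typing import Any, Callable, List


def shared_runtime_wrapper(compiled_fn: Callable) -> Callable:
    # Stands in for the common runtime logic that create_runtime_wrapper
    # provides to both paths (greatly simplified here).
    def runtime_fn(args: List[Any]):
        return compiled_fn(args)
    return runtime_fn


def inference_compile(flat_fn: Callable) -> Callable:
    compiled_fn = flat_fn  # placeholder for the inference compile-time work
    return shared_runtime_wrapper(compiled_fn)


def training_compile(flat_fn: Callable) -> Callable:
    # The real training path also layers an autograd.Function (CompiledFunction)
    # on top of the compiled forward/backward; omitted here.
    compiled_fn = flat_fn  # placeholder for the autograd compile-time work
    return shared_runtime_wrapper(compiled_fn)
```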

Contributor commented:

If we are refactoring all of the runtime wrappers to use CompilerWrapper, then I'm mostly just calling out that I would have originally expected every instance to have both a pre_compile and a post_compile, and I guess we have just a post_compile here since it lines up with how the runtime wrappers are used today / makes the refactor easier.

Do you imagine a better end state being that we eventually refactor the wrapper logic so there is more of a 1:1 between pre-compile and post-compile for each layer? Or do you think this is OK as an end state? (Either way, I think this refactor looks good to me; happy to stamp.)

jamesjwu (Contributor, Author) commented:

Hmm, so from what I can tell, the runtime change is only made in create_runtime_wrapper, is it not? Do you mind giving an example of somewhere where CompiledFunction is making a runtime change related to what create_runtime_wrapper is doing? I had thought that the only runtime change create_runtime_wrapper is responsible for was in the two places it's called in aot_dispatch_autograd and aot_dispatch_base, both of which are refactored here.

Just based on my own understanding, I had thought that create_runtime_wrapper was a runtime-only wrapper (in that it does not seem to affect the flat_fn state), but if you could point me to what you'd consider the "pre_compile" step of create_runtime_wrapper (i.e., where it modifies the flat_args or flat_fn), I can follow up and refactor that bit too!

jamesjwu (Contributor, Author) commented:

Thinking about this more, I think the point you're making is that aot_dispatch_base and aot_dispatch_autograd are, in a sense, themselves just big wrappers, and create_runtime_wrapper is actually just a post_compile for the pre-compile step that exists somewhere in the code of aot_dispatch_*.

I do think it would be valuable to disentangle the logic in aot_dispatch_* that relates to what create_runtime_wrapper is doing, so that it can be expressed as a pre_compile for the new wrapper. Happy to work on that, though I might prioritize it a bit lower just because it's not necessary for caching.

I think the same can be said about the other "post compile only" wrappers that I define in the next PR: there's probably some pre-compile logic in aot_dispatch_* that could be put into a pre compile step, but it would require a more involved refactor. I even tried to do it for rng functionalization, but found it to be pretty challenging to get right just due to the sheer number of tangled dependencies within aot_dispatch_autograd.
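For reference, a "post compile only" wrapper in this scheme would just leave pre_compile as a passthrough until the matching compile-time logic is pulled out of aot_dispatch_*. A hedged sketch (hypothetical class name, not the actual RNG functionalization or runtime wrapper code):

```python
from typing import Any, Callable, List


class PostCompileOnlyWrapper:
    # pre_compile is a passthrough for now: the corresponding compile-time
    # logic still lives inside aot_dispatch_* and hasn't been disentangled.
    def pre_compile(self, flat_fn: Callable, flat_args: List[Any]):
        return flat_fn, flat_args

    def post_compile(self, compiled_fn: Callable) -> Callable:
        # The runtime-only behavior (the part that exists today) goes here.
        def runtime_fn(runtime_args: List[Any]):
            return compiled_fn(runtime_args)
        return runtime_fn
```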

Contributor commented:

> but found it to be pretty challenging to get right just due to the sheer number of tangled dependencies within aot_dispatch_autograd.

yeah... agreed 😛

flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
Contributor commented:

Just another callout - right now, we have:

(1) multiple transformations (each of which will eventually get its own CompilerWrapper according to this PR)
(2) each transformation requires some metadata

And some of that metadata is shared across transformations (like aot_config and fw_metadata, which you're uniformly passing into every instance of CompilerWrapper), while other metadata is specific to a certain transformation (like indices_of_inps_to_detach).

I definitely wouldn't block your current refactor on this extra refactor, but I do wonder if we should aim to go more fully in one of those two directions: either each CompilerWrapper gets exactly the metadata it cares about and nothing is shared, or we glob all metadata into a single shared object that is plumbed around everywhere (we definitely lean more in this direction today).

jamesjwu (Contributor, Author) commented:

Hmm, this is a bit tricky to get right, I think: I don't really think it's possible to make each CompilerWrapper completely independent, simply because the fields that are shared are also modified by each CompilerWrapper's pre_compile function, so they're not actually independent.

Having a single metadata structure represent all of the shared info would definitely be nice, and it's also pretty much what a cache entry would look like, so I might end up doing that refactor when I actually create cache entries (basically, passing along a single data structure with all the fields shared between CompilerWrappers).

Though it's also kind of nice, code-readability-wise, for CompilerWrappers to have statically defined parameters for fields like indices_of_inps_to_detach instead of a single "metadata" field: if you change the metadata, it's easier to see which wrappers consume it, instead of having to hunt through each usage in the code.

I think the compromise here might be that there is a single object we plumb through AOT autograd, but the CompilerWrappers themselves still take individual arguments. Note that none of these CompilerWrapper objects should need to be directly plumbed outside the function they're defined in: in a perfect world or end state, we would just have a series of pre_compile steps, followed by actual compilation (aot_dispatch_base/autograd), followed by a series of post_compile steps.
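As a rough illustration of that end state (all names and signatures here are hypothetical, not the actual AOTAutograd code): a single shared metadata object is threaded through the driver, each wrapper is constructed with just the fields it needs, and compilation becomes a pipeline of pre_compile steps, the compiler, then post_compile steps applied in reverse order so the wrapping nests symmetrically.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class SharedMeta:
    # Hypothetical stand-in for info shared across wrappers
    # (think fw_metadata / aot_config style data).
    indices_of_inps_to_detach: List[int] = field(default_factory=list)


@dataclass
class NoopWrapper:
    # Constructed with only the fields it cares about, pulled off the
    # shared object by the driver.
    indices: List[int] = field(default_factory=list)

    def pre_compile(self, flat_fn: Callable, flat_args: List[Any]):
        return flat_fn, flat_args

    def post_compile(self, compiled_fn: Callable) -> Callable:
        return compiled_fn


def compile_with_wrappers(flat_fn, flat_args, wrappers, compiler):
    # Pre-compile steps in order, then the compiler, then post-compile
    # steps in reverse order.
    for w in wrappers:
        flat_fn, flat_args = w.pre_compile(flat_fn, flat_args)
    compiled_fn = compiler(flat_fn, flat_args)
    for w in reversed(wrappers):
        compiled_fn = w.post_compile(compiled_fn)
    return compiled_fn


if __name__ == "__main__":
    meta = SharedMeta(indices_of_inps_to_detach=[0])
    wrappers = [NoopWrapper(indices=meta.indices_of_inps_to_detach)]
    fn = compile_with_wrappers(
        lambda args: sum(args), [1, 2], wrappers, lambda f, a: f
    )
    print(fn([1, 2]))  # 3
```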

@bdhirsh bdhirsh (Contributor) left a comment:

lgtm

@jamesjwu jamesjwu added the ciflow/trunk and topic: not user facing labels and removed the release notes: AO frontend label May 8, 2024
@jamesjwu jamesjwu (Contributor, Author) commented May 8, 2024

@pytorchbot merge -i

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following 2 checks: inductor / cuda12.1-py3.10-gcc9-sm86 / test (dynamic_inductor_timm, 2, 2, linux.g5.4xlarge.nvidia.gpu), inductor / rocm6.1-py3.8-inductor / test (inductor, 1, 1, linux.rocm.gpu.2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the gh/jamesjwu/21/head branch June 8, 2024 01:54

Labels: ciflow/inductor, ciflow/trunk, Merged, topic: not user facing
