Setup initial testing harness and cache key generation for AOTAutograd Cache #124642


Closed
wants to merge 11 commits

Conversation

@jamesjwu (Contributor) commented Apr 22, 2024

Stack from ghstack (oldest at bottom):

This doesn't introduce any new behavior, but sets up a basic cache key generation mechanism that I can test. From here I will:

  • Add checks on the ops in an input FXGraph to make sure they are safe to cache. We'll be conservative in the first version here.
  • Add serialization for FX graphs
  • Save these FX graphs to disk in the cache
  • Support graphs with more complicated ops like higher order ops and specialized nn modules

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
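
For context, a rough sketch of what a cache key along these lines might look like (illustrative only; the function name and details below are not from this PR):

```python
# Illustrative sketch, not the PR's implementation: derive a stable cache key
# from the FX graph's generated code plus an already-pickled config. Assumes
# gm.code is deterministic for equivalent graphs.
import hashlib

import torch

def sketch_autograd_cache_key(gm: torch.fx.GraphModule, config_bytes: bytes) -> str:
    h = hashlib.sha256()
    h.update(gm.code.encode("utf-8"))  # generated Python source of the graph
    h.update(config_bytes)             # e.g. a pickle of the (reduced) AOTConfig
    return h.hexdigest()
```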


pytorch-bot (bot) commented Apr 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124642

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d9e2a33 with merge base 935a946:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jamesjwu added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 3b232da
Pull Request resolved: #124642
@jamesjwu changed the title from "Setup initial testing harness for AOTAutograd Cache" to "Setup initial testing harness and cache key generation for AOTAutograd Cache" on Apr 23, 2024
@jamesjwu requested review from oulgen and bdhirsh April 23, 2024 16:03
@jamesjwu marked this pull request as ready for review April 23, 2024 16:05
@@ -485,25 +485,56 @@ class FxGraphCachePickler(pickle.Pickler):
     dispatch_table[torch.Tensor] = _reduce_tensor
     dispatch_table[torch.SymInt] = _reduce_symint

-    @staticmethod
-    def dumps(obj) -> bytes:
+    @classmethod
@jamesjwu (Contributor, Author): This just lets me reuse the pickling implementation
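
To see why, a minimal hypothetical example (not the actual FxGraphCachePickler code): with a classmethod, a subclass that only swaps its dispatch table inherits dumps() unchanged, because the pickler is constructed from cls:

```python
import io
import pickle

class BasePickler(pickle.Pickler):
    # Pickler consults self.dispatch_table, so a class attribute works here.
    dispatch_table: dict = {}

    @classmethod
    def dumps(cls, obj) -> bytes:
        buf = io.BytesIO()
        cls(buf).dump(obj)  # instantiates the *calling* class, not BasePickler
        return buf.getvalue()

class AOTAutogradPickler(BasePickler):
    # Hypothetical subclass: reuses dumps() verbatim, customizes reductions.
    dispatch_table = dict(BasePickler.dispatch_table)
```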

    elif isinstance(obj, bytes):
        return "<bytes>"
    else:
        return str(obj)
Contributor: Since this is for debugging, would it make sense to use repr instead of str?
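
For illustration (not from the PR), repr keeps type and quoting information that str drops, which is useful when eyeballing debug output:

```python
s, n = "1", 1
print(str(s), str(n))    # prints: 1 1    -> the two values look identical
print(repr(s), repr(n))  # prints: '1' 1  -> string vs. int is visible
```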

@jamesjwu (Contributor, Author): @pytorchbot rebase

@pytorchmergebot (Collaborator): Rebased gh/jamesjwu/17/orig onto refs/remotes/origin/viable/strict because #124745 was rebased; please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/124642)

@pytorchmergebot (Collaborator): @pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator): Tried to rebase and push PR #124642, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main

    fx_graph = gm
    return gm

g = torch.compile(fn, backend=compiler)
Contributor: You might want fullgraph=True? If there are graph breaks, compiler() will get called multiple times (and the nonlocal will get overwritten with whatever the last graph was)
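
A sketch of the suggested fix (hypothetical test scaffold; capture_one_graph is not a name from this PR):

```python
import torch

def capture_one_graph(fn, *args):
    fx_graph = None

    def compiler(gm: torch.fx.GraphModule, example_inputs):
        nonlocal fx_graph
        fx_graph = gm  # without fullgraph=True, this runs once per subgraph
        return gm.forward

    # fullgraph=True turns any graph break into a hard error, so fx_graph
    # is guaranteed to hold the single whole-function graph afterwards.
    compiled = torch.compile(fn, backend=compiler, fullgraph=True)
    compiled(*args)
    return fx_graph
```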


def autograd_cache_hash(
    gm: torch.fx.GraphModule,
    config: AOTConfig,
Contributor:

Hmm, I think caching on the AOTConfig object is mostly reasonable. But just calling out: AOTConfig isn't an input to AOTAutograd; it's mostly a convenience object to bundle up and pass around state about what AOTAutograd should do during tracing, and it contains a mix of inputs to AOTAutograd and other stuff that is inferred implicitly (is_export=False, aot_id=...).

Just looking at some of the stuff in it:

  • aot_id: Including AOTConfig.aot_id will probably give us some false negatives (it's just a counter used to uniquely identify each compiled region in a Python process).

  • is_export: AOTAutograd is also used in an export setting where we don't need the warm cache, so this will always be set to False.

  • num_params_buffers / dynamic_shapes: these are both values inferred from looking at the dynamo graph. Then again, they're all cheap to cache and won't cause false negatives, so caching on them probably won't matter.

Contributor: nvm, I guess this comment isn't too relevant, since you're already defining a custom reduce for the AOTConfig (so you can explicitly, e.g., not hash on aot_id)
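
To make that idea concrete, a hypothetical sketch of such a reduce (the field names are real AOTConfig attributes, but the selection and function name are illustrative, not this PR's code):

```python
def _reduce_aot_config_sketch(config):
    # config is assumed to be an AOTConfig. Pickle only fields that should
    # affect the cache key; aot_id is deliberately omitted because it is a
    # per-process counter and would cause spurious cache misses.
    stable_fields = (
        config.num_params_buffers,
        config.keep_inference_input_mutations,
        config.is_export,
        config.dynamic_shapes,
    )
    # dispatch_table reducers return a (callable, args) pair for pickling.
    return (tuple, (stable_fields,))
```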


def _reduce_aot_config(config: AOTConfig):
    """
    Reduce the config to a stable key for caching.
Contributor:

Hmmm - if we care about non-inductor backends working with the cache, then we probably want to include fw_compiler/bw_compiler/decomposition_table in some form here, right?

That is kind of annoyingly redundant for the inductor case: we are already caching on the inductor source code, so also caching on, e.g., the exact decomposition_table arguments is redundant. But two use cases are:

(1) if someone who is not inductor is using AOTAutograd as their backend, we probably want to recompile if they change the decompositions they pass in

(2) There are two nice backends that we have for debugging: torch.compile(backend="aot_eager") vs. torch.compile(backend="aot_eager_decomp_partition"), which are mostly the same except that they use different decomposition tables and partitioning functions. We probably want the cache key to encode their differences

@jamesjwu (Contributor, Author):

Yup, I think we do need to add some combination of these, though computing a hash key from the callables directly seems hard. Will add in a followup PR (there are also probably other global settings, like vmap, that we'll need to add too)
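
One possible direction for that follow-up (purely an assumption, not code from this stack): key the compilers by a stable qualified name and the decomposition table by its operator keys, instead of hashing the callables themselves:

```python
def _callable_key(fn) -> str:
    # Unwrap functools.wraps-style decorators to reach the real function,
    # then identify it by module-qualified name rather than by object hash.
    fn = getattr(fn, "__wrapped__", fn)
    module = getattr(fn, "__module__", "?")
    qualname = getattr(fn, "__qualname__", repr(fn))
    return f"{module}.{qualname}"

def _decomp_table_key(decompositions: dict) -> tuple:
    # Operator overloads stringify stably (e.g. "aten.addmm.default");
    # sorting makes the key independent of dict insertion order.
    return tuple(sorted(str(op) for op in decompositions))
```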

@jamesjwu (Contributor, Author): @pytorchmergebot rebase

@pytorchmergebot (Collaborator): @pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator): Successfully rebased gh/jamesjwu/16/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/124642)

pytorchmergebot pushed a commit that referenced this pull request Apr 29, 2024
ghstack-source-id: b89440e
Pull Request resolved: #124642
@jamesjwu (Contributor, Author): @pytorchbot merge

@pytorch-bot (bot) added the ciflow/trunk label (Trigger trunk jobs on your pull request) Apr 30, 2024
@pytorchmergebot (Collaborator): Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@jamesjwu (Contributor, Author): @pytorchbot merge

@pytorchmergebot (Collaborator): Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

pytorch-bot (bot) pushed a commit that referenced this pull request May 3, 2024
Setup initial testing harness and cache key generation for AOTAutograd Cache (#124642)


Pull Request resolved: #124642
Approved by: https://github.com/aorenste
@github-actions (bot) deleted the gh/jamesjwu/16/head branch June 4, 2024 01:58