Setup initial testing harness and cache key generation for AOTAutograd Cache #124642
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124642
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit d9e2a33 with merge base 935a946. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -485,25 +485,56 @@ class FxGraphCachePickler(pickle.Pickler):
    dispatch_table[torch.Tensor] = _reduce_tensor
    dispatch_table[torch.SymInt] = _reduce_symint

    @staticmethod
    def dumps(obj) -> bytes:
    @classmethod
This just lets me reuse the pickling implementation
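A minimal sketch of the point being made here (not the actual torch/_inductor source; the subclass name is an assumption): making `dumps` a classmethod lets a subclass reuse the same pickling entry point while supplying only its own `dispatch_table`.

```python
import copyreg
import io
import pickle


class FxGraphCachePickler(pickle.Pickler):
    # Per-class reduction overrides; pickle.Pickler consults this table.
    dispatch_table = copyreg.dispatch_table.copy()

    @classmethod
    def dumps(cls, obj) -> bytes:
        # Because this is a classmethod, cls is whichever subclass was used,
        # so that subclass's dispatch_table (and other overrides) take effect.
        with io.BytesIO() as stream:
            pickler = cls(stream)
            pickler.dump(obj)
            return stream.getvalue()


class AOTAutogradCachePickler(FxGraphCachePickler):
    # Hypothetical subclass: inherits dumps() unchanged and registers its own
    # reducers (e.g. for AOTConfig) in a private copy of the table.
    dispatch_table = FxGraphCachePickler.dispatch_table.copy()
```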
elif isinstance(obj, bytes):
    return "<bytes>"
else:
    return str(obj)
Since this is for debugging, would it make sense to use `repr` instead of `str`?
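A quick illustration of the suggestion (illustrative only, not code from the PR): `repr` keeps quoting and escapes visible, which tends to be more useful in a debug dump than `str`.

```python
s = "line1\nline2"
print(str(s))    # prints across two lines; the embedded newline is not visible as an escape
print(repr(s))   # 'line1\nline2' -- quoting and escapes are preserved

b = b"\x00\xff"
print(str(b), repr(b))  # for bytes the two are identical: b'\x00\xff' b'\x00\xff'
```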
@pytorchbot rebase
Rebased
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
fx_graph = gm
return gm

g = torch.compile(fn, backend=compiler)
You might want `fullgraph=True`? If there are graph breaks, `compiler()` will get called multiple times (and the nonlocal will get overwritten with whatever the last graph was).
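A minimal sketch of the test pattern being discussed (the helper names are assumptions, not the PR's test code): the custom backend stashes the traced `GraphModule` in a nonlocal, and `fullgraph=True` guarantees there is exactly one graph to stash.

```python
import torch


def compile_and_capture(fn, *args):
    fx_graph = None

    def compiler(gm: torch.fx.GraphModule, example_inputs):
        nonlocal fx_graph
        fx_graph = gm  # capture the traced graph for hashing/inspection
        return gm      # run the captured graph as-is (an "eager" backend)

    # fullgraph=True: a graph break raises an error instead of silently calling
    # compiler() once per sub-graph and overwriting fx_graph with the last one.
    compiled = torch.compile(fn, backend=compiler, fullgraph=True)
    result = compiled(*args)
    return result, fx_graph


def fn(x):
    return torch.sin(x) + 1


_, gm = compile_and_capture(fn, torch.randn(4))
assert gm is not None
```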
def autograd_cache_hash(
    gm: torch.fx.GraphModule,
    config: AOTConfig,
Hmm, I think caching on the AOTConfig object is mostly reasonable. But just calling out: AOTConfig isn't an input to AOTAutograd - it's mostly a convenience object to bundle up and pass around state about what AOTAutograd should do during tracing, and it contains a mix of inputs to AOTAutograd and other stuff that is inferred implicitly (`is_export=False`, `aot_id=...`).

Just looking at some of the stuff in it:
- `aot_id`: Including `AOTConfig.aot_id` will probably give us some false negatives (this is just a counter used to uniquely identify each compiled region in a python process).
- `is_export`: AOTAutograd is also used in an export setting where we don't need the warm cache, so this will just always be set to false.
- `num_params_buffers`/`dynamic_shapes`: these are both values that are inferred from looking at the dynamo graph. Then again, these are all cheap to cache and won't cause false negatives, so caching on them probably won't matter.
nvm I guess this comment isn't too relevant, since you're already defining a custom reduce for the AOTConfig (so you can explicitly e.g. not hash on aot_id)
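A rough sketch of the idea referenced above (assumed field names, not the PR's actual `_reduce_aot_config`): a custom reducer that builds the cache key only from fields that are stable across processes, deliberately skipping `aot_id` and the compiler callables.

```python
from dataclasses import fields


def _ident(x):
    return x


def _reduce_aot_config_sketch(config):
    """Return a reduce tuple whose payload is a stable, picklable key."""
    # aot_id is a per-process counter, so hashing it would cause spurious
    # cache misses; the compiler/partitioner callables and the decomposition
    # table are not directly hashable, so they are handled separately.
    excluded = {"aot_id", "fw_compiler", "bw_compiler", "inference_compiler",
                "partition_fn", "decompositions"}
    key = tuple(
        (f.name, getattr(config, f.name))
        for f in fields(config)        # AOTConfig is assumed to be a dataclass
        if f.name not in excluded
    )
    return (_ident, (key,))
```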
def _reduce_aot_config(config: AOTConfig):
    """
    Reduce the config to a stable key for caching.
Hmmm - if we care about non-inductor backends working with the cache, then we probably want to include `fw_compiler`/`bw_compiler`/`decomposition_table` in some form here, right?

That is kind of annoyingly redundant for the inductor case - we are already caching on the inductor source code, so caching on the exact e.g. decomposition_table arguments is redundant. But two use cases are:
(1) if someone who is not inductor is using AOTAutograd as their backend, we probably want to recompile if they change the decompositions they pass in
(2) there are two nice backends that we have for debugging, `torch.compile(backend="aot_eager")` vs. `torch.compile(backend="aot_eager_decomp_partition")`, that are mostly the same except that they use different decomposition tables and partitioning functions. We probably want the cache key to encode their differences.
Yup, I think we do need to add some combination of these, though computing a hash key from the callables directly seems hard. Will add in a followup PR (there are also probably other global settings like `vmap` we'll need to add too).
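One possible shape for that follow-up, sketched under the assumption that keying on a stable name is acceptable (these helpers are not in the PR): fingerprint the compilers and partitioner by their qualified names and the decomposition table by the ops it covers, so e.g. `aot_eager` and `aot_eager_decomp_partition` produce different keys.

```python
def _callable_fingerprint(fn) -> str:
    # A cheap, stable proxy for a callable's identity across processes.
    if fn is None:
        return "<none>"
    module = getattr(fn, "__module__", "?")
    name = getattr(fn, "__qualname__", type(fn).__name__)
    return f"{module}.{name}"


def _compiler_key(config) -> tuple:
    # config.decompositions is assumed to map ops -> decomposition functions.
    decomp_ops = tuple(
        sorted(str(op) for op in (getattr(config, "decompositions", None) or {}))
    )
    return (
        _callable_fingerprint(getattr(config, "fw_compiler", None)),
        _callable_fingerprint(getattr(config, "bw_compiler", None)),
        _callable_fingerprint(getattr(config, "partition_fn", None)),
        decomp_ops,
    )
```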
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
@pytorchbot merge
Merge failed. Reason: this PR needs a label before it can be merged. To add a label, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…d Cache (#124642)

This doesn't introduce any new behavior, but sets up a basic cache key generation mechanism that I can test. From here I will:
- Add checks on the ops in an input FXGraph to make sure they are safe to cache. We'll be conservative in the first version here.
- Add serialization for FX graphs
- Save these FX graphs to disk in the cache
- Support graphs with more complicated ops like higher order ops and specialized nn modules

Pull Request resolved: #124642
Approved by: https://github.com/aorenste
Stack from ghstack (oldest at bottom):
This doesn't introduce any new behavior, but sets up a basic cache key generation mechanism that I can test. From here I will:
- Add checks on the ops in an input FXGraph to make sure they are safe to cache. We'll be conservative in the first version here.
- Add serialization for FX graphs
- Save these FX graphs to disk in the cache
- Support graphs with more complicated ops like higher order ops and specialized nn modules
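For illustration only, a self-contained sketch of what "cache key generation" means here; it hashes a readable form of the FX graph rather than using the PR's FxGraphCachePickler-based machinery, and the helper name is an assumption.

```python
import hashlib

import torch


def sketch_cache_key(gm: torch.fx.GraphModule, config_repr: str) -> str:
    # Hash a stable textual form of the graph plus a summary of the config.
    payload = (str(gm.graph) + config_repr).encode()
    return hashlib.sha256(payload).hexdigest()


def fn(x):
    return x.sin() + 1


gm = torch.fx.symbolic_trace(fn)
key1 = sketch_cache_key(gm, config_repr="dynamic_shapes=False")
key2 = sketch_cache_key(gm, config_repr="dynamic_shapes=False")
assert key1 == key2  # key generation is deterministic for the same graph/config
```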
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang