[dynamo] compiled_autograd support for post_acc_grad hooks #112326
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/112326. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit f2d3f04 with merge base a7a0955. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
return variable_list();
// A little hack - this is only here for the purpose of hooks. It will get cleared.
return variable_list({fake_variable_copy});
@jansel surely there's a better way to smuggle stuff out? This does get a fake tensor into the hook (if you set a breakpoint in def post_acc_grad_hook you can see it has the correct value, with grad).
Don't smuggle it out. Just call into a virtual method on tensor_post_acc_grad_hooks().
Eager returns nothing here; compiled autograd should do the same.
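For reference, this is the eager behavior being described: the post-accumulate-grad hook is called with the parameter itself once its .grad has been populated, the engine makes no use of any return value, and AccumulateGrad contributes no outputs; the hook runs purely for its side effects. A minimal eager illustration (plain PyTorch, no compiled autograd):

import torch

p = torch.ones(3, requires_grad=True)

def post_acc_grad_hook(param):
    # runs after param.grad has been accumulated; the engine ignores any return value
    param.grad.mul_(2)

p.register_post_accumulate_grad_hook(post_acc_grad_hook)
p.sum().backward()
print(p.grad)  # tensor([2., 2., 2.]) - the hook's side effect, not a graph output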
if (typeid(*call.node) == typeid(torch::autograd::AccumulateGrad)) {
  // The return of AccumulateGrad should be [], but we hack it to make hooks work.
  // This restores it correctly.
  outputs = variable_list();
}
@jansel see comment above. I dislike this.
torch/_dynamo/compiled_autograd.py (Outdated)
assert len(inputs) == 1
hook = self.hooks_proxy[hook_id]
@jansel this is an odd one - if we don't unpack the argument and instead pass inputs through as-is, the hook is invoked with a list, which is incorrect.
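To make the unpacking point concrete, here is a tiny self-contained sketch; invoke_post_acc_grad_hook is a hypothetical stand-in for the tracer-side call, not the actual compiled_autograd code:

import torch

def invoke_post_acc_grad_hook(hook, inputs):
    # AccumulateGrad hands over exactly one tensor, wrapped in a list
    assert len(inputs) == 1
    hook(inputs[0])  # the hook must see the tensor itself, not the list
    return inputs

def double_in_place(t):
    t.mul_(2)

param = torch.ones(2)
invoke_post_acc_grad_hook(double_in_place, [param])
print(param)  # tensor([2., 2.])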
with disable_proxy_modes_tracing():
    inputs = maybe_clone(inputs[0])
    self.bind_tensors_to_proxies([inputs], proxies)
return inputs
Does this type of hook actually return anything? The call to it in eager just ignores the result.
SwapSavedVariables saved(compiler_call, state);
variable_list outputs = call.node->apply_with_saved(inputs, saved);
if (!call.post_acc_grad_hooks.empty()) {
Don't do this here; in eager it is called in accumulate_grad.cpp. We should do the same.
What's the right way to invoke the Python fn here?
I'm doing this outside because I couldn't figure out the plumbing. I also tried to do it as an inductor op, but I don't think that supports hook signatures.
Just follow the same pattern as the other hooks.
I thought that's what I did here. The post and pre hooks are called above and below. From what I saw, no other hook is called from an AccumulateGrad or from within a node. I don't understand how to call a Python hook from within the accumulate_grad implementation without plumbing a bunch of Python notions into it.
Even if we grab the hook and call it within apply_with_saved in AccumulateGrad, you still need to pass a py_compiler from the_autograd_compiler, right? I didn't want to start mucking around with the apply_with_saved interface. Would you mind giving me a little more detail on how you would do it here?
Like, we could skip the py stuff and just do

auto& hook = tensor_post_acc_grad_hooks();
if (hook != nullptr) {
  (*hook)(variable);
}

But then it would bypass all the proxy and fake tensor stuff, which I don't think is what we want? Unless you want me to invoke the above with the fake tensor between the before/after...?
Well, I guess this works, but not sure if this is correct. I'll PR what I have and we can discuss it there.
auto& hook = tensor_post_acc_grad_hooks();
if (hook != nullptr) {
  (*hook)(variable_copy);
}
This seems wrong. It will just trace through the hook (which might not be traceable).
Need something like:
hook->apply_with_saved(variable_copy, saved)
Then inside the handler for apply_with_saved, the hook is lifted to an input.
Asked offline, I do not understand.
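For readers following along, the suggestion appears to be the same pattern compiled autograd uses for its other hooks: rather than tracing through the hook body, the hook object is collected and passed into the generated graph as an input, and the graph only records an opaque call to it. A rough, hypothetical Python sketch of that idea (backward_graph and call_hook are illustrative names, not the actual torch internals):

import torch

def call_hook(hook, grad):
    # opaque boundary: the graph records a call to this helper instead of tracing the hook body
    hook(grad)
    return grad

def backward_graph(grads, hooks):
    # the hooks are "lifted to inputs": they arrive at runtime as arguments,
    # so their possibly untraceable bodies never get baked into the trace
    return [call_hook(h, g) for g, h in zip(grads, hooks)]

g = torch.ones(2)
backward_graph([g], [lambda t: t.mul_(2)])
print(g)  # tensor([2., 2.])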
def hook(p):
    p.add_(p.grad)
def hook(input_t):
    input_t.mul_(2)
I'd expect this to work at this point, though see minor comment.
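For context, a minimal end-to-end exercise of such a hook might look like the sketch below. It assumes the torch._dynamo.compiled_autograd.enable context manager from this stack and the register_post_accumulate_grad_hook API; it is an illustrative sketch, not the PR's actual test:

import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # compile the generated backward graph; the "eager" backend keeps the example light
    return torch.compile(gm, backend="eager")

p = torch.ones(3, requires_grad=True)

def hook(input_t):
    input_t.mul_(2)  # runs after input_t.grad has been accumulated

p.register_post_accumulate_grad_hook(hook)

with compiled_autograd.enable(compiler_fn):
    p.sum().backward()

print(p)       # the hook's in-place mul_ is visible on the parameter
print(p.grad)  # tensor([1., 1., 1.])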
AutogradCompilerCall& compiler;
TraceState& state;
PyObject* py_compiler;
Add comment about borrowed(?) ownership.
Yes, borrowed.
Yes, it works.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…12326) Pull Request resolved: pytorch#112326 Approved by: https://github.com/jansel ghstack dependencies: pytorch#112325
Stack from ghstack (oldest at bottom):
cc @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng