[JIT] Support enough of closures to write autograd functions #15411
Conversation
Thanks!! This is definitely much cleaner than the old separate forward, backward way!

Nit: The compiled graph in the description has an additional `self` printed in the first line, is this intended?

def mul_forward(self,
    self: Tensor,
});
const std::vector<std::string> functions = {
    R"(
def mul(self, other):
Hmmm, how does the CompilationUnit differentiate when `other` is a Tensor or a Scalar?
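One way to make that distinction explicit in script strings, sketched under the assumption that an unannotated parameter defaults to Tensor and that a scalar overload would carry an explicit annotation such as `number` (or a concrete `float`/`int`); this is an illustration, not something this diff adds:

```python
import torch

# Hypothetical illustration: two overloads written as script source strings.
# Assumption: the script frontend treats unannotated `other` as Tensor and
# accepts `number` (int-or-float scalar) as an annotation; `float` or `int`
# would pin a concrete scalar type instead.
cu = torch.jit.CompilationUnit('''
def mul_tensor(self, other):
    return self * other

def mul_scalar(self, other: number):
    return self * other
''')
print(cu.module.code)  # inspect the signatures the compiler inferred
```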
@zdevito has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
This PR adds enough of the infrastructure for supporting closures (inner script functions) to allow us to express symbolic gradients using them. We never actually run graphs that contain these closures. The symbolic_script infrastructure just extracts them out of the original forward graph and turns them into discrete forward/backward pairs. This cuts down on the type annotations necessary to write forward/backward pairs and aligns closely with the "differentiator" function approach to expressing reverse-mode AD.

Example:

This code:

```
import torch
r = torch.jit.CompilationUnit('''
def mul_forward(self, other):
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward
''')
print(r.module.code)
```

Will produce this graph (pretty-printed for clarity):

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  backward = (self.__lambda, (other, self))
  return (torch.mul(self, other), backward)

def __lambda(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```

symbolic_script will then do some modifications to remove the unsupported prim::Function node, yielding:

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  return (torch.mul(self, other), (other, self))

def backward(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```

Pull Request resolved: pytorch/pytorch#15411
Differential Revision: D13523340
Pulled By: zdevito
fbshipit-source-id: 4d4a269460e595b16802c00ec55ae00e3e682d49
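As a plain eager-mode Python sketch of what such a forward/backward pair computes (illustration only, not part of this PR: it bypasses the JIT and symbolic_script entirely, and the tensor shapes below are arbitrary assumptions):

```python
import torch

def mul_forward(self, other):
    # Return the forward result plus a closure mapping grad_output to the
    # input gradients, mirroring the script example above (eager mode only).
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward

a = torch.randn(3, 4)
b = torch.randn(4)                 # broadcasts against `a`
out, backward = mul_forward(a, b)
grad_a, grad_b = backward(torch.ones_like(out))
assert grad_a.shape == a.shape and grad_b.shape == b.shape
```

The `sum_to_size` calls reduce broadcasted gradients back to the inputs' original shapes, matching what the extracted `backward` graph does with its captured `(other, self)` context.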