Skip to content

Conversation

zdevito
Copy link
Contributor

@zdevito zdevito commented Dec 19, 2018

This PR adds enough of the infra for supporting closures (inner script functions) in order to allow us to expression symbolic gradients using them. We do not actually ever run graphs that contain these closures. The symbolic_script infrastructure just extracts them out of the original forward graph and turns them into discrete forward/backward pairs. This cuts down on the type annotations necessary to write forward/backward pairs and aligns closely with the "differentiator" function approach to expression reverse-mode AD.

Example:

This code:

import torch

r = torch.jit.CompilationUnit(
'''
def mul_forward(self, other):
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward
''')

print(r.module.code)

Will produce this graph (pretty printed for clarity):

def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  backward = (self.__lambda, (other, self))
  return (torch.mul(self, other), backward)


def __lambda(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)

symbolic_script will then do some modifications to remove the unsuppored prim::Function node, yielding:

# same as before but will self.__lambda removed
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  return (torch.mul(self, other), (other, self))


# just the captured closure, unchanged
def backward(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Dec 19, 2018
Copy link
Contributor

@ailzhang ailzhang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!! This is definitely much cleaner than the old separate forward, backward way!
Nit: The compiled graph in the description has an additional self printed in the first line, is this intended?

def mul_forward(self,
    self: Tensor,

@zdevito
Copy link
Contributor Author

zdevito commented Dec 19, 2018

extra self is a bug, David has a fix landing for it already. It shouldn't affect the correctness for our purposes.

});
const std::vector<std::string> functions = {
R"(
def mul(self, other):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm how does the CompilationUnit differentiate when other is a Tensor or Scalar?

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zdevito has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zdevito has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito added a commit to zdevito/ATen that referenced this pull request Dec 20, 2018
Summary:
This PR adds enough of the infra for supporting closures (inner script functions) in order to allow us to expression symbolic gradients using them. We do not actually ever run graphs that contain these closures. The symbolic_script infrastructure just extracts them out of the original forward graph and turns them into discrete forward/backward pairs. This cuts down on the type annotations necessary to write forward/backward pairs and aligns closely with the "differentiator" function approach to expression reverse-mode AD.

Example:

This code:
```
import torch

r = torch.jit.CompilationUnit(
'''
def mul_forward(self, other):
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward
''')

print(r.module.code)
```

Will produce this graph (pretty printed for clarity):

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  backward = (self.__lambda, (other, self))
  return (torch.mul(self, other), backward)

def __lambda(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```

symbolic_script will then do some modifications to remove the unsuppored prim::Function node, yielding:

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  return (torch.mul(self, other), (other, self))

def backward(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```
Pull Request resolved: pytorch/pytorch#15411

Differential Revision: D13523340

Pulled By: zdevito

fbshipit-source-id: 4d4a269460e595b16802c00ec55ae00e3e682d49
@ezyang ezyang added the merged label Jun 25, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants