autograd.Function supports vmap staticmethod
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. Together, these mean
that an autograd.Function with a vmap staticmethod can be used with vmap.

```py
class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```
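
For example, a minimal usage sketch (assuming the `to_numpy` helper used
above is defined; the test files in this PR additionally enable the feature
via `torch._C._set_autograd_function_extension_enabled(True)`):

```py
import torch
from functorch import vmap

x = torch.randn(5, 3)
y = torch.randn(5, 3)

# NumpyMul.vmap is invoked once for the whole batch; NumpyMul.forward never runs.
result = vmap(NumpyMul.apply)(x, y)
assert result.shape == (5, 3)
```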

API Spec
- The staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It contains None for an arg
that is not being vmapped over and an integer batch-dimension index for
one that is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors;
they have already been unwrapped.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs` (see the
sketch below).

Semantics
- vmap(NumpyMul.apply)(x, y) will apply the vmap staticmethod if there
is one and will never actually run NumpyMul.forward.
- In order for the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x, y)`), the vmap staticmethod must call
into operations that vmap understands (i.e. PyTorch operators or other
autograd.Functions with vmap support); see the sketch below.
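
A minimal sketch of the nested case, reusing the `NumpyMul` example above
(again assuming the `to_numpy` helper and functorch's `vmap`):

```py
from functorch import vmap

x = torch.randn(4, 5, 3)
y = torch.randn(4, 5, 3)

# NumpyMul.vmap is written in terms of movedim and NumpyMul.apply, both of
# which vmap understands, so nesting composes.
result = vmap(vmap(NumpyMul.apply))(x, y)
assert result.shape == (4, 5, 3)
```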

At a high level, this PR:
- adds a vmap rule for custom_function_call (sketched conceptually below)
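
For intuition, here is a toy, self-contained stand-in for what that rule
does at the Python level; it reuses the `NumpyMul` example from above.
`VmapInfo`, `toy_custom_function_call_vmap`, and the `(tensor, bdim)` pair
representation are invented for illustration only; the real kernel operates
on functorch BatchedTensors and interpreter levels and differs in detail:

```py
import torch

class VmapInfo:
    # Illustrative stand-in for the `info` object handed to the staticmethod.
    def __init__(self, batch_size):
        self.batch_size = batch_size

def toy_custom_function_call_vmap(autograd_function, batch_size, *batched_args):
    # Each arg is a (tensor, bdim) pair, where bdim is None for unbatched args.
    tensors = tuple(t for t, _ in batched_args)
    in_dims = tuple(d for _, d in batched_args)
    # Delegate to the user's vmap staticmethod with the unwrapped tensors.
    outputs, out_dims = autograd_function.vmap(
        VmapInfo(batch_size), in_dims, *tensors)
    return outputs, out_dims

# Batched x, unbatched y: NumpyMul.vmap receives in_dims == (0, None).
x = torch.randn(7, 3)
y = torch.randn(3)
out, out_dim = toy_custom_function_call_vmap(NumpyMul, 7, (x, 0), (y, None))
assert out.shape == (7, 3) and out_dim == 0
```

The real rule also has to re-wrap the returned outputs as BatchedTensors at
the current vmap level (the kind of plumbing that helpers like `makeBatched`
and `unwrapTensorAtLevel` in PlumbingHelper provide), which the toy version
skips.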

Testing
- Added some tests for in_dims and info
- Added vmap staticmethods to most of the autograd.Functions in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that would add
~200 LOC to the PR, but let me know if you'd prefer it here.

ghstack-source-id: b17a03e7563a663418f30f99f6e21366c9d62015
Pull Request resolved: #90037
zou3519 committed Dec 2, 2022
1 parent ca3bd35 commit e52ef1d
Showing 12 changed files with 297 additions and 19 deletions.
9 changes: 2 additions & 7 deletions aten/src/ATen/functorch/ADInterpreters.cpp
@@ -44,17 +44,12 @@ Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) {
   return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true);
 }
 
-static Tensor base_lift(const Tensor& tensor, int64_t level) {
-  auto tensor_ = unwrapIfDead(tensor);
-  return materializeGradWrappers(tensor_, level);
-}
-
 Tensor GradInterpreterPtr::lift(const Tensor& tensor) const {
-  return base_lift(tensor, level());
+  return materializeGradWrappers(tensor, level());
 }
 
 Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const {
-  return base_lift(tensor, level());
+  return materializeGradWrappers(tensor, level());
 }
 
 static void autogradBasedTransformProcess(
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/PlumbingHelper.cpp
@@ -27,7 +27,7 @@ std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, option
   return res;
 }
 
-std::tuple<Tensor, optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) {
+std::tuple<Tensor, c10::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) {
   auto* batched = maybeGetBatchedImpl(tensor);
   if (!batched) {
     return std::make_tuple(tensor, nullopt);
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/PlumbingHelper.h
@@ -33,7 +33,7 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, optional<int64_t> bdim, int64
 // If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
 // doesn't match, then this returns (tensor, nullopt).
 // Otherwise, it returns (unwrap(tensor), bdim).
-TORCH_API std::tuple<Tensor, optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
+TORCH_API std::tuple<Tensor, c10::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
 
 // Creates a vector of BatchedTensor
 TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, optional<int64_t> bdim, int64_t level);
3 changes: 2 additions & 1 deletion test/functorch/common_utils.py
@@ -14,6 +14,7 @@
 import os
 import unittest
 from torch.testing._internal.common_device_type import toleranceOverride
+from torch.testing._internal.autograd_function_db import autograd_function_db
 from collections import namedtuple
 
 IS_FBCODE = os.getenv('FUNCTORCH_TEST_FBCODE') == '1'
@@ -351,7 +352,7 @@ def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
 
 
 def skipOps(test_case_name, base_test_name, to_skip):
-    all_opinfos = op_db + additional_op_db
+    all_opinfos = op_db + additional_op_db + autograd_function_db
     for decorate_meta in to_skip:
         matching_opinfos = [o for o in all_opinfos
                             if o.name == decorate_meta.op_name and
125 changes: 125 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -1030,6 +1030,126 @@ def backward(ctx, grad_output):
        grad(grad(A.apply))(x, y)


class TestAutogradFunctionVmapAPI(TestCase):
    def test_no_vmap_staticmethod(self, device):
        class NumpyCube(torch.autograd.Function):
            @staticmethod
            def forward(input):
                input_np = to_numpy(input)
                dinput = torch.tensor(3 * input_np ** 2, device=input.device)
                return torch.tensor(input_np ** 3, device=input.device), dinput

            @staticmethod
            def setup_context(ctx, outputs, input):
                ctx.save_for_backward(input, outputs[1])

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                raise RuntimeError("foobar")

        x = torch.randn(3, device=device)
        with self.assertRaisesRegex(RuntimeError, 'does not have a vmap rule defined'):
            vmap(NumpyCube.apply)(x)

    def test_info_object(self, device):
        batch_size = 10

        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                pass

            @staticmethod
            def setup_context(ctx, outputs, input):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                self.assertEqual(info.batch_size, batch_size)
                return input, in_dims[0]

        x = torch.randn(batch_size, 3, device=device)
        vmap(Id.apply)(x)

    def test_in_dims_single_input(self, device):
        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                pass

            @staticmethod
            def setup_context(ctx, outputs, input):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                self.assertEqual(in_dims, (1,))
                return input, in_dims[0]

        B = 10
        x = torch.randn(3, B, device=device)
        vmap(Id.apply, in_dims=1)(x)
        vmap(Id.apply, in_dims=(1,))(x)

    def test_in_dims_multiple_inputs(self, device):
        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                pass

            @staticmethod
            def setup_context(ctx, outputs, x, y):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, x, y):
                self.assertEqual(in_dims, (0, [0, 0]))
                self.assertTrue(isinstance(in_dims, tuple))
                self.assertTrue(isinstance(in_dims[1], list))
                return (x, y), in_dims

        x = torch.randn(2, device=device)
        vmap(Id.apply)(x, [x, x])

    def test_skips_empty_layer(self, device):
        class Id(torch.autograd.Function):
            @staticmethod
            def forward(input):
                return input

            @staticmethod
            def setup_context(ctx, outputs, input):
                pass

            @staticmethod
            def backward(ctx, grad_output, grad_saved):
                pass

            @staticmethod
            def vmap(info, in_dims, input):
                raise RuntimeError("expected to not be called")

        def f(x):
            y = torch.tensor(1.)
            y = Id.apply(y)
            return x * 1

        x = torch.randn(2, 3)
        vmap(f)(x)


class TestVmapOfGrad(TestCase):
    def test_per_sample_grads_inplace_view(self, device):
        def compute_loss(weight, x, t):
@@ -3792,6 +3912,11 @@ def f(x):
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(
    TestAutogradFunctionVmapAPI,
    globals(),
    only_for=only_for,
)
instantiate_parametrized_tests(
    TestMakeFunctional,
)
13 changes: 8 additions & 5 deletions test/functorch/test_ops.py
@@ -669,6 +669,7 @@ def fn(inp, *args, **kwargs):
xfail("double"), # rank 4 tensor for channels_last
xfail("float"), # rank 4 tensor for channels_last
xfail("half"), # rank 4 tensor for channels_last
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
# It looks like you're either (1) calling .item() on a Tensor or
# (2) attempting to use a Tensor in some data-dependent control flow or
# (3) encountering this error in PyTorch internals.
@@ -730,7 +731,7 @@ def fn(inp, *args, **kwargs):
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
}))
-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride('TestOperators', 'test_vmapvjpvjp', (
tol1('linalg.svd',
@@ -807,6 +808,7 @@ def vjp_of_vjp(*args_and_cotangents):
xfail('svd_lowrank', ''), # randomness
xfail('to_sparse', ''), # non-dense output
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
# ----------------------------------------------------------------------

# ---------------------------- BUGS ------------------------------------
@@ -846,7 +848,7 @@ def vjp_of_vjp(*args_and_cotangents):
})

@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride('TestOperators', 'test_vmapvjp', (
tol1('linalg.svd',
@@ -1044,7 +1046,7 @@ def test():
pass
check_vmap_fallback(self, test, op, dry_run=False)

-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@@ -1163,7 +1165,7 @@ def test():

check_vmap_fallback(self, test, op, dry_run=False)

-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps('TestOperators', 'test_vjpvmap', vjp_fail.union({
skip('bernoulli', ''), # vjpvmap testing can't handle randomness
skip('normal', ''), # vjpvmap testing can't handle randomness
Expand All @@ -1176,6 +1178,7 @@ def test():
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
skip('to_sparse', ''), # non-dense output
skip('ormqr', ''), # takes too long
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
# fallback path doesn't work
# All of the following are bugs and need to be fixed
@@ -1681,7 +1684,7 @@ def fn(input, weight, bias):
self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))

@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float32, torch.double))
@skipOps('TestOperators', 'test_vmap_autograd_grad', {
xfail('linalg.eig'), # all close?
# The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
8 changes: 6 additions & 2 deletions test/functorch/test_vmap.py
@@ -50,6 +50,9 @@
from functorch.experimental import chunk_vmap
from torch._C._functorch import reshape_dim_into, reshape_dim_outof
from functorch._src.make_functional import functional_init_with_buffers
from torch.testing._internal.autograd_function_db import autograd_function_db

torch._C._set_autograd_function_extension_enabled(True)

FALLBACK_REGEX = 'There is a performance drop'

@@ -3239,6 +3242,7 @@ def test():
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
xfail('sparse.sampled_addmm'), # sparse
xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
skip('_softmax_backward_data'),
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@@ -3283,7 +3287,7 @@ def test():
}

@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
tol1('linalg.det',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
@@ -3310,7 +3314,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=False,
skip_inplace=inplace_failure_list)

-@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
+@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', (
tol1('linalg.det',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
2 changes: 2 additions & 0 deletions torch/_C/_functorch.pyi
@@ -1,5 +1,6 @@
from torch import Tensor
from enum import Enum
from typing import Optional, Tuple

# Defined in torch/csrc/functorch/init.cpp

@@ -14,6 +15,7 @@ def maybe_get_level(tensor: Tensor) -> int: ...
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...

def set_autograd_function_allowed(allowed: bool) -> None: ...
def get_autograd_function_allowed() -> bool: ...