Setup_context does not contain default values of forward() (#108561)
Fixes #108529

As the title states: when forward() declares default arguments and the caller omits them, setup_context() does not receive those default values in its inputs tuple.
Pull Request resolved: #108561
Approved by: https://github.com/soulitzer
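
In concrete terms, apply() now binds forward()'s default arguments before setup_context() runs, so inputs always matches forward()'s full signature. A minimal sketch of the fixed behavior, modeled on the PowFunction test added below (the tensor values are illustrative only):

import torch
from torch.autograd import Function

class PowFunction(Function):
    @staticmethod
    def forward(x, y=3):
        return torch.pow(x, y)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # With this fix, inputs is (x, 3) even though the caller never passed y.
        x, y = inputs
        ctx.save_for_backward(x)
        ctx.y = y

    @staticmethod
    def backward(ctx, gO):
        x, = ctx.saved_tensors
        return gO * ctx.y * torch.pow(x, ctx.y - 1), None

x = torch.tensor(2.0, requires_grad=True)
out = PowFunction.apply(x)           # y falls back to its default of 3
grad, = torch.autograd.grad(out, x)  # d(x**3)/dx at x=2 -> 12.0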
FFFrog authored and pytorchmergebot committed Sep 19, 2023
1 parent 1427b81 commit 70f2ada
Showing 4 changed files with 107 additions and 2 deletions.
test/functorch/test_eager_transforms.py (2 changes: 1 addition & 1 deletion)
@@ -1323,7 +1323,7 @@ def vmap(info, in_dims, input):
     def test_in_dims_multiple_inputs(self, device):
         class Id(torch.autograd.Function):
             @staticmethod
-            def forward(input):
+            def forward(x, y):
                 pass
 
             @staticmethod
test/test_autograd.py (50 changes: 50 additions & 0 deletions)
@@ -8460,6 +8460,56 @@ def _test_op(fn, inp, args):
         _test_op(torch.view_as_complex, torch.rand(2, 2), ())
         _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())
 
+    def test_setup_context_when_forward_has_default_args(self):
+        class PowFunction(Function):
+            @staticmethod
+            def forward(x, y=3):
+                return torch.pow(x, y)
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                x, y = inputs
+                ctx.save_for_backward(x)
+                ctx.y = y
+
+            @staticmethod
+            def backward(ctx, gO):
+                x, = ctx.saved_tensors
+                y = ctx.y
+                return gO * y * torch.pow(x, y - 1), None
+
+        class PowFunctionWithClassmethod(Function):
+            @classmethod
+            def forward(cls, x, y=3):
+                return torch.pow(x, y)
+
+            @classmethod
+            def setup_context(cls, ctx, inputs, output):
+                x, y = inputs
+                ctx.save_for_backward(x)
+                ctx.y = y
+
+            @classmethod
+            def backward(cls, ctx, gO):
+                x, = ctx.saved_tensors
+                y = ctx.y
+                return gO * y * torch.pow(x, y - 1), None
+
+        x = torch.tensor(2.0, requires_grad=True)
+
+        y = torch.tensor(8.0)
+        y_expected = torch.tensor(12.0)
+
+        y1 = PowFunction.apply(x)
+        y1_expected, = torch.autograd.grad(y1, x)
+
+        y2 = PowFunctionWithClassmethod.apply(x)
+        y2_expected, = torch.autograd.grad(y2, x)
+
+        self.assertEqual(y, y1)
+        self.assertEqual(y_expected, y1_expected)
+        self.assertEqual(y, y2)
+        self.assertEqual(y_expected, y2_expected)
 
 def index_perm_variable(shape, max_indices):
     if not isinstance(shape, tuple):
torch/autograd/function.py (14 changes: 13 additions & 1 deletion)
@@ -1,4 +1,5 @@
 import functools
+import inspect
 import warnings
 from collections import OrderedDict
 from typing import Any, List, Optional, Tuple
@@ -533,12 +534,23 @@ def vmap(info, in_dims, *args):
 
     @classmethod
     def apply(cls, *args, **kwargs):
+        def bind_default_args(func, *args, **kwargs):
+            signature = inspect.signature(func)
+            bound_args = signature.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+
+            return bound_args.args
+
+        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+        if is_setup_ctx_defined:
+            args = bind_default_args(cls.forward, *args, **kwargs)
+
         if not torch._C._are_functorch_transforms_active():
             # See NOTE: [functorch vjp and autograd interaction]
             args = _functorch.utils.unwrap_dead_wrappers(args)
             return super().apply(*args, **kwargs)  # type: ignore[misc]
 
-        if cls.setup_context == _SingleLevelFunction.setup_context:
+        if not is_setup_ctx_defined:
             raise RuntimeError(
                 "In order to use an autograd.Function with functorch transforms "
                 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
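The bind_default_args helper above leans on the standard inspect module: Signature.bind() pairs the caller's positional and keyword arguments with forward()'s parameters, and apply_defaults() fills in whatever was omitted before the arguments reach setup_context(). A small standalone sketch of that mechanism (the example forward() below is hypothetical):

import inspect

def forward(x, y=3):
    return x ** y

signature = inspect.signature(forward)
bound_args = signature.bind(2)   # caller passed only x
bound_args.apply_defaults()      # y is filled in with its default, 3
print(bound_args.args)           # (2, 3)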
torch/testing/_internal/autograd_function_db.py (43 changes: 43 additions & 0 deletions)
@@ -463,6 +463,40 @@ def jvp(ctx, gx, gy):
             torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
         )
 
+
+def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(3, 5))
+
+
+class ForwardHasDefaultArgs(torch.autograd.Function):
+    @staticmethod
+    def forward(x, idx=(2,)):
+        return x[idx]
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, idx = inputs
+        ctx.x_shape = x.shape
+        ctx.idx = idx
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = grad_output.new_zeros(ctx.x_shape)
+        result[ctx.idx] = grad_output
+        return result, None
+
+    @staticmethod
+    def vmap(info, in_dims, x, idx):
+        x_bdim, _ = in_dims
+        x = x.movedim(x_bdim, 1)
+        return ForwardHasDefaultArgs.apply(x, idx), 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)
+
+
 autograd_function_db = [
     OpInfo(
         'NumpyCubeAutogradFunction',
@@ -584,4 +618,13 @@ def jvp(ctx, gx, gy):
         dtypes=all_types_and(torch.bool, torch.half),
         supports_out=False,
     ),
+    OpInfo(
+        'ForwardHasDefaultArgsAutogradFunction',
+        op=ForwardHasDefaultArgs.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_forward_default_args,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
 ]
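
As a quick usage sketch (not part of the commit), the new entry can be exercised directly while leaving idx at its default; the tensor shape and values here are illustrative only:

import torch
from torch.testing._internal.autograd_function_db import ForwardHasDefaultArgs

x = torch.randn(3, 5, requires_grad=True)
out = ForwardHasDefaultArgs.apply(x)   # idx defaults to (2,), so out is row 2 of x
out.sum().backward()
print(x.grad[2])                       # ones; every other row of x.grad stays zero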
